Skip to content

Commit 4ba6ec4

Browse files
authored
Fix slice computations for starts and ends for large slices (#558)
Signed-off-by: Kevin Chen <kevinch@nvidia.com>
1 parent 7efb99f commit 4ba6ec4

File tree

4 files changed

+80
-75
lines changed

4 files changed

+80
-75
lines changed

ShapeTensor.hpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,6 @@ class ShapeTensor
115115
return mValues[k];
116116
}
117117

118-
//! Return a mutable reference to the underlying value vector.
//! Only valid when every element's value is known at build time
//! (asserts mAllValuesKnown). NOTE(review): handing out a non-const
//! reference lets callers mutate internals and bypass class invariants —
//! prefer a const accessor or element-wise access where possible.
std::vector<int64_t>& getValues()
{
    assert(mAllValuesKnown);
    return mValues;
}
124-
125118
//! Return true if x and y always have the same value.
126119
friend bool operator==(const ShapeTensor& x, const ShapeTensor& y);
127120
friend ShapeTensor shapeOf(const ShapeTensor& t);

builtin_op_importers.cpp

Lines changed: 6 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -3060,47 +3060,6 @@ DEFINE_BUILTIN_OP_IMPORTER(Size)
30603060
return {{&size.tensor(ctx)}};
30613061
}
30623062

3063-
//! Return subscripts such that gather(concat(x,y),subscripts)
3064-
//! will return x with x[subcripts[i]] replaced by y[i].
3065-
ShapeTensor axesToInterlaceSubscripts(const ShapeTensor& axes, int nbDims)
3066-
{
3067-
std::vector<int64_t> subscripts(nbDims);
3068-
std::iota(subscripts.begin(), subscripts.end(), 0);
3069-
for (int32_t i = 0; i < axes.size(); ++i)
3070-
{
3071-
subscripts[axes[i]] = nbDims + i;
3072-
}
3073-
return ShapeTensor(1, std::move(subscripts));
3074-
}
3075-
3076-
//! Decode a start or end index according to ONNX Slice rules:
//! "If the value passed to start or end is larger than the n (the number of
//! elements in this dimension), it represents n.... If a negative value is
//! passed for step, it represents slicing backward."
ShapeTensor decodeOnnxIndices(IImporterContext* ctx, const ShapeTensor& indices, const ShapeTensor& dims)
{
    // Oblique calculation implements the rules using only operations available in TensorRT.
    // min(max(indices, -dims), dims) clamps each index into [-dims, dims];
    // the subtracted term is dims * (-1 if the original index was negative, else 0),
    // i.e. it adds dims exactly for negative indices to wrap them around.
    // NOTE(review): this decode does not account for step direction, which is why the
    // commit replaces it with decodeOnnxStartsAndEnds for large/reverse slices — verify.
    return sub(
        ctx, min(ctx, max(ctx, indices, mul(ctx, shapeVector(-1), dims)), dims), mul(ctx, dims, max(ctx, shapeVector(-1), min(ctx, shapeVector(0), indices))));
}
3086-
3087-
//! Compute per-axis output sizes for a Slice from decoded starts/ends/steps.
//! shift adds 1 for slices that cut across an entire axis in reverse
//! (set by the caller when an end index was INT_MIN); dims is used only
//! for its shape, to build a zero tensor of matching rank.
ShapeTensor computeSizes(IImporterContext* ctx, const ShapeTensor& starts, const ShapeTensor& ends,
    const ShapeTensor& steps, const ShapeTensor& shift, const ShapeTensor& dims)
{
    if (steps.isAll(1))
    {
        // The general formula in the else is correct,
        // but creates much debris for this common case.
        return sub(ctx, ends, starts);
    }
    else
    {
        // "If a negative value is passed for step, it represents slicing backward."
        // Compute ceil((end-start)/step) + shift using only operations available in TensorRT.
        // Uses the identity ceil(x) = -floor(-x): 0 - floor((start-end)/step).
        return add(ctx, sub(ctx, similar(ctx, dims, 0), floorDiv(ctx, sub(ctx, starts, ends), steps)), shift);
    }
}
3103-
31043063
DEFINE_BUILTIN_OP_IMPORTER(Slice)
31053064
{
31063065
const int nbInputs = node.input().size();
@@ -3170,43 +3129,22 @@ DEFINE_BUILTIN_OP_IMPORTER(Slice)
31703129
ASSERT(std::unordered_set<int64_t>(axes.begin(), axes.end()).size() == static_cast<size_t>(axes.size()),
31713130
ErrorCode::kINVALID_NODE);
31723131

3173-
// Create a shift shapeTensor. We need to add 1 to the size for any slice that cuts across an entire
3174-
// axis in reverse. It is 0 for all other slices.
3175-
ShapeTensor shift = shapeVector(0);
3176-
if (ends.allValuesKnown())
3177-
{
3178-
shift = ends;
3179-
for (int64_t& val : shift.getValues())
3180-
{
3181-
if (val == static_cast<int64_t>(INT_MIN))
3182-
{
3183-
val = 1;
3184-
}
3185-
else
3186-
{
3187-
val = 0;
3188-
}
3189-
}
3190-
}
3191-
31923132
if (axes.size() < dims.size() || !isIota)
31933133
{
3194-
// axes specify a subset of the dimensions, or out of order.
3195-
// Convert starts/ends/steps/shift to complete in-order form.
3134+
// Axes specify a subset of the dimensions, or out of order.
3135+
// Convert starts/ends/steps to complete in-order form.
31963136
const ShapeTensor subscripts{axesToInterlaceSubscripts(axes, dims.size())};
31973137
starts = interlace(ctx, similar(ctx, dims, 0), starts, subscripts);
31983138
ends = interlace(ctx, dims, ends, subscripts);
31993139
steps = interlace(ctx, similar(ctx, dims, 1), steps, subscripts);
3200-
shift = interlace(ctx, similar(ctx, dims, 0), shift, subscripts);
32013140
}
32023141

3203-
// "If a negative value is passed for any of the start or end indices,
3204-
// it represents number of elements before the end of that dimension."
3205-
starts = decodeOnnxIndices(ctx, starts, dims);
3206-
ends = decodeOnnxIndices(ctx, ends, dims);
3142+
// ONNX has a bunch of rules for converting out of bounds starts/ends
3143+
// indices into the actual indices to use.
3144+
decodeOnnxStartsAndEnds(ctx, dims, steps, starts, ends);
32073145

32083146
// TensorRT uses sizes of the output dimensions instead of ends.
3209-
const ShapeTensor sizes = computeSizes(ctx, starts, ends, steps, shift, dims);
3147+
const ShapeTensor sizes = computeSliceSizes(ctx, starts, ends, steps, dims);
32103148

32113149
nvinfer1::ISliceLayer* slice = addSlice(ctx, data, starts, sizes, steps);
32123150

onnx2trt_utils.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1848,4 +1848,64 @@ Status weightsToVector(TensorOrWeights weights, std::vector<int64_t>* weightVect
18481848
return Status(ErrorCode::kSUCCESS);
18491849
}
18501850

1851+
//! Return ShapeTensor representing x clamped to closed interval [lowerBound,upperBound].
1852+
static ShapeTensor clamp(IImporterContext* ctx, const ShapeTensor& x, const ShapeTensor& lowerBound, const ShapeTensor& upperBound)
1853+
{
1854+
return min(ctx, max(ctx, x, lowerBound), upperBound);
1855+
}
1856+
1857+
//! Return ShapeTensor representing indices < 0 ? inputDims + indices : indices
1858+
static ShapeTensor bumpIfNegative(IImporterContext* ctx, const ShapeTensor& inputDims, const ShapeTensor& indices)
1859+
{
1860+
const auto signs = clamp(ctx, indices, shapeVector(-1), shapeVector(0));
1861+
return sub(ctx, indices, mul(ctx, signs, inputDims));
1862+
}
1863+
1864+
//! Decode, in place, ONNX Slice starts/ends into actual indices usable by the
//! importer. inputDims holds the sizes of the sliced dimensions; steps holds
//! the per-axis step values (negative means slicing backward).
void decodeOnnxStartsAndEnds(IImporterContext* ctx, const ShapeTensor& inputDims, const ShapeTensor& steps, ShapeTensor& starts, ShapeTensor& ends)
{
    //! The ONNX specification is unreliable (https://github.com/onnx/onnx/issues/3063)
    //! thus the logic here is designed to match that in
    //! https://github.com/onnx/onnx/blob/master/onnx/defs/tensor/defs.cc .

    // Set stepSign to step < 0 ? -1 : 0.
    const auto stepSign = clamp(ctx, steps, shapeVector(-1), shapeVector(0));

    // Update starts: wrap negative values, then clamp to [0, dims] for forward
    // slices or [0, dims - 1] for reverse slices (stepSign shifts the upper bound).
    starts = bumpIfNegative(ctx, inputDims, starts);
    starts = clamp(ctx, starts, shapeVector(0), add(ctx, inputDims, stepSign));

    // Update ends: wrap negative values, then clamp to [0, dims] for forward
    // slices or [-1, dims] for reverse slices (stepSign lowers the floor so an
    // exclusive end can sit just before index 0).
    ends = bumpIfNegative(ctx, inputDims, ends);
    ends = clamp(ctx, ends, stepSign, inputDims);
}
1881+
1882+
//! Return subscripts such that gather(concat(x,y),subscripts)
//! will return x with x[subscripts[i]] replaced by y[i].
ShapeTensor axesToInterlaceSubscripts(const ShapeTensor& axes, int nbDims)
{
    // Identity permutation [0, 1, ..., nbDims-1], then redirect each axis in
    // `axes` to the matching element of y, at offset nbDims + i in concat(x,y).
    std::vector<int64_t> subscripts(nbDims);
    std::iota(subscripts.begin(), subscripts.end(), 0);
    for (int32_t i = 0; i < axes.size(); ++i)
    {
        subscripts[axes[i]] = nbDims + i;
    }
    return ShapeTensor(1, std::move(subscripts));
}
1892+
1893+
//! Return ShapeTensor of per-axis output sizes for a Slice.
//! starts and ends must already have been decoded by decodeOnnxStartsAndEnds;
//! dims is used only for its shape, to build a zero tensor of matching rank.
ShapeTensor computeSliceSizes(IImporterContext* ctx, const ShapeTensor& starts, const ShapeTensor& ends,
    const ShapeTensor& steps, const ShapeTensor& dims)
{
    if (steps.isAll(1))
    {
        // size = end - start when every step is 1. The general formula in the
        // else is also correct, but creates much debris for this common case.
        return sub(ctx, ends, starts);
    }
    else
    {
        // "If a negative value is passed for step, it represents slicing backward."
        // Compute ceil((end-start)/step) using only operations available on ShapeTensor,
        // using the identity ceil(x) = -floor(-x): 0 - floor((start-end)/step).
        return sub(ctx, similar(ctx, dims, 0), floorDiv(ctx, sub(ctx, starts, ends), steps));
    }
}
1910+
18511911
} // namespace onnx2trt

onnx2trt_utils.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ static std::ostream& operator<<(std::ostream& stream, const nvinfer1::DataType&
9595
namespace onnx2trt
9696
{
9797

98+
class ShapeTensor;
99+
98100
// Helper function to calculate the volume of a Dims object
99101
int64_t volume(const nvinfer1::Dims& dims);
100102

@@ -310,4 +312,16 @@ nvinfer1::ITensor* unsqueezeTensor(IImporterContext* ctx, const ::ONNX_NAMESPACE
310312
// Helper function to convert a ShapedWeights object into a vector
311313
Status weightsToVector(TensorOrWeights weights, std::vector<int64_t>* weightVector);
312314

315+
//! Decode in place the starts and ends indices according to ONNX Slice rules.
316+
void decodeOnnxStartsAndEnds(IImporterContext* ctx, const ShapeTensor& inputDims, const ShapeTensor& steps, ShapeTensor& starts, ShapeTensor& ends);
317+
318+
//! Return ShapeTensor representing size of result of Slice.
319+
//! starts and ends should first be decoded by decodeOnnxStartsAndEnds.
320+
ShapeTensor computeSliceSizes(IImporterContext* ctx, const ShapeTensor& starts, const ShapeTensor& ends,
321+
const ShapeTensor& steps, const ShapeTensor& dims);
322+
323+
//! Return subscripts such that gather(concat(x,y),subscripts)
324+
//! will return x with x[subscripts[i]] replaced by y[i].
325+
ShapeTensor axesToInterlaceSubscripts(const ShapeTensor& axes, int nbDims);
326+
313327
} // namespace onnx2trt

0 commit comments

Comments
 (0)