diff --git a/ShapeTensor.hpp b/ShapeTensor.hpp index 371bd1f8..f5c3c88b 100644 --- a/ShapeTensor.hpp +++ b/ShapeTensor.hpp @@ -115,13 +115,6 @@ class ShapeTensor return mValues[k]; } - //! Return underlying mValues - std::vector& getValues() - { - assert(mAllValuesKnown); - return mValues; - } - //! Return true if x and y always have the same value. friend bool operator==(const ShapeTensor& x, const ShapeTensor& y); friend ShapeTensor shapeOf(const ShapeTensor& t); diff --git a/builtin_op_importers.cpp b/builtin_op_importers.cpp index bb19eea0..8ddf9fd2 100644 --- a/builtin_op_importers.cpp +++ b/builtin_op_importers.cpp @@ -3057,47 +3057,6 @@ DEFINE_BUILTIN_OP_IMPORTER(Size) return {{&size.tensor(ctx)}}; } -//! Return subscripts such that gather(concat(x,y),subscripts) -//! will return x with x[subcripts[i]] replaced by y[i]. -ShapeTensor axesToInterlaceSubscripts(const ShapeTensor& axes, int nbDims) -{ - std::vector subscripts(nbDims); - std::iota(subscripts.begin(), subscripts.end(), 0); - for (int32_t i = 0; i < axes.size(); ++i) - { - subscripts[axes[i]] = nbDims + i; - } - return ShapeTensor(1, std::move(subscripts)); -} - -//! Decode a start or end index according to ONNX Slice rules: -//! "If the value passed to start or end is larger than the n (the number of -//! elements in this dimension), it represents n.... If a negative value is -//! passed for step, it represents slicing backward." -ShapeTensor decodeOnnxIndices(IImporterContext* ctx, const ShapeTensor& indices, const ShapeTensor& dims) -{ - // Oblique calculation implements the rules using only operations available in TensorRT. 
- return sub( - ctx, min(ctx, max(ctx, indices, mul(ctx, shapeVector(-1), dims)), dims), mul(ctx, dims, max(ctx, shapeVector(-1), min(ctx, shapeVector(0), indices)))); -} - -ShapeTensor computeSizes(IImporterContext* ctx, const ShapeTensor& starts, const ShapeTensor& ends, - const ShapeTensor& steps, const ShapeTensor& shift, const ShapeTensor& dims) -{ - if (steps.isAll(1)) - { - // The general formula in the else is correct, - // but creates much debris for this common case. - return sub(ctx, ends, starts); - } - else - { - // "If a negative value is passed for step, it represents slicing backward." - // Compute ceil((end-start)/step) + shift using only operations available in TensorRT. - return add(ctx, sub(ctx, similar(ctx, dims, 0), floorDiv(ctx, sub(ctx, starts, ends), steps)), shift); - } -} - DEFINE_BUILTIN_OP_IMPORTER(Slice) { const int nbInputs = node.input().size(); @@ -3167,43 +3126,22 @@ DEFINE_BUILTIN_OP_IMPORTER(Slice) ASSERT(std::unordered_set(axes.begin(), axes.end()).size() == static_cast(axes.size()), ErrorCode::kINVALID_NODE); - // Create a shift shapeTensor. We need to add 1 to the size for any slice that cuts across an entire - // axis in reverse. It is 0 for all other slices. - ShapeTensor shift = shapeVector(0); - if (ends.allValuesKnown()) - { - shift = ends; - for (int64_t& val : shift.getValues()) - { - if (val == static_cast(INT_MIN)) - { - val = 1; - } - else - { - val = 0; - } - } - } - if (axes.size() < dims.size() || !isIota) { - // axes specify a subset of the dimensions, or out of order. - // Convert starts/ends/steps/shift to complete in-order form. + // Axes specify a subset of the dimensions, or out of order. + // Convert starts/ends/steps to complete in-order form. 
const ShapeTensor subscripts{axesToInterlaceSubscripts(axes, dims.size())}; starts = interlace(ctx, similar(ctx, dims, 0), starts, subscripts); ends = interlace(ctx, dims, ends, subscripts); steps = interlace(ctx, similar(ctx, dims, 1), steps, subscripts); - shift = interlace(ctx, similar(ctx, dims, 0), shift, subscripts); } - // "If a negative value is passed for any of the start or end indices, - // it represents number of elements before the end of that dimension." - starts = decodeOnnxIndices(ctx, starts, dims); - ends = decodeOnnxIndices(ctx, ends, dims); + // ONNX has a bunch of rules for converting out of bounds starts/ends + // indices into the actual indices to use. + decodeOnnxStartsAndEnds(ctx, dims, steps, starts, ends); // TensorRT uses sizes of the output dimensions instead of ends. - const ShapeTensor sizes = computeSizes(ctx, starts, ends, steps, shift, dims); + const ShapeTensor sizes = computeSliceSizes(ctx, starts, ends, steps, dims); nvinfer1::ISliceLayer* slice = addSlice(ctx, data, starts, sizes, steps); diff --git a/onnx2trt_utils.cpp b/onnx2trt_utils.cpp index f6c583e4..cd45047a 100644 --- a/onnx2trt_utils.cpp +++ b/onnx2trt_utils.cpp @@ -1848,4 +1848,64 @@ Status weightsToVector(TensorOrWeights weights, std::vector* weightVect return Status(ErrorCode::kSUCCESS); } +//! Return ShapeTensor representing x clamped to closed interval [lowerBound,upperBound]. +static ShapeTensor clamp(IImporterContext* ctx, const ShapeTensor& x, const ShapeTensor& lowerBound, const ShapeTensor& upperBound) +{ + return min(ctx, max(ctx, x, lowerBound), upperBound); +} + +//! Return ShapeTensor representing indices < 0 ? 
inputDims + indices : indices +static ShapeTensor bumpIfNegative(IImporterContext* ctx, const ShapeTensor& inputDims, const ShapeTensor& indices) +{ + const auto signs = clamp(ctx, indices, shapeVector(-1), shapeVector(0)); + return sub(ctx, indices, mul(ctx, signs, inputDims)); +} + +void decodeOnnxStartsAndEnds(IImporterContext* ctx, const ShapeTensor& inputDims, const ShapeTensor& steps, ShapeTensor& starts, ShapeTensor& ends) +{ + //! The ONNX specification is unreliable (https://github.com/onnx/onnx/issues/3063) + //! thus the logic here is designed to match that in + //! https://github.com/onnx/onnx/blob/master/onnx/defs/tensor/defs.cc . + + // Set stepSign to step < 0 ? -1 : 0. + const auto stepSign = clamp(ctx, steps, shapeVector(-1), shapeVector(0)); + + // Update starts. + starts = bumpIfNegative(ctx, inputDims, starts); + starts = clamp(ctx, starts, shapeVector(0), add(ctx, inputDims, stepSign)); + + // Update ends + ends = bumpIfNegative(ctx, inputDims, ends); + ends = clamp(ctx, ends, stepSign, inputDims); +} + +ShapeTensor axesToInterlaceSubscripts(const ShapeTensor& axes, int nbDims) +{ + std::vector subscripts(nbDims); + std::iota(subscripts.begin(), subscripts.end(), 0); + for (int32_t i = 0; i < axes.size(); ++i) + { + subscripts[axes[i]] = nbDims + i; + } + return ShapeTensor(1, std::move(subscripts)); +} + +ShapeTensor computeSliceSizes(IImporterContext* ctx, const ShapeTensor& starts, const ShapeTensor& ends, + const ShapeTensor& steps, const ShapeTensor& dims) +{ + if (steps.isAll(1)) + { + // The general formula in the else is correct, + // but creates much debris for this common case. + return sub(ctx, ends, starts); + } + else + { + // "If a negative value is passed for step, it represents slicing backward." + // Compute ceil((end-start)/step) using only operations available on ShapeTensor, + // using the identity ceil(x) = -floor(-x). 
+ return sub(ctx, similar(ctx, dims, 0), floorDiv(ctx, sub(ctx, starts, ends), steps)); + } +} + } // namespace onnx2trt diff --git a/onnx2trt_utils.hpp b/onnx2trt_utils.hpp index 7be30f66..b6b78e59 100644 --- a/onnx2trt_utils.hpp +++ b/onnx2trt_utils.hpp @@ -95,6 +95,8 @@ static std::ostream& operator<<(std::ostream& stream, const nvinfer1::DataType& namespace onnx2trt { +class ShapeTensor; + // Helper function to calculate the volume of a Dims object int64_t volume(const nvinfer1::Dims& dims); @@ -310,4 +312,16 @@ nvinfer1::ITensor* unsqueezeTensor(IImporterContext* ctx, const ::ONNX_NAMESPACE // Helper function to convert a ShapedWeights object into a vector Status weightsToVector(TensorOrWeights weights, std::vector* weightVector); +//! Decode in place the starts and ends indices according to ONNX Slice rules. +void decodeOnnxStartsAndEnds(IImporterContext* ctx, const ShapeTensor& inputDims, const ShapeTensor& steps, ShapeTensor& starts, ShapeTensor& ends); + +//! Return ShapeTensor representing size of result of Slice. +//! starts and ends should first be decoded by decodeOnnxStartsAndEnds. +ShapeTensor computeSliceSizes(IImporterContext* ctx, const ShapeTensor& starts, const ShapeTensor& ends, + const ShapeTensor& steps, const ShapeTensor& dims); + +//! Return subscripts such that gather(concat(x,y),subscripts) +//! will return x with x[subscripts[i]] replaced by y[i]. +ShapeTensor axesToInterlaceSubscripts(const ShapeTensor& axes, int nbDims); + } // namespace onnx2trt