@@ -3060,47 +3060,6 @@ DEFINE_BUILTIN_OP_IMPORTER(Size)
30603060 return {{&size.tensor (ctx)}};
30613061}
30623062
3063- // ! Return subscripts such that gather(concat(x,y),subscripts)
3064- // ! will return x with x[subcripts[i]] replaced by y[i].
3065- ShapeTensor axesToInterlaceSubscripts (const ShapeTensor& axes, int nbDims)
3066- {
3067- std::vector<int64_t > subscripts (nbDims);
3068- std::iota (subscripts.begin (), subscripts.end (), 0 );
3069- for (int32_t i = 0 ; i < axes.size (); ++i)
3070- {
3071- subscripts[axes[i]] = nbDims + i;
3072- }
3073- return ShapeTensor (1 , std::move (subscripts));
3074- }
3075-
3076- // ! Decode a start or end index according to ONNX Slice rules:
3077- // ! "If the value passed to start or end is larger than the n (the number of
3078- // ! elements in this dimension), it represents n.... If a negative value is
3079- // ! passed for step, it represents slicing backward."
3080- ShapeTensor decodeOnnxIndices (IImporterContext* ctx, const ShapeTensor& indices, const ShapeTensor& dims)
3081- {
3082- // Oblique calculation implements the rules using only operations available in TensorRT.
3083- return sub (
3084- ctx, min (ctx, max (ctx, indices, mul (ctx, shapeVector (-1 ), dims)), dims), mul (ctx, dims, max (ctx, shapeVector (-1 ), min (ctx, shapeVector (0 ), indices))));
3085- }
3086-
3087- ShapeTensor computeSizes (IImporterContext* ctx, const ShapeTensor& starts, const ShapeTensor& ends,
3088- const ShapeTensor& steps, const ShapeTensor& shift, const ShapeTensor& dims)
3089- {
3090- if (steps.isAll (1 ))
3091- {
3092- // The general formula in the else is correct,
3093- // but creates much debris for this common case.
3094- return sub (ctx, ends, starts);
3095- }
3096- else
3097- {
3098- // "If a negative value is passed for step, it represents slicing backward."
3099- // Compute ceil((end-start)/step) + shift using only operations available in TensorRT.
3100- return add (ctx, sub (ctx, similar (ctx, dims, 0 ), floorDiv (ctx, sub (ctx, starts, ends), steps)), shift);
3101- }
3102- }
3103-
31043063DEFINE_BUILTIN_OP_IMPORTER (Slice)
31053064{
31063065 const int nbInputs = node.input ().size ();
@@ -3170,43 +3129,22 @@ DEFINE_BUILTIN_OP_IMPORTER(Slice)
31703129 ASSERT (std::unordered_set<int64_t >(axes.begin (), axes.end ()).size () == static_cast <size_t >(axes.size ()),
31713130 ErrorCode::kINVALID_NODE );
31723131
3173- // Create a shift shapeTensor. We need to add 1 to the size for any slice that cuts across an entire
3174- // axis in reverse. It is 0 for all other slices.
3175- ShapeTensor shift = shapeVector (0 );
3176- if (ends.allValuesKnown ())
3177- {
3178- shift = ends;
3179- for (int64_t & val : shift.getValues ())
3180- {
3181- if (val == static_cast <int64_t >(INT_MIN))
3182- {
3183- val = 1 ;
3184- }
3185- else
3186- {
3187- val = 0 ;
3188- }
3189- }
3190- }
3191-
31923132 if (axes.size () < dims.size () || !isIota)
31933133 {
3194- // axes specify a subset of the dimensions, or out of order.
3195- // Convert starts/ends/steps/shift to complete in-order form.
3134+ // Axes specify a subset of the dimensions, or out of order.
3135+ // Convert starts/ends/steps to complete in-order form.
31963136 const ShapeTensor subscripts{axesToInterlaceSubscripts (axes, dims.size ())};
31973137 starts = interlace (ctx, similar (ctx, dims, 0 ), starts, subscripts);
31983138 ends = interlace (ctx, dims, ends, subscripts);
31993139 steps = interlace (ctx, similar (ctx, dims, 1 ), steps, subscripts);
3200- shift = interlace (ctx, similar (ctx, dims, 0 ), shift, subscripts);
32013140 }
32023141
3203- // "If a negative value is passed for any of the start or end indices,
3204- // it represents number of elements before the end of that dimension."
3205- starts = decodeOnnxIndices (ctx, starts, dims);
3206- ends = decodeOnnxIndices (ctx, ends, dims);
3142+ // ONNX has a bunch of rules for converting out of bounds starts/ends
3143+ // indices into the actual indices to use.
3144+ decodeOnnxStartsAndEnds (ctx, dims, steps, starts, ends);
32073145
32083146 // TensorRT uses sizes of the output dimensions instead of ends.
3209- const ShapeTensor sizes = computeSizes (ctx, starts, ends, steps, shift , dims);
3147+ const ShapeTensor sizes = computeSliceSizes (ctx, starts, ends, steps, dims);
32103148
32113149 nvinfer1::ISliceLayer* slice = addSlice (ctx, data, starts, sizes, steps);
32123150
0 commit comments