Skip to content

Commit 6027b6c

Browse files
committed
Support gather with negative indices (onnx#678)
Signed-off-by: Kevin Chen <kevinch@nvidia.com>
1 parent 351f22b commit 6027b6c

3 files changed

Lines changed: 58 additions & 14 deletions

File tree

builtin_op_importers.cpp

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1190,17 +1190,22 @@ DEFINE_BUILTIN_OP_IMPORTER(Floor)
11901190

11911191
DEFINE_BUILTIN_OP_IMPORTER(Gather)
11921192
{
1193-
nvinfer1::ITensor& data = convertToTensor(inputs.at(0), ctx);
1193+
nvinfer1::ITensor* data = &convertToTensor(inputs.at(0), ctx);
11941194
// TRT does not support BOOL input types for this node
1195-
ASSERT(data.getType() != nvinfer1::DataType::kBOOL, ErrorCode::kUNSUPPORTED_NODE);
1196-
nvinfer1::ITensor& indices = convertToTensor(inputs.at(1), ctx);
1195+
ASSERT( (data->getType() != nvinfer1::DataType::kBOOL) && "This version of TensorRT does not support BOOL input type for the Gather operator.", ErrorCode::kUNSUPPORTED_NODE);
1196+
1197+
nvinfer1::ITensor* indices = &convertToTensor(inputs.at(1), ctx);
11971198
OnnxAttrs attrs(node, ctx);
1198-
int axis = attrs.get<int>("axis", 0);
1199-
int nbDims = inputs.at(0).shape().nbDims;
1199+
int32_t axis = attrs.get<int32_t>("axis", 0);
1200+
int32_t nbDims = inputs.at(0).shape().nbDims;
12001201
TRT_CHECK(convertAxis(axis, nbDims));
12011202
LOG_VERBOSE("Using Gather axis: " << axis);
1202-
auto* layer = ctx->network()->addGather(data, indices, axis);
1203-
ctx->registerLayer(layer, node.name());
1203+
1204+
// Convert any negative indices to positive ones
1205+
indices = convertGatherIndices(ctx, data, indices, axis);
1206+
1207+
auto* layer = ctx->network()->addGather(*data, *indices, axis);
1208+
ctx->registerLayer(layer, getNodeName(node));
12041209
RETURN_FIRST_OUTPUT(layer);
12051210
}
12061211

@@ -1231,11 +1236,11 @@ DEFINE_BUILTIN_OP_IMPORTER(GatherElements)
12311236
*/
12321237
// clang-format on
12331238

1234-
nvinfer1::ITensor& data = convertToTensor(inputs.at(0), ctx);
1235-
nvinfer1::ITensor& index = convertToTensor(inputs.at(1), ctx);
1239+
nvinfer1::ITensor* data = &convertToTensor(inputs.at(0), ctx);
1240+
nvinfer1::ITensor* index = &convertToTensor(inputs.at(1), ctx);
12361241

1237-
const nvinfer1::Dims& idxDims = index.getDimensions();
1238-
const nvinfer1::Dims& daDims = data.getDimensions();
1242+
const nvinfer1::Dims& idxDims = index->getDimensions();
1243+
const nvinfer1::Dims& daDims = data->getDimensions();
12391244

12401245
// Note the above transformation requires dimensions to be known at parse time, so check for dynamic shapes
12411246
ASSERT(!isDynamic(daDims) && !isDynamic(idxDims)
@@ -1246,6 +1251,9 @@ DEFINE_BUILTIN_OP_IMPORTER(GatherElements)
12461251
int32_t axis = attrs.get<int32_t>("axis", 0);
12471252
int32_t dataNbDims = daDims.nbDims;
12481253

1254+
// Convert any negative indices to positive ones
1255+
index = convertGatherIndices(ctx, data, index, axis);
1256+
12491257
TRT_CHECK(convertAxis(axis, dataNbDims));
12501258
LOG_VERBOSE("Using Gather axis: " << axis);
12511259

@@ -1262,12 +1270,12 @@ DEFINE_BUILTIN_OP_IMPORTER(GatherElements)
12621270
auto* biasTensor = addConstant(ctx, biasVector, ::ONNX_NAMESPACE::TensorProto::INT32, idxDims)->getOutput(0);
12631271

12641272
auto* mul
1265-
= ctx->network()->addElementWise(index, *axisPitchTensor, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0);
1273+
= ctx->network()->addElementWise(*index, *axisPitchTensor, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0);
12661274
auto* newIndices
12671275
= ctx->network()->addElementWise(*mul, *biasTensor, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0);
12681276

12691277
nvinfer1::Dims flattenDataDims{1, {static_cast<int32_t>(volume(daDims))}};
1270-
auto* reshape = ctx->network()->addShuffle(data);
1278+
auto* reshape = ctx->network()->addShuffle(*data);
12711279
reshape->setReshapeDimensions(flattenDataDims);
12721280
reshape->setZeroIsPlaceholder(false);
12731281

@@ -1277,7 +1285,6 @@ DEFINE_BUILTIN_OP_IMPORTER(GatherElements)
12771285
RETURN_FIRST_OUTPUT(layer);
12781286
}
12791287

1280-
12811288
DEFINE_BUILTIN_OP_IMPORTER(Gemm)
12821289
{
12831290
OnnxAttrs attrs(node, ctx);

onnx2trt_utils.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,37 @@ onnx2trt::ShapedWeights createZeroShifts(const onnx2trt::ShapedWeights& shiftInt
405405
return shift;
406406
}
407407

408+
nvinfer1::ITensor* createZeroTensor(IImporterContext* ctx, nvinfer1::ITensor* data)
409+
{
410+
nvinfer1::ITensor* zero;
411+
if (data->getType() == nvinfer1::DataType::kFLOAT)
412+
{
413+
zero
414+
= addConstant(ctx, std::vector<float>{0.f}, ::ONNX_NAMESPACE::TensorProto::FLOAT, {0, {1}})->getOutput(0);
415+
}
416+
else
417+
{
418+
zero
419+
= addConstant(ctx, std::vector<int>{0}, ::ONNX_NAMESPACE::TensorProto::INT32, {0, {1}})->getOutput(0);
420+
}
421+
broadcastTensors(ctx, zero, data);
422+
zero = ctx->network()->addElementWise(*data, *zero, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0);
423+
return zero;
424+
}
425+
426+
nvinfer1::ITensor* convertGatherIndices(IImporterContext* ctx, nvinfer1::ITensor* data, nvinfer1::ITensor* indices, int32_t axis)
427+
{
428+
// Create a condition tensor that is 1 for the elements in indices that are < 0, or 0 otherwise
429+
auto condition = ctx->network()->addElementWise(*indices, *createZeroTensor(ctx, indices), nvinfer1::ElementWiseOperation::kLESS)->getOutput(0);
430+
auto axisLength = getAxisLength(ctx, data, axis);
431+
broadcastTensors(ctx, axisLength, indices);
432+
// Create a shifted tensor that is indices + axisLength
433+
auto shifted = ctx->network()->addElementWise(*indices, *axisLength, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0);
434+
// Select between the shifted and original data based on condition
435+
auto select = ctx->network()->addSelect(*condition, *shifted, *indices);
436+
return select->getOutput(0);
437+
}
438+
408439
template <typename DataType>
409440
DataType* convertINT32Data(const int32_t* weightValues, nvinfer1::Dims shape, int32_t onnxdtype, IImporterContext* ctx)
410441
{

onnx2trt_utils.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,13 +181,19 @@ bool convertDtype(int32_t onnx_dtype, nvinfer1::DataType* trt_dtype);
181181
// Helper function to convert INT64 weight values into INT32
182182
int32_t* convertINT64(const int64_t* weightValues, nvinfer1::Dims shape, IImporterContext* ctx);
183183

184+
// Helper function to convert negative gather indices into positive ones
185+
nvinfer1::ITensor* convertGatherIndices(IImporterContext* ctx, nvinfer1::ITensor* data, nvinfer1::ITensor* indices, int32_t axis);
186+
184187
// Helper function to convert ONNX padding into TRT padding
185188
bool convertOnnxPadding(
186189
const std::vector<int64_t>& onnxPadding, nvinfer1::Dims2* begPadding, nvinfer1::Dims2* endPadding);
187190

188191
// Helper function to create zero shifts for QuantizeLinear/DequantizeLinear ops
189192
onnx2trt::ShapedWeights createZeroShifts(const onnx2trt::ShapedWeights& shiftInt8, int32_t type, IImporterContext* ctx);
190193

194+
// Helper function to create a tensor of all zeros with the same shape as a data tensor
195+
nvinfer1::ITensor* createZeroTensor(IImporterContext* ctx, nvinfer1::ITensor* data);
196+
191197
// Helper function to convert an ONNX weight into a ShapedWeights object
192198
bool convertOnnxWeights(
193199
const ::ONNX_NAMESPACE::TensorProto& onnxTensor, onnx2trt::ShapedWeights* weights, IImporterContext* ctx);

0 commit comments

Comments
 (0)