@@ -1190,17 +1190,22 @@ DEFINE_BUILTIN_OP_IMPORTER(Floor)
11901190
11911191DEFINE_BUILTIN_OP_IMPORTER (Gather)
11921192{
1193- nvinfer1::ITensor& data = convertToTensor (inputs.at (0 ), ctx);
1193+ nvinfer1::ITensor* data = & convertToTensor (inputs.at (0 ), ctx);
11941194 // TRT does not support BOOL input types for this node
1195- ASSERT (data.getType () != nvinfer1::DataType::kBOOL , ErrorCode::kUNSUPPORTED_NODE );
1196- nvinfer1::ITensor& indices = convertToTensor (inputs.at (1 ), ctx);
1195+ ASSERT ( (data->getType () != nvinfer1::DataType::kBOOL ) && " This version of TensorRT does not support BOOL input type for the Gather operator." , ErrorCode::kUNSUPPORTED_NODE );
1196+
1197+ nvinfer1::ITensor* indices = &convertToTensor (inputs.at (1 ), ctx);
11971198 OnnxAttrs attrs (node, ctx);
1198- int axis = attrs.get <int >(" axis" , 0 );
1199- int nbDims = inputs.at (0 ).shape ().nbDims ;
1199+ int32_t axis = attrs.get <int32_t >(" axis" , 0 );
1200+ int32_t nbDims = inputs.at (0 ).shape ().nbDims ;
12001201 TRT_CHECK (convertAxis (axis, nbDims));
12011202 LOG_VERBOSE (" Using Gather axis: " << axis);
1202- auto * layer = ctx->network ()->addGather (data, indices, axis);
1203- ctx->registerLayer (layer, node.name ());
1203+
1204+ // Convert any negative indices to positive ones
1205+ indices = convertGatherIndices (ctx, data, indices, axis);
1206+
1207+ auto * layer = ctx->network ()->addGather (*data, *indices, axis);
1208+ ctx->registerLayer (layer, getNodeName (node));
12041209 RETURN_FIRST_OUTPUT (layer);
12051210}
12061211
@@ -1231,11 +1236,11 @@ DEFINE_BUILTIN_OP_IMPORTER(GatherElements)
12311236 */
12321237 // clang-format on
12331238
1234- nvinfer1::ITensor& data = convertToTensor (inputs.at (0 ), ctx);
1235- nvinfer1::ITensor& index = convertToTensor (inputs.at (1 ), ctx);
1239+ nvinfer1::ITensor* data = & convertToTensor (inputs.at (0 ), ctx);
1240+ nvinfer1::ITensor* index = & convertToTensor (inputs.at (1 ), ctx);
12361241
1237- const nvinfer1::Dims& idxDims = index. getDimensions ();
1238- const nvinfer1::Dims& daDims = data. getDimensions ();
1242+ const nvinfer1::Dims& idxDims = index-> getDimensions ();
1243+ const nvinfer1::Dims& daDims = data-> getDimensions ();
12391244
12401245 // Note the above tranformation requires dimensions to be known at parse time, so check for dynamic shapes
12411246 ASSERT (!isDynamic (daDims) && !isDynamic (idxDims)
@@ -1246,6 +1251,9 @@ DEFINE_BUILTIN_OP_IMPORTER(GatherElements)
12461251 int32_t axis = attrs.get <int32_t >(" axis" , 0 );
12471252 int32_t dataNbDims = daDims.nbDims ;
12481253
1254+ // Convert any negative indices to positive ones
1255+ index = convertGatherIndices (ctx, data, index, axis);
1256+
12491257 TRT_CHECK (convertAxis (axis, dataNbDims));
12501258 LOG_VERBOSE (" Using Gather axis: " << axis);
12511259
@@ -1262,12 +1270,12 @@ DEFINE_BUILTIN_OP_IMPORTER(GatherElements)
12621270 auto * biasTensor = addConstant (ctx, biasVector, ::ONNX_NAMESPACE::TensorProto::INT32, idxDims)->getOutput (0 );
12631271
12641272 auto * mul
1265- = ctx->network ()->addElementWise (index, *axisPitchTensor, nvinfer1::ElementWiseOperation::kPROD )->getOutput (0 );
1273+ = ctx->network ()->addElementWise (* index, *axisPitchTensor, nvinfer1::ElementWiseOperation::kPROD )->getOutput (0 );
12661274 auto * newIndices
12671275 = ctx->network ()->addElementWise (*mul, *biasTensor, nvinfer1::ElementWiseOperation::kSUM )->getOutput (0 );
12681276
12691277 nvinfer1::Dims flattenDataDims{1 , {static_cast <int32_t >(volume (daDims))}};
1270- auto * reshape = ctx->network ()->addShuffle (data);
1278+ auto * reshape = ctx->network ()->addShuffle (* data);
12711279 reshape->setReshapeDimensions (flattenDataDims);
12721280 reshape->setZeroIsPlaceholder (false );
12731281
@@ -1277,7 +1285,6 @@ DEFINE_BUILTIN_OP_IMPORTER(GatherElements)
12771285 RETURN_FIRST_OUTPUT (layer);
12781286}
12791287
1280-
12811288DEFINE_BUILTIN_OP_IMPORTER (Gemm)
12821289{
12831290 OnnxAttrs attrs (node, ctx);
0 commit comments