@@ -1206,85 +1206,73 @@ DEFINE_BUILTIN_OP_IMPORTER(Gather)
12061206
12071207DEFINE_BUILTIN_OP_IMPORTER (GatherElements)
12081208{
1209+
1210+ // We can treat GatherElements as a regular Gather operation with transformed input and indices tensors.
1211+ // Consider a simple example of a 3D tensor with axis = 1.
1212+ // The regular forumla of out[i][j][k] = in[i][idx[i][j][k]][k] can be rewritten as out[i][j][k] = in'[idx'[i,j,k]]
1213+ // Where in' is a squeezed down 1D representation of the data and idx' is calculated from the following formula:
1214+ // idx' = idx[i,j,k] * pitch[1] + bias. The bias is calculated as i*pitch[0] + k*pitch[2].
1215+
1216+ // clang-format off
1217+ /* Example: Data is 3D tensor of shape [2,2,2] with values [[[1,2], [3,4]], [[5,6], [7,8]]]
1218+ Indices is a 3D tensor of shape [2,2,1] with values [[[0], [1]], [[0], [1]]]
1219+ From the original formula, the output is [[[1], [3]], [[5], [7]]],
1220+
1221+ Pitch vector of data is [4,2,1].
1222+
1223+ idx` calculation:
1224+ idx`[0, 0, 0] = [idx[0,0,0]](0) * [pitch[axis]](2) + [i(0)*pitch[0](4)](0) + [k(0)*pitch[2](1)](0) = 0
1225+ idx`[0, 1, 0] = [idx[0,1,0]](1) * [pitch[axis]](2) + [i(0)*pitch[0](4)](0) + [k(0)*pitch[2](1)](0) = 2
1226+ idx`[1, 0, 0] = [idx[1,0,0]](0) * [pitch[axis]](2) + [i(1)*pitch[0](4)](4) + [k(0)*pitch[2](1)](0) = 4
1227+ idx`[1, 1, 0] = [idx[1,1,0]](1) * [pitch[axis]](2) + [i(1)*pitch[0](4)](4) + [k(0)*pitch[2](1)](0) = 6
1228+ = [[[0], [2]], [[4], [6]]]
1229+
1230+ After linearizing data to 1D: [1,2,3,4,5,6,7,8], gathering on axis 0 with the new indices gives the same results.
1231+ */
1232+ // clang-format on
1233+
12091234 nvinfer1::ITensor& data = convertToTensor (inputs.at (0 ), ctx);
12101235 nvinfer1::ITensor& index = convertToTensor (inputs.at (1 ), ctx);
12111236
12121237 const nvinfer1::Dims& idxDims = index.getDimensions ();
1213- const nvinfer1::Dims& dataDims = data.getDimensions ();
1238+ const nvinfer1::Dims& daDims = data.getDimensions ();
1239+
1240+ // Note the above tranformation requires dimensions to be known at parse time, so check for dynamic shapes
1241+ ASSERT (!isDynamic (daDims) && !isDynamic (idxDims)
1242+ && " This version of TenosrRT does not support GatherElements on dynamic shapes!" ,
1243+ ErrorCode::kUNSUPPORTED_NODE );
12141244
12151245 OnnxAttrs attrs (node, ctx);
12161246 int32_t axis = attrs.get <int32_t >(" axis" , 0 );
1217- int32_t dataNbDims = dataDims .nbDims ;
1247+ int32_t dataNbDims = daDims .nbDims ;
12181248
12191249 TRT_CHECK (convertAxis (axis, dataNbDims));
12201250 LOG_VERBOSE (" Using Gather axis: " << axis);
12211251
1222- // Calculate how many indices
1252+ // Calculate data pitches vector, and create axisPitch vector
12231253 int64_t nIndx = volume (idxDims);
1254+ std::vector<int32_t > pitches = calculatePitches (daDims);
1255+ std::vector<int32_t > axisPitch (nIndx, pitches[axis]);
12241256
1225- // Calculate pitches of input tensor
1226- int32_t nDataElements = volume (dataDims), pitch = 1 ;
1227- int32_t pitches[nvinfer1::Dims::MAX_DIMS] = {0 };
1228- pitches[dataDims.nbDims -1 ] = pitch;
1229- for (int32_t i = dataDims.nbDims -2 ; i >= 0 ; i--)
1230- {
1231- pitch *= dataDims.d [i];
1232- pitches[i] = pitch;
1233- }
1234-
1235- // Generate constants based on axis
1236- std::vector<int32_t > sCoeff (nIndx, pitches[axis]);
1237- std::vector<int32_t > aCoeff;
1238-
1239- // Transform a 1-d index back to the nDims
1240- for (int32_t i = 0 ; i < nIndx; i++)
1241- {
1242- std::vector<int32_t > nDimsIdx; // this can be an array
1243- int32_t currI = i;
1244-
1245- for (int32_t j = 0 ; j < dataDims.nbDims ; j++)
1246- {
1247- int32_t currIdxVal = currI / pitches[j];
1248- nDimsIdx.push_back (currIdxVal);
1249- currI = currI % pitches[j];
1250- }
1251-
1252- int32_t bias = 0 ;
1253- // calculate the aCoeff
1254- for (size_t j = 0 ; j < nDimsIdx.size (); j++)
1255- {
1256-
1257- if (j == (size_t )axis)
1258- {
1259- continue ;
1260- }
1261- bias += nDimsIdx[j] * pitches[j];
1262- }
1263- aCoeff.push_back (bias);
1264- }
1265-
1266- auto * sCoeffLayer = addConstant (ctx, sCoeff , ::ONNX_NAMESPACE::TensorProto::INT32, idxDims);
1267- auto * aCoeffLayer = addConstant (ctx, aCoeff, ::ONNX_NAMESPACE::TensorProto::INT32, idxDims);
1268-
1269- nvinfer1::ITensor* sCoeffTensor = sCoeffLayer ->getOutput (0 );
1270- nvinfer1::ITensor* aCoeffTensor = aCoeffLayer->getOutput (0 );
1271- auto * mul = ctx->network ()->addElementWise (index, *sCoeffTensor , nvinfer1::ElementWiseOperation::kPROD );
1272-
1273- nvinfer1::ITensor* mulTensor = mul->getOutput (0 );
1274- auto * add = ctx->network ()->addElementWise (*mulTensor, *aCoeffTensor, nvinfer1::ElementWiseOperation::kSUM );
1257+ // Calculate bias vector
1258+ std::vector<int32_t > biasVector = calculateBias (daDims, idxDims, pitches, axis);
12751259
1276- nvinfer1::ITensor* addTensor = add->getOutput (0 );
1260+ // Perform idx` = idx * pitch[axis] + bias calculation.
1261+ auto * axisPitchTensor = addConstant (ctx, axisPitch, ::ONNX_NAMESPACE::TensorProto::INT32, idxDims)->getOutput (0 );
1262+ auto * biasTensor = addConstant (ctx, biasVector, ::ONNX_NAMESPACE::TensorProto::INT32, idxDims)->getOutput (0 );
12771263
1278- nvinfer1::Dims flattenDataDims{1 };
1264+ auto * mul
1265+ = ctx->network ()->addElementWise (index, *axisPitchTensor, nvinfer1::ElementWiseOperation::kPROD )->getOutput (0 );
1266+ auto * newIndices
1267+ = ctx->network ()->addElementWise (*mul, *biasTensor, nvinfer1::ElementWiseOperation::kSUM )->getOutput (0 );
12791268
1280- flattenDataDims.nbDims = 1 ;
1281- flattenDataDims.d [0 ] = nDataElements;
1269+ nvinfer1::Dims flattenDataDims{1 , {static_cast <int32_t >(volume (daDims))}};
12821270 auto * reshape = ctx->network ()->addShuffle (data);
12831271 reshape->setReshapeDimensions (flattenDataDims);
12841272 reshape->setZeroIsPlaceholder (false );
12851273
12861274 nvinfer1::ITensor* flattenData = reshape->getOutput (0 );
1287- auto * layer = ctx->network ()->addGather (*flattenData, *addTensor , 0 );
1275+ auto * layer = ctx->network ()->addGather (*flattenData, *newIndices , 0 );
12881276 ctx->registerLayer (layer, getNodeName (node));
12891277 RETURN_FIRST_OUTPUT (layer);
12901278}
0 commit comments