Skip to content

Commit fd98985

Browse files
committed
Update gather elements implementation (#675)
Signed-off-by: Kevin Chen <kevinch@nvidia.com>
1 parent 123f8f4 commit fd98985

3 files changed

Lines changed: 115 additions & 58 deletions

File tree

builtin_op_importers.cpp

Lines changed: 46 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,85 +1206,73 @@ DEFINE_BUILTIN_OP_IMPORTER(Gather)
12061206

12071207
DEFINE_BUILTIN_OP_IMPORTER(GatherElements)
12081208
{
1209+
1210+
// We can treat GatherElements as a regular Gather operation with transformed input and indices tensors.
1211+
// Consider a simple example of a 3D tensor with axis = 1.
1212+
// The regular formula of out[i][j][k] = in[i][idx[i][j][k]][k] can be rewritten as out[i][j][k] = in'[idx'[i,j,k]]
1213+
// Where in' is a squeezed down 1D representation of the data and idx' is calculated from the following formula:
1214+
// idx' = idx[i,j,k] * pitch[1] + bias. The bias is calculated as i*pitch[0] + k*pitch[2].
1215+
1216+
// clang-format off
1217+
/* Example: Data is 3D tensor of shape [2,2,2] with values [[[1,2], [3,4]], [[5,6], [7,8]]]
1218+
Indices is a 3D tensor of shape [2,2,1] with values [[[0], [1]], [[0], [1]]]
1219+
From the original formula, the output is [[[1], [3]], [[5], [7]]],
1220+
1221+
Pitch vector of data is [4,2,1].
1222+
1223+
idx` calculation:
1224+
idx`[0, 0, 0] = [idx[0,0,0]](0) * [pitch[axis]](2) + [i(0)*pitch[0](4)](0) + [k(0)*pitch[2](1)](0) = 0
1225+
idx`[0, 1, 0] = [idx[0,1,0]](1) * [pitch[axis]](2) + [i(0)*pitch[0](4)](0) + [k(0)*pitch[2](1)](0) = 2
1226+
idx`[1, 0, 0] = [idx[1,0,0]](0) * [pitch[axis]](2) + [i(1)*pitch[0](4)](4) + [k(0)*pitch[2](1)](0) = 4
1227+
idx`[1, 1, 0] = [idx[1,1,0]](1) * [pitch[axis]](2) + [i(1)*pitch[0](4)](4) + [k(0)*pitch[2](1)](0) = 6
1228+
= [[[0], [2]], [[4], [6]]]
1229+
1230+
After linearizing data to 1D: [1,2,3,4,5,6,7,8], gathering on axis 0 with the new indices gives the same results.
1231+
*/
1232+
// clang-format on
1233+
12091234
nvinfer1::ITensor& data = convertToTensor(inputs.at(0), ctx);
12101235
nvinfer1::ITensor& index = convertToTensor(inputs.at(1), ctx);
12111236

12121237
const nvinfer1::Dims& idxDims = index.getDimensions();
1213-
const nvinfer1::Dims& dataDims = data.getDimensions();
1238+
const nvinfer1::Dims& daDims = data.getDimensions();
1239+
1240+
// Note the above transformation requires dimensions to be known at parse time, so check for dynamic shapes
1241+
ASSERT(!isDynamic(daDims) && !isDynamic(idxDims)
1242+
&& "This version of TensorRT does not support GatherElements on dynamic shapes!",
1243+
ErrorCode::kUNSUPPORTED_NODE);
12141244

12151245
OnnxAttrs attrs(node, ctx);
12161246
int32_t axis = attrs.get<int32_t>("axis", 0);
1217-
int32_t dataNbDims = dataDims.nbDims;
1247+
int32_t dataNbDims = daDims.nbDims;
12181248

12191249
TRT_CHECK(convertAxis(axis, dataNbDims));
12201250
LOG_VERBOSE("Using Gather axis: " << axis);
12211251

1222-
// Calculate how many indices
1252+
// Calculate data pitches vector, and create axisPitch vector
12231253
int64_t nIndx = volume(idxDims);
1254+
std::vector<int32_t> pitches = calculatePitches(daDims);
1255+
std::vector<int32_t> axisPitch(nIndx, pitches[axis]);
12241256

1225-
// Calculate pitches of input tensor
1226-
int32_t nDataElements = volume(dataDims), pitch = 1;
1227-
int32_t pitches[nvinfer1::Dims::MAX_DIMS] = {0};
1228-
pitches[dataDims.nbDims-1] = pitch;
1229-
for (int32_t i = dataDims.nbDims-2; i >= 0 ; i--)
1230-
{
1231-
pitch *= dataDims.d[i];
1232-
pitches[i] = pitch;
1233-
}
1234-
1235-
// Generate constants based on axis
1236-
std::vector<int32_t> sCoeff(nIndx, pitches[axis]);
1237-
std::vector<int32_t> aCoeff;
1238-
1239-
// Transform a 1-d index back to the nDims
1240-
for (int32_t i = 0; i < nIndx; i++)
1241-
{
1242-
std::vector<int32_t> nDimsIdx; //this can be an array
1243-
int32_t currI = i;
1244-
1245-
for (int32_t j = 0; j < dataDims.nbDims; j++)
1246-
{
1247-
int32_t currIdxVal = currI / pitches[j];
1248-
nDimsIdx.push_back(currIdxVal);
1249-
currI = currI % pitches[j];
1250-
}
1251-
1252-
int32_t bias = 0;
1253-
//calculate the aCoeff
1254-
for (size_t j = 0; j < nDimsIdx.size(); j++)
1255-
{
1256-
1257-
if (j == (size_t)axis)
1258-
{
1259-
continue;
1260-
}
1261-
bias += nDimsIdx[j] * pitches[j];
1262-
}
1263-
aCoeff.push_back(bias);
1264-
}
1265-
1266-
auto* sCoeffLayer = addConstant(ctx, sCoeff, ::ONNX_NAMESPACE::TensorProto::INT32, idxDims);
1267-
auto* aCoeffLayer = addConstant(ctx, aCoeff, ::ONNX_NAMESPACE::TensorProto::INT32, idxDims);
1268-
1269-
nvinfer1::ITensor* sCoeffTensor = sCoeffLayer->getOutput(0);
1270-
nvinfer1::ITensor* aCoeffTensor = aCoeffLayer->getOutput(0);
1271-
auto* mul = ctx->network()->addElementWise(index, *sCoeffTensor, nvinfer1::ElementWiseOperation::kPROD);
1272-
1273-
nvinfer1::ITensor* mulTensor = mul->getOutput(0);
1274-
auto* add = ctx->network()->addElementWise(*mulTensor, *aCoeffTensor, nvinfer1::ElementWiseOperation::kSUM);
1257+
// Calculate bias vector
1258+
std::vector<int32_t> biasVector = calculateBias(daDims, idxDims, pitches, axis);
12751259

1276-
nvinfer1::ITensor* addTensor = add->getOutput(0);
1260+
// Perform idx` = idx * pitch[axis] + bias calculation.
1261+
auto* axisPitchTensor = addConstant(ctx, axisPitch, ::ONNX_NAMESPACE::TensorProto::INT32, idxDims)->getOutput(0);
1262+
auto* biasTensor = addConstant(ctx, biasVector, ::ONNX_NAMESPACE::TensorProto::INT32, idxDims)->getOutput(0);
12771263

1278-
nvinfer1::Dims flattenDataDims{1};
1264+
auto* mul
1265+
= ctx->network()->addElementWise(index, *axisPitchTensor, nvinfer1::ElementWiseOperation::kPROD)->getOutput(0);
1266+
auto* newIndices
1267+
= ctx->network()->addElementWise(*mul, *biasTensor, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0);
12791268

1280-
flattenDataDims.nbDims = 1;
1281-
flattenDataDims.d[0] = nDataElements;
1269+
nvinfer1::Dims flattenDataDims{1, {static_cast<int32_t>(volume(daDims))}};
12821270
auto* reshape = ctx->network()->addShuffle(data);
12831271
reshape->setReshapeDimensions(flattenDataDims);
12841272
reshape->setZeroIsPlaceholder(false);
12851273

12861274
nvinfer1::ITensor* flattenData = reshape->getOutput(0);
1287-
auto* layer = ctx->network()->addGather(*flattenData, *addTensor, 0);
1275+
auto* layer = ctx->network()->addGather(*flattenData, *newIndices, 0);
12881276
ctx->registerLayer(layer, getNodeName(node));
12891277
RETURN_FIRST_OUTPUT(layer);
12901278
}

onnx2trt_utils.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,68 @@ Status broadcastTensors(IImporterContext* ctx, nvinfer1::ITensor*& t1, nvinfer1:
136136
return Status::success();
137137
}
138138

139+
// Helper functions for calculateBias:
140+
int32_t getBias(const std::vector<int32_t>& dimension_count, const std::vector<int32_t>& pitches, int32_t axis)
141+
{
142+
int32_t result{0};
143+
for (int32_t i = 0; i < static_cast<int32_t>(dimension_count.size()); i++)
144+
{
145+
if (i != axis)
146+
{
147+
result += dimension_count[i] * pitches[i];
148+
}
149+
}
150+
return result;
151+
}
152+
153+
void incrementOuterDimension(std::vector<int32_t>& dimensionCount, nvinfer1::Dims idxDims)
154+
{
155+
// Start at [x,x,0]. Increment starting from the innermost (last) dimension, carrying over into outer dimensions.
156+
int32_t rank = dimensionCount.size();
157+
158+
for (int32_t i = rank - 1; i >= 0; i--)
159+
{
160+
int dimLimit = idxDims.d[i];
161+
// If we're not at the limit, increment current axis and return
162+
if (++dimensionCount[i] != dimLimit)
163+
{
164+
break;
165+
}
166+
// Else, we increment on the next dimension and reset current one
167+
dimensionCount[i] = 0;
168+
}
169+
}
170+
171+
std::vector<int32_t> calculateBias(
172+
const nvinfer1::Dims& daDims, const nvinfer1::Dims& idxDims, const std::vector<int32_t>& pitches, int32_t axis)
173+
{
174+
std::vector<int32_t> biasVector;
175+
std::vector<int32_t> dimensionCount(daDims.nbDims, 0);
176+
int64_t total = volume(idxDims);
177+
178+
for (int64_t i = 0; i < total; i++)
179+
{
180+
int32_t bias = getBias(dimensionCount, pitches, axis);
181+
biasVector.push_back(bias);
182+
incrementOuterDimension(dimensionCount, idxDims);
183+
}
184+
return biasVector;
185+
}
186+
187+
std::vector<int32_t> calculatePitches(const nvinfer1::Dims& inputDims)
188+
{
189+
int32_t pitch = 1;
190+
int32_t nbDims = inputDims.nbDims;
191+
std::vector<int32_t> pitches(nbDims);
192+
pitches[nbDims - 1] = pitch;
193+
for (int32_t i = nbDims - 2; i >= 0; i--)
194+
{
195+
pitch *= inputDims.d[i + 1];
196+
pitches[i] = pitch;
197+
}
198+
return pitches;
199+
}
200+
139201
bool canUseLinearResize(const size_t scaleSize, const float* scaleFactors)
140202
{
141203
// Linear resize supports up to 3D resize on the outermost dimensions.

onnx2trt_utils.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,13 @@ Status broadcastTensors(IImporterContext* ctx, nvinfer1::ITensor*& t1, nvinfer1:
156156
// Helper function to broadcast three tensors to the largest one's shape
157157
Status broadcastTensors(IImporterContext* ctx, nvinfer1::ITensor*& t1, nvinfer1::ITensor*& t2, nvinfer1::ITensor*& t3);
158158

159+
// Helper function to calculate the bias tensor for GatherElements.
160+
std::vector<int32_t> calculateBias(
161+
const nvinfer1::Dims& daDims, const nvinfer1::Dims& idxDims, const std::vector<int32_t>& pitches, int32_t axis);
162+
163+
// Helper function to calculate and return a vector representation of the pitches of a given shape
164+
std::vector<int32_t> calculatePitches(const nvinfer1::Dims& inputDims);
165+
159166
// Helper function to check that linear resize can be used
160167
bool canUseLinearResize(const size_t scaleSize, const float* scaleFactors);
161168

0 commit comments

Comments
 (0)