Skip to content

Commit 44fc991

Browse files
committed
Add constant weights for GEMMs to refitmap (#556)
Signed-off-by: Kevin Chen <kevinch@nvidia.com>
1 parent eb559b6 commit 44fc991

File tree

1 file changed: 3 additions (+3) and 0 deletions (-0).

builtin_op_importers.cpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1062,6 +1062,8 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm)
10621062
{
10631063
auto transposedWeights = ctx->createTempWeights(weights.type, weights.shape);
10641064
ASSERT(transposeWeights(weights, {1, 0}, &transposedWeights), ErrorCode::kUNSUPPORTED_NODE);
1065+
transposedWeights.setName(weights.getName());
1066+
LOG_WARNING("Weight " << transposedWeights.getName() << " has been transposed! If you plan on overwriting this weight with the Refitter API, the new weights must be pre-transposed");
10651067
weights = transposedWeights;
10661068
// Since we've already transposed now, we can set transpose to false.
10671069
transB = false;
@@ -1070,6 +1072,7 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm)
10701072
= ctx->network()->addConstant(weights.shape, static_cast<nvinfer1::Weights>(weights));
10711073
// Map the constant layer to the weights name.
10721074
ctx->registerLayer(weightsLayer, node.input(1));
1075+
ctx->insertRefitMap(weights.getName(), weightsLayer->getName(), nvinfer1::WeightsRole::kCONSTANT);
10731076
inputB = weightsLayer->getOutput(0);
10741077
}
10751078
else

0 commit comments

Comments
 (0)