diff --git a/builtin_op_importers.cpp b/builtin_op_importers.cpp index bb19eea0..d1a6c5c2 100644 --- a/builtin_op_importers.cpp +++ b/builtin_op_importers.cpp @@ -1062,6 +1062,8 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm) { auto transposedWeights = ctx->createTempWeights(weights.type, weights.shape); ASSERT(transposeWeights(weights, {1, 0}, &transposedWeights), ErrorCode::kUNSUPPORTED_NODE); + transposedWeights.setName(weights.getName()); + LOG_WARNING("Weight " << transposedWeights.getName() << " has been transposed! If you plan on overwriting this weight with the Refitter API, the new weights must be pre-transposed"); weights = transposedWeights; // Since we've already transposed now, we can set transpose to false. transB = false; @@ -1070,6 +1072,7 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm) = ctx->network()->addConstant(weights.shape, static_cast(weights)); // Map the constant layer to the weights name. ctx->registerLayer(weightsLayer, node.input(1)); + ctx->insertRefitMap(weights.getName(), weightsLayer->getName(), nvinfer1::WeightsRole::kCONSTANT); inputB = weightsLayer->getOutput(0); } else