Skip to content

Commit 2eb74d9

Browse files
authored
Fix GEMM import assertion (#485)
1 parent ff2e754 commit 2eb74d9

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

builtin_op_importers.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,10 +1023,11 @@ DEFINE_BUILTIN_OP_IMPORTER(Gemm)
10231023
float beta = attrs.get("beta", 1.f);
10241024
bool transA = attrs.get("transA", false);
10251025
bool transB = attrs.get("transB", false);
1026+
1027+
// TRT does not support INT32 input types for this node
1028+
ASSERT(!inputs.at(0).isInt32() && !inputs.at(1).isInt32() && "TensorRT does not support INT32 inputs for GEMM!", ErrorCode::kUNSUPPORTED_NODE);
10261029
nvinfer1::ITensor& inputA = convertToTensor(inputs.at(0), ctx);
10271030
nvinfer1::ITensor* inputB = &convertToTensor(inputs.at(1), ctx);
1028-
// TRT does not support INT32 input types for this node
1029-
ASSERT(inputA.getType() == inputB->getType() && inputA.getType() != nvinfer1::DataType::kINT32, ErrorCode::kUNSUPPORTED_NODE);
10301031

10311032
// Use FC if it is likely to be faster - which is usually when no Shuffles are required.
10321033
bool canUseFC = inputs.at(0).is_tensor() && inputs.at(1).is_weights() && inputs.at(2).is_weights() && alpha == 1.f

0 commit comments

Comments
 (0)