Skip to content

Commit e989f0f

Browse files
committed
Cast BOOL concats to INT32 (#620)
Signed-off-by: Kevin Chen <kevinch@nvidia.com>
1 parent e9a0748 commit e989f0f

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

builtin_op_importers.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -354,11 +354,17 @@ DEFINE_BUILTIN_OP_IMPORTER(Clip)
354354
DEFINE_BUILTIN_OP_IMPORTER(Concat)
355355
{
356356
std::vector<nvinfer1::ITensor*> tensors;
357+
// Cast boolean inputs to INT32
358+
bool isBool = false;
357359
for (auto& input : inputs)
358360
{
359-
// TRT does not support BOOL input types for this node
360-
ASSERT(!input.isBool(), ErrorCode::kUNSUPPORTED_NODE);
361-
tensors.push_back(&convertToTensor(input, ctx));
361+
auto* tensorPtr = &convertToTensor(input, ctx);
362+
if (tensorPtr->getType() == nvinfer1::DataType::kBOOL)
363+
{
364+
tensorPtr = castHelper(ctx, tensorPtr, nvinfer1::DataType::kINT32);
365+
isBool = true;
366+
}
367+
tensors.push_back(tensorPtr);
362368
}
363369
OnnxAttrs attrs(node, ctx);
364370
int axis = attrs.get<int>("axis");
@@ -368,6 +374,12 @@ DEFINE_BUILTIN_OP_IMPORTER(Concat)
368374
ctx->registerLayer(layer, node.name());
369375
ASSERT(layer, ErrorCode::kUNSUPPORTED_NODE);
370376
layer->setAxis(axis);
377+
378+
if (isBool)
379+
{
380+
return {{castHelper(ctx, layer->getOutput(0), nvinfer1::DataType::kBOOL)}};
381+
}
382+
371383
RETURN_FIRST_OUTPUT(layer);
372384
}
373385

onnx2trt_utils.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,13 @@ bool canUseLinearResize(const size_t scaleSize, const float* scaleFactors)
152152
return true;
153153
}
154154

155+
nvinfer1::ITensor* castHelper(IImporterContext* ctx, nvinfer1::ITensor* input, nvinfer1::DataType dtype)
156+
{
157+
nvinfer1::IIdentityLayer* cast = ctx->network()->addIdentity(*input);
158+
cast->setOutputType(0, dtype);
159+
return cast->getOutput(0);
160+
}
161+
155162
nvinfer1::ITensor* constantOfShape(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ITensor* constant, nvinfer1::ITensor* shape)
156163
{
157164
int rank = shape->getDimensions().d[0];

onnx2trt_utils.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@ Status broadcastTensors(IImporterContext* ctx, nvinfer1::ITensor*& t1, nvinfer1:
154154
// Helper function to check that linear resize can be used
155155
bool canUseLinearResize(const size_t scaleSize, const float* scaleFactors);
156156

157+
// Helper function to add a Cast layer in the network
158+
nvinfer1::ITensor* castHelper(IImporterContext* ctx, nvinfer1::ITensor* input, nvinfer1::DataType dtype);
159+
157160
// Helper function for constantOfShape operator. Input shape must be a shape tensor
158161
nvinfer1::ITensor* constantOfShape(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ITensor* constant, nvinfer1::ITensor* shape);
159162

0 commit comments

Comments
 (0)