Skip to content

Commit b3eda61

Browse files
authored
Perform weights datatype check for slice (onnx#632)
Signed-off-by: Kevin Chen <kevinch@nvidia.com>
1 parent 5c22586 commit b3eda61

File tree

3 files changed

+18
-0
lines changed

3 files changed

+18
-0
lines changed

builtin_op_importers.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3073,6 +3073,7 @@ ShapeTensor computeSizes(IImporterContext* ctx, const ShapeTensor& starts, const
30733073

30743074
DEFINE_BUILTIN_OP_IMPORTER(Slice)
30753075
{
3076+
ASSERT(validateInputs(inputs), ErrorCode::kUNSUPPORTED_NODE);
30763077
const int nbInputs = node.input().size();
30773078
// "...it uses this information to slice the input data tensor."
30783079
nvinfer1::ITensor& data = convertToTensor(inputs.at(0), ctx);

onnx2trt_utils.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1806,6 +1806,20 @@ int64_t volume(const nvinfer1::Dims& dims)
18061806
return std::accumulate(dims.d, dims.d + dims.nbDims, 1, std::multiplies<int64_t>{});
18071807
}
18081808

1809+
bool validateInputs(std::vector<TensorOrWeights>& inputs)
1810+
{
1811+
nvinfer1::DataType type = nvinfer1::DataType::kFLOAT;
1812+
bool valid = true;
1813+
for (auto& input : inputs)
1814+
{
1815+
if (input.is_weights())
1816+
{
1817+
valid = valid && convertDtype(input.weights().type, &type);
1818+
}
1819+
}
1820+
return valid;
1821+
}
1822+
18091823
Status weightsToVector(TensorOrWeights weights, std::vector<int64_t>* weightVector)
18101824
{
18111825
ASSERT(weights.is_weights(), ErrorCode::kUNSUPPORTED_NODE);

onnx2trt_utils.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,9 @@ NodeImportResult unaryHelper(IImporterContext* ctx, const ::ONNX_NAMESPACE::Node
304304
// Helper function to unsqueeze tensors on a given set of axes
305305
nvinfer1::ITensor* unsqueezeTensor(IImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ITensor& tensor, const std::vector<int>& axes, bool regLayer = false);
306306

307+
// Helper function to validate inputs and receive type information
308+
bool validateInputs(std::vector<TensorOrWeights>& inputs);
309+
307310
// Helper function to convert a ShapedWeights object into a vector
308311
Status weightsToVector(TensorOrWeights weights, std::vector<int64_t>* weightVector);
309312

0 commit comments

Comments
 (0)