Skip to content

Commit 351f22b

Browse files
committed
Support optional plugin fields (onnx#676)
Signed-off-by: Kevin Chen <kevinch@nvidia.com>
1 parent e5ee2b5 commit 351f22b

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

builtin_op_importers.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4191,11 +4191,19 @@ std::tuple<const void*, size_t> copyField(
41914191

41924192
// Load plugin fields from an ONNX node, using fieldData for temporary allocations.
41934193
std::vector<nvinfer1::PluginField> loadFields(string_map<std::vector<uint8_t>>& fieldData, const OnnxAttrs& attrs,
4194-
const nvinfer1::PluginFieldCollection* fieldNames)
4194+
const nvinfer1::PluginFieldCollection* fieldNames, IImporterContext* ctx)
41954195
{
41964196
std::vector<nvinfer1::PluginField> fields{};
41974197
for (int i = 0; i < fieldNames->nbFields; ++i)
41984198
{
4199+
// Some plugins may have default values for fields that map to optional attributes in an ONNX graph.
4200+
if (!attrs.count(fieldNames->fields[i].name))
4201+
{
4202+
LOG_WARNING("Attribute " << fieldNames->fields[i].name
4203+
<< " not found in plugin node! Ensure that the plugin creator has a default value "
4204+
"defined or the engine may fail to build.");
4205+
continue;
4206+
}
41994207
// Name must be retrieved from the map so that it is alive for long enough.
42004208
const std::string& fieldName = fieldData.emplace(fieldNames->fields[i].name, std::vector<uint8_t>{}).first->first;
42014209
const void* data{nullptr};
@@ -4288,7 +4296,7 @@ DEFINE_BUILTIN_OP_IMPORTER(FallbackPluginImporter)
42884296
const nvinfer1::PluginFieldCollection* fieldNames = creator->getFieldNames();
42894297
// Field data needs to be type erased, we use fieldData for temporary allocations.
42904298
string_map<std::vector<uint8_t>> fieldData{};
4291-
std::vector<nvinfer1::PluginField> fields = loadFields(fieldData, attrs, fieldNames);
4299+
std::vector<nvinfer1::PluginField> fields = loadFields(fieldData, attrs, fieldNames, ctx);
42924300

42934301
nvinfer1::IPluginV2* plugin = createPlugin(node.name(), creator, fields);
42944302
ASSERT(plugin && "Could not create plugin", ErrorCode::kUNSUPPORTED_NODE);

0 commit comments

Comments
 (0)