Skip to content

Commit 1e591b9

Browse files
authored
fix duplicate layer names bug (#446) (#467)
Suppose we have a network with (not all distinct) layer names layer layer_1 layer When ImporterContext sees "layer", it sees it's not in mLayerNameCounts, and sets mLayerNameCounts["layer"] = 1 and adds a TRT layer with name "layer". It then sees "layer_1", concludes it's not in mLayerNameCounts, so it sets mLayerNameCounts["layer_1"] = 1 and adds a TRT layer with name "layer_1". NOW when it sees "layer", it sees that mLayerNameCounts["layer"] == 1, so we produce a "uniqueName" of "layer" + "_" + std::to_string(mLayerNameCounts["layer"] ), ie "layer_1", which is a name conflict for the TRT net. This change keeps track of all inserted names in a set and in the case of duplicates, tries suffix-appended modifications of the duplicated name by ever increasing integers until a name appears which has not been used.
1 parent 347e50f commit 1e591b9

1 file changed

Lines changed: 36 additions & 17 deletions

File tree

ImporterContext.hpp

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,12 @@
2626
#include "onnx2trt_utils.hpp"
2727

2828
#include <list>
29+
#include <set>
30+
#include <string>
2931
#include <unordered_map>
3032

3133
namespace onnx2trt
3234
{
33-
3435
class ImporterContext final : public IImporterContext
3536
{
3637
nvinfer1::INetworkDefinition* _network;
@@ -44,13 +45,17 @@ class ImporterContext final : public IImporterContext
4445
StringMap<float> mTensorRangeMins;
4546
StringMap<float> mTensorRangeMaxes;
4647
StringMap<nvinfer1::DataType> mLayerPrecisions;
47-
StringMap<size_t>
48-
mTensorNameCounts; // Keep track of how many times a tensor name shows up, to avoid duplicate naming in TRT.
49-
StringMap<size_t>
50-
mLayerNameCounts; // Keep track of how many times a tensor name shows up, to avoid duplicate naming in TRT.
51-
std::unordered_set<std::string> mUnsupportedShapeTensors; // Container to hold output tensor names of layers that produce shape tensor outputs but do not natively support them.
52-
StringMap<std::string> mLoopTensors; // Container to map subgraph tensors to their original outer graph names.
53-
std::string mOnnxFileLocation; // Keep track of the directory of the parsed ONNX file
48+
std::set<std::string> mTensorNames; // keep track of tensor names used so far,
49+
// to avoid duplicate naming in TRT.
50+
std::set<std::string> mLayerNames; // keep track of layer names used so far,
51+
// to avoid duplicate naming in TRT.
52+
int64_t mSuffixCounter = 0; // increasing suffix counter used to uniquify layer names.
53+
std::unordered_set<std::string> mUnsupportedShapeTensors; // Container to hold any shape tensors that are
54+
// the output of layers that do not support
55+
// shape tensors.
56+
StringMap<std::string> mLoopTensors; // Container to map subgraph tensors to
57+
// their original outer graph names.
58+
std::string mOnnxFileLocation; // Keep track of the directory of the parsed ONNX file
5459

5560
public:
5661
ImporterContext(nvinfer1::INetworkDefinition* network, nvinfer1::ILogger* logger)
@@ -98,13 +103,12 @@ class ImporterContext final : public IImporterContext
98103
{
99104
return mOnnxFileLocation;
100105
}
101-
// This actually handles weights as well, but is named this way to be consistent with the tensors()
106+
// This actually handles weights as well, but is named this way to be
107+
// consistent with the tensors()
102108
virtual void registerTensor(TensorOrWeights tensor, const std::string& basename) override
103109
{
104110
// TRT requires unique tensor names.
105-
const std::string uniqueName
106-
= mTensorNameCounts[basename] ? (basename + "_" + std::to_string(mTensorNameCounts[basename])) : basename;
107-
++mTensorNameCounts[basename];
111+
const std::string uniqueName = generateUniqueName(mTensorNames, basename);
108112

109113
if (tensor)
110114
{
@@ -122,8 +126,9 @@ class ImporterContext final : public IImporterContext
122126
convertINT64(reinterpret_cast<int64_t*>(weights.values), weights.shape, ctx), weights.shape};
123127
}
124128
}
125-
// Overwrite previous tensors registered with the same name (this only happens when there are subgraphs,
126-
// and in that case, overwriting is the desired behavior).
129+
// Overwrite previous tensors registered with the same name (this only
130+
// happens when there are subgraphs, and in that case, overwriting is the
131+
// desired behavior).
127132
this->tensors()[basename] = std::move(tensor);
128133
}
129134

@@ -133,9 +138,7 @@ class ImporterContext final : public IImporterContext
133138
if (layer)
134139
{
135140
const std::string name = basename.empty() ? layer->getName() : basename;
136-
const std::string uniqueName
137-
= mLayerNameCounts[name] ? (name + "_" + std::to_string(mLayerNameCounts[name])) : name;
138-
++mLayerNameCounts[name];
141+
const std::string uniqueName = generateUniqueName(mLayerNames, basename);
139142

140143
auto* ctx = this; // To enable logging.
141144
LOG_VERBOSE("Registering layer: " << name << " for ONNX node: " << basename);
@@ -225,6 +228,22 @@ class ImporterContext final : public IImporterContext
225228
return _opsets.at(domain);
226229
}
227230
}
231+
232+
private:
233+
std::string generateUniqueName(std::set<std::string>& namesSet, const std::string& basename)
234+
{
235+
std::string candidate = basename;
236+
237+
while (namesSet.find(candidate) != namesSet.end())
238+
{
239+
candidate = basename + "_" + std::to_string(mSuffixCounter);
240+
++mSuffixCounter;
241+
}
242+
243+
namesSet.insert(candidate);
244+
245+
return candidate;
246+
}
228247
};
229248

230249
} // namespace onnx2trt

0 commit comments

Comments
 (0)