You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: ImporterContext.hpp
+35-26Lines changed: 35 additions & 26 deletions
Original file line number
Diff line number
Diff line change
@@ -26,12 +26,11 @@
26
26
#include"onnx2trt_utils.hpp"
27
27
28
28
#include<list>
29
-
#include<set>
30
-
#include<string>
31
29
#include<unordered_map>
32
30
33
31
namespaceonnx2trt
34
32
{
33
+
35
34
classImporterContextfinal : public IImporterContext
36
35
{
37
36
nvinfer1::INetworkDefinition* _network;
@@ -45,22 +44,20 @@ class ImporterContext final : public IImporterContext
45
44
StringMap<float> mTensorRangeMins;
46
45
StringMap<float> mTensorRangeMaxes;
47
46
StringMap<nvinfer1::DataType> mLayerPrecisions;
48
-
std::set<std::string> mTensorNames; // keep track of tensor names used so far,
49
-
// to avoid duplicate naming in TRT.
50
-
std::set<std::string> mLayerNames; // keep track of layer names used so far,
51
-
// to avoid duplicate naming in TRT.
52
-
int64_tmSuffixCounter = 0; // increasing suffix counter used to uniquify layer names.
53
-
std::unordered_set<std::string> mUnsupportedShapeTensors; // Container to hold any shape tensors that are
54
-
// the output of layers that do not support
55
-
// shape tensors.
56
-
StringMap<std::string> mLoopTensors; // Container to map subgraph tensors to
57
-
// their original outer graph names.
58
-
std::string mOnnxFileLocation; // Keep track of the directory of the parsed ONNX file
47
+
std::set<std::string> mTensorNames; // Keep track of how many times a tensor name shows up, to avoid duplicate naming in TRT.
48
+
std::set<std::string> mLayerNames; // Keep track of how many times a tensor name shows up, to avoid duplicate naming in TRT.
49
+
int64_tmSuffixCounter = 0; // increasing suffix counter used to uniquify layer names.
50
+
std::unordered_set<std::string> mUnsupportedShapeTensors; // Container to hold output tensor names of layers that produce shape tensor outputs but do not natively support them.
51
+
StringMap<std::string> mLoopTensors; // Container to map subgraph tensors to their original outer graph names.
52
+
std::string mOnnxFileLocation; // Keep track of the directory of the parsed ONNX file
53
+
std::list<std::string> mInitializerNames; // Keep track of unique names of any initializers
54
+
RefitMap_t* mRefitMap; // Keep track of names of ONNX refittable weights with their corresponding TRT layer and role
auto reduceOp = type == nvinfer1::LayerType::kREDUCE ? (static_cast<nvinfer1::IReduceLayer*>(layer))->getOperation() : nvinfer1::ReduceOperation::kSUM;
487
488
if (!supportsShapeTensor(type, elementwiseOp, reduceOp))
0 commit comments