Skip to content

Commit 4f349a1

Browse files
committed
feat(//cpp/trtorchc): Adding a new CLI application for TRTorch which
will serve as a replacement to trtorchexec and act like a GCC style compiler for TorchScript Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent 3381073 commit 4f349a1

File tree

5 files changed

+501
-25
lines changed

5 files changed

+501
-25
lines changed

cpp/api/include/trtorch/ptq.h

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -104,18 +104,17 @@ class Int8Calibrator : Algorithm {
104104
std::stringstream ss;
105105
ss << "Reading Calibration Cache from " << cache_file_path_;
106106
logging::log(logging::Level::kINFO, ss.str());
107+
107108
cache_.clear();
108-
std::ifstream cache_file(cache_file_path_, std::ios::binary);
109-
cache_file >> std::noskipws;
110-
if (cache_file.good()) {
111-
std::copy(std::istream_iterator<char>(cache_file),
112-
std::istream_iterator<char>(),
113-
std::back_inserter(cache_));
114-
ss << "Cache read";
115-
logging::log(logging::Level::kDEBUG, ss.str());
109+
std::ifstream input(cache_file_path_, std::ios::binary);
110+
input >> std::noskipws;
111+
if (input.good()) {
112+
std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(),
113+
std::back_inserter(cache_));
114+
logging::log(logging::Level::kDEBUG, "Cache read");
116115
}
117-
cache_size_ = cache_.size();
118-
return cache_size_ ? cache_.data() : nullptr;
116+
length = cache_.size();
117+
return length ? cache_.data() : nullptr;
119118
}
120119
return nullptr;
121120
}
@@ -220,23 +219,17 @@ class Int8CacheCalibrator : Algorithm {
220219
std::stringstream ss;
221220
ss << "Reading Calibration Cache from " << cache_file_path_;
222221
logging::log(logging::Level::kINFO, ss.str());
222+
223223
cache_.clear();
224-
std::ifstream cache_file;
225-
cache_file.open(cache_file_path_, std::ios::in | std::ios::binary);
226-
cache_file.unsetf(std::ios::skipws);
227-
cache_file.seekg(0, std::ios::beg);
228-
cache_.reserve(cache_file.tellg());
229-
cache_file.seekg(0, std::ios::beg);
230-
if (cache_file.good()) {
231-
std::cout << "Trying to read cache" << std::endl;
232-
std::copy(std::istreambuf_iterator<char>(cache_file),
233-
std::istreambuf_iterator<char>(),
234-
std::back_inserter(cache_));
235-
ss << "Cache read";
236-
logging::log(logging::Level::kDEBUG, ss.str());
224+
std::ifstream input(cache_file_path_, std::ios::binary);
225+
input >> std::noskipws;
226+
if (input.good()) {
227+
std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(),
228+
std::back_inserter(cache_));
229+
logging::log(logging::Level::kDEBUG, "Cache read");
237230
}
238-
cache_size_ = cache_.size();
239-
return cache_size_ ? cache_.data() : nullptr;
231+
length = cache_.size();
232+
return length ? cache_.data() : nullptr;
240233
}
241234

242235

cpp/api/include/trtorch/trtorch.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,14 @@ struct TRTORCH_API ExtraInfo {
142142
* @return false
143143
*/
144144
constexpr bool operator==(DataType other) const { return value == other.value; }
145+
/**
146+
* @brief Comparison operator for DataType
147+
*
148+
* @param other
149+
* @return true
150+
* @return false
151+
*/
152+
constexpr bool operator==(DataType::Value other) const { return value == other; }
145153
/**
146154
* @brief Comparison operator for DataType
147155
*
@@ -150,6 +158,14 @@ struct TRTORCH_API ExtraInfo {
150158
* @return false
151159
*/
152160
constexpr bool operator!=(DataType other) const { return value != other.value; }
161+
/**
162+
* @brief Comparison operator for DataType
163+
*
164+
* @param other
165+
* @return true
166+
* @return false
167+
*/
168+
constexpr bool operator!=(DataType::Value other) const { return value != other; }
153169
private:
154170
Value value;
155171
};

cpp/trtorchc/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package(default_visibility = ["//visibility:public"])
2+
3+
cc_binary(
4+
name = "trtorchc",
5+
srcs = [
6+
"main.cpp"
7+
],
8+
deps = [
9+
"@libtorch//:libtorch",
10+
"@libtorch//:caffe2",
11+
"//third_party/args",
12+
"//cpp/api:trtorch"
13+
],
14+
)

cpp/trtorchc/README.md

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# trtorchc
2+
3+
trtorchc is a compiler CLI application using the TRTorch compiler. It serves as an easy way to compile a
4+
TorchScript Module with TRTorch from the command-line to quickly check support or as part of
5+
a deployment pipeline. All basic features of the compiler are supported including post training
6+
quantization (though you must already have a calibration cache file to use). The compiler can
7+
output two formats, either a TorchScript program with the TensorRT engine embedded or
8+
the TensorRT engine itself as a PLAN file.
9+
10+
All that is required to run the program after compilation is linking against libtrtorch.so in C++
11+
or importing the trtorch package in Python. All other aspects of using compiled modules are identical
12+
to standard TorchScript. Load with `torch.jit.load()` and run like you would run any other module.
13+
14+
15+
```
16+
trtorchc [input_file_path] [output_file_path]
17+
[input_shapes...] {OPTIONS}
18+
19+
TRTorch is a compiler for TorchScript, it will compile and optimize
20+
TorchScript programs to run on NVIDIA GPUs using TensorRT
21+
22+
OPTIONS:
23+
24+
-h, --help Display this help menu
25+
Verbosity of the compiler
26+
-v, --verbose Dumps debugging information about the
27+
compilation process onto the console
28+
-w, --warnings Disables warnings generated during
29+
compilation onto the console (warnings
30+
are on by default)
31+
--info Dumps info messages generated during
32+
compilation onto the console
33+
--build-debuggable-engine Creates a debuggable engine
34+
--use-strict-types Restrict operating type to only use set
35+
default operation precision
36+
(op_precision)
37+
--allow-gpu-fallback (Only used when targeting DLA
38+
(device-type)) Lets engine run layers on
39+
GPU if they are not supported on DLA
40+
-p[precision],
41+
--default-op-precision=[precision]
42+
Default operating precision for the
43+
engine (Int8 requires a
44+
calibration-cache argument) [ float |
45+
float32 | f32 | half | float16 | f16 |
46+
int8 | i8 ] (default: float)
47+
-d[type], --device-type=[type] The type of device the engine should be
48+
built for [ gpu | dla ] (default: gpu)
49+
--engine-capability=[capability] The capability the engine should be
50+
built for [ default | safe_gpu |
51+
safe_dla ]
52+
--calibration-cache-file=[file_path]
53+
Path to calibration cache file to use
54+
for post training quantization
55+
--num-min-timing-iter=[num_iters] Number of minimization timing iterations
56+
used to select kernels
57+
--num-avg-timing-iters=[num_iters]
58+
Number of averaging timing iterations
59+
used to select kernels
60+
--workspace-size=[workspace_size] Maximum size of workspace given to
61+
TensorRT
62+
--max-batch-size=[max_batch_size] Maximum batch size (must be >= 1 to be
63+
set, 0 means not set)
64+
-t[threshold],
65+
--threshold=[threshold] Maximum acceptable numerical deviation
66+
from standard torchscript output
67+
(default 2e-5)
68+
--save-engine Instead of compiling a full
69+
TorchScript program, save the created
70+
engine to the path specified as the
71+
output path
72+
input_file_path Path to input TorchScript file
73+
output_file_path Path for compiled TorchScript (or
74+
TensorRT engine) file
75+
input_shapes... Sizes for inputs to engine, can either
76+
be a single size or a range defined by
77+
Min, Optimal, Max sizes, e.g.
78+
"(N,..,C,H,W)"
79+
"[(MIN_N,..,MIN_C,MIN_H,MIN_W);(OPT_N,..,OPT_C,OPT_H,OPT_W);(MAX_N,..,MAX_C,MAX_H,MAX_W)]"
80+
"--" can be used to terminate flag options and force all following
81+
arguments to be treated as positional options
82+
```
83+
84+
e.g.
85+
```
86+
trtorchc tests/modules/ssd_traced.jit.pt ssd_trt.ts "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]" -p f16
87+
```

0 commit comments

Comments
 (0)