@@ -10,18 +10,19 @@ namespace converters {
1010namespace impl {
1111namespace {
1212
13- bool add_conv_deconv (ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
13+ bool add_conv_deconv (
14+ ConversionCtx* ctx,
15+ const torch::jit::Node* n,
16+ args& args,
17+ nvinfer1::Dims& stride,
18+ nvinfer1::Dims& padding,
19+ nvinfer1::Dims& dilation,
20+ bool transposed,
21+ nvinfer1::Dims& out_padding,
22+ int64_t groups) {
1423 // Input to conv/deconv
1524 auto in = args[0 ].ITensor ();
1625
17- // Conv /deconv parameters
18- auto stride = util::toDims (args[3 ].unwrapToIntList ());
19- auto padding = util::toDims (args[4 ].unwrapToIntList ());
20- auto dilation = util::toDims (args[5 ].unwrapToIntList ());
21- bool transposed = args[6 ].unwrapToBool ();
22- auto out_padding = util::toDims (args[7 ].unwrapToIntList ());
23- int64_t groups = args[8 ].unwrapToInt ();
24-
2526 // Reshape the parameters to 2D if needed
2627 if (stride.nbDims == 1 ) {
2728 stride = util::unsqueezeDims (stride, 1 , 1 );
@@ -174,28 +175,66 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
174175 return true ;
175176}
176177
177- auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
178- .pattern({
179- R"SIG( aten::_convolution(Tensor input, Tensor weight,
178+ auto conv_registrations TRTORCH_UNUSED =
179+ RegisterNodeConversionPatterns ()
180+ .pattern({
181+ R"SIG( aten::_convolution(Tensor input, Tensor weight,
180182 Tensor? bias, int[] stride, int[] padding,
181183 int[] dilation, bool transposed,
182184 int[] output_padding, int groups, bool benchmark,
183185 bool deterministic, bool cudnn_enabled, bool allow_tf32) -> (Tensor))SIG" ,
184- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
185- return add_conv_deconv (ctx, n, args);
186- }})
187- .pattern({
188- R"SIG( aten::_convolution.deprecated(Tensor input, Tensor weight,
186+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
187+ // Conv /deconv parameters
188+ auto stride = util::toDims (args[3 ].unwrapToIntList ());
189+ auto padding = util::toDims (args[4 ].unwrapToIntList ());
190+ auto dilation = util::toDims (args[5 ].unwrapToIntList ());
191+ bool transposed = args[6 ].unwrapToBool ();
192+ auto out_padding = util::toDims (args[7 ].unwrapToIntList ());
193+ int64_t groups = args[8 ].unwrapToInt ();
194+ return add_conv_deconv (ctx, n, args, stride, padding, dilation, transposed, out_padding, groups);
195+ }})
196+ .pattern({
197+ R"SIG( aten::_convolution.deprecated(Tensor input, Tensor weight,
189198 Tensor? bias, int[] stride, int[] padding,
190199 int[] dilation, bool transposed,
191200 int[] output_padding, int groups, bool benchmark,
192201 bool deterministic, bool cudnn_enabled) -> (Tensor))SIG" ,
193- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
194- // This pattern is only matched for traced JIT models which do not
195- // have allow_tf32 bool in the function signature. The TRT conversion
196- // code is exactly same as the above call.
197- return add_conv_deconv (ctx, n, args);
198- }});
202+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
203+ // This pattern is only matched for traced JIT models which do not
204+ // have allow_tf32 bool in the function signature. The TRT conversion
205+ // code is exactly same as the above call.
206+ auto stride = util::toDims (args[3 ].unwrapToIntList ());
207+ auto padding = util::toDims (args[4 ].unwrapToIntList ());
208+ auto dilation = util::toDims (args[5 ].unwrapToIntList ());
209+ bool transposed = args[6 ].unwrapToBool ();
210+ auto out_padding = util::toDims (args[7 ].unwrapToIntList ());
211+ int64_t groups = args[8 ].unwrapToInt ();
212+ return add_conv_deconv (ctx, n, args, stride, padding, dilation, transposed, out_padding, groups);
213+ }})
214+ .pattern(
215+ {R"SIG( aten::conv1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor)SIG" ,
216+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
217+ // Conv /deconv parameters
218+ auto stride = util::toDims (args[3 ].unwrapToIntList ());
219+ auto padding = util::toDims (args[4 ].unwrapToIntList ());
220+ auto dilation = util::toDims (args[5 ].unwrapToIntList ());
221+ bool transposed = false ;
222+ nvinfer1::Dims out_padding{1 , {0 }};
223+ int64_t groups = args[6 ].unwrapToInt ();
224+ return add_conv_deconv (ctx, n, args, stride, padding, dilation, transposed, out_padding, groups);
225+ }})
226+ .pattern(
227+ {R"SIG( aten::conv_transpose1d(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor)SIG" ,
228+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
229+ // Conv /deconv parameters
230+ auto stride = util::toDims (args[3 ].unwrapToIntList ());
231+ auto padding = util::toDims (args[4 ].unwrapToIntList ());
232+ auto out_padding = util::toDims (args[5 ].unwrapToIntList ());
233+ bool transposed = true ;
234+ int64_t groups = args[6 ].unwrapToInt ();
235+ auto dilation = util::toDims (args[7 ].unwrapToIntList ());
236+ return add_conv_deconv (ctx, n, args, stride, padding, dilation, transposed, out_padding, groups);
237+ }});
199238} // namespace
200239} // namespace impl
201240} // namespace converters
0 commit comments