@@ -14,22 +14,41 @@ namespace converters {
1414namespace impl {
1515namespace {
1616
17+ nvinfer1::ITensor* add_bias (nvinfer1::ITensor* a, nvinfer1::ITensor* b, std::string b_name, ConversionCtx* ctx, const torch::jit::Node* n) {
18+ auto a_dim = a->getDimensions ();
19+ auto b_dim = b->getDimensions ();
20+
21+ LOG_DEBUG (b_name << " tensor shape: " << b_dim);
22+
23+ TRTORCH_CHECK (util::broadcastable (a_dim, b_dim, false ), " bias " << b_name << " is not broadcastable - can't be added to previous matmul operation." );
24+
25+ if (util::toVec (a_dim) != util::toVec (b_dim)) {
26+ LOG_DEBUG (b_name << " 's dimensions need to be reshaped" );
27+
28+ auto shuffle = ctx->net ->addShuffle (*b);
29+ TRTORCH_CHECK (shuffle, " Unable to create shuffle layer from node: " << *n);
30+ shuffle->setReshapeDimensions (util::toDimsPad (util::toVec (b_dim), a_dim.nbDims ));
31+ b = shuffle->getOutput (0 );
32+ }
33+
34+ auto add = ctx->net ->addElementWise (*a, *b, nvinfer1::ElementWiseOperation::kSUM );
35+ TRTORCH_CHECK (add, " Unable to create ElementWise layer from node: " << *n);
36+
37+ return add->getOutput (0 );
38+ }
39+
1740auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
1841 .pattern({
1942 " aten::lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor)" ,
2043 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
2144 auto input = args[0 ].ITensorOrFreeze (ctx);
2245 auto w_ih = args[2 ].ITensorOrFreeze (ctx);
2346 auto w_hh = args[3 ].ITensorOrFreeze (ctx);
24- auto b_ih = args[4 ].ITensorOrFreeze (ctx);
25- auto b_hh = args[5 ].ITensorOrFreeze (ctx);
2647
2748 LOG_DEBUG (" Input tensor shape: " << input->getDimensions ());
2849 LOG_DEBUG (" w_ih tensor shape: " << w_ih->getDimensions ());
2950 LOG_DEBUG (" w_hh tensor shape: " << w_hh->getDimensions ());
30- LOG_DEBUG (" b_ih tensor shape: " << b_ih->getDimensions ());
31- LOG_DEBUG (" b_hh tensor shape: " << b_hh->getDimensions ());
32-
51+
3352 std::vector<nvinfer1::ITensor*> state;
3453 auto hx = args[1 ].IValue ()->toListRef ();
3554 for (unsigned int i = 0 ; i < hx.size (); i++) {
@@ -51,81 +70,56 @@ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
5170 // calculate first half of gates
5271 auto mm1 = ctx->net ->addMatrixMultiply (*input, nvinfer1::MatrixOperation::kNONE , *w_ih, nvinfer1::MatrixOperation::kTRANSPOSE );
5372 TRTORCH_CHECK (mm1, " Unable to create matrix multiplication node: " << *n);
54-
5573 auto mm1_out = mm1->getOutput (0 );
56- auto mm1_dim = mm1_out->getDimensions ();
57- auto b_ih_dim = b_ih->getDimensions ();
58-
59- TRTORCH_CHECK (util::broadcastable (mm1_dim, b_ih_dim, false ));
6074
61- if (util::toVec (mm1_dim) != util::toVec (b_ih_dim)) {
62- LOG_DEBUG (" b_ih dimensions need to be reshaped" );
63-
64- auto shuffle = ctx->net ->addShuffle (*b_ih);
65- TRTORCH_CHECK (shuffle, " Unable to create shuffle layer from node: " << *n);
66- shuffle->setReshapeDimensions (util::toDimsPad (util::toVec (b_ih_dim), mm1_dim.nbDims ));
67- b_ih = shuffle->getOutput (0 );
68- }
69-
70- auto add1 = ctx->net ->addElementWise (*mm1_out, *b_ih, nvinfer1::ElementWiseOperation::kSUM );
71- TRTORCH_CHECK (add1, " Unable to create ElementWise layer from node: " << *n);
72- auto add1_out = add2->getOutput (0 );
75+ auto out1 = !args[4 ].IValue ()->isNone () ? add_bias (mm1_out, args[4 ].ITensorOrFreeze (ctx), " b_ih" , ctx, n) : mm1_out;
7376
7477 // calculate second half of gates
75- auto mm2 = ctx->net ->addMatrixMultiply (*state[0 ], nvinfer1::MatrixOperation::kNONE , *w_hh, nvinfer1::MatrixOperation::kTRANSPOE );
78+ auto mm2 = ctx->net ->addMatrixMultiply (*state[0 ], nvinfer1::MatrixOperation::kNONE , *w_hh, nvinfer1::MatrixOperation::kTRANSPOSE );
7679 TRTORCH_CHECK (mm2, " Unable to create matrix multiplication node: " << *n);
77-
7880 auto mm2_out = mm2->getOutput (0 );
79- auto mm2_dim = mm2_out->getDimensions ();
80- auto b_hh_dim = b_hh->getDimensions ();
81-
82- TRTORCH_CHECK (util::broadcastable (mm2_dim, b_hh_dim, false ));
8381
84- if (util::toVec (mm2_dim) != util::toVec (b_hh_dim)) {
85- LOG_DEBUG (" b_hh dimensions need to be reshaped" );
86-
87- auto shuffle = ctx->net ->addShuffle (*b_hh);
88- TRTORCH_CHECK (shuffle, " Unable to create shuffle layer from node: " << *n);
89- shuffle->setReshapeDimensions (util::toDimsPad (util::toVec (b_hh_dim), mm2_dim.nbDims ));
90- b_hh = shuffle->getOutput (0 );
91- }
92-
93- auto add2 = ctx->net ->addElementWise (*mm2_out, *b_ih, nvinfer1::ElementWiseOperation::kSUM );
94- TRTORCH_CHECK (add2, " Unable to create ElementWise layer from node: " << *n);
95- auto add2_out = add2->getOutput (0 );
82+ auto out2 = !args[5 ].IValue ()->isNone () ? add_bias (mm2_out, args[5 ].ITensorOrFreeze (ctx), " b_hh" , ctx, n) : mm2_out;
9683
9784 // gates
98- auto add3 = ctx->net ->addElementWise (*add1_out , *add2_out , nvinfer1::ElementWiseOperation::kSUM );
99- TRTORCH_CHECK (add3 , " Unable to create ElementWise layer from node: " << *n);
100- auto add3_out = add3 ->getOutput (0 );
85+ auto add = ctx->net ->addElementWise (*out1 , *out2 , nvinfer1::ElementWiseOperation::kSUM );
86+ TRTORCH_CHECK (add , " Unable to create ElementWise layer from node: " << *n);
87+ auto add_out = add ->getOutput (0 );
10188
10289 // chunk Tensor into 4 parts and apply activation functions
103- auto dims = util::toVec (add3_out ->getDimensions ());
90+ auto dims = util::toVec (add_out ->getDimensions ());
10491 auto batch = dims[0 ];
10592 auto hidden = dims[1 ]/4 ;
10693
107- auto size = util::toDims (std::vector<int64_t >({batch, hidden}));
108- auto stride = util::toDims (std::vector<int64_t >({1 , 1 }));
94+ std::vector<int64_t > size_vec = {batch, hidden};
95+ std::vector<int64_t > stride_vec = {1 , 1 };
96+ std::vector<int64_t > offset0 = {0 , 0 };
97+ std::vector<int64_t > offset1 = {0 , hidden};
98+ std::vector<int64_t > offset2 = {0 , 2 *hidden};
99+ std::vector<int64_t > offset3 = {0 , 3 *hidden};
100+
101+ auto size = util::toDims (size_vec);
102+ auto stride = util::toDims (stride_vec);
109103
110- auto slice1 = ctx->net ->addSlice (*add3_out , util::toDims (std::vector< int64_t >({ 0 , 0 }) ), size, stride);
104+ auto slice1 = ctx->net ->addSlice (*add_out , util::toDims (offset0 ), size, stride);
111105 TRTORCH_CHECK (slice1, " Unable to create Slice layer from node: " << *n);
112106 auto activ1 = ctx->net ->addActivation (*slice1->getOutput (0 ), nvinfer1::ActivationType::kSIGMOID );
113107 TRTORCH_CHECK (activ1, " Unable to create sigmoid activation layer from node: " << *n);
114108 auto ingate = activ1->getOutput (0 );
115109
116- auto slice2 = ctx->net ->addSlice (*add3_out , util::toDims (std::vector< int64_t >({ 0 , hidden}) ), size, stride);
110+ auto slice2 = ctx->net ->addSlice (*add_out , util::toDims (offset1 ), size, stride);
117111 TRTORCH_CHECK (slice2, " Unable to create Slice layer from node: " << *n);
118112 auto activ2 = ctx->net ->addActivation (*slice2->getOutput (0 ), nvinfer1::ActivationType::kSIGMOID );
119113 TRTORCH_CHECK (activ2, " Unable to create sigmoid activation layer from node: " << *n);
120114 auto forgetgate = activ2->getOutput (0 );
121115
122- auto slice3 = ctx->net ->addSlice (*add3_out , util::toDims (std::vector< int64_t >({ 0 , 2 *hidden}) ), size, stride);
116+ auto slice3 = ctx->net ->addSlice (*add_out , util::toDims (offset2 ), size, stride);
123117 TRTORCH_CHECK (slice3, " Unable to create Slice layer from node: " << *n);
124118 auto activ3 = ctx->net ->addActivation (*slice3->getOutput (0 ), nvinfer1::ActivationType::kTANH );
125119 TRTORCH_CHECK (activ3, " Unable to create tanh activation layer from node: " << *n);
126120 auto cellgate = activ3->getOutput (0 );
127121
128- auto slice4 = ctx->net ->addSlice (*add3_out , util::toDims (std::vector< int64_t >({ 0 , 3 *hidden}) ), size, stride);
122+ auto slice4 = ctx->net ->addSlice (*add_out , util::toDims (offset3 ), size, stride);
129123 TRTORCH_CHECK (slice4, " Unable to create Slice layer from node: " << *n);
130124 auto activ4 = ctx->net ->addActivation (*slice4->getOutput (0 ), nvinfer1::ActivationType::kSIGMOID );
131125 TRTORCH_CHECK (activ4, " Unable to create sigmoid activation layer from node: " << *n);
0 commit comments