@@ -99,16 +99,57 @@ auto lstm_cell_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
9999 TRTORCH_CHECK (add3, " Unable to create ElementWise layer from node: " << *n);
100100 auto add3_out = add3->getOutput (0 );
101101
102-
103-
104-
105-
106- auto mm_layer = ctx->net ->addMatrixMultiply (*self, nvinfer1::MatrixOperation::kNONE , *other, nvinfer1::MatrixOperation::kNONE );
107- TRTORCH_CHECK (mm_layer, " Unable to create matrix multiplication node: " << *n);
108- mm_layer->setName (util::node_info (n).c_str ());
109- auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], mm_layer->getOutput (0 ));
102+ // chunk Tensor into 4 parts and apply activation functions
103+ auto dims = util::toVec (add3_out->getDimensions ());
104+ auto batch = dims[0 ];
105+ auto hidden = dims[1 ]/4 ;
106+
107+ auto size = util::toDims (std::vector<int64_t >({batch, hidden}));
108+ auto stride = util::toDims (std::vector<int64_t >({1 , 1 }));
109+
110+ auto slice1 = ctx->net ->addSlice (*add3_out, util::toDims (std::vector<int64_t >({0 , 0 })), size, stride);
111+ TRTORCH_CHECK (slice1, " Unable to create Slice layer from node: " << *n);
112+ auto activ1 = ctx->net ->addActivation (*slice1->getOutput (0 ), nvinfer1::ActivationType::kSIGMOID );
113+ TRTORCH_CHECK (activ1, " Unable to create sigmoid activation layer from node: " << *n);
114+ auto ingate = activ1->getOutput (0 );
115+
116+ auto slice2 = ctx->net ->addSlice (*add3_out, util::toDims (std::vector<int64_t >({0 , hidden})), size, stride);
117+ TRTORCH_CHECK (slice2, " Unable to create Slice layer from node: " << *n);
118+ auto activ2 = ctx->net ->addActivation (*slice2->getOutput (0 ), nvinfer1::ActivationType::kSIGMOID );
119+ TRTORCH_CHECK (activ2, " Unable to create sigmoid activation layer from node: " << *n);
120+ auto forgetgate = activ2->getOutput (0 );
121+
122+ auto slice3 = ctx->net ->addSlice (*add3_out, util::toDims (std::vector<int64_t >({0 , 2 *hidden})), size, stride);
123+ TRTORCH_CHECK (slice3, " Unable to create Slice layer from node: " << *n);
124+ auto activ3 = ctx->net ->addActivation (*slice3->getOutput (0 ), nvinfer1::ActivationType::kTANH );
125+ TRTORCH_CHECK (activ3, " Unable to create tanh activation layer from node: " << *n);
126+ auto cellgate = activ3->getOutput (0 );
127+
128+ auto slice4 = ctx->net ->addSlice (*add3_out, util::toDims (std::vector<int64_t >({0 , 3 *hidden})), size, stride);
129+ TRTORCH_CHECK (slice4, " Unable to create Slice layer from node: " << *n);
130+ auto activ4 = ctx->net ->addActivation (*slice4->getOutput (0 ), nvinfer1::ActivationType::kSIGMOID );
131+ TRTORCH_CHECK (activ4, " Unable to create sigmoid activation layer from node: " << *n);
132+ auto outgate = activ4->getOutput (0 );
133+
134+ // compute cy
135+ auto forget_cx = ctx->net ->addElementWise (*forgetgate, *state[1 ], nvinfer1::ElementWiseOperation::kPROD );
136+ TRTORCH_CHECK (forget_cx, " Unable to create ElementWise layer from node: " << *n);
137+ auto in_cell = ctx->net ->addElementWise (*ingate, *cellgate, nvinfer1::ElementWiseOperation::kPROD );
138+ TRTORCH_CHECK (in_cell, " Unable to create ElementWise layer from node: " << *n);
139+ auto cy = ctx->net ->addElementWise (*forget_cx->getOutput (0 ), *in_cell->getOutput (0 ), nvinfer1::ElementWiseOperation::kPROD );
140+ TRTORCH_CHECK (cy, " Unable to create ElementWise layer from node: " << *n);
141+ auto cy_out = ctx->AssociateValueAndTensor (n->outputs ()[1 ], cy->getOutput (0 ));
142+
143+ // compute hy
144+ auto cy_tanh = ctx->net ->addActivation (*cy_out, nvinfer1::ActivationType::kTANH );
145+ TRTORCH_CHECK (cy_tanh, " Unable to create tanh activation layer from node: " << *n);
146+ auto hy = ctx->net ->addElementWise (*outgate, *cy_tanh->getOutput (0 ), nvinfer1::ElementWiseOperation::kPROD );
147+ TRTORCH_CHECK (hy, " Unable to create ElementWise layer from node: " << *n);
148+ auto hy_out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], hy->getOutput (0 ));
149+
150+ LOG_DEBUG (" Output tensor [hy] shape: " << hy_out->getDimensions ());
151+ LOG_DEBUG (" Output tensor [cy] shape: " << cy_out->getDimensions ());
110152
111- LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
112153 return true ;
113154 }
114155 });
0 commit comments