1+ #include < string>
2+ #include " gtest/gtest.h"
3+ #include " torch/csrc/jit/ir/irparser.h"
4+ #include " tests/util/util.h"
5+ #include " core/compiler.h"
6+
7+ TEST (Converters, ATenLSTMCellConvertsCorrectlyWithBias) {
8+ const auto graph = R"IR(
9+ graph(%0 : Tensor,
10+ %1 : Tensor,
11+ %2 : Tensor,
12+ %3 : Tensor,
13+ %4 : Tensor,
14+ %5 : Tensor,
15+ %6 : Tensor):
16+ %7 : Tensor[] = prim::ListConstruct(%1, %2)
17+ %8 : Tensor, %9 : Tensor = aten::lstm_cell(%0, %7, %3, %4, %5, %6)
18+ return (%8))IR" ;
19+
20+ auto g = std::make_shared<torch::jit::Graph>();
21+ torch::jit::parseIR (graph, &*g);
22+
23+ auto input = at::randn ({50 , 10 }, {at::kCUDA });
24+ auto h0 = at::randn ({50 , 20 }, {at::kCUDA });
25+ auto c0 = at::randn ({50 , 20 }, {at::kCUDA });
26+ auto w_ih = at::randn ({4 *20 , 10 }, {at::kCUDA });
27+ auto w_hh = at::randn ({4 *20 , 20 }, {at::kCUDA });
28+ auto b_ih = at::randn ({4 *20 }, {at::kCUDA });
29+ auto b_hh = at::randn ({4 *20 }, {at::kCUDA });
30+
31+ auto jit_input = at::clone (input);
32+ auto jit_h0 = at::clone (h0);
33+ auto jit_c0 = at::clone (c0);
34+ auto jit_w_ih = at::clone (w_ih);
35+ auto jit_w_hh = at::clone (w_hh);
36+ auto jit_b_ih = at::clone (b_ih);
37+ auto jit_b_hh = at::clone (b_hh);
38+
39+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
40+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_input, jit_h0, jit_c0, jit_w_ih, jit_w_hh, jit_b_ih, jit_b_hh});
41+
42+ auto trt_input = at::clone (input);
43+ auto trt_h0 = at::clone (h0);
44+ auto trt_c0 = at::clone (c0);
45+ auto trt_w_ih = at::clone (w_ih);
46+ auto trt_w_hh = at::clone (w_hh);
47+ auto trt_b_ih = at::clone (b_ih);
48+ auto trt_b_hh = at::clone (b_hh);
49+
50+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
51+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh, trt_b_ih, trt_b_hh});
52+
53+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
54+ }
55+
56+ TEST (Converters, ATenLSTMCellConvertsCorrectlyWithoutBias) {
57+ const auto graph = R"IR(
58+ graph(%0 : Tensor,
59+ %1 : Tensor,
60+ %2 : Tensor,
61+ %3 : Tensor,
62+ %4 : Tensor):
63+ %5 : None = prim::Constant()
64+ %6 : None = prim::Constant()
65+ %7 : Tensor[] = prim::ListConstruct(%1, %2)
66+ %8 : Tensor, %9 : Tensor = aten::lstm_cell(%0, %7, %3, %4, %5, %6)
67+ return (%8))IR" ;
68+
69+ auto g = std::make_shared<torch::jit::Graph>();
70+ torch::jit::parseIR (graph, &*g);
71+
72+ auto input = at::randn ({50 , 10 }, {at::kCUDA });
73+ auto h0 = at::randn ({50 , 20 }, {at::kCUDA });
74+ auto c0 = at::randn ({50 , 20 }, {at::kCUDA });
75+ auto w_ih = at::randn ({4 *20 , 10 }, {at::kCUDA });
76+ auto w_hh = at::randn ({4 *20 , 20 }, {at::kCUDA });
77+
78+ auto jit_input = at::clone (input);
79+ auto jit_h0 = at::clone (h0);
80+ auto jit_c0 = at::clone (c0);
81+ auto jit_w_ih = at::clone (w_ih);
82+ auto jit_w_hh = at::clone (w_hh);
83+
84+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
85+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_input, jit_h0, jit_c0, jit_w_ih, jit_w_hh});
86+
87+ auto trt_input = at::clone (input);
88+ auto trt_h0 = at::clone (h0);
89+ auto trt_c0 = at::clone (c0);
90+ auto trt_w_ih = at::clone (w_ih);
91+ auto trt_w_hh = at::clone (w_hh);
92+
93+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
94+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_input, trt_h0, trt_c0, trt_w_ih, trt_w_hh});
95+
96+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ].reshape_as (jit_results[0 ]), 2e-6 ));
97+ }
0 commit comments