55#include " core/compiler.h"
66
77TEST (Converters, ATenMeanConvertsCorrectly) {
8+ const auto graph = R"IR(
9+ graph(%0 : Tensor):
10+ %4 : None = prim::Constant()
11+ %5 : Tensor = aten::mean(%0, %4)
12+ return (%5))IR" ;
13+
14+ auto g = std::make_shared<torch::jit::Graph>();
15+ torch::jit::script::parseIR (graph, &*g);
16+
17+ auto in = at::randint (-5 , 5 , {4 , 4 }, at::kCUDA );
18+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
19+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
20+
21+ in = at::clone (in);
22+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
23+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
24+
25+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ]));
26+ }
27+
28+ TEST (Converters, ATenMeanHigherDimensionConvertsCorrectly) {
29+ const auto graph = R"IR(
30+ graph(%0 : Tensor):
31+ %4 : None = prim::Constant()
32+ %5 : Tensor = aten::mean(%0, %4)
33+ return (%5))IR" ;
34+
35+ auto g = std::make_shared<torch::jit::Graph>();
36+ torch::jit::script::parseIR (graph, &*g);
37+
38+ auto in = at::randint (-5 , 5 , {4 , 4 , 4 , 4 }, at::kCUDA );
39+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
40+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
41+
42+ in = at::clone (in);
43+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
44+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
45+
46+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ]));
47+ }
48+
49+ TEST (Converters, ATenMeanRowConvertsCorrectly) {
850 const auto graph = R"IR(
951 graph(%0 : Tensor):
1052 %1 : int = prim::Constant[value=1]()
11- %2 : int[] = prim::ListConstruct(%1)
53+ %2 : int[] = prim::ListConstruct(%1)
1254 %3 : bool = prim::Constant[value=0]()
1355 %4 : None = prim::Constant()
14- %5 : Tensor = aten::mean(%0, %2, %3, %4)
56+ %5 : Tensor = aten::mean(%0, %2, %3, %4)
1557 return (%5))IR" ;
1658
1759 auto g = std::make_shared<torch::jit::Graph>();
@@ -24,18 +66,18 @@ TEST(Converters, ATenMeanConvertsCorrectly) {
2466 in = at::clone (in);
2567 params = trtorch::core::conversion::get_named_params (g->inputs (), {});
2668 auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
27-
69+
2870 ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ]));
2971}
3072
3173TEST (Converters, ATenMeanKeepDimsConvertsCorrectly) {
3274 const auto graph = R"IR(
3375 graph(%0 : Tensor):
3476 %1 : int = prim::Constant[value=1]()
35- %2 : int[] = prim::ListConstruct(%1)
77+ %2 : int[] = prim::ListConstruct(%1)
3678 %3 : bool = prim::Constant[value=1]()
3779 %4 : None = prim::Constant()
38- %5 : Tensor = aten::mean(%0, %2, %3, %4)
80+ %5 : Tensor = aten::mean(%0, %2, %3, %4)
3981 return (%5))IR" ;
4082
4183 auto g = std::make_shared<torch::jit::Graph>();
@@ -48,6 +90,6 @@ TEST(Converters, ATenMeanKeepDimsConvertsCorrectly) {
4890 in = at::clone (in);
4991 params = trtorch::core::conversion::get_named_params (g->inputs (), {});
5092 auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
51-
93+
5294 ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ]));
5395}
0 commit comments