@@ -75,4 +75,107 @@ TEST(Evaluators, ZerosDataTypeEvaluatesCorrectly) {
7575 auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {in});
7676
7777 ASSERT_TRUE (at::equal (jit_results[0 ].toTensor ().to (at::kCUDA ), trt_results[0 ].toTensor ()));
78+ }
79+
80+ TEST (Evaluators, ATenArangeIntEvaluatesCorrectly) {
81+ const auto graph = R"IR(
82+ graph():
83+ %0 : int = prim::Constant[value=51]()
84+ %1 : None = prim::Constant()
85+ %2 : Tensor = aten::arange(%0, %1, %1, %1, %1)
86+ return (%2))IR" ;
87+
88+ auto g = std::make_shared<torch::jit::Graph>();
89+ torch::jit::parseIR (graph, &*g);
90+
91+ auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
92+ auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
93+
94+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ].toTensor (), trt_results[0 ].toTensor (), 2e-6 ));
95+ }
96+
97+ TEST (Evaluators, ATenArangeFloatEvaluatesCorrectly) {
98+ const auto graph = R"IR(
99+ graph():
100+ %0 : float = prim::Constant[value=51.2]()
101+ %1 : None = prim::Constant()
102+ %2 : Tensor = aten::arange(%0, %1, %1, %1, %1)
103+ return (%2))IR" ;
104+
105+ auto g = std::make_shared<torch::jit::Graph>();
106+ torch::jit::parseIR (graph, &*g);
107+
108+ auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
109+ auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
110+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ].toTensor (), trt_results[0 ].toTensor (), 2e-6 ));
111+ }
112+
113+ TEST (Evaluators, ATenArangeStartEndIntEvaluatesCorrectly) {
114+ const auto graph = R"IR(
115+ graph():
116+ %0 : int = prim::Constant[value=1]()
117+ %1 : int = prim::Constant[value=51]()
118+ %2 : None = prim::Constant()
119+ %3 : Tensor = aten::arange(%0, %1, %2, %2, %2, %2)
120+ return (%3))IR" ;
121+
122+ auto g = std::make_shared<torch::jit::Graph>();
123+ torch::jit::parseIR (graph, &*g);
124+
125+ auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
126+ auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
127+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ].toTensor (), trt_results[0 ].toTensor (), 2e-6 ));
128+ }
129+
130+ TEST (Evaluators, ATenArangeStartEndFloatEvaluatesCorrectly) {
131+ const auto graph = R"IR(
132+ graph():
133+ %0 : float = prim::Constant[value=1.5]()
134+ %1 : float = prim::Constant[value=51.2]()
135+ %2 : None = prim::Constant()
136+ %3 : Tensor = aten::arange(%0, %1, %2, %2, %2, %2)
137+ return (%3))IR" ;
138+
139+ auto g = std::make_shared<torch::jit::Graph>();
140+ torch::jit::parseIR (graph, &*g);
141+
142+ auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
143+ auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
144+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ].toTensor (), trt_results[0 ].toTensor (), 2e-6 ));
145+ }
146+
147+ TEST (Evaluators, ATenArangeStartEndStepIntEvaluatesCorrectly) {
148+ const auto graph = R"IR(
149+ graph():
150+ %0 : int = prim::Constant[value=1]()
151+ %1 : int = prim::Constant[value=51]()
152+ %2 : int = prim::Constant[value=1]()
153+ %3 : None = prim::Constant()
154+ %4 : Tensor = aten::arange(%0, %1, %2, %3, %3, %3, %3)
155+ return (%4))IR" ;
156+
157+ auto g = std::make_shared<torch::jit::Graph>();
158+ torch::jit::parseIR (graph, &*g);
159+
160+ auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
161+ auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
162+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ].toTensor (), trt_results[0 ].toTensor (), 2e-6 ));
163+ }
164+
165+ TEST (Evaluators, ATenArangeStartEndStepFloatEvaluatesCorrectly) {
166+ const auto graph = R"IR(
167+ graph():
168+ %0 : float = prim::Constant[value=1.2]()
169+ %1 : float = prim::Constant[value=51.6]()
170+ %2 : float = prim::Constant[value=1.5]()
171+ %3 : None = prim::Constant()
172+ %4 : Tensor = aten::arange(%0, %1, %2, %3, %3, %3, %3)
173+ return (%4))IR" ;
174+
175+ auto g = std::make_shared<torch::jit::Graph>();
176+ torch::jit::parseIR (graph, &*g);
177+
178+ auto jit_results = trtorch::tests::util::EvaluateGraphJIT (g, {});
179+ auto trt_results = trtorch::tests::util::EvaluateGraph (g->block (), {});
180+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ].toTensor (), trt_results[0 ].toTensor (), 2e-6 ));
78181}
0 commit comments