@@ -77,7 +77,7 @@ TEST(Converters, ATenSoftmaxNDConvertsCorrectlyAbove3DIndex) {
7777 ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
7878}
7979
80- TEST (Converters, ATenSoftmaxNDConvertsCorrectlyNegtiveIndex ) {
80+ TEST (Converters, ATenSoftmaxNDConvertsCorrectlyNegtiveOneIndex ) {
8181 const auto graph = R"IR(
8282 graph(%0 : Tensor):
8383 %1 : None = prim::Constant()
@@ -100,5 +100,31 @@ TEST(Converters, ATenSoftmaxNDConvertsCorrectlyNegtiveIndex) {
100100
101101 auto trt = trt_results[0 ].reshape_as (jit_results[0 ]);
102102
103+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
104+ }
105+
106+ TEST (Converters, ATenSoftmaxNDConvertsCorrectlyNegtiveIndex) {
107+ const auto graph = R"IR(
108+ graph(%0 : Tensor):
109+ %1 : None = prim::Constant()
110+ %2 : int = prim::Constant[value=-2]()
111+ %3 : Tensor = aten::softmax(%0, %2, %1)
112+ return (%3))IR" ;
113+
114+ auto g = std::make_shared<torch::jit::Graph>();
115+ torch::jit::parseIR (graph, &*g);
116+
117+ auto in = at::randint (0 , 5 , {1 , 2 , 2 , 2 , 2 }, {at::kCUDA });
118+
119+ auto jit_in = at::clone (in);
120+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
121+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {jit_in});
122+
123+ auto trt_in = at::clone (in);
124+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
125+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {trt_in});
126+
127+ auto trt = trt_results[0 ].reshape_as (jit_results[0 ]);
128+
103129 ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
104130}
0 commit comments