@@ -109,3 +109,50 @@ TEST(Converters, ATenHardTanhCustomRangeConvertsCorrectly) {
109109 ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
110110}
111111
112+ TEST (Converters, ATenPReLUConvertsCorrectly) {
113+ const auto graph = R"IR(
114+ graph(%0 : Tensor,
115+ %1 : Float(1)):
116+ %3 : Tensor = aten::prelu(%0, %1)
117+ return (%3))IR" ;
118+
119+ auto g = std::make_shared<torch::jit::Graph>();
120+ torch::jit::parseIR (graph, &*g);
121+
122+ auto in = at::randint (-5 , 5 , {5 }, {at::kCUDA });
123+ auto slope = at::randint (-5 , 5 , {1 }, {at::kCUDA });
124+
125+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {slope});
126+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
127+
128+ in = at::clone (in);
129+ params = trtorch::core::conversion::get_named_params (g->inputs (), {slope});
130+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
131+
132+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
133+ }
134+
135+ TEST (Converters, ATenPReLUMultiChannelConvertsCorrectly) {
136+ const auto graph = R"IR(
137+ graph(%0 : Tensor,
138+ %1 : Float(10)):
139+ %3 : Tensor = aten::prelu(%0, %1)
140+ return (%3))IR" ;
141+
142+ auto g = std::make_shared<torch::jit::Graph>();
143+ torch::jit::parseIR (graph, &*g);
144+
145+ auto in = at::randint (-5 , 5 , {1 ,10 , 1 , 1 }, {at::kCUDA });
146+ auto slope = at::randint (-5 , 5 , {10 }, {at::kCUDA });
147+
148+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {slope});
149+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
150+
151+ in = at::clone (in);
152+ params = trtorch::core::conversion::get_named_params (g->inputs (), {slope});
153+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
154+
155+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
156+ }
157+
158+
0 commit comments