33TEST_P (CppAPITests, CompiledModuleIsClose) {
44 std::vector<torch::jit::IValue> jit_inputs_ivalues;
55 std::vector<torch::jit::IValue> trt_inputs_ivalues;
6- for (auto in_shape : input_shapes) {
7- auto in = at::randint (5 , in_shape, {at::kCUDA });
6+ std::vector<torch_tensorrt::Input> shapes;
7+ for (uint64_t i = 0 ; i < input_shapes.size (); i++) {
8+ auto in = at::randint (5 , input_shapes[i], {at::kCUDA }).to (input_types[i]);
89 jit_inputs_ivalues.push_back (in.clone ());
910 trt_inputs_ivalues.push_back (in.clone ());
11+ auto in_spec = torch_tensorrt::Input (input_shapes[i]);
12+ in_spec.dtype = input_types[i];
13+ shapes.push_back (in_spec);
14+ std::cout << in_spec << std::endl;
1015 }
1116
1217 torch::jit::IValue jit_results_ivalues = torch_tensorrt::tests::util::RunModuleForward (mod, jit_inputs_ivalues);
1318 std::vector<at::Tensor> jit_results;
14- jit_results.push_back (jit_results_ivalues.toTensor ());
19+ if (jit_results_ivalues.isTuple ()) {
20+ auto tuple = jit_results_ivalues.toTuple ();
21+ for (auto t : tuple->elements ()) {
22+ jit_results.push_back (t.toTensor ());
23+ }
24+ } else {
25+ jit_results.push_back (jit_results_ivalues.toTensor ());
26+ }
27+
28+ auto spec = torch_tensorrt::ts::CompileSpec (shapes);
29+ spec.truncate_long_and_double = true ;
1530
16- auto trt_mod = torch_tensorrt::ts::compile (mod, input_shapes );
31+ auto trt_mod = torch_tensorrt::ts::compile (mod, spec );
1732 torch::jit::IValue trt_results_ivalues = torch_tensorrt::tests::util::RunModuleForward (trt_mod, trt_inputs_ivalues);
1833 std::vector<at::Tensor> trt_results;
19- trt_results.push_back (trt_results_ivalues.toTensor ());
34+ if (trt_results_ivalues.isTuple ()) {
35+ auto tuple = trt_results_ivalues.toTuple ();
36+ for (auto t : tuple->elements ()) {
37+ trt_results.push_back (t.toTensor ());
38+ }
39+ } else {
40+ trt_results.push_back (trt_results_ivalues.toTensor ());
41+ }
2042
2143 for (size_t i = 0 ; i < trt_results.size (); i++) {
2244 ASSERT_TRUE (
@@ -30,13 +52,14 @@ INSTANTIATE_TEST_SUITE_P(
3052 CompiledModuleForwardIsCloseSuite,
3153 CppAPITests,
3254 testing::Values (
33- PathAndInSize ({" tests/modules/resnet18_traced.jit.pt" , {{1 , 3 , 224 , 224 }}, 2e-5 }),
34- PathAndInSize({" tests/modules/resnet50_traced.jit.pt" , {{1 , 3 , 224 , 224 }}, 2e-5 }),
35- PathAndInSize({" tests/modules/mobilenet_v2_traced.jit.pt" , {{1 , 3 , 224 , 224 }}, 2e-5 }),
36- PathAndInSize({" tests/modules/resnet18_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, 2e-5 }),
37- PathAndInSize({" tests/modules/resnet50_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, 2e-5 }),
38- PathAndInSize({" tests/modules/mobilenet_v2_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, 2e-5 }),
39- PathAndInSize({" tests/modules/efficientnet_b0_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, 8e-3 }),
40- PathAndInSize({" tests/modules/vit_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, 8e-2 })));
55+ PathAndInput ({" tests/modules/resnet18_traced.jit.pt" , {{1 , 3 , 224 , 224 }}, {at::kFloat }, 2e-5 }),
56+ PathAndInput({" tests/modules/resnet50_traced.jit.pt" , {{1 , 3 , 224 , 224 }}, {at::kFloat }, 2e-5 }),
57+ PathAndInput({" tests/modules/mobilenet_v2_traced.jit.pt" , {{1 , 3 , 224 , 224 }}, {at::kFloat }, 2e-5 }),
58+ PathAndInput({" tests/modules/resnet18_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, {at::kFloat }, 2e-5 }),
59+ PathAndInput({" tests/modules/resnet50_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, {at::kFloat }, 2e-5 }),
60+ PathAndInput({" tests/modules/mobilenet_v2_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, {at::kFloat }, 2e-5 }),
61+ PathAndInput({" tests/modules/efficientnet_b0_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, {at::kFloat }, 8e-3 }),
62+ PathAndInput({" tests/modules/bert_base_uncased_traced.jit.pt" , {{1 , 14 }, {1 , 14 }}, {at::kInt , at::kInt }, 8e-2 }),
63+ PathAndInput({" tests/modules/vit_scripted.jit.pt" , {{1 , 3 , 224 , 224 }}, {at::kFloat }, 8e-2 })));
4164
4265#endif
0 commit comments