@@ -23,44 +23,52 @@ def test_compile_traced(self):
2323 "enabled_precisions" : {torch .float }
2424 }
2525
26- trt_mod = trtorch .compile (self .traced_model , compile_spec )
26+ trt_mod = trtorch .compile (self .traced_model , ** compile_spec )
2727 same = (trt_mod (self .input ) - self .traced_model (self .input )).abs ().max ()
2828 self .assertTrue (same < 2e-2 )
2929
3030 def test_compile_script (self ):
31+ trt_mod = trtorch .compile (self .scripted_model , inputs = [self .input ], device = trtorch .Device (gpu_id = 0 ), enabled_precisions = {torch .float })
32+ same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
33+ self .assertTrue (same < 2e-2 )
34+
35+ def test_from_torch_tensor (self ):
3136 compile_spec = {
32- "inputs" : [trtorch . Input ( shape = self .input . shape ) ],
37+ "inputs" : [self .input ],
3338 "device" : {
3439 "device_type" : trtorch .DeviceType .GPU ,
3540 "gpu_id" : 0 ,
3641 },
3742 "enabled_precisions" : {torch .float }
3843 }
3944
40- trt_mod = trtorch .compile (self .scripted_model , compile_spec )
45+ trt_mod = trtorch .compile (self .scripted_model , ** compile_spec )
4146 same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
4247 self .assertTrue (same < 2e-2 )
4348
44- def test_from_torch_tensor (self ):
49+ def test_device (self ):
50+ compile_spec = {"inputs" : [self .input ], "device" : trtorch .Device ("gpu:0" ), "enabled_precisions" : {torch .float }}
51+
52+ trt_mod = trtorch .compile (self .scripted_model , ** compile_spec )
53+ same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
54+ self .assertTrue (same < 2e-2 )
55+
56+
57+ def test_compile_script_from_dict (self ):
4558 compile_spec = {
46- "inputs" : [self .input ],
59+ "inputs" : [trtorch . Input ( shape = self .input . shape ) ],
4760 "device" : {
4861 "device_type" : trtorch .DeviceType .GPU ,
4962 "gpu_id" : 0 ,
5063 },
5164 "enabled_precisions" : {torch .float }
5265 }
5366
54- trt_mod = trtorch .compile (self .scripted_model , compile_spec )
55- same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
67+ trt_mod = trtorch .compile (self .traced_model , ** compile_spec )
68+ same = (trt_mod (self .input ) - self .traced_model (self .input )).abs ().max ()
5669 self .assertTrue (same < 2e-2 )
5770
58- def test_device (self ):
59- compile_spec = {"inputs" : [self .input ], "device" : trtorch .Device ("gpu:0" ), "enabled_precisions" : {torch .float }}
6071
61- trt_mod = trtorch .compile (self .scripted_model , compile_spec )
62- same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
63- self .assertTrue (same < 2e-2 )
6472
6573
6674class TestCompileHalf (ModelTestCase ):
@@ -80,7 +88,7 @@ def test_compile_script_half(self):
8088 "enabled_precisions" : {torch .half }
8189 }
8290
83- trt_mod = trtorch .compile (self .scripted_model , compile_spec )
91+ trt_mod = trtorch .compile (self .scripted_model , ** compile_spec )
8492 same = (trt_mod (self .input .half ()) - self .scripted_model (self .input .half ())).abs ().max ()
8593 trtorch .logging .log (trtorch .logging .Level .Debug , "Max diff: " + str (same ))
8694 self .assertTrue (same < 3e-2 )
@@ -103,7 +111,7 @@ def test_compile_script_half_by_default(self):
103111 "enabled_precisions" : {torch .float , torch .half }
104112 }
105113
106- trt_mod = trtorch .compile (self .scripted_model , compile_spec )
114+ trt_mod = trtorch .compile (self .scripted_model , ** compile_spec )
107115 same = (trt_mod (self .input .half ()) - self .scripted_model (self .input .half ())).abs ().max ()
108116 trtorch .logging .log (trtorch .logging .Level .Debug , "Max diff: " + str (same ))
109117 self .assertTrue (same < 3e-2 )
@@ -132,7 +140,7 @@ def test_compile_script(self):
132140 }
133141 }
134142
135- trt_mod = trtorch .compile (self .scripted_model , compile_spec )
143+ trt_mod = trtorch .compile (self .scripted_model , ** compile_spec )
136144 same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
137145 self .assertTrue (same < 2e-3 )
138146
@@ -160,7 +168,7 @@ def test_compile_script(self):
160168 }
161169 }
162170
163- trt_mod = trtorch .compile (self .scripted_model , compile_spec )
171+ trt_mod = trtorch .compile (self .scripted_model , ** compile_spec )
164172 same = (trt_mod (self .input ) - self .scripted_model (self .input )).abs ().max ()
165173 self .assertTrue (same < 2e-3 )
166174
@@ -183,7 +191,7 @@ def test_pt_to_trt_to_pt(self):
183191 }
184192 }
185193
186- trt_engine = trtorch .convert_method_to_trt_engine (self .ts_model , "forward" , compile_spec )
194+ trt_engine = trtorch .convert_method_to_trt_engine (self .ts_model , "forward" , ** compile_spec )
187195 trt_mod = trtorch .embed_engine_in_new_module (trt_engine , trtorch .Device ("cuda:0" ))
188196 same = (trt_mod (self .input ) - self .ts_model (self .input )).abs ().max ()
189197 self .assertTrue (same < 2e-3 )
0 commit comments