@@ -219,6 +219,53 @@ def test_is_colored_output_on(self):
219219 self .assertTrue (color )
220220
221221
222+ class TestDevice (unittest .TestCase ):
223+
224+ def test_from_string_constructor (self ):
225+ device = trtorch .Device ("cuda:0" )
226+ self .assertEqual (device .device_type , trtorch .DeviceType .GPU )
227+ self .assertEqual (device .gpu_id , 0 )
228+
229+ device = trtorch .Device ("gpu:1" )
230+ self .assertEqual (device .device_type , trtorch .DeviceType .GPU )
231+ self .assertEqual (device .gpu_id , 1 )
232+
233+ def test_from_string_constructor_dla (self ):
234+ device = trtorch .Device ("dla:0" )
235+ self .assertEqual (device .device_type , trtorch .DeviceType .DLA )
236+ self .assertEqual (device .gpu_id , 0 )
237+ self .assertEqual (device .dla_core , 0 )
238+
239+ device = trtorch .Device ("dla:1" , allow_gpu_fallback = True )
240+ self .assertEqual (device .device_type , trtorch .DeviceType .DLA )
241+ self .assertEqual (device .gpu_id , 0 )
242+ self .assertEqual (device .dla_core , 1 )
243+ self .assertEqual (device .allow_gpu_fallback , True )
244+
245+ def test_kwargs_gpu (self ):
246+ device = trtorch .Device (gpu_id = 0 )
247+ self .assertEqual (device .device_type , trtorch .DeviceType .GPU )
248+ self .assertEqual (device .gpu_id , 0 )
249+
250+ def test_kwargs_dla_and_settings (self ):
251+ device = trtorch .Device (dla_core = 1 , allow_gpu_fallback = False )
252+ self .assertEqual (device .device_type , trtorch .DeviceType .DLA )
253+ self .assertEqual (device .gpu_id , 0 )
254+ self .assertEqual (device .dla_core , 1 )
255+ self .assertEqual (device .allow_gpu_fallback , False )
256+
257+ device = trtorch .Device (gpu_id = 1 , dla_core = 0 , allow_gpu_fallback = True )
258+ self .assertEqual (device .device_type , trtorch .DeviceType .DLA )
259+ self .assertEqual (device .gpu_id , 1 )
260+ self .assertEqual (device .dla_core , 0 )
261+ self .assertEqual (device .allow_gpu_fallback , True )
262+
263+ def test_from_torch (self ):
264+ device = trtorch .Device ._from_torch_device (torch .device ("cuda:0" ))
265+ self .assertEqual (device .device_type , trtorch .DeviceType .GPU )
266+ self .assertEqual (device .gpu_id , 0 )
267+
268+
222269def test_suite ():
223270 suite = unittest .TestSuite ()
224271 suite .addTest (unittest .makeSuite (TestLoggingAPIs ))
@@ -231,6 +278,7 @@ def test_suite():
231278 suite .addTest (
232279 TestModuleFallbackToTorch .parametrize (TestModuleFallbackToTorch , model = models .resnet18 (pretrained = True )))
233280 suite .addTest (unittest .makeSuite (TestCheckMethodOpSupport ))
281+ suite .addTest (unittest .makeSuite (TestDevice ))
234282
235283 return suite
236284
0 commit comments