@@ -122,6 +122,23 @@ def _parse_device(device_info: Dict[str, Any]) -> trtorch._C.Device:
122122
123123 return info
124124
125+ def _parse_torch_fallback (fallback_info : Dict [str , Any ]) -> trtorch ._C .TorchFallback :
126+ info = trtorch ._C .TorchFallback ()
127+ if "enabled" not in fallback_info :
128+ raise KeyError ("Enabled is required parameter" )
129+ else :
130+ assert isinstance (fallback_info ["enabled" ], bool )
131+ info .enabled = fallback_info ["enabled" ]
132+ if "min_block_size" in fallback_info :
133+ assert isinstance (fallback_info ["min_block_size" ], int )
134+ info .min_block_size = fallback_info ["min_block_size" ]
135+
136+ if "forced_fallback_operators" in fallback_info :
137+ assert isinstance (fallback_info ["forced_fallback_operators" ], list )
138+ info .forced_fallback_operators = fallback_info ["forced_fallback_operators" ]
139+
140+ return info
141+
125142
126143def _parse_compile_spec (compile_spec : Dict [str , Any ]) -> trtorch ._C .CompileSpec :
127144 info = trtorch ._C .CompileSpec ()
@@ -174,6 +191,10 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
174191 assert type (compile_spec ["max_batch_size" ]) is int
175192 info .max_batch_size = compile_spec ["max_batch_size" ]
176193
194+ if "torch_fallback" in compile_spec :
195+ info .torch_fallback = _parse_torch_fallback (compile_spec ["torch_fallback" ])
196+
197+
177198 return info
178199
179200
@@ -242,7 +263,13 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
242263 d .set_dla_core (parsed_spec .device .dla_core )
243264 d .set_allow_gpu_fallback (parsed_spec .device .allow_gpu_fallback )
244265
266+ torch_fallback = torch .classes .tensorrt .TorchFallback ()
267+ torch_fallback .set_enabled (parsed_spec .torch_fallback .enabled )
268+ torch_fallback .set_min_block_size (parsed_spec .torch_fallback .min_block_size )
269+ torch_fallback .set_forced_fallback_operators (parsed_spec .torch_fallback .forced_fallback_operators )
270+
245271 backend_spec .set_device (d )
272+ backend_spec .set_torch_fallback (fallback )
246273 backend_spec .set_op_precision (int (parsed_spec .op_precision ))
247274 backend_spec .set_disable_tf32 (parsed_spec .disable_tf32 )
248275 backend_spec .set_refit (parsed_spec .refit )
0 commit comments