@@ -24,40 +24,29 @@ c10::IValue preprocess(const torch::jit::Module& mod, const c10::Dict<c10::IValu
2424
2525c10::impl::GenericDict TensorRTBackend::compile (c10::IValue mod_val, c10::impl::GenericDict method_compile_spec) {
2626 auto mod = mod_val.toModule ();
27- mod = core::lowering::LowerModule (mod);
28-
2927 auto spec = c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
30- core::lowering::LowerInfo lower_info;
31- for (auto it = spec.begin (), end = spec.end (); it != end; ++it) {
32- const auto & method_name = it->key ();
33- auto method = mod.get_method (method_name);
34- auto graph = method.graph ();
35- core::lowering::LowerGraph (graph, lower_info);
36- }
3728
3829 auto handles = c10::impl::GenericDict (
3930 c10::StringType::get (), c10::getCustomClassType<c10::intrusive_ptr<core::runtime::TRTEngine>>());
4031
4132 for (auto it = spec.begin (), end = spec.end (); it != end; ++it) {
33+ auto mod_ = mod.clone ();
4234 const auto & method_name = it->key ();
43- auto method = mod.get_method (method_name);
44- auto g = method.graph ();
45-
4635 auto raw_spec = it->value ().toCustomClass <trtorch::pyapi::CompileSpec>();
4736 LOG_DEBUG (raw_spec->stringify ());
4837 auto cfg = raw_spec->toInternalCompileSpec ();
49- auto convert_cfg = std::move (cfg.convert_info );
50- auto graph_and_ivalues = torch::jit::LowerGraph (*g, mod._ivalue ());
38+ auto graph_and_ivals = Lower (mod_, method_name, cfg.lower_info );
5139
52- g = graph_and_ivalues .first ;
53- auto params = graph_and_ivalues .second ;
40+ auto g = graph_and_ivals .first ;
41+ auto params = graph_and_ivals .second ;
5442 auto named_params = core::conversion::get_named_params (g->inputs (), params);
5543
44+ auto convert_cfg = std::move (cfg.convert_info );
5645 auto device_spec = convert_cfg.engine_settings .device ;
5746 auto device = core::runtime::CudaDevice (device_spec.gpu_id , device_spec.device_type );
5847 auto serialized_engine = core::conversion::ConvertBlockToEngine (g->block (), convert_cfg, named_params);
5948 auto engine_handle = c10::make_intrusive<core::runtime::TRTEngine>(it->key (), serialized_engine, device);
60- handles.insert (method. name () , at::IValue (engine_handle));
49+ handles.insert (method_name , at::IValue (engine_handle));
6150 }
6251
6352 return c10::impl::toGenericDict (handles);
0 commit comments