@@ -253,6 +253,7 @@ GraphAndMapping ConstructFallbackGraph(
       }
       // update the input ranges for each segment
       convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
+
       auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
       auto temp_g = std::make_shared<torch::jit::Graph>();
       auto device_spec = convert_cfg.engine_settings.device;
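For context, this hunk sits inside the per-segment loop of `ConstructFallbackGraph`: each TensorRT-targeted segment has its input specs re-associated, is converted to an engine, and is spliced into the stitched fallback graph. Below is a minimal standalone sketch of that stitching pattern; all names in it are illustrative stand-ins, not the TRTorch API.

```cpp
// Sketch of the per-segment stitching loop: TensorRT-convertible segments are
// compiled to engines and embedded in a fresh graph, unsupported segments keep
// their original TorchScript. Segment, Graph, BuildEngine are stand-ins.
#include <memory>
#include <string>
#include <vector>

struct Segment {
  bool is_trt;     // did partitioning mark this segment for TensorRT?
  std::string ir;  // placeholder for the segment's code
};

struct Graph {
  std::vector<std::string> nodes;
};

std::string BuildEngine(const Segment& seg) {
  return "trt_engine(" + seg.ir + ")";  // stands in for ConvertBlockToEngine
}

std::shared_ptr<Graph> StitchFallbackGraph(const std::vector<Segment>& segments) {
  auto g = std::make_shared<Graph>();
  for (const auto& seg : segments) {
    if (seg.is_trt) {
      g->nodes.push_back(BuildEngine(seg));  // embed an engine-execution node
    } else {
      g->nodes.push_back(seg.ir);  // fall back to running this piece in PyTorch
    }
  }
  return g;
}
```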
@@ -288,7 +289,7 @@ GraphAndMapping ConstructFallbackGraph(
 }
 
 
-void MapInputsAndDetermineDTypes(CompileSpec& cfg, std::shared_ptr<torch::jit::Graph>& g, ir::StaticParams& static_params, const util::InputTypeMap& first_use_type_map) {
+void MapInputsAndDetermineDTypes(CompileSpec& cfg, std::shared_ptr<torch::jit::Graph>& g, ir::StaticParams& static_params, ir::TypeMap& first_use_type_map) {
   // Associate input specs with inputs
   cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
 
@@ -303,9 +304,31 @@ void MapInputsAndDetermineDTypes(CompileSpec& cfg, std::shared_ptr<torch::jit::G
       } else if (!est_type_opt && !spec.dtype_is_user_defined) {
         // If we cannot calculate the type and the user did not define the type, then default to FP32
         LOG_WARNING(
-            "Cannot deterime input type from calcuations in graph for input "
+            "Cannot infer input type from calculations in graph for input "
             << in->debugName() << ". Assuming it is Float32. If not, specify input type explicitly");
         spec.dtype = nvinfer1::DataType::kFLOAT;
+      } else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) {
+        if (!est_type_opt) {
+          LOG_INFO("Cannot infer input tensor dtype in graph, unable to verify user input dtype settings");
+        } else {
+          if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
+            std::stringstream ss;
+            ss << "For input " << in->debugName() << ", found user specified input dtype as ";
+            ss << cfg.convert_info.inputs.find(in)->second.dtype;
+            ss << ", however when inspecting the graph, the input type expected was inferred to be ";
+            ss << est_type_opt.value() << std::endl;
+            ss << "The compiler is going to use the user setting " << cfg.convert_info.inputs.find(in)->second.dtype;
+            ss << "\nThis conflict may cause an error at runtime because partial compilation is enabled and therefore\n";
+            ss << "compatibility with PyTorch's data type convention is required.\n";
+            ss << "If you do see errors at runtime, either:\n";
+            ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
+            ss << "- Disable partial compilation by setting require_full_compilation to True";
+            auto warn_str = ss.str();
+            LOG_WARNING(warn_str);
+            // Overwrite the type map with the user setting
+            first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
+          }
+        }
       } else {
         // The user defined the type so no changes are necessary
       }
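The new branch above encodes a "user setting wins" policy: when partial compilation is enabled and the user-specified dtype disagrees with the dtype inferred from the input's first use, the compiler warns and overwrites the inferred-type map so downstream partitioning sees the user's choice. Here is a compilable sketch of just that policy; `DType`, the map types, and `ReconcileDTypes` are stand-ins for the TRTorch/nvinfer1 types, not the real API.

```cpp
// Standalone sketch of the dtype-conflict policy: if the user setting and the
// inferred type disagree, warn and make the user setting win by overwriting
// the inferred-type map. All names here are illustrative stand-ins.
#include <iostream>
#include <map>
#include <optional>
#include <string>

enum class DType { kFloat32, kHalf, kInt32 };

std::string to_string(DType t) {
  switch (t) {
    case DType::kFloat32: return "Float32";
    case DType::kHalf: return "Half";
    default: return "Int32";
  }
}

// first_use_types: input name -> dtype inferred from the input's first use.
// user_specs: input name -> user-provided dtype, if any.
void ReconcileDTypes(
    std::map<std::string, std::optional<DType>>& first_use_types,
    const std::map<std::string, DType>& user_specs) {
  for (auto& [name, inferred] : first_use_types) {
    auto spec = user_specs.find(name);
    if (spec == user_specs.end()) continue;  // nothing user-defined, keep inference
    if (inferred && *inferred != spec->second) {
      std::cerr << "WARNING: input " << name << ": user specified "
                << to_string(spec->second) << " but graph expects "
                << to_string(*inferred) << "; using the user setting\n";
      inferred = spec->second;  // user setting overrides the inferred type
    }
  }
}

int main() {
  std::map<std::string, std::optional<DType>> inferred{{"x", DType::kFloat32}};
  std::map<std::string, DType> user{{"x", DType::kHalf}};
  ReconcileDTypes(inferred, user);
  std::cout << "x -> " << to_string(*inferred["x"]) << "\n";  // x -> Half
}
```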
@@ -317,10 +340,11 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
   auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);
 
   auto g = graph_and_parameters.first;
+  TRTORCH_CHECK(conversion::VerifyConverterSupportForBlock(g->block()), "Not all operations in graph are supported by the compiler");
   auto params = graph_and_parameters.second;
   auto static_params = ir::get_static_params(g->inputs(), params);
   // Infer the type of an input from the weights of the calculation
-  auto first_use_types = util::get_block_first_calc_dtypes_opt(g->block());
+  auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
 
   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
 
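The added `TRTORCH_CHECK` makes `ConvertGraphToTRTEngine` fail fast: converter support is verified before any engine construction begins, so the error surfaces at the root cause rather than mid-conversion. A sketch of the same fail-fast shape, where `SKETCH_CHECK` and `VerifySupport` are stand-ins for `TRTORCH_CHECK` and `conversion::VerifyConverterSupportForBlock`:

```cpp
// Minimal fail-fast pattern: verify support up front and abort with a clear
// message before spending time on conversion. Names are illustrative.
#include <sstream>
#include <stdexcept>

#define SKETCH_CHECK(cond, msg)              \
  do {                                       \
    if (!(cond)) {                           \
      std::ostringstream ss;                 \
      ss << "CHECK failed: " << (msg);       \
      throw std::runtime_error(ss.str());    \
    }                                        \
  } while (0)

bool VerifySupport(bool all_ops_supported) { return all_ops_supported; }

void ConvertOrDie(bool all_ops_supported) {
  SKETCH_CHECK(VerifySupport(all_ops_supported),
               "Not all operations in graph are supported by the compiler");
  // ... expensive lowering / engine construction would follow here ...
}
```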
@@ -357,11 +381,21 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   auto params = graph_and_parameters.second;
   auto static_params = ir::get_static_params(g->inputs(), params);
   // Infer the type of an input from the weights of the calculation
-  auto first_use_types = util::get_block_first_calc_dtypes_opt(g->block());
+  auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
 
   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
 
-  if (cfg.partition_info.enabled) {
+  if (cfg.partition_info.enabled &&
+      (cfg.lower_info.forced_fallback_modules.size() == 0 &&
+       cfg.partition_info.forced_fallback_operators.size() == 0 &&
+       conversion::VerifyConverterSupportForBlock(g->block(), true))) {
+    LOG_INFO("Skipping partitioning since model is fully supported");
+  }
+
+  if (cfg.partition_info.enabled &&
+      !(cfg.lower_info.forced_fallback_modules.size() == 0 &&
+        cfg.partition_info.forced_fallback_operators.size() == 0 &&
+        conversion::VerifyConverterSupportForBlock(g->block(), false))) {
     auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
     auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params);
     new_g = graph_and_mapping.first;
@@ -374,6 +408,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
       return mod;
     }
   } else {
+    TRTORCH_CHECK(conversion::VerifyConverterSupportForBlock(g->block()), "Not all operations in graph are supported by the compiler");
     auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
     auto device_spec = cfg.convert_info.engine_settings.device;
     auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
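Taken together, the two new conditions in `CompileGraph` reduce to one predicate: partition only when partitioning is enabled and the model is not already fully convertible (no forced fallback modules or operators, and every op has a converter); a fully supported model skips partitioning, and an unpartitioned unsupported model trips the check above. A small sketch of that decision, with illustrative names rather than the real `cfg`/`conversion::` members (the real code also passes a flag to suppress or print unsupported-op errors on the two verification calls, which this boolean collapses):

```cpp
// Standalone sketch of the partitioning gate: run partitioning only when it is
// enabled AND the model is not already fully convertible. Names are stand-ins.
#include <cstddef>
#include <iostream>

struct PartitionDecisionInputs {
  bool partitioning_enabled;
  std::size_t forced_fallback_modules;
  std::size_t forced_fallback_operators;
  bool all_ops_have_converters;
};

bool FullySupported(const PartitionDecisionInputs& in) {
  return in.forced_fallback_modules == 0 && in.forced_fallback_operators == 0 &&
         in.all_ops_have_converters;
}

bool ShouldPartition(const PartitionDecisionInputs& in) {
  return in.partitioning_enabled && !FullySupported(in);
}

int main() {
  PartitionDecisionInputs in{true, 0, 0, true};
  std::cout << std::boolalpha << ShouldPartition(in) << "\n";  // false: skip, fully supported
  in.all_ops_have_converters = false;
  std::cout << ShouldPartition(in) << "\n";  // true: fall back to partitioning
}
```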