@@ -528,39 +528,32 @@ def load_onnx_weights_and_quant(path, config):
528528 return weights_dict
529529
530530def emb_layernorm (builder , network , config , weights_dict , builder_config , sequence_lengths , batch_sizes ):
531- if len (batch_sizes ) > 1 or len (sequence_lengths ) > 1 :
532- # int8 only support some of the sequence length, we dynamic on sequence length is not allowed.
533- input_ids = network .add_input (name = "input_ids" , dtype = trt .int32 , shape = (- 1 if len (sequence_lengths ) > 1 else sequence_lengths [0 ], - 1 if len (batch_sizes ) > 1 else batch_sizes [0 ]))
534- segment_ids = network .add_input (name = "segment_ids" , dtype = trt .int32 , shape = (- 1 if len (sequence_lengths ) > 1 else sequence_lengths [0 ], - 1 if len (batch_sizes ) > 1 else batch_sizes [0 ]))
535- input_mask = network .add_input (name = "input_mask" , dtype = trt .int32 , shape = (- 1 if len (sequence_lengths ) > 1 else sequence_lengths [0 ], - 1 if len (batch_sizes ) > 1 else batch_sizes [0 ]))
536-
537- # Specify profiles for the batch sizes we're interested in.
538- # Make sure the profile also works for all sizes not covered by the previous profile.
539- prev_batch_size = 0
540- for batch_size in sorted (batch_sizes ):
541- if len (sequence_lengths ) == 1 :
542- min_shape = (sequence_lengths [0 ], prev_batch_size + 1 )
543- shape = (sequence_lengths [0 ], batch_size )
531+ # int8 only supports certain sequence lengths, so a dynamic (wildcard) sequence-length dimension is not allowed.
532+ input_ids = network .add_input (name = "input_ids" , dtype = trt .int32 , shape = (- 1 , - 1 if len (sequence_lengths ) > 1 else sequence_lengths [0 ]))
533+ segment_ids = network .add_input (name = "segment_ids" , dtype = trt .int32 , shape = (- 1 , - 1 if len (sequence_lengths ) > 1 else sequence_lengths [0 ]))
534+ input_mask = network .add_input (name = "input_mask" , dtype = trt .int32 , shape = (- 1 , - 1 if len (sequence_lengths ) > 1 else sequence_lengths [0 ]))
535+
536+ # Specify profiles for the batch sizes we're interested in.
537+ # Make sure the profile also works for all sizes not covered by the previous profile.
538+
539+ for batch_size in sorted (batch_sizes ):
540+ if len (sequence_lengths ) == 1 :
541+ profile = builder .create_optimization_profile ()
542+ min_shape = (1 , sequence_lengths [0 ])
543+ shape = (batch_size , sequence_lengths [0 ])
544+ profile .set_shape ("input_ids" , min = min_shape , opt = shape , max = shape )
545+ profile .set_shape ("segment_ids" , min = min_shape , opt = shape , max = shape )
546+ profile .set_shape ("input_mask" , min = min_shape , opt = shape , max = shape )
547+ builder_config .add_optimization_profile (profile )
548+ else :
549+ for sequence_length in sorted (sequence_lengths ):
550+ profile = builder .create_optimization_profile ()
551+ min_shape = (1 , sequence_length )
552+ shape = (batch_size , sequence_length )
544553 profile .set_shape ("input_ids" , min = min_shape , opt = shape , max = shape )
545554 profile .set_shape ("segment_ids" , min = min_shape , opt = shape , max = shape )
546555 profile .set_shape ("input_mask" , min = min_shape , opt = shape , max = shape )
547556 builder_config .add_optimization_profile (profile )
548- else :
549- prev_sequence_length = 0
550- for sequence_length in sorted (sequence_lengths ):
551- profile = builder .create_optimization_profile ()
552- min_shape = (prev_sequence_length + 1 , prev_batch_size + 1 )
553- shape = (sequence_length , batch_size )
554- profile .set_shape ("input_ids" , min = min_shape , opt = shape , max = shape )
555- profile .set_shape ("segment_ids" , min = min_shape , opt = shape , max = shape )
556- profile .set_shape ("input_mask" , min = min_shape , opt = shape , max = shape )
557- builder_config .add_optimization_profile (profile )
558- prev_sequence_length = sequence_length
559- prev_batch_size = batch_size
560- else :
561- input_ids = network .add_input (name = "input_ids" , dtype = trt .int32 , shape = (sequence_lengths [0 ], batch_sizes [0 ]))
562- segment_ids = network .add_input (name = "segment_ids" , dtype = trt .int32 , shape = (sequence_lengths [0 ], batch_sizes [0 ]))
563- input_mask = network .add_input (name = "input_mask" , dtype = trt .int32 , shape = (sequence_lengths [0 ], batch_sizes [0 ]))
564557
565558 wbeta = trt .PluginField ("bert_embeddings_layernorm_beta" , weights_dict ["bert_embeddings_layernorm_beta" ].numpy (), trt .PluginFieldType .FLOAT32 )
566559 wgamma = trt .PluginField ("bert_embeddings_layernorm_gamma" , weights_dict ["bert_embeddings_layernorm_gamma" ].numpy (), trt .PluginFieldType .FLOAT32 )
@@ -574,7 +567,15 @@ def emb_layernorm(builder, network, config, weights_dict, builder_config, sequen
574567 pfc = trt .PluginFieldCollection ([wbeta , wgamma , wwordemb , wtokemb , wposemb , output_fp16 , mha_type ])
575568 fn = emln_plg_creator .create_plugin ("embeddings" , pfc )
576569
577- inputs = [input_ids , segment_ids , input_mask ]
570+ input_ids = network .add_shuffle (input_ids )
571+ input_ids .second_transpose = (1 , 0 )
572+ segment_ids = network .add_shuffle (segment_ids )
573+ segment_ids .second_transpose = (1 , 0 )
574+ input_mask = network .add_shuffle (input_mask )
575+ input_mask .second_transpose = (1 , 0 )
576+ inputs = [input_ids .get_output (0 ),
577+ segment_ids .get_output (0 ),
578+ input_mask .get_output (0 )]
578579 emb_layer = network .add_plugin_v2 (inputs , fn )
579580
580581 if config .use_qat :
0 commit comments