@@ -428,7 +428,11 @@ def assert_onnx_proto_equal(
428428 a_keys = [_opset_import_key (opset_import ) for opset_import in a_value ]
429429 b_keys = [_opset_import_key (opset_import ) for opset_import in b_value ]
430430 elif field == "value_info" :
431- if ignore_initializer_value_proto :
431+ if (
432+ ignore_initializer_value_proto
433+ and isinstance (actual , onnx .GraphProto )
434+ and isinstance (expected , onnx .GraphProto )
435+ ):
432436 # Filter out initializers from the value_info list
433437 a_value = [
434438 value_info
@@ -474,7 +478,11 @@ def assert_onnx_proto_equal(
474478 actual_value_i , google .protobuf .message .Message
475479 ) and isinstance (expected_value_i , google .protobuf .message .Message ):
476480 try :
477- assert_onnx_proto_equal (actual_value_i , expected_value_i )
481+ assert_onnx_proto_equal (
482+ actual_value_i ,
483+ expected_value_i ,
484+ ignore_initializer_value_proto = ignore_initializer_value_proto ,
485+ )
478486 except AssertionError as e :
479487 error_message = f"Field { field } index { i } in sequence not equal. type(actual_value_i): { type (actual_value_i )} , type(expected_value_i): { type (expected_value_i )} , actual_value_i: { actual_value_i } , expected_value_i: { expected_value_i } "
480488 error_message = (
@@ -500,7 +508,9 @@ def assert_onnx_proto_equal(
500508 elif isinstance (a_value , google .protobuf .message .Message ) and isinstance (
501509 b_value , google .protobuf .message .Message
502510 ):
503- assert_onnx_proto_equal (a_value , b_value )
511+ assert_onnx_proto_equal (
512+ a_value , b_value , ignore_initializer_value_proto = ignore_initializer_value_proto
513+ )
504514 elif a_value != b_value :
505515 if (
506516 isinstance (a_value , float )
0 commit comments