@@ -374,7 +374,9 @@ def _find_duplicates(with_duplicates: Collection[Any]) -> list[Any]:
374374
375375
376376def assert_onnx_proto_equal (
377- a : google .protobuf .message .Message | Any , b : google .protobuf .message .Message | Any
377+ actual : google .protobuf .message .Message | Any ,
378+ expected : google .protobuf .message .Message | Any ,
379+ ignore_initializer_value_proto : bool = False ,
378380) -> None :
379381 """Assert that two ONNX protos are equal.
380382
@@ -386,18 +388,31 @@ def assert_onnx_proto_equal(
386388 compared disregarding the order of their elements.
387389
388390 Args:
389- a: The first ONNX proto.
390- b: The second ONNX proto.
391+ actual: The first ONNX proto.
392+ expected: The second ONNX proto.
393+ ignore_initializer_value_proto: Ignore value protos for initializers if there
394+ are extra ones in the actual proto.
391395 """
392- assert type (a ) is type (b ), f"Type not equal: { type (a )} != { type (b )} "
396+ assert type (actual ) is type (expected ), (
397+ f"Type not equal: { type (actual )} != { type (expected )} "
398+ )
393399
394- a_fields = {field .name : value for field , value in a .ListFields ()}
395- b_fields = {field .name : value for field , value in b .ListFields ()}
400+ a_fields = {field .name : value for field , value in actual .ListFields ()}
401+ b_fields = {field .name : value for field , value in expected .ListFields ()}
396402 all_fields = sorted (set (a_fields .keys ()) | set (b_fields .keys ()))
397- for field in all_fields :
403+ if isinstance (actual , onnx .GraphProto ) and isinstance (expected , onnx .GraphProto ):
404+ actual_initializer_names = {i .name for i in actual .initializer }
405+ expected_initializer_names = {i .name for i in expected .initializer }
406+ else :
407+ actual_initializer_names = set ()
408+ expected_initializer_names = set ()
409+
410+ # Record and report all errors
411+ errors = []
412+ for field in all_fields : # pylint: disable=too-many-nested-blocks
398413 # Obtain the default value if the field is not set. This way we can compare the two fields.
399- a_value = getattr (a , field )
400- b_value = getattr (b , field )
414+ a_value = getattr (actual , field )
415+ b_value = getattr (expected , field )
401416 if (
402417 isinstance (a_value , Sequence )
403418 and isinstance (b_value , Sequence )
@@ -413,6 +428,22 @@ def assert_onnx_proto_equal(
413428 a_keys = [_opset_import_key (opset_import ) for opset_import in a_value ]
414429 b_keys = [_opset_import_key (opset_import ) for opset_import in b_value ]
415430 elif field == "value_info" :
431+ if (
432+ ignore_initializer_value_proto
433+ and isinstance (actual , onnx .GraphProto )
434+ and isinstance (expected , onnx .GraphProto )
435+ ):
436+ # Filter out initializers from the value_info list
437+ a_value = [
438+ value_info
439+ for value_info in a_value
440+ if value_info .name not in actual_initializer_names
441+ ]
442+ b_value = [
443+ value_info
444+ for value_info in b_value
445+ if value_info .name not in expected_initializer_names
446+ ]
416447 a_value = sorted (a_value , key = _value_info_key )
417448 b_value = sorted (b_value , key = _value_info_key )
418449 a_keys = [_value_info_key (value_info ) for value_info in a_value ]
@@ -424,51 +455,62 @@ def assert_onnx_proto_equal(
424455 b_keys = [_function_key (functions ) for functions in b_value ]
425456
426457 if a_keys != b_keys :
427- keys_only_in_a = set (a_keys ) - set (b_keys )
428- keys_only_in_b = set (b_keys ) - set (a_keys )
458+ keys_only_in_actual = set (a_keys ) - set (b_keys )
459+ keys_only_in_expected = set (b_keys ) - set (a_keys )
429460 error_message = (
430- f"Field { field } not equal: keys_only_in_a= { keys_only_in_a } , keys_only_in_b= { keys_only_in_b } . "
461+ f"Field { field } not equal: keys_only_in_actual= { keys_only_in_actual } , keys_only_in_expected= { keys_only_in_expected } . "
431462 f"Field type: { type (a_value )} . "
432463 f"Duplicated a_keys: { _find_duplicates (a_keys )} , duplicated b_keys: { _find_duplicates (b_keys )} "
433464 )
434- raise AssertionError (error_message )
435- if len (a_value ) != len (b_value ):
465+ errors . append (error_message )
466+ elif len (a_value ) != len (b_value ):
436467 error_message = (
437468 f"Field { field } not equal: len(a)={ len (a_value )} , len(b)={ len (b_value )} "
438469 f"Field type: { type (a_value )} "
439470 )
440- raise AssertionError (error_message )
441- # Check every element
442- for i in range (len (a_value )): # pylint: disable=consider-using-enumerate
443- a_value_i = a_value [i ]
444- b_value_i = b_value [i ]
445- if isinstance (a_value_i , google .protobuf .message .Message ) and isinstance (
446- b_value_i , google .protobuf .message .Message
447- ):
448- try :
449- assert_onnx_proto_equal (a_value_i , b_value_i )
450- except AssertionError as e :
451- error_message = f"Field { field } index { i } in sequence not equal. type(a_value_i): { type (a_value_i )} , type(b_value_i): { type (b_value_i )} , a_value_i: { a_value_i } , b_value_i: { b_value_i } "
452- raise AssertionError (error_message ) from e
453- elif a_value_i != b_value_i :
454- if (
455- isinstance (a_value_i , float )
456- and isinstance (b_value_i , float )
457- and math .isnan (a_value_i )
458- and math .isnan (b_value_i )
459- ):
460- # Consider NaNs equal
461- continue
462- error_message = f"Field { field } index { i } in sequence not equal. type(a_value_i): { type (a_value_i )} , type(b_value_i): { type (b_value_i )} "
463- for line in difflib .ndiff (
464- str (a_value_i ).splitlines (), str (b_value_i ).splitlines ()
465- ):
466- error_message += "\n " + line
467- raise AssertionError (error_message )
471+ errors .append (error_message )
472+ else :
473+ # Check every element
474+ for i in range (len (a_value )): # pylint: disable=consider-using-enumerate
475+ actual_value_i = a_value [i ]
476+ expected_value_i = b_value [i ]
477+ if isinstance (
478+ actual_value_i , google .protobuf .message .Message
479+ ) and isinstance (expected_value_i , google .protobuf .message .Message ):
480+ try :
481+ assert_onnx_proto_equal (
482+ actual_value_i ,
483+ expected_value_i ,
484+ ignore_initializer_value_proto = ignore_initializer_value_proto ,
485+ )
486+ except AssertionError as e :
487+ error_message = f"Field { field } index { i } in sequence not equal. type(actual_value_i): { type (actual_value_i )} , type(expected_value_i): { type (expected_value_i )} , actual_value_i: { actual_value_i } , expected_value_i: { expected_value_i } "
488+ error_message = (
489+ str (e ) + "\n \n Caused by the above error\n \n " + error_message
490+ )
491+ errors .append (error_message )
492+ elif actual_value_i != expected_value_i :
493+ if (
494+ isinstance (actual_value_i , float )
495+ and isinstance (expected_value_i , float )
496+ and math .isnan (actual_value_i )
497+ and math .isnan (expected_value_i )
498+ ):
499+ # Consider NaNs equal
500+ continue
501+ error_message = f"Field { field } index { i } in sequence not equal. type(actual_value_i): { type (actual_value_i )} , type(expected_value_i): { type (expected_value_i )} "
502+ for line in difflib .ndiff (
503+ str (actual_value_i ).splitlines (),
504+ str (expected_value_i ).splitlines (),
505+ ):
506+ error_message += "\n " + line
507+ errors .append (error_message )
468508 elif isinstance (a_value , google .protobuf .message .Message ) and isinstance (
469509 b_value , google .protobuf .message .Message
470510 ):
471- assert_onnx_proto_equal (a_value , b_value )
511+ assert_onnx_proto_equal (
512+ a_value , b_value , ignore_initializer_value_proto = ignore_initializer_value_proto
513+ )
472514 elif a_value != b_value :
473515 if (
474516 isinstance (a_value , float )
@@ -478,5 +520,11 @@ def assert_onnx_proto_equal(
478520 ):
479521 # Consider NaNs equal
480522 continue
481- error_message = f"Field { field } not equal. field_a: { a_value } , field_b: { b_value } "
482- raise AssertionError (error_message )
523+ error_message = (
524+ f"Field { field } not equal. field_actual: { a_value } , field_expected: { b_value } "
525+ )
526+ errors .append (error_message )
527+ if errors :
528+ raise AssertionError (
529+ f"Protos not equal: { type (actual )} != { type (expected )} \n " + "\n " .join (errors )
530+ )
0 commit comments