Skip to content

Commit e24e489

Browse files
authored
Update proto comparison error message (#2215)
1 parent 8e0e86b commit e24e489

2 files changed

Lines changed: 99 additions & 46 deletions

File tree

onnxscript/testing/__init__.py

Lines changed: 93 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,9 @@ def _find_duplicates(with_duplicates: Collection[Any]) -> list[Any]:
374374

375375

376376
def assert_onnx_proto_equal(
377-
a: google.protobuf.message.Message | Any, b: google.protobuf.message.Message | Any
377+
actual: google.protobuf.message.Message | Any,
378+
expected: google.protobuf.message.Message | Any,
379+
ignore_initializer_value_proto: bool = False,
378380
) -> None:
379381
"""Assert that two ONNX protos are equal.
380382
@@ -386,18 +388,31 @@ def assert_onnx_proto_equal(
386388
compared disregarding the order of their elements.
387389
388390
Args:
389-
a: The first ONNX proto.
390-
b: The second ONNX proto.
391+
actual: The first ONNX proto.
392+
expected: The second ONNX proto.
393+
ignore_initializer_value_proto: Ignore value protos for initializers if there
394+
are extra ones in the actual proto.
391395
"""
392-
assert type(a) is type(b), f"Type not equal: {type(a)} != {type(b)}"
396+
assert type(actual) is type(expected), (
397+
f"Type not equal: {type(actual)} != {type(expected)}"
398+
)
393399

394-
a_fields = {field.name: value for field, value in a.ListFields()}
395-
b_fields = {field.name: value for field, value in b.ListFields()}
400+
a_fields = {field.name: value for field, value in actual.ListFields()}
401+
b_fields = {field.name: value for field, value in expected.ListFields()}
396402
all_fields = sorted(set(a_fields.keys()) | set(b_fields.keys()))
397-
for field in all_fields:
403+
if isinstance(actual, onnx.GraphProto) and isinstance(expected, onnx.GraphProto):
404+
actual_initializer_names = {i.name for i in actual.initializer}
405+
expected_initializer_names = {i.name for i in expected.initializer}
406+
else:
407+
actual_initializer_names = set()
408+
expected_initializer_names = set()
409+
410+
# Record and report all errors
411+
errors = []
412+
for field in all_fields: # pylint: disable=too-many-nested-blocks
398413
# Obtain the default value if the field is not set. This way we can compare the two fields.
399-
a_value = getattr(a, field)
400-
b_value = getattr(b, field)
414+
a_value = getattr(actual, field)
415+
b_value = getattr(expected, field)
401416
if (
402417
isinstance(a_value, Sequence)
403418
and isinstance(b_value, Sequence)
@@ -413,6 +428,22 @@ def assert_onnx_proto_equal(
413428
a_keys = [_opset_import_key(opset_import) for opset_import in a_value]
414429
b_keys = [_opset_import_key(opset_import) for opset_import in b_value]
415430
elif field == "value_info":
431+
if (
432+
ignore_initializer_value_proto
433+
and isinstance(actual, onnx.GraphProto)
434+
and isinstance(expected, onnx.GraphProto)
435+
):
436+
# Filter out initializers from the value_info list
437+
a_value = [
438+
value_info
439+
for value_info in a_value
440+
if value_info.name not in actual_initializer_names
441+
]
442+
b_value = [
443+
value_info
444+
for value_info in b_value
445+
if value_info.name not in expected_initializer_names
446+
]
416447
a_value = sorted(a_value, key=_value_info_key)
417448
b_value = sorted(b_value, key=_value_info_key)
418449
a_keys = [_value_info_key(value_info) for value_info in a_value]
@@ -424,51 +455,62 @@ def assert_onnx_proto_equal(
424455
b_keys = [_function_key(functions) for functions in b_value]
425456

426457
if a_keys != b_keys:
427-
keys_only_in_a = set(a_keys) - set(b_keys)
428-
keys_only_in_b = set(b_keys) - set(a_keys)
458+
keys_only_in_actual = set(a_keys) - set(b_keys)
459+
keys_only_in_expected = set(b_keys) - set(a_keys)
429460
error_message = (
430-
f"Field {field} not equal: keys_only_in_a={keys_only_in_a}, keys_only_in_b={keys_only_in_b}. "
461+
f"Field {field} not equal: keys_only_in_actual={keys_only_in_actual}, keys_only_in_expected={keys_only_in_expected}. "
431462
f"Field type: {type(a_value)}. "
432463
f"Duplicated a_keys: {_find_duplicates(a_keys)}, duplicated b_keys: {_find_duplicates(b_keys)}"
433464
)
434-
raise AssertionError(error_message)
435-
if len(a_value) != len(b_value):
465+
errors.append(error_message)
466+
elif len(a_value) != len(b_value):
436467
error_message = (
437468
f"Field {field} not equal: len(a)={len(a_value)}, len(b)={len(b_value)} "
438469
f"Field type: {type(a_value)}"
439470
)
440-
raise AssertionError(error_message)
441-
# Check every element
442-
for i in range(len(a_value)): # pylint: disable=consider-using-enumerate
443-
a_value_i = a_value[i]
444-
b_value_i = b_value[i]
445-
if isinstance(a_value_i, google.protobuf.message.Message) and isinstance(
446-
b_value_i, google.protobuf.message.Message
447-
):
448-
try:
449-
assert_onnx_proto_equal(a_value_i, b_value_i)
450-
except AssertionError as e:
451-
error_message = f"Field {field} index {i} in sequence not equal. type(a_value_i): {type(a_value_i)}, type(b_value_i): {type(b_value_i)}, a_value_i: {a_value_i}, b_value_i: {b_value_i}"
452-
raise AssertionError(error_message) from e
453-
elif a_value_i != b_value_i:
454-
if (
455-
isinstance(a_value_i, float)
456-
and isinstance(b_value_i, float)
457-
and math.isnan(a_value_i)
458-
and math.isnan(b_value_i)
459-
):
460-
# Consider NaNs equal
461-
continue
462-
error_message = f"Field {field} index {i} in sequence not equal. type(a_value_i): {type(a_value_i)}, type(b_value_i): {type(b_value_i)}"
463-
for line in difflib.ndiff(
464-
str(a_value_i).splitlines(), str(b_value_i).splitlines()
465-
):
466-
error_message += "\n" + line
467-
raise AssertionError(error_message)
471+
errors.append(error_message)
472+
else:
473+
# Check every element
474+
for i in range(len(a_value)): # pylint: disable=consider-using-enumerate
475+
actual_value_i = a_value[i]
476+
expected_value_i = b_value[i]
477+
if isinstance(
478+
actual_value_i, google.protobuf.message.Message
479+
) and isinstance(expected_value_i, google.protobuf.message.Message):
480+
try:
481+
assert_onnx_proto_equal(
482+
actual_value_i,
483+
expected_value_i,
484+
ignore_initializer_value_proto=ignore_initializer_value_proto,
485+
)
486+
except AssertionError as e:
487+
error_message = f"Field {field} index {i} in sequence not equal. type(actual_value_i): {type(actual_value_i)}, type(expected_value_i): {type(expected_value_i)}, actual_value_i: {actual_value_i}, expected_value_i: {expected_value_i}"
488+
error_message = (
489+
str(e) + "\n\nCaused by the above error\n\n" + error_message
490+
)
491+
errors.append(error_message)
492+
elif actual_value_i != expected_value_i:
493+
if (
494+
isinstance(actual_value_i, float)
495+
and isinstance(expected_value_i, float)
496+
and math.isnan(actual_value_i)
497+
and math.isnan(expected_value_i)
498+
):
499+
# Consider NaNs equal
500+
continue
501+
error_message = f"Field {field} index {i} in sequence not equal. type(actual_value_i): {type(actual_value_i)}, type(expected_value_i): {type(expected_value_i)}"
502+
for line in difflib.ndiff(
503+
str(actual_value_i).splitlines(),
504+
str(expected_value_i).splitlines(),
505+
):
506+
error_message += "\n" + line
507+
errors.append(error_message)
468508
elif isinstance(a_value, google.protobuf.message.Message) and isinstance(
469509
b_value, google.protobuf.message.Message
470510
):
471-
assert_onnx_proto_equal(a_value, b_value)
511+
assert_onnx_proto_equal(
512+
a_value, b_value, ignore_initializer_value_proto=ignore_initializer_value_proto
513+
)
472514
elif a_value != b_value:
473515
if (
474516
isinstance(a_value, float)
@@ -478,5 +520,11 @@ def assert_onnx_proto_equal(
478520
):
479521
# Consider NaNs equal
480522
continue
481-
error_message = f"Field {field} not equal. field_a: {a_value}, field_b: {b_value}"
482-
raise AssertionError(error_message)
523+
error_message = (
524+
f"Field {field} not equal. field_actual: {a_value}, field_expected: {b_value}"
525+
)
526+
errors.append(error_message)
527+
if errors:
528+
raise AssertionError(
529+
f"Protos not equal: {type(actual)} != {type(expected)}\n" + "\n".join(errors)
530+
)

tools/ir/model_zoo_test/model_zoo_test.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import traceback
1919

2020
import onnx
21+
import onnxruntime as ort
2122
import tqdm
2223
from onnx import hub
2324

@@ -42,8 +43,12 @@ def test_model(model_info: hub.ModelInfo) -> float:
4243
ir_model = ir.serde.deserialize_model(model)
4344
serialized = ir.serde.serialize_model(ir_model)
4445
end = time.time()
45-
onnxscript.testing.assert_onnx_proto_equal(serialized, model)
46+
onnxscript.testing.assert_onnx_proto_equal(
47+
serialized, model, ignore_initializer_value_proto=True
48+
)
4649
onnx.checker.check_model(serialized)
50+
# Check the model can be loaded with onnxruntime
51+
ort.InferenceSession(serialized.SerializeToString())
4752
return end - start
4853

4954

0 commit comments

Comments
 (0)