Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 93 additions & 45 deletions onnxscript/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,9 @@


def assert_onnx_proto_equal(
a: google.protobuf.message.Message | Any, b: google.protobuf.message.Message | Any
actual: google.protobuf.message.Message | Any,
expected: google.protobuf.message.Message | Any,
ignore_initializer_value_proto: bool = False,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ignore-initializer-value-proto seems a very specialized niche-case ... why/where exactly do you need it? Wonder if there is some more general-purpose mechanism to achieve that ... I don't have any suggestions at the moment though.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are models in the onnx model zoo that does not have value info proto for initializers, but roundtripping will add those protos to the model. So I created a flag to ignore them. It is used in tools/ir/model_zoo_test/model_zoo_test.py

) -> None:
"""Assert that two ONNX protos are equal.

Expand All @@ -386,18 +388,31 @@
compared disregarding the order of their elements.

Args:
a: The first ONNX proto.
b: The second ONNX proto.
actual: The first ONNX proto.
expected: The second ONNX proto.
ignore_initializer_value_proto: Ignore value protos for initializers if there
are extra ones in the actual proto.
"""
assert type(a) is type(b), f"Type not equal: {type(a)} != {type(b)}"
assert type(actual) is type(expected), (

Check warning on line 396 in onnxscript/testing/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/testing/__init__.py#L396

Added line #L396 was not covered by tests
f"Type not equal: {type(actual)} != {type(expected)}"
)

a_fields = {field.name: value for field, value in a.ListFields()}
b_fields = {field.name: value for field, value in b.ListFields()}
a_fields = {field.name: value for field, value in actual.ListFields()}
b_fields = {field.name: value for field, value in expected.ListFields()}

Check warning on line 401 in onnxscript/testing/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/testing/__init__.py#L400-L401

Added lines #L400 - L401 were not covered by tests
all_fields = sorted(set(a_fields.keys()) | set(b_fields.keys()))
for field in all_fields:
if isinstance(actual, onnx.GraphProto) and isinstance(expected, onnx.GraphProto):
actual_initializer_names = {i.name for i in actual.initializer}
expected_initializer_names = {i.name for i in expected.initializer}

Check warning on line 405 in onnxscript/testing/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/testing/__init__.py#L404-L405

Added lines #L404 - L405 were not covered by tests
else:
actual_initializer_names = set()
expected_initializer_names = set()

Check warning on line 408 in onnxscript/testing/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/testing/__init__.py#L407-L408

Added lines #L407 - L408 were not covered by tests

# Record and report all errors
errors = []

Check warning on line 411 in onnxscript/testing/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/testing/__init__.py#L411

Added line #L411 was not covered by tests
for field in all_fields: # pylint: disable=too-many-nested-blocks
# Obtain the default value if the field is not set. This way we can compare the two fields.
a_value = getattr(a, field)
b_value = getattr(b, field)
a_value = getattr(actual, field)
b_value = getattr(expected, field)

Check warning on line 415 in onnxscript/testing/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/testing/__init__.py#L414-L415

Added lines #L414 - L415 were not covered by tests
if (
isinstance(a_value, Sequence)
and isinstance(b_value, Sequence)
Expand All @@ -413,6 +428,22 @@
a_keys = [_opset_import_key(opset_import) for opset_import in a_value]
b_keys = [_opset_import_key(opset_import) for opset_import in b_value]
elif field == "value_info":
if (
ignore_initializer_value_proto
and isinstance(actual, onnx.GraphProto)
and isinstance(expected, onnx.GraphProto)
):
# Filter out initializers from the value_info list
a_value = [

Check warning on line 437 in onnxscript/testing/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/testing/__init__.py#L437

Added line #L437 was not covered by tests
value_info
for value_info in a_value
if value_info.name not in actual_initializer_names
]
b_value = [

Check warning on line 442 in onnxscript/testing/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/testing/__init__.py#L442

Added line #L442 was not covered by tests
value_info
for value_info in b_value
if value_info.name not in expected_initializer_names
]
a_value = sorted(a_value, key=_value_info_key)
b_value = sorted(b_value, key=_value_info_key)
a_keys = [_value_info_key(value_info) for value_info in a_value]
Expand All @@ -424,51 +455,62 @@
b_keys = [_function_key(functions) for functions in b_value]

if a_keys != b_keys:
keys_only_in_a = set(a_keys) - set(b_keys)
keys_only_in_b = set(b_keys) - set(a_keys)
keys_only_in_actual = set(a_keys) - set(b_keys)
keys_only_in_expected = set(b_keys) - set(a_keys)

Check warning on line 459 in onnxscript/testing/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/testing/__init__.py#L458-L459

Added lines #L458 - L459 were not covered by tests
error_message = (
f"Field {field} not equal: keys_only_in_a={keys_only_in_a}, keys_only_in_b={keys_only_in_b}. "
f"Field {field} not equal: keys_only_in_actual={keys_only_in_actual}, keys_only_in_expected={keys_only_in_expected}. "
f"Field type: {type(a_value)}. "
f"Duplicated a_keys: {_find_duplicates(a_keys)}, duplicated b_keys: {_find_duplicates(b_keys)}"
)
raise AssertionError(error_message)
if len(a_value) != len(b_value):
errors.append(error_message)

Check warning on line 465 in onnxscript/testing/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/testing/__init__.py#L465

Added line #L465 was not covered by tests
elif len(a_value) != len(b_value):
error_message = (
f"Field {field} not equal: len(a)={len(a_value)}, len(b)={len(b_value)} "
f"Field type: {type(a_value)}"
)
raise AssertionError(error_message)
# Check every element
for i in range(len(a_value)): # pylint: disable=consider-using-enumerate
a_value_i = a_value[i]
b_value_i = b_value[i]
if isinstance(a_value_i, google.protobuf.message.Message) and isinstance(
b_value_i, google.protobuf.message.Message
):
try:
assert_onnx_proto_equal(a_value_i, b_value_i)
except AssertionError as e:
error_message = f"Field {field} index {i} in sequence not equal. type(a_value_i): {type(a_value_i)}, type(b_value_i): {type(b_value_i)}, a_value_i: {a_value_i}, b_value_i: {b_value_i}"
raise AssertionError(error_message) from e
elif a_value_i != b_value_i:
if (
isinstance(a_value_i, float)
and isinstance(b_value_i, float)
and math.isnan(a_value_i)
and math.isnan(b_value_i)
):
# Consider NaNs equal
continue
error_message = f"Field {field} index {i} in sequence not equal. type(a_value_i): {type(a_value_i)}, type(b_value_i): {type(b_value_i)}"
for line in difflib.ndiff(
str(a_value_i).splitlines(), str(b_value_i).splitlines()
):
error_message += "\n" + line
raise AssertionError(error_message)
errors.append(error_message)

Check warning on line 471 in onnxscript/testing/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/testing/__init__.py#L471

Added line #L471 was not covered by tests
else:
# Check every element
for i in range(len(a_value)): # pylint: disable=consider-using-enumerate
actual_value_i = a_value[i]
expected_value_i = b_value[i]

Check warning on line 476 in onnxscript/testing/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/testing/__init__.py#L475-L476

Added lines #L475 - L476 were not covered by tests
if isinstance(
actual_value_i, google.protobuf.message.Message
) and isinstance(expected_value_i, google.protobuf.message.Message):
try:
assert_onnx_proto_equal(

Check warning on line 481 in onnxscript/testing/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/testing/__init__.py#L480-L481

Added lines #L480 - L481 were not covered by tests
actual_value_i,
expected_value_i,
ignore_initializer_value_proto=ignore_initializer_value_proto,
)
except AssertionError as e:
error_message = f"Field {field} index {i} in sequence not equal. type(actual_value_i): {type(actual_value_i)}, type(expected_value_i): {type(expected_value_i)}, actual_value_i: {actual_value_i}, expected_value_i: {expected_value_i}"
error_message = (

Check warning on line 488 in onnxscript/testing/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/testing/__init__.py#L486-L488

Added lines #L486 - L488 were not covered by tests
str(e) + "\n\nCaused by the above error\n\n" + error_message
)
errors.append(error_message)

Check warning on line 491 in onnxscript/testing/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/testing/__init__.py#L491

Added line #L491 was not covered by tests
elif actual_value_i != expected_value_i:
if (
isinstance(actual_value_i, float)
and isinstance(expected_value_i, float)
and math.isnan(actual_value_i)
and math.isnan(expected_value_i)
):
# Consider NaNs equal
continue
error_message = f"Field {field} index {i} in sequence not equal. type(actual_value_i): {type(actual_value_i)}, type(expected_value_i): {type(expected_value_i)}"

Check warning on line 501 in onnxscript/testing/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/testing/__init__.py#L500-L501

Added lines #L500 - L501 were not covered by tests
for line in difflib.ndiff(
str(actual_value_i).splitlines(),
str(expected_value_i).splitlines(),
):
error_message += "\n" + line
errors.append(error_message)

Check warning on line 507 in onnxscript/testing/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/testing/__init__.py#L506-L507

Added lines #L506 - L507 were not covered by tests
elif isinstance(a_value, google.protobuf.message.Message) and isinstance(
b_value, google.protobuf.message.Message
):
assert_onnx_proto_equal(a_value, b_value)
assert_onnx_proto_equal(

Check warning on line 511 in onnxscript/testing/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/testing/__init__.py#L511

Added line #L511 was not covered by tests
a_value, b_value, ignore_initializer_value_proto=ignore_initializer_value_proto
)
elif a_value != b_value:
if (
isinstance(a_value, float)
Expand All @@ -478,5 +520,11 @@
):
# Consider NaNs equal
continue
error_message = f"Field {field} not equal. field_a: {a_value}, field_b: {b_value}"
raise AssertionError(error_message)
error_message = (

Check warning on line 523 in onnxscript/testing/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/testing/__init__.py#L523

Added line #L523 was not covered by tests
f"Field {field} not equal. field_actual: {a_value}, field_expected: {b_value}"
)
errors.append(error_message)

Check warning on line 526 in onnxscript/testing/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/testing/__init__.py#L526

Added line #L526 was not covered by tests
if errors:
raise AssertionError(

Check warning on line 528 in onnxscript/testing/__init__.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/testing/__init__.py#L528

Added line #L528 was not covered by tests
f"Protos not equal: {type(actual)} != {type(expected)}\n" + "\n".join(errors)
)
7 changes: 6 additions & 1 deletion tools/ir/model_zoo_test/model_zoo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import traceback

import onnx
import onnxruntime as ort
import tqdm
from onnx import hub

Expand All @@ -42,8 +43,12 @@ def test_model(model_info: hub.ModelInfo) -> float:
ir_model = ir.serde.deserialize_model(model)
serialized = ir.serde.serialize_model(ir_model)
end = time.time()
onnxscript.testing.assert_onnx_proto_equal(serialized, model)
onnxscript.testing.assert_onnx_proto_equal(
serialized, model, ignore_initializer_value_proto=True
)
onnx.checker.check_model(serialized)
# Check the model can be loaded with onnxruntime
ort.InferenceSession(serialized.SerializeToString())
return end - start


Expand Down
Loading