Skip to content

Commit 25dffe6

Browse files
committed
Fix test
1 parent 31409d4 commit 25dffe6

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

onnxscript/testing/__init__.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,11 @@ def assert_onnx_proto_equal(
428428
a_keys = [_opset_import_key(opset_import) for opset_import in a_value]
429429
b_keys = [_opset_import_key(opset_import) for opset_import in b_value]
430430
elif field == "value_info":
431-
if ignore_initializer_value_proto:
431+
if (
432+
ignore_initializer_value_proto
433+
and isinstance(actual, onnx.GraphProto)
434+
and isinstance(expected, onnx.GraphProto)
435+
):
432436
# Filter out initializers from the value_info list
433437
a_value = [
434438
value_info
@@ -474,7 +478,11 @@ def assert_onnx_proto_equal(
474478
actual_value_i, google.protobuf.message.Message
475479
) and isinstance(expected_value_i, google.protobuf.message.Message):
476480
try:
477-
assert_onnx_proto_equal(actual_value_i, expected_value_i)
481+
assert_onnx_proto_equal(
482+
actual_value_i,
483+
expected_value_i,
484+
ignore_initializer_value_proto=ignore_initializer_value_proto,
485+
)
478486
except AssertionError as e:
479487
error_message = f"Field {field} index {i} in sequence not equal. type(actual_value_i): {type(actual_value_i)}, type(expected_value_i): {type(expected_value_i)}, actual_value_i: {actual_value_i}, expected_value_i: {expected_value_i}"
480488
error_message = (
@@ -500,7 +508,9 @@ def assert_onnx_proto_equal(
500508
elif isinstance(a_value, google.protobuf.message.Message) and isinstance(
501509
b_value, google.protobuf.message.Message
502510
):
503-
assert_onnx_proto_equal(a_value, b_value)
511+
assert_onnx_proto_equal(
512+
a_value, b_value, ignore_initializer_value_proto=ignore_initializer_value_proto
513+
)
504514
elif a_value != b_value:
505515
if (
506516
isinstance(a_value, float)

0 commit comments

Comments
 (0)