Skip to content

Commit 75ea900

Browse files
committed
update
1 parent a4c35fb commit 75ea900

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

onnxscript/ir/serde.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -625,9 +625,6 @@ def _deserialize_graph(
625625
if value.name in quantization_annotations:
626626
_deserialize_quantization_annotation(quantization_annotations[value.name], value)
627627

628-
# Build the value info dictionary to allow for quick lookup for this graph scope
629-
value_info = {info.name: info for info in proto.value_info}
630-
631628
# Initialize the values dictionary for this graph scope with the inputs and initializers
632629
values: dict[str, _core.Value] = {v.name: v for v in inputs} # type: ignore[misc]
633630

@@ -653,22 +650,22 @@ def _deserialize_graph(
653650
None,
654651
index=None,
655652
name=initializer_name,
656-
# Do not include shape or type as we need to respect the ONNX file
657-
# if the shape or type is not provided as ValueInfoProto
658-
# The shape/type information will be filled in in the subsequent ValueInfoProto
659-
# deserialization step (deserialize_value_info_proto)
653+
# Include shape and type even if the shape or type is not provided as ValueInfoProto.
654+
# Users expect initialized values to have shape and type information.
655+
type=_core.TensorType(tensor.dtype),
656+
shape=tensor.shape, # type: ignore[arg-type]
660657
const_value=tensor,
661658
)
662-
if initializer_name in value_info:
663-
# This is where we fill in the shape and type information for the initializer
664-
deserialize_value_info_proto(value_info[initializer_name], initializer_value)
665659
if initializer_value.name in quantization_annotations:
666660
_deserialize_quantization_annotation(
667661
quantization_annotations[initializer_value.name], initializer_value
668662
)
669663
values[initializer_name] = initializer_value
670664
initializer_values.append(initializer_value)
671665

666+
# Build the value info dictionary to allow for quick lookup for this graph scope
667+
value_info = {info.name: info for info in proto.value_info}
668+
672669
# Deserialize nodes with all known values
673670
nodes = [
674671
_deserialize_node(node, scoped_values, value_info, quantization_annotations)

0 commit comments

Comments
 (0)