Skip to content

Commit 5291e85

Browse files
authored
Merge branch 'main' into justinchu/remove-legacy
2 parents aadd644 + 883a74f commit 5291e85

File tree

5 files changed

+45
-28
lines changed

5 files changed

+45
-28
lines changed

onnxscript/ir/serde.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -627,32 +627,43 @@ def _deserialize_graph(
627627

628628
# Initialize the values dictionary for this graph scope with the inputs and initializers
629629
values: dict[str, _core.Value] = {v.name: v for v in inputs} # type: ignore[misc]
630+
631+
# Enter the graph scope by pushing the values for this scope to the stack
630632
scoped_values.append(values)
633+
631634
initializer_values = []
632-
for tensor in initializer_tensors:
633-
if tensor.name in values:
635+
for i, tensor in enumerate(initializer_tensors):
636+
initializer_name = tensor.name
637+
if not initializer_name:
638+
logger.warning(
639+
"Initializer tensor must have a name but the %s-th initializer does not. Skipping this initializer.",
640+
i,
641+
)
642+
continue
643+
if initializer_name in values:
634644
# The initializer is for an input
635-
initializer_value = values[tensor.name]
645+
initializer_value = values[initializer_name]
636646
initializer_value.const_value = tensor
637647
else:
638648
# The initializer is for some other value. Create this value first
639649
initializer_value = _core.Value(
640650
None,
641651
index=None,
642-
name=tensor.name,
643-
# TODO(justinchuby): Fix type hinting for shape and dtype
644-
shape=tensor.shape, # type: ignore
652+
name=initializer_name,
653+
# Include shape and type even if the shape or type is not provided as ValueInfoProto.
654+
# Users expect initialized values to have shape and type information.
645655
type=_core.TensorType(tensor.dtype),
656+
shape=tensor.shape, # type: ignore[arg-type]
646657
const_value=tensor,
647658
)
648659
if initializer_value.name in quantization_annotations:
649660
_deserialize_quantization_annotation(
650661
quantization_annotations[initializer_value.name], initializer_value
651662
)
652-
values[tensor.name] = initializer_value # type: ignore[index]
663+
values[initializer_name] = initializer_value
653664
initializer_values.append(initializer_value)
654665

655-
# Add ValueInfos for this graph scope
666+
# Build the value info dictionary to allow for quick lookup for this graph scope
656667
value_info = {info.name: info for info in proto.value_info}
657668

658669
# Deserialize nodes with all known values
@@ -663,7 +674,10 @@ def _deserialize_graph(
663674

664675
# Fill in values for graph outputs
665676
outputs = [deserialize_value_info_proto(info, values[info.name]) for info in proto.output]
677+
678+
# Exit the graph scope by popping the values for this scope from the stack
666679
scoped_values.pop()
680+
667681
return _core.Graph(
668682
inputs,
669683
outputs,
@@ -1204,24 +1218,24 @@ def _serialize_opset_imports_into(
12041218
opset_ids.add(domain=domain, version=version)
12051219

12061220

1207-
def _serialize_metadata_props_into(
1221+
def _serialize_string_string_maps(
12081222
string_string_entries: proto_containers.RepeatedCompositeFieldContainer[
12091223
onnx.StringStringEntryProto
12101224
],
12111225
from_: Mapping[str, str],
12121226
) -> None:
1213-
"""Serialize metadata properties into a repeated field of string-string entries.
1227+
"""Serialize a <str, str> mapping into a repeated field of string-string entries.
12141228
12151229
Args:
12161230
string_string_entries: The repeated field to serialize into.
1217-
from_: The mapping of metadata properties to serialize.
1231+
from_: The mapping of a <str, str> mapping to serialize.
12181232
"""
12191233
# Sort names for deterministic serialization
12201234
for key in sorted(from_):
12211235
string_string_entries.add(key=key, value=from_[key])
12221236

12231237

1224-
_serialize_string_string_maps = _serialize_metadata_props_into
1238+
_serialize_metadata_props_into = _serialize_string_string_maps
12251239

12261240

12271241
def _maybe_add_quantization_annotation(
@@ -1284,18 +1298,21 @@ def serialize_graph_into(
12841298
# TODO(justinchuby): We should add a method is_initializer() on Value when
12851299
# the initializer list is tracked
12861300
_maybe_add_quantization_annotation(graph_proto, input_)
1301+
input_names = {input_.name for input_ in from_.inputs}
12871302
# TODO(justinchuby): Support sparse_initializer
1288-
for initializer in from_.initializers.values():
1289-
_maybe_add_quantization_annotation(graph_proto, initializer)
1290-
if initializer.const_value is None:
1303+
for value in from_.initializers.values():
1304+
_maybe_add_quantization_annotation(graph_proto, value)
1305+
if _should_create_value_info_for_value(value) and value.name not in input_names:
1306+
# Serialize information about all initializers into value_info,
1307+
# except for those that are also graph inputs
1308+
serialize_value_into(graph_proto.value_info.add(), value)
1309+
if value.const_value is None:
12911310
# Skip initializers without constant values
1292-
logger.warning(
1293-
"Initializer '%s' does not have a constant value set.", initializer.name
1294-
)
1311+
logger.warning("Initializer '%s' does not have a constant value set.", value.name)
12951312
continue
12961313
# Make sure the tensor's name is the same as the value's name
1297-
initializer.const_value.name = initializer.name
1298-
serialize_tensor_into(graph_proto.initializer.add(), from_=initializer.const_value)
1314+
value.const_value.name = value.name
1315+
serialize_tensor_into(graph_proto.initializer.add(), from_=value.const_value)
12991316
for node in from_:
13001317
serialize_node_into(graph_proto.node.add(), from_=node)
13011318
for node_output in node.outputs:
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:06d78f841f26ec59cea1d15dd2c2a086cb907d6644ef8dac15e6d366935413e8
3-
size 43087292
2+
oid sha256:6dcf6976f8e324c497b0b74b2b9733c4b454f2c259488f5544bbc1aaaf57714c
3+
size 43091738
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:a336102b11d8439daa2c1a164a851f34414529a5610a046943fd869b1b44336f
3-
size 14665355
2+
oid sha256:ba424976b53bf2f141bfd001b48c0cc1c5c798b49def51f39a72f17e1f74e3a2
3+
size 14673089
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:31fbebb580ff85ed8eefa7fb95d4e2cbda41fe267afeaae2d4f4177264d1f4e7
3-
size 46918368
2+
oid sha256:12d24be13a03ea8ddebcc5ea229390d49fb0da08ad1df896b03703c664e2def1
3+
size 46921843
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
version https://git-lfs.github.com/spec/v1
2-
oid sha256:efd167b736106103235f42b480027c28c798dd46117526ca49067a2bdbc7b327
3-
size 311182
2+
oid sha256:6519a87ecf89132a9d39c59c47a442ae5833faf14811575d0b2323e8d13e30a8
3+
size 313873

0 commit comments

Comments
 (0)