@@ -627,32 +627,43 @@ def _deserialize_graph(
627627
628628 # Initialize the values dictionary for this graph scope with the inputs and initializers
629629 values : dict [str , _core .Value ] = {v .name : v for v in inputs } # type: ignore[misc]
630+
631+ # Enter the graph scope by pushing the values for this scope to the stack
630632 scoped_values .append (values )
633+
631634 initializer_values = []
632- for tensor in initializer_tensors :
633- if tensor .name in values :
635+ for i , tensor in enumerate (initializer_tensors ):
636+ initializer_name = tensor .name
637+ if not initializer_name :
638+ logger .warning (
639+ "Initializer tensor must have a name but the %s-th initializer does not. Skipping this initializer." ,
640+ i ,
641+ )
642+ continue
643+ if initializer_name in values :
634644 # The initializer is for an input
635- initializer_value = values [tensor . name ]
645+ initializer_value = values [initializer_name ]
636646 initializer_value .const_value = tensor
637647 else :
638648 # The initializer is for some other value. Create this value first
639649 initializer_value = _core .Value (
640650 None ,
641651 index = None ,
642- name = tensor . name ,
643- # TODO(justinchuby): Fix type hinting for shape and dtype
644- shape = tensor . shape , # type: ignore
652+ name = initializer_name ,
653+ # Include shape and type even if the shape or type is not provided as ValueInfoProto.
654+ # Users expect initialized values to have shape and type information.
645655 type = _core .TensorType (tensor .dtype ),
656+ shape = tensor .shape , # type: ignore[arg-type]
646657 const_value = tensor ,
647658 )
648659 if initializer_value .name in quantization_annotations :
649660 _deserialize_quantization_annotation (
650661 quantization_annotations [initializer_value .name ], initializer_value
651662 )
652- values [tensor . name ] = initializer_value # type: ignore[index]
663+ values [initializer_name ] = initializer_value
653664 initializer_values .append (initializer_value )
654665
655- # Add ValueInfos for this graph scope
666+ # Build the value info dictionary to allow for quick lookup for this graph scope
656667 value_info = {info .name : info for info in proto .value_info }
657668
658669 # Deserialize nodes with all known values
@@ -663,7 +674,10 @@ def _deserialize_graph(
663674
664675 # Fill in values for graph outputs
665676 outputs = [deserialize_value_info_proto (info , values [info .name ]) for info in proto .output ]
677+
678+ # Exit the graph scope by popping the values for this scope from the stack
666679 scoped_values .pop ()
680+
667681 return _core .Graph (
668682 inputs ,
669683 outputs ,
@@ -1204,24 +1218,24 @@ def _serialize_opset_imports_into(
12041218 opset_ids .add (domain = domain , version = version )
12051219
12061220
1207- def _serialize_metadata_props_into (
1221+ def _serialize_string_string_maps (
12081222 string_string_entries : proto_containers .RepeatedCompositeFieldContainer [
12091223 onnx .StringStringEntryProto
12101224 ],
12111225 from_ : Mapping [str , str ],
12121226) -> None :
1213- """Serialize metadata properties into a repeated field of string-string entries.
1227+ """Serialize a <str, str> mapping into a repeated field of string-string entries.
12141228
12151229 Args:
12161230 string_string_entries: The repeated field to serialize into.
1217- from_: The mapping of metadata properties to serialize.
1231+ from_: The mapping of a <str, str> mapping to serialize.
12181232 """
12191233 # Sort names for deterministic serialization
12201234 for key in sorted (from_ ):
12211235 string_string_entries .add (key = key , value = from_ [key ])
12221236
12231237
1224- _serialize_string_string_maps = _serialize_metadata_props_into
1238+ _serialize_metadata_props_into = _serialize_string_string_maps
12251239
12261240
12271241def _maybe_add_quantization_annotation (
@@ -1284,18 +1298,21 @@ def serialize_graph_into(
12841298 # TODO(justinchuby): We should add a method is_initializer() on Value when
12851299 # the initializer list is tracked
12861300 _maybe_add_quantization_annotation (graph_proto , input_ )
1301+ input_names = {input_ .name for input_ in from_ .inputs }
12871302 # TODO(justinchuby): Support sparse_initializer
1288- for initializer in from_ .initializers .values ():
1289- _maybe_add_quantization_annotation (graph_proto , initializer )
1290- if initializer .const_value is None :
1303+ for value in from_ .initializers .values ():
1304+ _maybe_add_quantization_annotation (graph_proto , value )
1305+ if _should_create_value_info_for_value (value ) and value .name not in input_names :
1306+ # Serialize information about all initializers into value_info,
1307+ # except for those that are also graph inputs
1308+ serialize_value_into (graph_proto .value_info .add (), value )
1309+ if value .const_value is None :
12911310 # Skip initializers without constant values
1292- logger .warning (
1293- "Initializer '%s' does not have a constant value set." , initializer .name
1294- )
1311+ logger .warning ("Initializer '%s' does not have a constant value set." , value .name )
12951312 continue
12961313 # Make sure the tensor's name is the same as the value's name
1297- initializer .const_value .name = initializer .name
1298- serialize_tensor_into (graph_proto .initializer .add (), from_ = initializer .const_value )
1314+ value .const_value .name = value .name
1315+ serialize_tensor_into (graph_proto .initializer .add (), from_ = value .const_value )
12991316 for node in from_ :
13001317 serialize_node_into (graph_proto .node .add (), from_ = node )
13011318 for node_output in node .outputs :
0 commit comments