@@ -1321,7 +1321,7 @@ def __init__(
13211321 domain : str ,
13221322 op_type : str ,
13231323 inputs : Iterable [Value | None ],
1324- attributes : Iterable [Attr | RefAttr ] = (),
1324+ attributes : Iterable [Attr ] = (),
13251325 * ,
13261326 overload : str = "" ,
13271327 num_outputs : int | None = None ,
@@ -1353,7 +1353,7 @@ def __init__(
13531353 metadata_props: The metadata properties.
13541354
13551355 Raises:
1356- TypeError: If the attributes are not :class:`Attr` or :class:`RefAttr` .
1356+ TypeError: If the attributes are not :class:`Attr`.
13571357 ValueError: If ``num_outputs``, when not ``None``, is not the same as the length of the outputs.
13581358 ValueError: If an output value is ``None``, when outputs is specified.
13591359 ValueError: If an output value has a producer set already, when outputs is specified.
@@ -1368,13 +1368,13 @@ def __init__(
13681368 # Values belong to their defining nodes. The values list is immutable
13691369 self ._outputs : tuple [Value , ...] = self ._create_outputs (num_outputs , outputs )
13701370 attributes = tuple (attributes )
1371- if attributes and not isinstance (attributes [0 ], ( Attr , RefAttr ) ):
1371+ if attributes and not isinstance (attributes [0 ], Attr ):
13721372 raise TypeError (
1373- f"Expected the attributes to be Attr or RefAttr , got { type (attributes [0 ])} . "
1373+ f"Expected the attributes to be Attr, got { type (attributes [0 ])} . "
13741374 "If you are copying the attributes from another node, make sure you call "
13751375 "node.attributes.values() because it is a dictionary."
13761376 )
1377- self ._attributes : OrderedDict [str , Attr | RefAttr ] = OrderedDict (
1377+ self ._attributes : OrderedDict [str , Attr ] = OrderedDict (
13781378 (attr .name , attr ) for attr in attributes
13791379 )
13801380 self ._overload : str = overload
@@ -1633,7 +1633,7 @@ def outputs(self, _: Sequence[Value]) -> None:
16331633 raise AttributeError ("outputs is immutable. Please create a new node instead." )
16341634
16351635 @property
1636- def attributes (self ) -> OrderedDict [str , Attr | RefAttr ]:
1636+ def attributes (self ) -> OrderedDict [str , Attr ]:
16371637 """The attributes of the node."""
16381638 return self ._attributes
16391639
@@ -3106,22 +3106,28 @@ def __repr__(self) -> str:
31063106 return f"{ self .__class__ .__name__ } ({ self .domain !r} , { self .name !r} , { self .overload !r} , inputs={ self .inputs !r} , attributes={ self .attributes !r} ), outputs={ self .outputs !r} )"
31073107
31083108
3109- class RefAttr (_protocols .ReferenceAttributeProtocol , _display .PrettyPrintable ):
3110- """Reference attribute."""
3109+ class Attr (
3110+ _protocols .AttributeProtocol ,
3111+ _protocols .ReferenceAttributeProtocol ,
3112+ _display .PrettyPrintable ,
3113+ ):
3114+ """Base class for ONNX attributes or references."""
31113115
3112- __slots__ = ("_name" , "_ref_attr_name" , "_type" , "doc_string" )
3116+ __slots__ = ("_name" , "_ref_attr_name" , "_type" , "_value" , " doc_string" )
31133117
31143118 def __init__ (
31153119 self ,
31163120 name : str ,
3117- ref_attr_name : str ,
31183121 type : _enums .AttributeType ,
3122+ value : Any ,
3123+ ref_attr_name : str | None = None ,
31193124 * ,
31203125 doc_string : str | None = None ,
3121- ) -> None :
3126+ ):
31223127 self ._name = name
3123- self ._ref_attr_name = ref_attr_name
31243128 self ._type = type
3129+ self ._value = value
3130+ self ._ref_attr_name = ref_attr_name
31253131 self .doc_string = doc_string
31263132
31273133 @property
@@ -3132,43 +3138,21 @@ def name(self) -> str:
31323138 def name (self , value : str ) -> None :
31333139 self ._name = value
31343140
3135- @property
3136- def ref_attr_name (self ) -> str :
3137- return self ._ref_attr_name
3138-
3139- @ref_attr_name .setter
3140- def ref_attr_name (self , value : str ) -> None :
3141- self ._ref_attr_name = value
3142-
31433141 @property
31443142 def type (self ) -> _enums .AttributeType :
31453143 return self ._type
31463144
3147- @type .setter
3148- def type (self , value : _enums .AttributeType ) -> None :
3149- self ._type = value
3150-
3151- def __repr__ (self ) -> str :
3152- return f"{ self .__class__ .__name__ } ({ self ._name !r} , { self ._type !r} , ref_attr_name={ self .ref_attr_name !r} )"
3153-
3154-
3155- class Attr (_protocols .AttributeProtocol , _display .PrettyPrintable ):
3156- """Base class for ONNX attributes."""
3145+ @property
3146+ def value (self ) -> Any :
3147+ return self ._value
31573148
3158- __slots__ = ("doc_string" , "name" , "type" , "value" )
3149+ @property
3150+ def ref_attr_name (self ) -> str | None :
3151+ return self ._ref_attr_name
31593152
3160- def __init__ (
3161- self ,
3162- name : str ,
3163- type : _enums .AttributeType ,
3164- value : Any ,
3165- * ,
3166- doc_string : str | None = None ,
3167- ):
3168- self .name = name
3169- self .type = type
3170- self .value = value
3171- self .doc_string = doc_string
3153+ def is_ref (self ) -> bool :
3154+ """Check if this attribute is a reference attribute."""
3155+ return self .ref_attr_name is not None
31723156
31733157 def __eq__ (self , other : object ) -> bool :
31743158 if not isinstance (other , _protocols .AttributeProtocol ):
@@ -3185,11 +3169,15 @@ def __eq__(self, other: object) -> bool:
31853169 return True
31863170
31873171 def __str__ (self ) -> str :
3172+ if self .is_ref ():
3173+ return f"@{ self .ref_attr_name } "
31883174 if self .type == _enums .AttributeType .GRAPH :
31893175 return textwrap .indent ("\n " + str (self .value ), " " * 4 )
31903176 return str (self .value )
31913177
31923178 def __repr__ (self ) -> str :
3179+ if self .is_ref ():
3180+ return f"{ self .__class__ .__name__ } ({ self .name !r} , { self .type !r} , ref_attr_name={ self .ref_attr_name !r} )"
31933181 return f"{ self .__class__ .__name__ } ({ self .name !r} , { self .type !r} , { self .value !r} )"
31943182
31953183 # Well typed getters
@@ -3269,6 +3257,29 @@ def as_graphs(self) -> Sequence[Graph]:
32693257
32703258
32713259# NOTE: The following functions are just for convenience
3260+
3261+
3262+ def RefAttr (
3263+ name : str ,
3264+ ref_attr_name : str ,
3265+ type : _enums .AttributeType ,
3266+ doc_string : str | None = None ,
3267+ ) -> Attr :
3268+ """Create a reference attribute.
3269+
3270+ Args:
3271+ name: The name of the attribute.
3272+ type: The type of the attribute.
3273+ ref_attr_name: The name of the referenced attribute.
3274+ doc_string: Documentation string.
3275+
3276+ Returns:
3277+ A reference attribute.
3278+ """
3279+ # NOTE: The function name is capitalized to maintain API backward compatibility.
3280+ return Attr (name , type , None , ref_attr_name = ref_attr_name , doc_string = doc_string )
3281+
3282+
32723283def AttrFloat32 (name : str , value : float , doc_string : str | None = None ) -> Attr :
32733284 """Create a float attribute."""
32743285 # NOTE: The function name is capitalized to maintain API backward compatibility.
0 commit comments