11import onnx
22import onnx .helper
33
4+
45class Tensor :
56 # Reference implementation placeholder
67 # represents a generic ONNX tensor type
7- def __init__ (self , dtype = onnx .TensorProto .UNDEFINED , shape = None ) -> None :
8+ def __init__ (self , dtype = onnx .TensorProto .UNDEFINED , shape = None ) -> None :
89 self .dtype = dtype
910 self .shape = shape
10-
11+
1112 def __str__ (self ) -> str :
1213 shapestr = str (self .shape ) if self .shape else "[...]"
1314 return onnx .TensorProto .DataType .Name (self .dtype ) + shapestr
14-
15+
1516 def to_type_proto (self ):
1617 # TODO: handle None
1718 return onnx .helper .make_tensor_type_proto (self .dtype , self .shape )
@@ -23,11 +24,12 @@ def to_type_proto(self):
2324# x : FLOAT['M', 'N'] (a tensor of rank 2 of unknown dimensions, with symbolic names)
2425# x : FLOAT[128, 1024] (a tensor of rank 2 of known dimensions)
2526
27+
2628class ParametricTensor :
2729 def __init__ (self , dtype ) -> None :
2830 self .dtype = dtype
29-
30- def __getitem__ (self , shape ):
31+
32+ def __getitem__ (self , shape ):
3133 def mk_dim (dim ):
3234 r = onnx .TensorShapeProto .Dimension ()
3335 if (isinstance (dim , int )):
@@ -44,26 +46,27 @@ def mk_dim(dim):
4446 else :
4547 s = [shape ]
4648 return Tensor (self .dtype , s )
47-
49+
4850 def to_type_proto (self ):
4951 return onnx .helper .make_tensor_type_proto (self .dtype , ())
50-
52+
5153 def __str__ (self ) -> str :
5254 return onnx .TensorProto .DataType .Name (self .dtype )
5355
54- FLOAT = ParametricTensor (onnx .TensorProto .FLOAT )
55- UINT8 = ParametricTensor (onnx .TensorProto .UINT8 )
56- INT8 = ParametricTensor (onnx .TensorProto .INT8 )
57- UINT16 = ParametricTensor (onnx .TensorProto .UINT16 )
58- INT16 = ParametricTensor (onnx .TensorProto .INT16 )
59- INT32 = ParametricTensor (onnx .TensorProto .INT32 )
60- INT64 = ParametricTensor (onnx .TensorProto .INT64 )
61- STRING = ParametricTensor (onnx .TensorProto .STRING )
62- BOOL = ParametricTensor (onnx .TensorProto .BOOL )
63- FLOAT16 = ParametricTensor (onnx .TensorProto .FLOAT16 )
64- DOUBLE = ParametricTensor (onnx .TensorProto .DOUBLE )
65- UINT32 = ParametricTensor (onnx .TensorProto .UINT32 )
66- UINT64 = ParametricTensor (onnx .TensorProto .UINT64 )
67- COMPLEX64 = ParametricTensor (onnx .TensorProto .COMPLEX64 )
68- COMPLEX128 = ParametricTensor (onnx .TensorProto .COMPLEX128 )
69- BFLOAT16 = ParametricTensor (onnx .TensorProto .BFLOAT16 )
56+
57+ FLOAT = ParametricTensor (onnx .TensorProto .FLOAT )
58+ UINT8 = ParametricTensor (onnx .TensorProto .UINT8 )
59+ INT8 = ParametricTensor (onnx .TensorProto .INT8 )
60+ UINT16 = ParametricTensor (onnx .TensorProto .UINT16 )
61+ INT16 = ParametricTensor (onnx .TensorProto .INT16 )
62+ INT32 = ParametricTensor (onnx .TensorProto .INT32 )
63+ INT64 = ParametricTensor (onnx .TensorProto .INT64 )
64+ STRING = ParametricTensor (onnx .TensorProto .STRING )
65+ BOOL = ParametricTensor (onnx .TensorProto .BOOL )
66+ FLOAT16 = ParametricTensor (onnx .TensorProto .FLOAT16 )
67+ DOUBLE = ParametricTensor (onnx .TensorProto .DOUBLE )
68+ UINT32 = ParametricTensor (onnx .TensorProto .UINT32 )
69+ UINT64 = ParametricTensor (onnx .TensorProto .UINT64 )
70+ COMPLEX64 = ParametricTensor (onnx .TensorProto .COMPLEX64 )
71+ COMPLEX128 = ParametricTensor (onnx .TensorProto .COMPLEX128 )
72+ BFLOAT16 = ParametricTensor (onnx .TensorProto .BFLOAT16 )
0 commit comments