Skip to content

Commit d2b8020

Browse files
committed
Python formatting
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent d96818d commit d2b8020

3 files changed

Lines changed: 35 additions & 27 deletions

File tree

onnxscript/onnxscript/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
# ElementType: an enumeration encoding the allowed element types in a tensor
88
# Corresponds to TensorProto::DataType
9+
10+
911
class ElementType:
1012
UNDEFINED = 0
1113

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
import onnx
22
import onnx.helper
33

4+
45
class Tensor:
56
# Reference implementation placeholder
67
# represents a generic ONNX tensor type
7-
def __init__(self, dtype = onnx.TensorProto.UNDEFINED, shape = None) -> None:
8+
def __init__(self, dtype=onnx.TensorProto.UNDEFINED, shape=None) -> None:
89
self.dtype = dtype
910
self.shape = shape
10-
11+
1112
def __str__(self) -> str:
1213
shapestr = str(self.shape) if self.shape else "[...]"
1314
return onnx.TensorProto.DataType.Name(self.dtype) + shapestr
14-
15+
1516
def to_type_proto(self):
1617
# TODO: handle None
1718
return onnx.helper.make_tensor_type_proto(self.dtype, self.shape)
@@ -23,11 +24,12 @@ def to_type_proto(self):
2324
# x : FLOAT['M', 'N'] (a tensor of rank 2 of unknown dimensions, with symbolic names)
2425
# x : FLOAT[128, 1024] (a tensor of rank 2 of known dimensions)
2526

27+
2628
class ParametricTensor:
2729
def __init__(self, dtype) -> None:
2830
self.dtype = dtype
29-
30-
def __getitem__ (self, shape):
31+
32+
def __getitem__(self, shape):
3133
def mk_dim(dim):
3234
r = onnx.TensorShapeProto.Dimension()
3335
if (isinstance(dim, int)):
@@ -44,26 +46,27 @@ def mk_dim(dim):
4446
else:
4547
s = [shape]
4648
return Tensor(self.dtype, s)
47-
49+
4850
def to_type_proto(self):
4951
return onnx.helper.make_tensor_type_proto(self.dtype, ())
50-
52+
5153
def __str__(self) -> str:
5254
return onnx.TensorProto.DataType.Name(self.dtype)
5355

54-
FLOAT = ParametricTensor (onnx.TensorProto.FLOAT)
55-
UINT8 = ParametricTensor (onnx.TensorProto.UINT8)
56-
INT8 = ParametricTensor (onnx.TensorProto.INT8)
57-
UINT16 = ParametricTensor (onnx.TensorProto.UINT16)
58-
INT16 = ParametricTensor (onnx.TensorProto.INT16)
59-
INT32 = ParametricTensor (onnx.TensorProto.INT32)
60-
INT64 = ParametricTensor (onnx.TensorProto.INT64)
61-
STRING = ParametricTensor (onnx.TensorProto.STRING)
62-
BOOL = ParametricTensor (onnx.TensorProto.BOOL)
63-
FLOAT16 = ParametricTensor (onnx.TensorProto.FLOAT16)
64-
DOUBLE = ParametricTensor (onnx.TensorProto.DOUBLE)
65-
UINT32 = ParametricTensor (onnx.TensorProto.UINT32)
66-
UINT64 = ParametricTensor (onnx.TensorProto.UINT64)
67-
COMPLEX64 = ParametricTensor (onnx.TensorProto.COMPLEX64)
68-
COMPLEX128 = ParametricTensor (onnx.TensorProto.COMPLEX128)
69-
BFLOAT16 = ParametricTensor (onnx.TensorProto.BFLOAT16)
56+
57+
FLOAT = ParametricTensor(onnx.TensorProto.FLOAT)
58+
UINT8 = ParametricTensor(onnx.TensorProto.UINT8)
59+
INT8 = ParametricTensor(onnx.TensorProto.INT8)
60+
UINT16 = ParametricTensor(onnx.TensorProto.UINT16)
61+
INT16 = ParametricTensor(onnx.TensorProto.INT16)
62+
INT32 = ParametricTensor(onnx.TensorProto.INT32)
63+
INT64 = ParametricTensor(onnx.TensorProto.INT64)
64+
STRING = ParametricTensor(onnx.TensorProto.STRING)
65+
BOOL = ParametricTensor(onnx.TensorProto.BOOL)
66+
FLOAT16 = ParametricTensor(onnx.TensorProto.FLOAT16)
67+
DOUBLE = ParametricTensor(onnx.TensorProto.DOUBLE)
68+
UINT32 = ParametricTensor(onnx.TensorProto.UINT32)
69+
UINT64 = ParametricTensor(onnx.TensorProto.UINT64)
70+
COMPLEX64 = ParametricTensor(onnx.TensorProto.COMPLEX64)
71+
COMPLEX128 = ParametricTensor(onnx.TensorProto.COMPLEX128)
72+
BFLOAT16 = ParametricTensor(onnx.TensorProto.BFLOAT16)

onnxscript/test/converter_test.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55

66
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
77

8+
89
class TestConverter(unittest.TestCase):
9-
def _convert (self, script):
10+
def _convert(self, script):
1011
converter = Converter()
1112
converter.convert(script)
1213

13-
def _convert_and_save (self, script):
14+
def _convert_and_save(self, script):
1415
converter = Converter()
1516
fnlist = converter.convert(script)
1617
TEST_OUTPUT_DIR = os.path.join(CURRENT_DIR, "testoutputs")
@@ -37,9 +38,10 @@ def foo(x):
3738
return msdomain.bar(x, x)
3839
"""
3940
self._convert(script)
41+
4042
def test_onnxfns(self):
4143
self._convert(os.path.join(CURRENT_DIR, "onnxfns.py"))
42-
44+
4345
def test_models(self):
4446
self._convert_and_save(os.path.join(CURRENT_DIR, "onnxmodels.py"))
4547

@@ -49,5 +51,6 @@ def test_if_models(self):
4951
def test_loop_models(self):
5052
self._convert_and_save(os.path.join(CURRENT_DIR, "loop.py"))
5153

54+
5255
if __name__ == '__main__':
53-
unittest.main()
56+
unittest.main()

0 commit comments

Comments
 (0)