Skip to content

Commit 9dcdda7

Browse files
committed
address reviews
1 parent 390f0e7 commit 9dcdda7

2 files changed

Lines changed: 26 additions & 27 deletions

File tree

onnxscript/ir/passes/common/constant_manipulation.py

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -87,33 +87,28 @@ def _constant_node_attribute_to_tensor(
8787
)
8888
return None
8989

90-
# Dispatch table for attribute-to-tensor conversion
91-
tensor_converters = {
92-
"value": lambda: attr_value.as_tensor(), # pylint: disable=unnecessary-lambda`
93-
"value_int": lambda: ir.tensor(
90+
if attr_name == "value":
91+
tensor = attr_value.as_tensor() # type: ignore[union-attr]
92+
elif attr_name == "value_int":
93+
tensor = ir.tensor(
9494
attr_value.as_int(), dtype=ir.DataType.INT64, name=initializer_name
95-
),
96-
"value_ints": lambda: ir.tensor(
95+
)
96+
elif attr_name == "value_ints":
97+
tensor = ir.tensor(
9798
attr_value.as_ints(), dtype=ir.DataType.INT64, name=initializer_name
98-
),
99-
"value_float": lambda: ir.tensor(
99+
)
100+
elif attr_name == "value_float":
101+
tensor = ir.tensor(
100102
attr_value.as_float(), dtype=ir.DataType.FLOAT, name=initializer_name
101-
),
102-
"value_floats": lambda: ir.tensor(
103+
)
104+
elif attr_name == "value_floats":
105+
tensor = ir.tensor(
103106
attr_value.as_floats(), dtype=ir.DataType.FLOAT, name=initializer_name
104-
),
105-
"value_string": lambda: ir.StringTensor(
106-
np.array(attr_value.value, dtype=np.bytes_), name=initializer_name
107-
),
108-
"value_strings": lambda: ir.StringTensor(
107+
)
108+
elif attr_name in ("value_string", "value_strings"):
109+
tensor = ir.StringTensor(
109110
np.array(attr_value.value, dtype=np.bytes_), name=initializer_name
110-
),
111-
}
112-
converter = tensor_converters.get(attr_name)
113-
if converter is None:
114-
logger.debug(
115-
"Unsupported constant node attribute '%s' in node '%s'", attr_name, node.name
116111
)
117-
return None
118-
119-
return converter() # type: ignore[return-value]
112+
else:
113+
tensor = None
114+
return tensor # type: ignore[return-value]

onnxscript/ir/passes/common/constant_manipulation_test.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class TestLiftConstantsToInitializersPass(unittest.TestCase):
2121
]
2222
)
2323
def test_pass_with_lifting_float_and_int_constants_to_initializers(
24-
self, ir_dtype, lift_all_constants
24+
self, ir_dtype: ir.DataType, lift_all_constants: bool
2525
):
2626
inputs = [
2727
ir.Value(name="input_a", type=ir.TensorType(ir_dtype), shape=ir.Shape((2, 3))),
@@ -80,7 +80,7 @@ def test_pass_with_lifting_float_and_int_constants_to_initializers(
8080
]
8181
)
8282
def test_pass_with_lifting_constants_to_initializers_within_subgraph(
83-
self, lift_all_constants
83+
self, lift_all_constants: bool
8484
):
8585
input_value = ir.Value(
8686
name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
@@ -167,7 +167,11 @@ def test_pass_with_lifting_constants_to_initializers_within_subgraph(
167167
]
168168
)
169169
def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings(
170-
self, value, constant_attribute, np_dtype, lift_all_constants
170+
self,
171+
value: float | int | str | list[float] | list[int] | list[str],
172+
constant_attribute: str,
173+
np_dtype: type[np.dtype],
174+
lift_all_constants: bool,
171175
):
172176
input_value = ir.Value(
173177
name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))

0 commit comments

Comments
 (0)