Skip to content

Commit 953a940

Browse files
committed
review(min_max_to_clip): remove is_graph_input condition
1 parent f71868c commit 953a940

2 files changed

Lines changed: 29 additions & 33 deletions

File tree

onnxscript/rewriter/rules/common/_min_max_to_clip.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,6 @@ def check(self, context, out1, out2, **_):
8989

9090
# Ensure all inputs except the first are constants
9191
for input_ in first_node.inputs[1:] + second_node.inputs[1:]:
92-
if input_.is_graph_input():
93-
return check_result.fail(f"{input_.name} is a graph input.")
94-
9592
if ir.convenience.get_const_tensor(input_) is None:
9693
return check_result.fail(f"{input_.name} is not a constant.")
9794

onnxscript/rewriter/rules/common/_min_max_to_clip_test.py

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -173,20 +173,21 @@ def test_failure_fuse_successive_min_or_max_non_constant(self, _, op_type, rewri
173173

174174
@parameterized.expand(
175175
[
176-
("min_graph_input", "Min", fuse_successive_min_rule),
177-
("max_graph_input", "Max", fuse_successive_max_rule),
176+
("min_graph_input", "Min"),
177+
("max_graph_input", "Max"),
178178
]
179179
)
180-
def test_failure_fuse_successive_min_or_max_graph_inputs(self, _, op_type, rewrite_rule):
180+
def test_successful_fuse_successive_min_or_max_graph_inputs_as_constants(self, _, op_type):
181181
base_model = ir.from_onnx_text(f"""
182182
< ir_version: 10, opset_import: ["" : 20] >
183183
test_model (float[N, 32, 14, 17] X, float[1] cst1, float[1] cst2) => (float[N, ?, ?, ?] Y)
184+
<float[1] cst1 = {{1.0}}, float[1] cst2 = {{6}}>
184185
{{
185186
x1 = {op_type}(X, cst1)
186187
Y = {op_type}(x1, cst2)
187188
}}
188189
""")
189-
self.run_failed_condition_test(base_model, rewrite_rule, "is a graph input")
190+
self.run_test(base_model, expected_op_types=[op_type])
190191

191192

192193
class TestMinMaxToClip(_TestMinMaxToClipBase):
@@ -215,6 +216,18 @@ def test_successful_min_max_to_clip_constants(self):
215216
""")
216217
self.run_test(base_model, expected_op_types=["Constant", "Clip"])
217218

219+
def test_successful_min_max_to_clip_graph_inputs_as_constants(self):
220+
base_model = ir.from_onnx_text("""
221+
< ir_version: 10, opset_import: ["" : 20] >
222+
test_model (float[N, 32, 14, 17] X, float[1] min, float[1] max) => (float [N, ?, ?, ?] Y)
223+
<float[1] min = {12.0}, float[1] max = {6.0}>
224+
{
225+
x1 = Min(X, min)
226+
Y = Max(x1, max)
227+
}
228+
""")
229+
self.run_test(base_model, expected_op_types=["Clip"])
230+
218231
def test_failure_min_max_to_clip_invalid_bounds(self):
219232
"""Min node should have the max value and Max node should have the min value."""
220233
base_model = ir.from_onnx_text("""
@@ -245,19 +258,6 @@ def test_failure_fuse_min_max_to_clip_non_constant(self):
245258
model, fuse_successive_min_max_rule, "is not a constant."
246259
)
247260

248-
def test_failure_min_max_to_clip_graph_inputs(self):
249-
base_model = ir.from_onnx_text("""
250-
< ir_version: 10, opset_import: ["" : 20] >
251-
test_model (float[N, 32, 14, 17] X, float[1] min, float[1] max) => (float [N, ?, ?, ?] Y)
252-
{
253-
x1 = Min(X, min)
254-
Y = Max(x1, max)
255-
}
256-
""")
257-
self.run_failed_condition_test(
258-
base_model, fuse_successive_min_max_rule, "is a graph input"
259-
)
260-
261261
def test_failure_min_max_to_clip_need_scalars(self):
262262
base_model = ir.from_onnx_text("""
263263
< ir_version: 10, opset_import: ["" : 20] >
@@ -299,6 +299,18 @@ def test_successful_max_min_to_clip_constants(self):
299299
""")
300300
self.run_test(base_model, expected_op_types=["Constant", "Clip"])
301301

302+
def test_successful_max_min_to_clip_graph_inputs_as_constants(self):
303+
base_model = ir.from_onnx_text("""
304+
< ir_version: 10, opset_import: ["" : 20] >
305+
test_model (float[N, 32, 14, 17] X, float[1] min, float[1] max) => (float [N, ?, ?, ?] Y)
306+
<float[1] min = {12.0}, float[1] max = {6.0}>
307+
{
308+
x1 = Max(X, max)
309+
Y = Min(x1, min)
310+
}
311+
""")
312+
self.run_test(base_model, expected_op_types=["Clip"])
313+
302314
def test_failure_max_min_to_clip_invalid_bounds(self):
303315
"""Min node should have the max value and Max node should have the min value."""
304316
base_model = ir.from_onnx_text("""
@@ -329,19 +341,6 @@ def test_failure_fuse_max_min_to_clip_non_constant(self):
329341
model, fuse_successive_max_min_rule, "is not a constant."
330342
)
331343

332-
def test_failure_max_min_to_clip_graph_inputs(self):
333-
base_model = ir.from_onnx_text("""
334-
< ir_version: 10, opset_import: ["" : 20] >
335-
test_model (float[N, 32, 14, 17] X, float[1] max, float[1] min) => (float [N, ?, ?, ?] Y)
336-
{
337-
x1 = Max(X, max)
338-
Y = Min(x1, min)
339-
}
340-
""")
341-
self.run_failed_condition_test(
342-
base_model, fuse_successive_max_min_rule, "is a graph input"
343-
)
344-
345344
def test_failure_max_min_to_clip_need_scalars(self):
346345
base_model = ir.from_onnx_text("""
347346
< ir_version: 10, opset_import: ["" : 20] >

0 commit comments

Comments
 (0)