@@ -173,20 +173,21 @@ def test_failure_fuse_successive_min_or_max_non_constant(self, _, op_type, rewri
173173
174174 @parameterized .expand (
175175 [
176- ("min_graph_input" , "Min" , fuse_successive_min_rule ),
177- ("max_graph_input" , "Max" , fuse_successive_max_rule ),
176+ ("min_graph_input" , "Min" ),
177+ ("max_graph_input" , "Max" ),
178178 ]
179179 )
180- def test_failure_fuse_successive_min_or_max_graph_inputs (self , _ , op_type , rewrite_rule ):
180+ def test_successful_fuse_successive_min_or_max_graph_inputs_as_constants (self , _ , op_type ):
181181 base_model = ir .from_onnx_text (f"""
182182 < ir_version: 10, opset_import: ["" : 20] >
183183 test_model (float[N, 32, 14, 17] X, float[1] cst1, float[1] cst2) => (float[N, ?, ?, ?] Y)
184+ <float[1] cst1 = {{1.0}}, float[1] cst2 = {{6}}>
184185 {{
185186 x1 = { op_type } (X, cst1)
186187 Y = { op_type } (x1, cst2)
187188 }}
188189 """ )
189- self .run_failed_condition_test (base_model , rewrite_rule , "is a graph input" )
190+ self .run_test (base_model , expected_op_types = [ op_type ] )
190191
191192
192193class TestMinMaxToClip (_TestMinMaxToClipBase ):
@@ -215,6 +216,18 @@ def test_successful_min_max_to_clip_constants(self):
215216 """ )
216217 self .run_test (base_model , expected_op_types = ["Constant" , "Clip" ])
217218
219+ def test_successful_min_max_to_clip_graph_inputs_as_constants (self ):
220+ base_model = ir .from_onnx_text ("""
221+ < ir_version: 10, opset_import: ["" : 20] >
222+ test_model (float[N, 32, 14, 17] X, float[1] min, float[1] max) => (float [N, ?, ?, ?] Y)
223+ <float[1] min = {12.0}, float[1] max = {6.0}>
224+ {
225+ x1 = Min(X, min)
226+ Y = Max(x1, max)
227+ }
228+ """ )
229+ self .run_test (base_model , expected_op_types = ["Clip" ])
230+
218231 def test_failure_min_max_to_clip_invalid_bounds (self ):
219232 """Min node should have the max value and Max node should have the min value."""
220233 base_model = ir .from_onnx_text ("""
@@ -245,19 +258,6 @@ def test_failure_fuse_min_max_to_clip_non_constant(self):
245258 model , fuse_successive_min_max_rule , "is not a constant."
246259 )
247260
248- def test_failure_min_max_to_clip_graph_inputs (self ):
249- base_model = ir .from_onnx_text ("""
250- < ir_version: 10, opset_import: ["" : 20] >
251- test_model (float[N, 32, 14, 17] X, float[1] min, float[1] max) => (float [N, ?, ?, ?] Y)
252- {
253- x1 = Min(X, min)
254- Y = Max(x1, max)
255- }
256- """ )
257- self .run_failed_condition_test (
258- base_model , fuse_successive_min_max_rule , "is a graph input"
259- )
260-
261261 def test_failure_min_max_to_clip_need_scalars (self ):
262262 base_model = ir .from_onnx_text ("""
263263 < ir_version: 10, opset_import: ["" : 20] >
@@ -299,6 +299,18 @@ def test_successful_max_min_to_clip_constants(self):
299299 """ )
300300 self .run_test (base_model , expected_op_types = ["Constant" , "Clip" ])
301301
302+ def test_successful_max_min_to_clip_graph_inputs_as_constants (self ):
303+ base_model = ir .from_onnx_text ("""
304+ < ir_version: 10, opset_import: ["" : 20] >
305+ test_model (float[N, 32, 14, 17] X, float[1] min, float[1] max) => (float [N, ?, ?, ?] Y)
306+ <float[1] min = {12.0}, float[1] max = {6.0}>
307+ {
308+ x1 = Max(X, max)
309+ Y = Min(x1, min)
310+ }
311+ """ )
312+ self .run_test (base_model , expected_op_types = ["Clip" ])
313+
302314 def test_failure_max_min_to_clip_invalid_bounds (self ):
303315 """Min node should have the max value and Max node should have the min value."""
304316 base_model = ir .from_onnx_text ("""
@@ -329,19 +341,6 @@ def test_failure_fuse_max_min_to_clip_non_constant(self):
329341 model , fuse_successive_max_min_rule , "is not a constant."
330342 )
331343
332- def test_failure_max_min_to_clip_graph_inputs (self ):
333- base_model = ir .from_onnx_text ("""
334- < ir_version: 10, opset_import: ["" : 20] >
335- test_model (float[N, 32, 14, 17] X, float[1] max, float[1] min) => (float [N, ?, ?, ?] Y)
336- {
337- x1 = Max(X, max)
338- Y = Min(x1, min)
339- }
340- """ )
341- self .run_failed_condition_test (
342- base_model , fuse_successive_max_min_rule , "is a graph input"
343- )
344-
345344 def test_failure_max_min_to_clip_need_scalars (self ):
346345 base_model = ir .from_onnx_text ("""
347346 < ir_version: 10, opset_import: ["" : 20] >
0 commit comments