Skip to content

Commit 14f20ce

Browse files
committed
Rewriter: fix fuse_relu_clip with None max
1 parent 19e5284 commit 14f20ce

2 files changed

Lines changed: 33 additions & 7 deletions

File tree

onnxscript/rewriter/rules/common/_fuse_relus_clips.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ def extract_min_max(self, node: ir.Node):
7474
min_clip = min_input.const_value.numpy()
7575

7676
if len(node.inputs) > 2:
77-
max_clip = node.inputs[2].const_value.numpy()
77+
max_clip = node.inputs[2]
78+
if max_clip is not None:
79+
max_clip = max_clip.const_value.numpy()
7880

7981
return min_clip, max_clip, dtype
8082

onnxscript/rewriter/rules/common/_fuse_relus_clips_test.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,7 @@
99
import parameterized
1010
from onnx_ir.passes.common import onnx_checker, shape_inference
1111

12-
from onnxscript.rewriter import (
13-
MatchingTracer,
14-
MatchStatus,
15-
RewriteRule,
16-
testing,
17-
)
12+
from onnxscript.rewriter import MatchingTracer, MatchStatus, RewriteRule, testing
1813
from onnxscript.rewriter.rules.common import _fuse_relus_clips
1914
from onnxscript.rewriter.rules.common._fuse_relus_clips import (
2015
successive_clip_relu_rule,
@@ -206,6 +201,35 @@ def test_successful_fuse_successive_relu_clip_no_min(self, _, nodes):
206201
""")
207202
self.run_test(model, expected_op_types=["Clip"])
208203

204+
@parameterized.parameterized.expand(
205+
[
206+
(
207+
"relu_then_clip",
208+
"""
209+
x1 = Relu(X)
210+
Y = Clip(x1,min,"")
211+
""",
212+
),
213+
(
214+
"clip_then_relu",
215+
"""
216+
x1 = Clip(X,min,"")
217+
Y = Relu(x1)
218+
""",
219+
),
220+
]
221+
)
222+
def test_successful_fuse_successive_relu_clip_no_max(self, _, nodes):
223+
model = ir.from_onnx_text(f"""
224+
< ir_version: 10, opset_import: ["" : 20] >
225+
test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y)
226+
<float min = {{1.0}}>
227+
{{
228+
{nodes}
229+
}}
230+
""")
231+
self.run_test(model, expected_op_types=["Clip"])
232+
209233
@parameterized.parameterized.expand(
210234
[
211235
(

0 commit comments

Comments
 (0)