Skip to content

Commit 4bd81b2

Browse files
committed
[Rewriter] Move reshape rule to another class (#2301)
- rewrite test with ir.tape approach. - include new tests around check function.
1 parent f06cfa5 commit 4bd81b2

1 file changed

Lines changed: 75 additions & 59 deletions

File tree

onnxscript/rewriter/basic_rules_test.py

Lines changed: 75 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import onnxscript.rewriter.basic_rules as basic_rules
1616
from onnxscript import ir
1717
from onnxscript.onnx_opset import opset18
18+
from onnxscript.rewriter import MatchingTracer, testing
19+
from onnxscript.rewriter import pattern as orp
1820

1921
FLOAT = onnx.TensorProto.FLOAT
2022

@@ -29,6 +31,10 @@ def _make_model(*args, **kwargs) -> ir.Model:
2931
return ir.serde.deserialize_model(onnx.helper.make_model(*args, **kwargs))
3032

3133

34+
def clone_model(model: ir.Model) -> ir.Model:
35+
return ir.from_proto(ir.to_proto(model))
36+
37+
3238
class BasicRulesTest(unittest.TestCase):
3339
def _get_random_inputs(self, model: onnx.ModelProto) -> dict[str, Any]:
3440
feeds: dict[str, Any] = {}
@@ -318,65 +324,6 @@ def test_unsqueeze_unsqueeze_rule(self, _: str, model: ir.Model):
318324
self.assertEqual(["Constant", "Unsqueeze"], [n.op_type for n in model.graph])
319325
self._check_model(model_proto, rewritten_model)
320326

321-
@parameterized.parameterized.expand(
322-
[
323-
(
324-
"double_reshape_1",
325-
_make_model(
326-
onnx.helper.make_graph(
327-
[
328-
onnx.helper.make_node("Reshape", ["X", "shape_"], ["Xu"]),
329-
onnx.helper.make_node("Reshape", ["Xu", "shape"], ["Y"]),
330-
],
331-
"name",
332-
[onnx.helper.make_tensor_value_info("X", FLOAT, [3, 4, 5])],
333-
[onnx.helper.make_tensor_value_info("Y", FLOAT, [5, 4, 3])],
334-
[
335-
onnx.numpy_helper.from_array(
336-
np.array([4, 5, 3], dtype=np.int64), name="shape_"
337-
),
338-
onnx.numpy_helper.from_array(
339-
np.array([5, 4, 3], dtype=np.int64), name="shape"
340-
),
341-
],
342-
),
343-
opset_imports=[onnx.helper.make_opsetid("", 18)],
344-
),
345-
),
346-
(
347-
"double_reshape_2",
348-
_make_model(
349-
onnx.helper.make_graph(
350-
[
351-
onnx.helper.make_node("Reshape", ["X", "shape_"], ["Xu"]),
352-
onnx.helper.make_node("Reshape", ["Xu", "shape"], ["Y"]),
353-
],
354-
"name",
355-
[onnx.helper.make_tensor_value_info("X", FLOAT, [3, 4, 5])],
356-
[onnx.helper.make_tensor_value_info("Y", FLOAT, [5, 4, 3])],
357-
[
358-
onnx.numpy_helper.from_array(
359-
np.array([-1], dtype=np.int64), name="shape_"
360-
),
361-
onnx.numpy_helper.from_array(
362-
np.array([5, 4, 3], dtype=np.int64), name="shape"
363-
),
364-
],
365-
),
366-
opset_imports=[onnx.helper.make_opsetid("", 18)],
367-
),
368-
),
369-
]
370-
)
371-
def test_reshape_reshape_rule(self, _: str, model: ir.Model):
372-
rule_set = basic_rules.basic_optimization_rules()
373-
model_proto = ir.serde.serialize_model(model)
374-
rule_set.apply_to_model(model)
375-
rewritten_model = ir.serde.serialize_model(model)
376-
377-
self.assertEqual(["Reshape"], [n.op_type for n in model.graph])
378-
self._check_model(model_proto, rewritten_model)
379-
380327
@classmethod
381328
def _slices_split_models(cls):
382329
models = [
@@ -465,5 +412,74 @@ def model3(X: ot.FLOAT[1, 1]):
465412
check(model3, 0)
466413

467414

415+
class ReshapeReshapeTest(unittest.TestCase):
416+
@staticmethod
417+
def create_model(input_shape, shape1, shape2):
418+
def _convert_shape(shape, name):
419+
if isinstance(shape, np.ndarray):
420+
shape = tape.initializer(ir.Tensor(shape, name=name))
421+
elif isinstance(shape, (list, tuple)):
422+
shape = ir.Input(name, ir.Shape(shape), ir.TensorType(ir.DataType.INT64))
423+
tape.graph_like.inputs.append(shape)
424+
else:
425+
raise TypeError(f"Unsupported type {type(shape)} for shape.")
426+
return shape
427+
428+
x = ir.Input("X", ir.Shape(input_shape), ir.TensorType(ir.DataType.FLOAT))
429+
y = ir.Input("Y", type=ir.TensorType(ir.DataType.FLOAT))
430+
tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20}))
431+
432+
# Build the graph.
433+
reshape = tape.op("Reshape", inputs=[x, _convert_shape(shape1, "shape_")])
434+
tape.op("Reshape", inputs=[reshape, _convert_shape(shape2, "shape")], output=y)
435+
model = ir.Model(tape.graph_like, ir_version=10)
436+
return model
437+
438+
@parameterized.parameterized.expand(
439+
[
440+
((3, 4, 5), [4, 5, 3], [5, 4, 3]),
441+
((3, 4, 5), [4, 5, 3], [5, 4, 3]),
442+
]
443+
)
444+
def test_reshape_reshape_rule(self, input_shape, shape1, shape2):
445+
model = self.create_model(
446+
input_shape, np.array(shape1, dtype="int64"), np.array(shape2, dtype="int64")
447+
)
448+
updated_model = clone_model(model)
449+
450+
# check rewrite approach.
451+
count = basic_rules.reshape_reshape_rule.apply_to_model(updated_model)
452+
self.assertEqual(count, 1)
453+
self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph])
454+
455+
# Check inference.
456+
inputs = np.random.default_rng(10).random(input_shape, dtype="float32")
457+
testing.assert_numerically_equal(model, updated_model, (inputs,), atol=0, rtol=0)
458+
459+
@parameterized.parameterized.expand(
460+
[
461+
((2,), np.array([1, 6], dtype="int64"), "ignored is not a constant"),
462+
(np.array([1, 6], dtype="int64"), (3,), "is not a constant"),
463+
(
464+
np.array([1, 6], dtype="int64"),
465+
np.array([0, 6], dtype="int64"),
466+
"non-positive values",
467+
),
468+
]
469+
)
470+
def test_unsupported_reshape_reshape(self, shape1, shape2, error_msg):
471+
model = self.create_model((1, 2, 3), shape1, shape2)
472+
473+
# Check rewrite approach.
474+
tracer = MatchingTracer()
475+
count = basic_rules.reshape_reshape_rule.apply_to_model(model, tracer=tracer)
476+
self.assertEqual(count, 0)
477+
478+
# Check that the error message is the expected one
479+
tracer_match = tracer.best_matches_map[basic_rules.reshape_reshape_rule][0]
480+
self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED)
481+
self.assertRegex(tracer_match.match_result.reason, error_msg)
482+
483+
468484
if __name__ == "__main__":
469485
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)