diff --git a/onnxscript/rewriter/rules/common/__init__.py b/onnxscript/rewriter/rules/common/__init__.py
index 752e3c9430..fc47e6f0e6 100644
--- a/onnxscript/rewriter/rules/common/__init__.py
+++ b/onnxscript/rewriter/rules/common/__init__.py
@@ -10,6 +10,7 @@
     "div_by_1_rule",
     "dropout_inference_rule",
     "dropout_zero_rule",
+    "flatten_to_reshape_rule",
     "fuse_batchnorm_into_conv_rule",
     "fuse_batchnorm_into_conv_transpose_rule",
     "fuse_batchnorm_into_gemm_rule",
@@ -44,6 +45,7 @@
 
 from onnxscript.rewriter.rules.common._basic_rules import (
     cast_cast_rule,
+    flatten_to_reshape_rule,
     no_op_cast_rule,
     no_op_expand_rule,
     no_op_transpose_rule,
diff --git a/onnxscript/rewriter/rules/common/_basic_rules.py b/onnxscript/rewriter/rules/common/_basic_rules.py
index 6f38050f3e..b7a648880a 100644
--- a/onnxscript/rewriter/rules/common/_basic_rules.py
+++ b/onnxscript/rewriter/rules/common/_basic_rules.py
@@ -11,6 +11,8 @@
 
 from typing import ClassVar, Sequence
 
+import numpy as np
+
 from onnxscript import ir
 from onnxscript.rewriter import _ir_utils as ir_utils
 from onnxscript.rewriter._basics import MatchResult
@@ -123,16 +125,42 @@ def pattern(self, op, x, shape_ignored, shape):
         return op.Reshape(op.Reshape(x, shape_ignored), shape)
 
     def rewrite(self, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value):
-        return op.Reshape(x, shape)
+        """Emit one Reshape from ``x`` using the shape resolved by ``check``."""
+        new_shape = op.initializer(ir.Tensor(self._new_shape, name=shape.name))
+        return op.Reshape(x, new_shape, allowzero=self._allowzero)
 
     def check(self, context, x, shape_ignored, shape) -> MatchResult:
+        """Resolve the shape of the fused Reshape, or fail if it cannot be determined.
+
+        The result is stored in ``self._new_shape`` and ``self._allowzero`` for ``rewrite``.
+        """
         check_result = MatchResult()
-        if shape_ignored.const_value is None:
-            return check_result.fail("Shape ignored is not a constant.")
-        if shape.const_value is None:
+
+        # Shape must be a constant.
+        if (np_shape := ir_utils.get_numpy_value(shape)) is None:
             return check_result.fail("Shape is not a constant.")
-        if shape.const_value.numpy().min() <= 0:
-            return check_result.fail("Shape has non-positive values.")
+        # Copy into a fresh array so that entries can be assigned below.
+        self._new_shape = np.array(np_shape, np_shape.dtype)
+
+        # Try to replace {0,-1} values in shape if reshape output is known.
+        if (reshape_output := context.output_values[0].shape) is not None:
+            for i, dim in enumerate(reshape_output):
+                if isinstance(dim, int) and dim > 0:
+                    self._new_shape[i] = dim
+
+        # Constraints for shape.
+        self._allowzero = context.nodes[0].attributes.get_int("allowzero", 0)
+        if self._allowzero == 1 and any(self._new_shape == 0):
+            return check_result
+        if any(self._new_shape == 0) and any(self._new_shape < 0):
+            return check_result.fail("Shape cannot contain both 0 and -1 dimensions.")
+        elif np.count_nonzero(self._new_shape == 0) > 1:
+            return check_result.fail("Shape cannot contain more than one 0 dimension.")
+
+        # At this point, we can safely replace '0' with '-1'.
+        # Note allowzero is removed since at this point it does not have any effect.
+        self._allowzero = None
+        self._new_shape = np.where(self._new_shape == 0, -1, self._new_shape)
         return check_result
 
 
@@ -279,6 +307,63 @@ def check(self, context, x, axes1, axes2) -> MatchResult:
         return check_result
 
 
+class Flatten2Reshape(RewriteRuleClassBase):
+    """Convert ``Flatten(x)`` to Reshape."""
+
+    def pattern(self, op, x: ir.Value):
+        return op.Flatten(x)
+
+    def rewrite(self, op, x: ir.Value):
+        """Emit ``Reshape(x, shape)`` using the 2-D shape resolved by ``check``."""
+        new_shape = op.initializer(ir.Tensor(self._new_shape, name=f"{x.name}/shape"))
+        return op.Reshape(x, new_shape)
+
+    def check(self, context, x: ir.Value) -> MatchResult:
+        """Compute the 2-D target shape, or fail if Reshape cannot express it."""
+        check_result = MatchResult()
+        self._new_shape = np.array([-1, -1], "int64")
+
+        # Convert axis to a non-negative value if possible.
+        axis = context.root.attributes.get_int("axis", 1)
+        input_rank = None
+        if (input_shape := x.shape) is not None:
+            input_rank = len(input_shape)
+            if axis < 0:
+                axis += input_rank
+
+        # Compute reshape shape following axis attribute.
+        if axis == 0:
+            self._new_shape[0] = 1
+        elif axis == 1:
+            # 0 makes Reshape copy the first input dimension (allowzero=0).
+            self._new_shape[0] = 0
+        elif axis == input_rank:
+            self._new_shape[1] = 1
+
+        # Try to update shape if output is known.
+        if (output_shape := context.output_values[0].shape) is not None:
+            for i, dim in enumerate(output_shape):
+                if isinstance(dim, int):
+                    self._new_shape[i] = dim
+
+        # Try to update shape if input is known.
+        if input_shape is not None:
+            if all(isinstance(dim, int) for dim in input_shape[:axis]):
+                self._new_shape[0] = np.prod(input_shape[:axis])
+            if all(isinstance(dim, int) for dim in input_shape[axis:]):
+                self._new_shape[1] = np.prod(input_shape[axis:])
+
+        # Reshape (allowzero=0) reads 0 as "copy the input dimension at this index", so a
+        # zero-sized output dimension is only expressible as the first one when axis == 1.
+        if self._new_shape[1] == 0 or (self._new_shape[0] == 0 and axis != 1):
+            return check_result.fail("Zero-sized dimensions are not supported.")
+
+        # Verify if it is possible to apply rule.
+        if np.count_nonzero(self._new_shape == -1) > 1:
+            return check_result.fail("Impossible to compute new shape.")
+        return check_result
+
+
 # Create rule instances
 cast_cast_rule = CastCast.rule()
 no_op_cast_rule = CastIdentity.rule()
@@ -289,6 +374,7 @@ def check(self, context, x, axes1, axes2) -> MatchResult:
 transpose_transpose_rule = TransposeTranspose.rule()
 unsqueeze_unsqueeze_rule = UnsqueezeUnsqueeze.rule()
 squeeze_reshape_1d_rule = SqueezeReshape.rule()
+flatten_to_reshape_rule = Flatten2Reshape.rule()
 
 
 def basic_optimization_rules() -> RewriteRuleSet:
@@ -311,6 +397,8 @@ def basic_optimization_rules() -> RewriteRuleSet:
             cast_cast_rule,
             no_op_cast_rule,
             no_op_expand_rule,
+            # flatten_to_reshape_rule is order sensitive to reshape_reshape_rule
+            flatten_to_reshape_rule,
             reshape_reshape_rule,
             slice_split_rule,
             no_op_transpose_rule,
diff --git a/onnxscript/rewriter/rules/common/_basic_rules_test.py b/onnxscript/rewriter/rules/common/_basic_rules_test.py
index 8709300763..9ce74b22a2 100644
--- a/onnxscript/rewriter/rules/common/_basic_rules_test.py
+++ b/onnxscript/rewriter/rules/common/_basic_rules_test.py
@@ -14,6 +14,8 @@
 import onnxscript.onnx_types as ot
 from onnxscript import ir
 from onnxscript.onnx_opset import opset18
+from onnxscript.rewriter import MatchingTracer, testing
+from onnxscript.rewriter import pattern as orp
 from onnxscript.rewriter.rules.common import _basic_rules
 
 FLOAT = onnx.TensorProto.FLOAT
@@ -29,6 +31,10 @@ def _make_model(*args, **kwargs) -> ir.Model:
     return ir.serde.deserialize_model(onnx.helper.make_model(*args, **kwargs))
 
 
+def clone_model(model: ir.Model) -> ir.Model:
+    return ir.from_proto(ir.to_proto(model))
+
+
 class BasicRulesTest(unittest.TestCase):
     def _get_random_inputs(self, model: onnx.ModelProto) -> dict[str, Any]:
         feeds: dict[str, Any] = {}
@@ -318,65 +324,6 @@ def test_unsqueeze_unsqueeze_rule(self, _: str, model: ir.Model):
         self.assertEqual(["Constant", "Unsqueeze"], [n.op_type for n in model.graph])
         self._check_model(model_proto, rewritten_model)
 
-    @parameterized.parameterized.expand(
-        [
-            (
-                "double_reshape_1",
-                _make_model(
-                    onnx.helper.make_graph(
-                        [
-                            onnx.helper.make_node("Reshape", ["X", "shape_"], ["Xu"]),
-                            onnx.helper.make_node("Reshape", ["Xu", "shape"], ["Y"]),
-                        ],
-                        "name",
-                        [onnx.helper.make_tensor_value_info("X", FLOAT, [3, 4, 5])],
-                        [onnx.helper.make_tensor_value_info("Y", FLOAT, [5, 4, 3])],
-                        [
-                            onnx.numpy_helper.from_array(
-                                np.array([4, 5, 3], dtype=np.int64), name="shape_"
-                            ),
-                            onnx.numpy_helper.from_array(
-                                np.array([5, 4, 3], dtype=np.int64), name="shape"
-                            ),
-                        ],
-                    ),
-                    opset_imports=[onnx.helper.make_opsetid("", 18)],
-                ),
-            ),
-            (
-                "double_reshape_2",
-                _make_model(
-                    onnx.helper.make_graph(
-                        [
-                            onnx.helper.make_node("Reshape", ["X", "shape_"], ["Xu"]),
-                            onnx.helper.make_node("Reshape", ["Xu", "shape"], ["Y"]),
-                        ],
-                        "name",
-                        [onnx.helper.make_tensor_value_info("X", FLOAT, [3, 4, 5])],
-                        [onnx.helper.make_tensor_value_info("Y", FLOAT, [5, 4, 3])],
-                        [
-                            onnx.numpy_helper.from_array(
-                                np.array([-1], dtype=np.int64), name="shape_"
-                            ),
-                            onnx.numpy_helper.from_array(
-                                np.array([5, 4, 3], dtype=np.int64), name="shape"
-                            ),
-                        ],
-                    ),
-                    opset_imports=[onnx.helper.make_opsetid("", 18)],
-                ),
-            ),
-        ]
-    )
-    def test_reshape_reshape_rule(self, _: str, model: ir.Model):
-        rule_set = _basic_rules.basic_optimization_rules()
-        model_proto = ir.serde.serialize_model(model)
-        rule_set.apply_to_model(model)
-        rewritten_model = ir.serde.serialize_model(model)
-
-        self.assertEqual(["Reshape"], [n.op_type for n in model.graph])
-        self._check_model(model_proto, rewritten_model)
-
     @classmethod
     def _slices_split_models(cls):
         models = [
@@ -465,5 +412,218 @@ def model3(X: ot.FLOAT[1, 1]):
         check(model3, 0)
 
 
+class ReshapeReshapeTest(unittest.TestCase):
+    @staticmethod
+    def create_model(
+        input_shape, shape1, shape2, allowzero1=0, allowzero2=0, infer_shape=False
+    ):
+        def _convert_shape(shape, name):
+            if isinstance(shape, np.ndarray):
+                shape = tape.initializer(ir.Tensor(shape, name=name))
+            elif isinstance(shape, (list, tuple)):
+                shape = ir.Input(name, ir.Shape(shape), ir.TensorType(ir.DataType.INT64))
+                tape.graph_like.inputs.append(shape)
+            else:
+                raise TypeError(f"Unsupported type {type(shape)} for shape.")
+            return shape
+
+        x = ir.Input("X", ir.Shape(input_shape), ir.TensorType(ir.DataType.FLOAT))
+        y = ir.Input("Y", type=ir.TensorType(ir.DataType.FLOAT))
+        tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20}))
+
+        # Build the graph.
+        reshape = tape.op(
+            "Reshape",
+            inputs=[x, _convert_shape(shape1, "shape_")],
+            attributes={"allowzero": allowzero1},
+        )
+        tape.op(
+            "Reshape",
+            inputs=[reshape, _convert_shape(shape2, "shape")],
+            attributes={"allowzero": allowzero2},
+            output=y,
+        )
+        model = ir.Model(tape.graph_like, ir_version=10)
+
+        # Infer shapes.
+        if infer_shape:
+            model = ir.passes.common.ShapeInferencePass()(model).model
+        return model
+
+    @parameterized.parameterized.expand(
+        [
+            ((3, 4, 5), [4, 5, 3], [5, 4, 3]),
+            ((3, 4, 5), [-1], [5, 4, 3]),
+            ((3, 4, 8), [2, 0, 3, -1], [0, 3, 2, 8]),
+            ((3, 4, 8), [3, 4, -1], [-1, 12], 1),
+            ((3, 4, 2), [0, 4, -1], [12, -1], 0, 1),
+            ((3, 0, 8), [4, 2, 0, 0], [3, 0], 1, 1),
+        ]
+    )
+    def test_reshape_reshape_rule(
+        self, input_shape, shape1, shape2, allowzero1=0, allowzero2=0
+    ):
+        model = self.create_model(
+            input_shape,
+            np.array(shape1, dtype="int64"),
+            np.array(shape2, dtype="int64"),
+            allowzero1=allowzero1,
+            allowzero2=allowzero2,
+        )
+        updated_model = clone_model(model)
+
+        # check rewrite approach.
+        count = _basic_rules.reshape_reshape_rule.apply_to_model(updated_model)
+        self.assertEqual(count, 1)
+        self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph])
+
+        # Check inference.
+        inputs = np.random.default_rng(10).random(input_shape, dtype="float32")
+        testing.assert_numerically_equal(model, updated_model, (inputs,), atol=0, rtol=0)
+
+    @parameterized.parameterized.expand([([3, 2, 3, 3, 3], 1), ([0, -1, 3, 2], 0)])
+    def test_reshape_dynamic_reshape_rule(self, shape1, allowzero1=0):
+        input_shape = (3, 6, 9)
+        shape1 = np.array(shape1, dtype="int64")
+        # Build the model with unknown shape1.
+        model = self.create_model(
+            input_shape,
+            (shape1.size,),
+            np.array((1, 6, 27), dtype="int64"),
+            allowzero1=allowzero1,
+        )
+        updated_model = clone_model(model)
+
+        # check rewrite approach.
+        count = _basic_rules.reshape_reshape_rule.apply_to_model(updated_model)
+        self.assertEqual(count, 1)
+        self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph])
+
+        # Check inference.
+        feeds = {
+            "X": np.random.default_rng(2).random(input_shape, dtype="float32"),
+            "shape_": shape1,
+        }
+        testing.assert_numerically_equal(model, updated_model, feeds, atol=0, rtol=0)
+
+    @parameterized.parameterized.expand(
+        [((3, 6, 9), [0, 3, 2, -1]), ((0, 6, 2), [0, 0, 3], 1)]
+    )
+    def test_reshape_reshape_dynamic_rule(self, input_shape, shape2, allowzero2=0):
+        # Note that shape inference is required for this test to be valid.
+        shape2 = np.array(shape2, dtype="int64")
+        model = self.create_model(
+            input_shape,
+            np.array((3, 2, -1), dtype="int64"),
+            shape2,
+            allowzero2=allowzero2,
+            infer_shape=True,
+        )
+        updated_model = clone_model(model)
+
+        # check rewrite approach.
+        count = _basic_rules.reshape_reshape_rule.apply_to_model(updated_model)
+        self.assertEqual(count, 1)
+        self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph])
+
+        # Check inference.
+        inputs = np.random.default_rng(7).random(input_shape, dtype="float32")
+        testing.assert_numerically_equal(model, updated_model, (inputs,), atol=0, rtol=0)
+
+    @parameterized.parameterized.expand(
+        [
+            ((3,), "is not a constant"),
+            (np.array([0, -1], dtype="int64"), "both 0 and -1 dimensions"),
+            (np.array([0, 0, 3], dtype="int64"), "more than one 0 dimension"),
+        ]
+    )
+    def test_unsupported_reshape_reshape(self, shape2, error_msg):
+        model = self.create_model((1, 2, 3), np.array([1, 6], dtype="int64"), shape2)
+
+        # Check rewrite approach.
+        tracer = MatchingTracer()
+        count = _basic_rules.reshape_reshape_rule.apply_to_model(model, tracer=tracer)
+        self.assertEqual(count, 0)
+
+        # Check that the error message is the expected one
+        tracer_match = tracer.best_matches_map[_basic_rules.reshape_reshape_rule][0]
+        self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED)
+        self.assertRegex(tracer_match.match_result.reason, error_msg)
+
+
+class Flatten2ReshapeTest(unittest.TestCase):
+    @staticmethod
+    def create_model(input_shape, axis=1):
+        x = ir.Input("X", ir.Shape(input_shape), ir.TensorType(ir.DataType.FLOAT))
+        y = ir.Input("Y", type=ir.TensorType(ir.DataType.FLOAT))
+        tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20}))
+
+        # Build the graph.
+        tape.op("Flatten", inputs=[x], attributes={"axis": axis}, output=y)
+        model = ir.Model(tape.graph_like, ir_version=10)
+        return model
+
+    @parameterized.parameterized.expand(list(range(-5, 6)))
+    def test_flatten_to_reshape_rule(self, axis):
+        input_shape = (1, 4, 8, 7, 5)
+        model = self.create_model(input_shape=input_shape, axis=axis)
+        updated_model = clone_model(model)
+
+        # check rewrite approach.
+        count = _basic_rules.flatten_to_reshape_rule.apply_to_model(updated_model)
+        self.assertEqual(count, 1)
+        self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph])
+
+        # Check inference.
+        inputs = np.random.default_rng(13).random(input_shape, dtype="float32")
+        testing.assert_numerically_equal(model, updated_model, (inputs,), atol=0, rtol=0)
+
+    @parameterized.parameterized.expand(list(range(-4, 5)))
+    def test_flatten_to_reshape_dynamic_input(self, axis):
+        model = self.create_model(input_shape=("N", "C1", "C2", "C3"), axis=axis)
+        # Rule is supported in all cases if the output shape is known for non-special cases.
+        input_shape = (1, 2, 3, 4)
+        if axis not in {-3, 0, 1, 4}:
+            out_shape = ir.Shape((np.prod(input_shape[:axis]), np.prod(input_shape[axis:])))
+            model.graph.outputs[0].shape = out_shape
+        updated_model = clone_model(model)
+
+        # check rewrite approach.
+        count = _basic_rules.flatten_to_reshape_rule.apply_to_model(updated_model)
+        self.assertEqual(count, 1)
+        self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph])
+
+        # Check inference.
+        inputs = np.random.default_rng(17).random(input_shape, dtype="float32")
+        testing.assert_numerically_equal(model, updated_model, (inputs,), atol=0, rtol=0)
+
+    def test_unsupported_flatten_to_reshape(self):
+        model = self.create_model(input_shape=("N", "C1", "C2"), axis=2)
+
+        # Check rewrite approach.
+        tracer = MatchingTracer()
+        count = _basic_rules.flatten_to_reshape_rule.apply_to_model(model, tracer=tracer)
+        self.assertEqual(count, 0)
+
+        # Check that the error message is the expected one
+        tracer_match = tracer.best_matches_map[_basic_rules.flatten_to_reshape_rule][0]
+        self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED)
+        self.assertRegex(tracer_match.match_result.reason, "Impossible to compute new shape")
+
+    def test_unsupported_flatten_to_reshape_zero_dim(self):
+        # Reshape would read the leading 0 of [0, 3] as "copy input dim 0" (i.e. 2).
+        model = self.create_model(input_shape=(2, 0, 3), axis=2)
+
+        # Check rewrite approach.
+        tracer = MatchingTracer()
+        count = _basic_rules.flatten_to_reshape_rule.apply_to_model(model, tracer=tracer)
+        self.assertEqual(count, 0)
+
+        # Check that the error message is the expected one
+        tracer_match = tracer.best_matches_map[_basic_rules.flatten_to_reshape_rule][0]
+        self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED)
+        self.assertRegex(tracer_match.match_result.reason, "Zero-sized dimensions")
+
+
 if __name__ == "__main__":
     unittest.main(verbosity=2)