diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index 232750af78..fc000dc176 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -35,6 +35,7 @@ _broadcast_to_matmul, _cast_constant_of_shape, _collapse_slices, + _fuse_batchnorm, _fuse_pad_into_conv, _fuse_relus_clips, _min_max_to_clip, @@ -53,6 +54,7 @@ *_basic_rules.basic_optimization_rules(), *_redundant_scatter_nd.rules, *_fuse_pad_into_conv.rules, + *_fuse_batchnorm.rules, ) diff --git a/onnxscript/rewriter/rules/common/_fuse_batchnorm.py b/onnxscript/rewriter/rules/common/_fuse_batchnorm.py index a5ceb00468..9d8b8f23f4 100644 --- a/onnxscript/rewriter/rules/common/_fuse_batchnorm.py +++ b/onnxscript/rewriter/rules/common/_fuse_batchnorm.py @@ -15,7 +15,7 @@ """ from abc import ABC, abstractmethod -from typing import Mapping +from typing import ClassVar, Mapping import numpy as np @@ -33,16 +33,6 @@ def _reshape_for_broadcast(x: np.ndarray, rank: int, axis: int = 1) -> np.ndarra class _FuseBatchNormBase(RewriteRuleClassBase, ABC): """Interface for BatchNormalization nodes fusion.""" - def __init__( - self, - op_type: str, - name: str | None = None, - remove_nodes: bool = True, - as_function: bool = False, - ) -> None: - super().__init__(name=name, remove_nodes=remove_nodes, as_function=as_function) - self.op_type = op_type - @abstractmethod def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: """Return the axis along which BatchNorm scale should be broadcasted.""" @@ -116,8 +106,7 @@ def check(self, context, x, inbound_out: ir.Value, batchnorm_out: ir.Value) -> M class FuseBatchNormIntoConv(_FuseBatchNormBase): """Replaces ``BatchNormalization(Conv(x))`` with ``Conv(x)``.""" - def __init__(self): - super().__init__("Conv") + op_type: ClassVar = "Conv" def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: return 0 @@ -133,8 +122,7 @@ def pattern(self, op, x): class FuseBatchNormIntoConvTranspose(_FuseBatchNormBase): """Replaces ``BatchNormalization(ConvTranspose(x))`` with ``ConvTranspose(x)``.""" - def __init__(self): - super().__init__("ConvTranspose") + op_type: ClassVar = "ConvTranspose" def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: return 1 @@ -150,8 +138,7 @@ def pattern(self, op, x): class FuseBatchNormIntoGemm(_FuseBatchNormBase): """Replaces ``BatchNormalization(Gemm(x))`` with ``Gemm(x)``.""" - def __init__(self): - super().__init__("Gemm") + op_type: ClassVar = "Gemm" def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: return (