diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 6953a76929..30a97315d0 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -41,11 +41,11 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v5 + uses: actions/checkout@v6 # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@v3 + uses: github/codeql-action/init@v4 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -59,7 +59,7 @@ jobs: # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # If this step fails, then you should remove it and run the build manually (see below) - name: Autobuild - uses: github/codeql-action/autobuild@v3 + uses: github/codeql-action/autobuild@v4 # â„šī¸ Command-line programs to run using the OS shell. # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun @@ -72,4 +72,4 @@ jobs: # ./location_of_script_within_repo/buildscript.sh - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v3 + uses: github/codeql-action/analyze@v4 diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 88787d6cce..a87792fd2f 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -20,7 +20,7 @@ jobs: pull-requests: write steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: misspell # Check spelling uses: reviewdog/action-misspell@v1 with: @@ -43,9 +43,9 @@ jobs: permissions: security-events: write steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: # Version range or exact version of Python to use, using SemVer's version range syntax. Reads from .python-version if unset. 
python-version: "3.10" @@ -78,7 +78,7 @@ jobs: # To toggle linter comments in the files page, press `i` on the keyboard if: always() continue-on-error: true - uses: github/codeql-action/upload-sarif@v3 + uses: github/codeql-action/upload-sarif@v4 with: # Path to SARIF file relative to the root of the repository sarif_file: lintrunner.sarif diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index c547608cc6..fcff6d2dd4 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -57,9 +57,9 @@ jobs: nox-tag: test-onnx-ir-git runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - name: Install nox @@ -83,7 +83,7 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} - name: Upload torchlib error reports if: always() - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v5 with: name: Error reports (${{ matrix.name }}-${{ matrix.os }}) path: error_reports @@ -95,9 +95,9 @@ jobs: os: [ubuntu-latest, windows-latest] runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.10" cache: pip @@ -119,9 +119,9 @@ jobs: update_readme: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 - name: Update readme run: | python docs/update_readme.py diff --git a/.github/workflows/pages.yaml b/.github/workflows/pages.yaml index c38de94b15..51ae68abcc 100644 --- a/.github/workflows/pages.yaml +++ b/.github/workflows/pages.yaml @@ -25,14 +25,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v5 + uses: actions/checkout@v6 - name: Setup Pages uses: 
actions/configure-pages@v4 - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.10" - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Install dependencies run: | python -m pip install --upgrade pip setuptools wheel diff --git a/.lintrunner.toml b/.lintrunner.toml index 7b31bab564..ed937d352c 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -39,6 +39,7 @@ include_patterns = [ exclude_patterns = [ 'tests/**', # Skip linting test files for speed # FIXME: Fix typing annotations in these files + 'examples/custom_op_expansion.py', 'onnxscript/converter_test.py', 'onnxscript/converter.py', 'onnxscript/evaluator_test.py', @@ -57,7 +58,6 @@ exclude_patterns = [ 'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME 'onnxscript/tools/function_unittest_producer.py', # FIXME 'onnxscript/rewriter/onnxruntime/transformers/layernorm.py', # FIXME - 'onnxscript/rewriter/generic_pattern.py', # FIXME ] command = [ 'python', diff --git a/VERSION b/VERSION index 267577d47e..d3532a107e 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.4.1 +0.5.7 diff --git a/docs/api/optimizer.md b/docs/api/optimizer.md index 90de403099..6c8adf21bb 100644 --- a/docs/api/optimizer.md +++ b/docs/api/optimizer.md @@ -15,5 +15,4 @@ optimizer.inline optimizer.basic_constant_propagation optimizer.fold_constants - optimizer.remove_unused_nodes ``` diff --git a/examples/custom_op_expansion.py b/examples/custom_op_expansion.py new file mode 100644 index 0000000000..c261ff18d7 --- /dev/null +++ b/examples/custom_op_expansion.py @@ -0,0 +1,61 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ruff: noqa + +"""A utility and an example showing how onnxscript functions can be used to define function expansions +and be used with the inliner to replace calls to the custom function with an expanded subgraph. +This is useful to perform certain classes of graph surgery easily. 
+""" + +import onnx + +import onnxscript +import onnxscript.utils.replace as replace + +script = onnxscript.script +FLOAT = onnxscript.FLOAT +op = onnxscript.values.opset22 +local = onnxscript.values.Opset("local", 1) + + +# Example Model: Actual models can come from ModelBuilder or Exporter or any other source. +# Models can contain calls to custom operations (from a custom domain like 'local' here or +# even "com.microsoft" etc.) +@script() +def model_script(X: FLOAT["N"], Y: FLOAT["N"]) -> FLOAT["N"]: + DoubleX = op.Add(X, X) + YSquare = op.Mul(Y, Y) + # Example call to a custom operation + Temp1 = local.CustomOp1(DoubleX, YSquare) + # Another call to a custom operation with an attribute + Temp2 = local.CustomOp2(Temp1, alp=0.9) + return Temp2 + + +# Define expansions for custom operations as onnxscript functions +@script(opset=local) +def CustomOp1(X: FLOAT["N"], Y: FLOAT["N"]) -> FLOAT["N"]: + Temp1 = op.Sub(X, Y) + return op.Div(Temp1, X) + + +@script(opset=local) +def CustomOp2(X: FLOAT["N"], alp: float) -> FLOAT["N"]: + Temp2 = op.Elu(X, alpha=alp) + return op.Mul(Temp2, Temp2) + + +# Now, we can replace the custom operations in the model with their expansions: + +functions = [CustomOp1.to_function_proto(), CustomOp2.to_function_proto()] + +model = model_script.to_model_proto() + +print("Original Model with custom operations:") +print(onnx.printer.to_text(model)) + + +updated_model = replace.replace_functions(model, functions) + +print("\nUpdated Model after replacing custom operations with their expansions:") +print(onnx.printer.to_text(updated_model)) diff --git a/examples/pattern_rewriting.py b/examples/pattern_rewriting.py index 7b5c56d5e3..fd84d7f3cb 100644 --- a/examples/pattern_rewriting.py +++ b/examples/pattern_rewriting.py @@ -141,28 +141,3 @@ def rotary_apply_pattern(op, x, pos_ids, axis): rule = pattern.RewriteRule(rotary_match_pattern, rotary_apply_pattern, verbose=10) rule.apply_to_model(ir_model) - -# TODO(rama): Update the following, the 
trace-printed looks different now. - -###################################### -# The logs shows every time the algorithm rejected a pattern. -# We can see the following: -# -# :: -# -# [OnnxGenericPattern.match] NONE - line: 673:onnxscript.rewriter.generic_pattern, op_type=Cast -# --hint--: BACKWARD: different node types -# --pattern -# ConcatTraining(transpose, transpose) -> (output, length) -# -- model -# ConcatTrainingBad(_onx_transpose0, _onx_transpose0) -> (_onx_concattraining0, _onx_concattraining1) -# iteration=1 -# --marked-- #2 -# Cast(_onx_cos0) ~ Cast(cos) [140186194226496-140186194222320] -# Cos(_onx_concattraining0) ~ Cos(output) [140186194230816-140186194223472] -# len(stacked)=0:[] -# -# Line 673 in file `generic_pattern.py`, the match was rejected. -# It says while comparing two nodes in the backward direction, -# node types do not match. -# It also says that two nodes were actually matched. diff --git a/noxfile.py b/noxfile.py index f69c5af9bd..fc80761b68 100644 --- a/noxfile.py +++ b/noxfile.py @@ -12,7 +12,6 @@ COMMON_TEST_DEPENDENCIES = ( - "beartype==0.17.2", "expecttest==0.1.6", "hypothesis", "numpy", @@ -30,9 +29,9 @@ "ml-dtypes", ) ONNX = "onnx==1.17" -ONNX_RUNTIME = "onnxruntime==1.20.1" -PYTORCH = "torch==2.5.1" -TORCHVISON = "torchvision==0.20.1" +ONNX_RUNTIME = "onnxruntime==1.23.0" +PYTORCH = "torch==2.7.1" +TORCHVISON = "torchvision==0.22.1" TRANSFORMERS = "transformers==4.37.2" ONNX_RUNTIME_NIGHTLY_DEPENDENCIES = ( "flatbuffers", @@ -42,7 +41,7 @@ "packaging", "protobuf", ) -ONNX_IR = "onnx_ir==0.1.7" +ONNX_IR = "onnx_ir==0.1.12" ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir" diff --git a/onnxscript/__init__.py b/onnxscript/__init__.py index b839093d2b..bccfd84cd4 100644 --- a/onnxscript/__init__.py +++ b/onnxscript/__init__.py @@ -55,6 +55,7 @@ "opset20", "opset21", "opset22", + "opset23", "opset_ai_onnx_ml1", "opset_ai_onnx_ml2", "opset_ai_onnx_ml3", @@ -92,6 +93,7 @@ opset20, opset21, opset22, + opset23, 
opset_ai_onnx_ml1, opset_ai_onnx_ml2, opset_ai_onnx_ml3, diff --git a/onnxscript/_framework_apis/torch_2_5.py b/onnxscript/_framework_apis/torch_2_5.py index 2f8601c7c6..5bbb64af88 100644 --- a/onnxscript/_framework_apis/torch_2_5.py +++ b/onnxscript/_framework_apis/torch_2_5.py @@ -13,6 +13,7 @@ ] import dataclasses +import importlib.util import os import pathlib from typing import Callable @@ -63,20 +64,48 @@ def check_model(model: ir.Model) -> None: del model # Unused yet -def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike) -> None: +def save_model_with_external_data( + model: ir.Model, model_path: str | os.PathLike, verbose: bool = False +) -> None: """Save the model with external data. The model is unchanged after saving.""" # TODO(#1835): Decide if we want to externalize large attributes as well - for value in model.graph.initializers.values(): - if value.const_value is None: - raise ValueError( - "The model contains uninitialized initializer values. " - "Please make sure all initializer values are initialized." - ) + uninitialized_values = [ + value.name for value in model.graph.initializers.values() if value.const_value is None + ] + if uninitialized_values: + raise ValueError( + f"The model contains uninitialized initializer values ({uninitialized_values}). " + "Please make sure all initializer values are initialized." 
+ ) destination_path = pathlib.Path(model_path) data_path = f"{destination_path.name}.data" - ir.save(model, model_path, external_data=data_path) + # Show a progress bar if verbose is True and tqdm is installed + use_tqdm = verbose and importlib.util.find_spec("tqdm") is not None + + if use_tqdm: + import tqdm # pylint: disable=import-outside-toplevel + + with tqdm.tqdm() as pbar: + total_set = False + + def callback( + tensor: ir.TensorProtocol, metadata: ir.external_data.CallbackInfo + ) -> None: + nonlocal total_set + if not total_set: + pbar.total = metadata.total + total_set = True + + pbar.update() + pbar.set_description( + f"Saving {tensor.name} ({tensor.dtype.short_name()}, {tensor.shape}) at offset {metadata.offset}" + ) + + ir.save(model, model_path, external_data=data_path, callback=callback) + else: + ir.save(model, model_path, external_data=data_path) def get_torchlib_ops() -> list[_OnnxFunctionMeta]: diff --git a/onnxscript/_internal/analysis.py b/onnxscript/_internal/analysis.py index 0403f60c91..c89542d344 100644 --- a/onnxscript/_internal/analysis.py +++ b/onnxscript/_internal/analysis.py @@ -47,183 +47,254 @@ def get_id(e): return {get_id(lhs)} -def assigned_vars( - stmt: ast.stmt | list[ast.stmt], formatter: sourceinfo.Formatter -) -> Set[str]: - """Return the set of all variables that may be assigned to in an execution of input stmt - or sequence of statements. 
- """ - - def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]: - result: set[Any] = set() - for s in block: - result = result | assigned_vars(s, formatter) - return result - - if isinstance(stmt, ast.Assign): - return _lhs_vars(stmt.targets[0]) - if isinstance(stmt, ast.AnnAssign): - return _lhs_vars(stmt.target) - if isinstance(stmt, ast.Return): - return set() - if isinstance(stmt, ast.If): - return assigned_in_block(stmt.body) | assigned_in_block(stmt.orelse) - if isinstance(stmt, ast.For): - return assigned_in_block(stmt.body) | {_get_loop_var(stmt, formatter)} - if isinstance(stmt, ast.While): - return assigned_in_block(stmt.body) - if isinstance(stmt, list): - return assigned_in_block(stmt) - if isinstance(stmt, ast.Break): - return set() - if ast_utils.is_print_call(stmt): - return set() - if ast_utils.is_doc_string(stmt): - return set() - error_message = formatter(stmt, f"Unsupported statement type {type(stmt)!r}.") - raise ValueError(error_message) +class AstAnalyzer: + def __init__( + self, + fun: ast.FunctionDef, + formatter: sourceinfo.Formatter, + globals: dict[str, Any] | None = None, + ) -> None: + self._formatter = formatter + self._constant_if_condition: dict[ast.If, bool] = {} + self._live_in: dict[ast.stmt, Set[str]] = {} + self._live_out: dict[ast.stmt, Set[str]] = {} + if globals: + self._compute_constant_if_conditions(fun, globals) + self.do_liveness_analysis(fun) + def live_in(self, stmt: ast.stmt) -> Set[str] | None: + """Get the set of variables that are live at the entry of the given statement.""" + return self._live_in.get(stmt) -def do_liveness_analysis(fun: ast.FunctionDef, formatter: sourceinfo.Formatter): - """Perform liveness analysis of the given function-ast. The results of the - analysis are stored directly with each statement-ast `s` as attributes `s.live_in` - and `s.live_out`. 
- """ + def live_out(self, stmt: ast.stmt) -> Set[str] | None: + """Get the set of variables that are live at the exit of the given statement.""" + return self._live_out.get(stmt) - def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: - stmt.live_out = live_out # type: ignore[attr-defined] - live = do_visit(stmt, live_out) - stmt.live_in = live # type: ignore[attr-defined] - return live + def _compute_constant_if_conditions( + self, fun: ast.FunctionDef, globals: dict[str, Any] + ) -> None: + """Identify if-statements with constant conditions. - def do_visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: - def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: - for s in reversed(block): - live_out = visit(s, live_out) - return live_out + If-statements of the form `if name:` where `name` is an outer-scope variable + and name is not assigned to within the function body, are treated as constant + conditions. The value of such conditions is determined from the outer-scope. + """ + + assigned_vars = self.assigned_vars(fun.body) + for node in ast.walk(fun): + if isinstance(node, ast.If): + if isinstance(node.test, ast.Name): + python_var = node.test.id + if python_var not in assigned_vars and python_var in globals: + # Condition depends on an outer-scope variable. + self._constant_if_condition[node] = bool(globals[python_var]) + + def constant_if_condition(self, if_stmt: ast.If) -> Optional[bool]: + """Return the constant value of the if-statement condition, if it is constant. + + Args: + if_stmt: The if-statement-ast to analyze. + + Returns: + The constant boolean value of the if-statement condition, or None if not constant. + """ + return self._constant_if_condition.get(if_stmt, None) + + def assigned_vars(self, stmt: ast.stmt | list[ast.stmt]) -> Set[str]: + """Return the set of all variables that may be assigned to in an execution of input stmt + or sequence of statements. 
+ """ + + def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]: + result: set[Any] = set() + for s in block: + result = result | self.assigned_vars(s) + return result if isinstance(stmt, ast.Assign): - return live_out.difference(_lhs_vars(stmt.targets[0])) | _used_vars(stmt.value) + return _lhs_vars(stmt.targets[0]) if isinstance(stmt, ast.AnnAssign): - return live_out.difference(_lhs_vars(stmt.target)) | _used_vars(stmt.value) + return _lhs_vars(stmt.target) if isinstance(stmt, ast.Return): - return _used_vars(stmt.value) + return set() if isinstance(stmt, ast.If): - live1 = visitBlock(stmt.body, live_out) - live2 = visitBlock(stmt.orelse, live_out) - return live1 | live2 | _used_vars(stmt.test) + constant_cond = self.constant_if_condition(stmt) + if constant_cond is None: + return assigned_in_block(stmt.body) | assigned_in_block(stmt.orelse) + elif constant_cond: + return assigned_in_block(stmt.body) + else: + return assigned_in_block(stmt.orelse) if isinstance(stmt, ast.For): - p_loop_var = _get_loop_var(stmt, formatter) - prev = None - curr = live_out - while curr != prev: - prev = curr - curr = visitBlock(stmt.body, prev).difference({p_loop_var}) - return curr + return assigned_in_block(stmt.body) | {_get_loop_var(stmt, self._formatter)} if isinstance(stmt, ast.While): - cond_vars = _used_vars(stmt.test) - prev = None - curr = live_out | cond_vars - while curr != prev: - prev = curr - curr = visitBlock(stmt.body, prev) | cond_vars - return curr + return assigned_in_block(stmt.body) + if isinstance(stmt, list): + return assigned_in_block(stmt) if isinstance(stmt, ast.Break): - # The following is sufficient for the current restricted usage, where - # a (conditional) break is allowed only as the last statement of a loop. - # Break statements in the middle of the loop, however, will require - # a generalization. 
- return live_out - if ast_utils.is_doc_string(stmt): - return live_out + return set() if isinstance(stmt, ast.FunctionDef): - return live_out + # Supported function-definitions (used for higher order ops like Scan) + # do not assign to any variable in the outer scope. + return set() if ast_utils.is_print_call(stmt): - return live_out - raise ValueError(formatter(stmt, f"Unsupported statement type {type(stmt)!r}.")) + return set() + if ast_utils.is_doc_string(stmt): + return set() + error_message = self._formatter(stmt, f"Unsupported statement type {type(stmt)!r}.") + raise ValueError(error_message) - assert isinstance(fun, ast.FunctionDef) - live: set[Any] = set() - for s in reversed(fun.body): - live = visit(s, live) + def do_liveness_analysis(self, fun: ast.FunctionDef): + """Perform liveness analysis of the given function-ast.""" + def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: + self._live_out[stmt] = live_out + live = do_visit(stmt, live_out) + self._live_in[stmt] = live + return live -def exposed_uses(stmts: Sequence[ast.stmt], formatter: sourceinfo.Formatter): - """Return the set of variables that are used before being defined by given block. - In essence, this identifies the "inputs" to a given code-block. 
- For example, consider the following code-block: - :: + def do_visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: + def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: + for s in reversed(block): + live_out = visit(s, live_out) + return live_out - x = x + 10 - y = 20 - z = x + y - x = 30 + if isinstance(stmt, ast.Assign): + return live_out.difference(_lhs_vars(stmt.targets[0])) | _used_vars(stmt.value) + if isinstance(stmt, ast.AnnAssign): + return live_out.difference(_lhs_vars(stmt.target)) | _used_vars(stmt.value) + if isinstance(stmt, ast.Return): + return _used_vars(stmt.value) + if isinstance(stmt, ast.If): + constant_cond = self.constant_if_condition(stmt) + if constant_cond is None: + live1 = visitBlock(stmt.body, live_out) + live2 = visitBlock(stmt.orelse, live_out) + return live1 | live2 | _used_vars(stmt.test) + elif constant_cond: + return visitBlock(stmt.body, live_out) + else: + return visitBlock(stmt.orelse, live_out) + if isinstance(stmt, ast.For): + p_loop_var = _get_loop_var(stmt, self._formatter) + prev = None + curr = live_out + while curr != prev: + prev = curr + curr = visitBlock(stmt.body, prev).difference({p_loop_var}) + return curr + if isinstance(stmt, ast.While): + cond_vars = _used_vars(stmt.test) + prev = None + curr = live_out | cond_vars + while curr != prev: + prev = curr + curr = visitBlock(stmt.body, prev) | cond_vars + return curr + if isinstance(stmt, ast.Break): + # The following is sufficient for the current restricted usage, where + # a (conditional) break is allowed only as the last statement of a loop. + # Break statements in the middle of the loop, however, will require + # a generalization. 
+ return live_out + if ast_utils.is_doc_string(stmt): + return live_out + if isinstance(stmt, ast.FunctionDef): + return live_out + if ast_utils.is_print_call(stmt): + return live_out + raise ValueError( + self._formatter(stmt, f"Unsupported statement type {type(stmt)!r}.") + ) - The exposed_uses of this code-block is { x }. The value of z is not used within - the block. Even though the value of y is used within the block, it is assigned - a value before it is used. However, in contrast, the incoming value of x is used - (in the first statement). Hence x is included in the exposed_uses. - """ + assert isinstance(fun, ast.FunctionDef) + live: set[Any] = set() + for s in reversed(fun.body): + live = visit(s, live) - def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: - for stmt in reversed(block): - live_out = visit(stmt, live_out) - return live_out + def exposed_uses(self, stmts: Sequence[ast.stmt]): + """Return the set of variables that are used before being defined by given block. + In essence, this identifies the "inputs" to a given code-block. + For example, consider the following code-block: + :: - def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: - if isinstance(stmt, ast.Assign): - return live_out.difference(_lhs_vars(stmt.targets[0])) | _used_vars(stmt.value) - if isinstance(stmt, ast.AnnAssign): - return live_out.difference(_lhs_vars(stmt.target)) | _used_vars(stmt.value) - if isinstance(stmt, ast.Return): - return _used_vars(stmt.value) - if isinstance(stmt, ast.If): - live1 = visitBlock(stmt.body, live_out) - live2 = visitBlock(stmt.orelse, live_out) - return (live1 | live2) | _used_vars(stmt.test) - if ast_utils.is_print_call(stmt): - return live_out - if ast_utils.is_doc_string(stmt): - return live_out - if isinstance(stmt, ast.For): - # Analysis assumes loop may execute zero times. Results can be improved - # for loops that execute at least once. 
- loop_var_set = {_get_loop_var(stmt, formatter)} - used_after_loop = live_out.difference(loop_var_set) - used_inside_loop = visitBlock(stmt.body, set()).difference(loop_var_set) - used_in_loop_header = _used_vars(stmt.iter) - return used_inside_loop | used_in_loop_header | used_after_loop - if isinstance(stmt, ast.While): - # Analysis assumes loop may execute zero times. Results can be improved - # for loops that execute at least once. - used_inside_loop = visitBlock(stmt.body, set()) - used_in_loop_header = _used_vars(stmt.test) - return used_inside_loop | used_in_loop_header | live_out - if isinstance(stmt, ast.Break): - # Currently, we assume that break statements are only allowed as the last - # statement in a loop, as "if cond: break". - return live_out - if isinstance(stmt, ast.FunctionDef): - if stmt.name in live_out: - live_out.remove(stmt.name) - live_out = live_out | outer_scope_variables(stmt, formatter) + x = x + 10 + y = 20 + z = x + y + x = 30 + + The exposed_uses of this code-block is { x }. The value of z is not used within + the block. Even though the value of y is used within the block, it is assigned + a value before it is used. However, in contrast, the incoming value of x is used + (in the first statement). Hence x is included in the exposed_uses. 
+ """ + + def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: + for stmt in reversed(block): + live_out = visit(stmt, live_out) return live_out - raise ValueError(formatter(stmt, f"Unsupported statement type {type(stmt)!r}.")) - return visitBlock(stmts, set()) + def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: + if isinstance(stmt, ast.Assign): + return live_out.difference(_lhs_vars(stmt.targets[0])) | _used_vars(stmt.value) + if isinstance(stmt, ast.AnnAssign): + return live_out.difference(_lhs_vars(stmt.target)) | _used_vars(stmt.value) + if isinstance(stmt, ast.Return): + return _used_vars(stmt.value) + if isinstance(stmt, ast.If): + constant_cond = self.constant_if_condition(stmt) + if constant_cond is None: + live1 = visitBlock(stmt.body, live_out) + live2 = visitBlock(stmt.orelse, live_out) + return (live1 | live2) | _used_vars(stmt.test) + elif constant_cond: + return visitBlock(stmt.body, live_out) + else: + return visitBlock(stmt.orelse, live_out) + if ast_utils.is_print_call(stmt): + return live_out + if ast_utils.is_doc_string(stmt): + return live_out + if isinstance(stmt, ast.For): + # Analysis assumes loop may execute zero times. Results can be improved + # for loops that execute at least once. + loop_var_set = {_get_loop_var(stmt, self._formatter)} + used_after_loop = live_out.difference(loop_var_set) + used_inside_loop = visitBlock(stmt.body, set()).difference(loop_var_set) + used_in_loop_header = _used_vars(stmt.iter) + return used_inside_loop | used_in_loop_header | used_after_loop + if isinstance(stmt, ast.While): + # Analysis assumes loop may execute zero times. Results can be improved + # for loops that execute at least once. 
+ used_inside_loop = visitBlock(stmt.body, set()) + used_in_loop_header = _used_vars(stmt.test) + return used_inside_loop | used_in_loop_header | live_out + if isinstance(stmt, ast.Break): + # Currently, we assume that break statements are only allowed as the last + # statement in a loop, as "if cond: break". + return live_out + if isinstance(stmt, ast.FunctionDef): + if stmt.name in live_out: + live_out.remove(stmt.name) + live_out = live_out | self.outer_scope_variables(stmt) + return live_out + raise ValueError( + self._formatter(stmt, f"Unsupported statement type {type(stmt)!r}.") + ) + return visitBlock(stmts, set()) -def outer_scope_variables(fun: ast.FunctionDef, formatter: sourceinfo.Formatter): - """Return the set of outer-scope variables used in a nested function. + def outer_scope_variables(self, fun: ast.FunctionDef): + """Return the set of outer-scope variables used in a nested function. - Args: - fun: The function-ast to analyze. - formatter: The formatter object. + Args: + fun: The function-ast to analyze. + formatter: The formatter object. - Returns: - A set of variable names (strings). - """ - assert isinstance(fun, ast.FunctionDef) - used_vars_ = exposed_uses(fun.body, formatter) - inputs = [x.arg for x in fun.args.args] - return used_vars_.difference(inputs) + Returns: + A set of variable names (strings). 
+ """ + assert isinstance(fun, ast.FunctionDef) + used_vars_ = self.exposed_uses(fun.body) + inputs = [x.arg for x in fun.args.args] + return used_vars_.difference(inputs) diff --git a/onnxscript/_internal/analysis_test.py b/onnxscript/_internal/analysis_test.py index 74e7ca4c18..7a7e5feaa0 100644 --- a/onnxscript/_internal/analysis_test.py +++ b/onnxscript/_internal/analysis_test.py @@ -14,24 +14,27 @@ class AnalysisResultsVisitor(ast.NodeVisitor): """Visitor class to flatten the results of liveness analysis in a pre-order traversal.""" - def __init__(self) -> None: + def __init__(self, analyzer: analysis.AstAnalyzer) -> None: super().__init__() self.results: list[Any] = [] + self.analyzer = analyzer def generic_visit(self, node): - if hasattr(node, "live_in"): - self.results.append(node.live_in) + live_in = self.analyzer.live_in(node) + if live_in is not None: + self.results.append(live_in) ast.NodeVisitor.generic_visit(self, node) if isinstance(node, (ast.For, ast.While)): last = node.body[-1] - self.results.append(last.live_out) # type: ignore + live_out = self.analyzer.live_out(last) + self.results.append(live_out) # type: ignore class TestLivenessAnalysis(unittest.TestCase): def analyze(self, fun): source, parse_tree = ast_utils.get_src_and_ast(fun) - analysis.do_liveness_analysis(parse_tree, formatter(source)) - visitor = AnalysisResultsVisitor() + analyzer = analysis.AstAnalyzer(parse_tree, formatter(source)) + visitor = AnalysisResultsVisitor(analyzer) visitor.visit(parse_tree) return visitor.results @@ -113,7 +116,8 @@ def while_eg(x): class TestExposedUses(unittest.TestCase): def assertUses(self, f, expected): source, parse_tree = ast_utils.get_src_and_ast(f) - result = analysis.exposed_uses(parse_tree.body, formatter(source)) + analyzer = analysis.AstAnalyzer(parse_tree, formatter(source)) + result = analyzer.exposed_uses(parse_tree.body) self.assertEqual(result, set(expected)) def test_basic(self): @@ -190,7 +194,8 @@ def f(x): class 
TestAssignedVarAnalysis(unittest.TestCase): def assert_assigned_vars(self, f, expected: set[str]): source, parse_tree = ast_utils.get_src_and_ast(f) - result = analysis.assigned_vars(parse_tree.body, formatter(source)) + analyzer = analysis.AstAnalyzer(parse_tree, formatter(source)) + result = analyzer.assigned_vars(parse_tree.body) self.assertEqual(result, expected) def test_basic_defs(self): @@ -248,5 +253,42 @@ def f(x): self.assert_assigned_vars(f, {"x", "y"}) +class ConstantIfAnalysisTest(unittest.TestCase): + def test_constant_ifs(self): + cond1 = True + cond2 = False + + def f(x): + if cond1: + y = x + 1 + else: + y = x + 2 + if cond2: + z = y * 2 + else: + z = y * 3 + if x > 0: + w = z - 1 + else: + w = z + 1 + return w + + source, parse_tree = ast_utils.get_src_and_ast(f) + + analyzer = analysis.AstAnalyzer( + parse_tree, formatter(source), {"cond1": True, "cond2": False} + ) + for node in ast.walk(parse_tree): + if isinstance(node, ast.If): + result = analyzer.constant_if_condition(node) + if isinstance(node.test, ast.Name): + if node.test.id == "cond1": + self.assertEqual(result, True) + elif node.test.id == "cond2": + self.assertEqual(result, False) + else: + self.assertIsNone(result) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxscript/_internal/runtime_typing.py b/onnxscript/_internal/runtime_typing.py deleted file mode 100644 index 3cf8a8db57..0000000000 --- a/onnxscript/_internal/runtime_typing.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""An internal wrapper for the beartype library. - -Decorate a function with `@runtime_typing.checked` to enable runtime -type checking. The decorator is a no-op when the `beartype` library is not -installed. 
-""" - -import typing -import warnings - -__all__ = [ - "checked", -] - -T = typing.TypeVar("T", bound=typing.Callable[..., typing.Any]) - -try: - from beartype import beartype as _beartype_decorator - from beartype import roar as _roar - - checked = typing.cast(typing.Callable[[T], T], _beartype_decorator) - - # Beartype warns when we import from typing because the types are deprecated - # in Python 3.9. But there will be a long time until we can move to using - # the native container types for type annotations (when 3.9 is the lowest - # supported version). So we silence the warning. - warnings.filterwarnings( - "ignore", - category=_roar.BeartypeDecorHintPep585DeprecationWarning, - ) -except ImportError: - - def checked(func: T) -> T: # type: ignore[no-redef] - return func - -except Exception as e: # pylint: disable=broad-exception-caught - # Warn errors that are not import errors (unexpected). - warnings.warn(f"{e}", stacklevel=2) - - def checked(func: T) -> T: # type: ignore[no-redef] - return func diff --git a/onnxscript/backend/onnx_export_test.py b/onnxscript/backend/onnx_export_test.py index 49eb398750..1f913ed897 100644 --- a/onnxscript/backend/onnx_export_test.py +++ b/onnxscript/backend/onnx_export_test.py @@ -84,6 +84,7 @@ def skip(pattern: str | Pattern, reason: str, *, condition: bool = True): ), skip(r"^test_ai_onnx_ml_label_encoder", "ONNX Runtime does not support Opset 21 at 1.17"), skip(r"^test_ai_onnx_ml_tree_ensemble", "Opset 23 is not supported"), + skip(r"^test_attention", "ONNX Runtime 1.23 fails on these tests"), ) if sys.platform == "win32": diff --git a/onnxscript/converter.py b/onnxscript/converter.py index dfcddefbd3..3e87c366ad 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -183,6 +183,13 @@ def __init__( self._nextvar: int = 0 self._used_vars: set[str] = set() self._locals: List[Dict[str, LocalSymValue]] = [{}] + self._analyzer: analysis.AstAnalyzer | None = None + + @property + def analyzer(self) -> 
analysis.AstAnalyzer: + if self._analyzer is None: + raise RuntimeError("Analyzer not initialized.") + return self._analyzer @property def default_opset(self) -> values.Opset: @@ -1089,12 +1096,24 @@ def ret(exp, i, suffix): return ret(val, 0, "") def _translate_if_stmt(self, stmt: ast.If) -> None: - if hasattr(stmt, "live_out"): - live_defs = list( - stmt.live_out.intersection(analysis.assigned_vars(stmt, self._message)) - ) - else: - live_defs = list(analysis.assigned_vars(stmt, self._message)) + constant_cond = self.analyzer.constant_if_condition(stmt) + if constant_cond is True: + # Translate only the "then" branch + for s in stmt.body: + self._translate_stmt(s) + return + if constant_cond is False: + # Translate only the "else" branch + for s in stmt.orelse: + self._translate_stmt(s) + return + live_def_set = self.analyzer.assigned_vars(stmt) + live_out = self.analyzer.live_out(stmt) + if live_out is not None: + # Ideally, live_out should never be None here. But handle this conditionally + # due to some existing usage. + live_def_set = live_out.intersection(live_def_set) + live_defs = list(live_def_set) test = self._translate_expr(stmt.test, "cond").name lineno = self._source_of(stmt).lineno thenGraph, sub_fct_then = self._translate_block( @@ -1174,9 +1193,11 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: else: self.fail(loop_stmt, f"Unexpected loop type {type(loop_stmt)!r}.") # analyze loop body - exposed_uses = analysis.exposed_uses(loop_stmt.body, self._message) - vars_def_in_loop = analysis.assigned_vars(loop_stmt.body, self._message) - loop_state_vars = vars_def_in_loop.intersection(exposed_uses | loop_stmt.live_out) + exposed_uses = self.analyzer.exposed_uses(loop_stmt.body) + vars_def_in_loop = self.analyzer.assigned_vars(loop_stmt.body) + live_out = self.analyzer.live_out(loop_stmt) + assert live_out is not None, "live_out cannot be None here." 
+ loop_state_vars = vars_def_in_loop.intersection(exposed_uses | live_out) scan_outputs = set() # TODO outputs = list(loop_state_vars | scan_outputs) @@ -1362,7 +1383,7 @@ def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None: self._enter_scope(fn.name, fn) self._translate_function_def_common(fn) function_ir = self._exit_scope() - outer_scope_vars = analysis.outer_scope_variables(fn, self._message) + outer_scope_vars = self.analyzer.outer_scope_variables(fn) function_ir.outer_scope_variables = [ (var, self._lookup(var, self._source_of(fn))) for var in outer_scope_vars ] @@ -1448,10 +1469,11 @@ def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction: self._set_default_opset(opset, stmt) domain = self.this_module.domain self._current_fn = self.ir_builder.new_function(stmt.name, domain, True) - analysis.do_liveness_analysis(stmt, self._message) + self._analyzer = analysis.AstAnalyzer(stmt, self._message, self.globals) fn_ir = self._translate_function_def_common(stmt) fn_ir.debug_print() self.this_module.add_function_def(fn_ir) + self._analyzer = None return fn_ir raise ValueError(f"Unsupported top-level statement type {type(stmt)!r}.") diff --git a/onnxscript/converter_test.py b/onnxscript/converter_test.py index 9a7ca504a7..a35711aea9 100644 --- a/onnxscript/converter_test.py +++ b/onnxscript/converter_test.py @@ -710,6 +710,36 @@ def model(x): self.assertEqual(len(onnx_opset_import), 1) self.assertEqual(onnx_opset_import[0].version, 19) + def test_traced_if(self): + """Test that traced if statements are converted correctly.""" + + @script() + def add_model(x: FLOAT[10]) -> FLOAT[10]: + y = op.Add(x, x) + return y + + @script() + def sub_model(x: FLOAT[10]) -> FLOAT[10]: + y = op.Sub(x, x) + return y + + def make_model(flag: bool): + @script() + def model(x: FLOAT[10]) -> FLOAT[10]: + if flag: + y = op.Add(x, x) + else: + y = op.Sub(x, x) + return y + + return model.to_model_proto() + + model_true = make_model(True) + 
onnxscript.testing.assert_isomorphic(model_true, add_model.to_model_proto()) + + model_false = make_model(False) + onnxscript.testing.assert_isomorphic(model_false, sub_model.to_model_proto()) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py index a8d15c242a..b547737bf5 100644 --- a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py +++ b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py @@ -12,7 +12,7 @@ import parameterized import onnxscript -import onnxscript.function_libs.torch_lib.ops # Import to populate registry +import onnxscript.function_libs.torch_lib.ops # Import to populate registry # noqa: F401 from onnxscript.function_libs.tools.torch_lib import deduce_type_constraints from onnxscript.function_libs.torch_lib import registration diff --git a/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py b/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py index ebbdd43bd8..6661e34afe 100644 --- a/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py +++ b/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py @@ -16,7 +16,6 @@ from typing import Any, Dict, List, Sequence import torch -import torchgen.gen import torchgen.model from torch._ops import _OpNamespace from torchgen.model import FunctionSchema diff --git a/onnxscript/function_libs/torch_lib/ops/common.py b/onnxscript/function_libs/torch_lib/ops/common.py index d7784a5289..38544b59ba 100644 --- a/onnxscript/function_libs/torch_lib/ops/common.py +++ b/onnxscript/function_libs/torch_lib/ops/common.py @@ -5,6 +5,8 @@ # mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value" from __future__ import annotations +from collections.abc import Sequence + import numpy.typing as npt import onnx @@ -26,14 +28,24 
@@ @onnxscript.script(common_opset) def Rank(input: tensor_typing.TTensor) -> INT64: - """Take the rank of the input tensor.""" + """Deprecated. + + NOTE: Do not remove, for backward compatibility with PyTorch < 2.10. + + Take the rank of the input tensor. + """ return op.Size(op.Shape(input)) @onnxscript.script(common_opset) def IsScalar(input: tensor_typing.TTensor) -> BOOL: - """Return whether the input has rank 0, or is a scalar.""" + """Deprecated. + + NOTE: Do not remove, for backward compatibility with PyTorch < 2.10. + + Return whether the input has rank 0, or is a scalar. + """ return op.Equal(op.Size(op.Shape(input)), op.Constant(value_int=0)) @@ -78,3 +90,22 @@ def constant( A constant node. """ return op.Constant(value=ir.tensor(array, dtype=ir.DataType(dtype))) + + +def merge_dims(dims: Sequence[int | INT64]) -> INT64: + """Concatenate dimensions into a single value.""" + + if not dims: + return op.Constant(value_ints=ir.AttrInt64s("value_ints", [])) + + neg_one_1d = op.Constant(value_ints=ir.AttrInt64s("value_ints", [-1])) + + result_dims = [ + op.Constant(value_ints=[d]) if isinstance(d, int) else op.Reshape(d, neg_one_1d) + for d in dims + ] + + # Set the output type to INT64 so op.Concat can be used + for dim in result_dims: + dim.dtype = ir.DataType.INT64 + return op.Concat(*result_dims, axis=0) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 6eb9fb4cbb..64905496ec 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -18,7 +18,6 @@ import torch from onnxscript import ( - BFLOAT16, BOOL, COMPLEX64, COMPLEX128, @@ -30,9 +29,6 @@ INT32, INT64, UINT8, - UINT16, - UINT32, - UINT64, graph, ir, ) @@ -56,10 +52,10 @@ from onnxscript.onnx_opset import opset18 as op from onnxscript.onnx_types import TensorType +_INT32_MAX = 2147483647 _INT64_MAX = 9223372036854775807 _INT64_MIN = -9223372036854775808 _MATH_PI = math.pi -Rank = 
common_ops.Rank @torch_op("aten::_local_scalar_dense", trace_only=True) @@ -77,13 +73,11 @@ def aten__local_scalar_dense(self: TensorType) -> TensorType: @torch_op("aten::_log_softmax", trace_only=True) -def aten__log_softmax_half( - self: Union[FLOAT16, BFLOAT16], dim: int, half_to_float: bool -) -> FLOAT: +def aten__log_softmax(self: TFloat, dim: int, half_to_float: bool) -> TFloatHighPrecision: """_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" self_is_scalar = len(self.shape) == 0 - if half_to_float: + if half_to_float and self.dtype in {ir.DataType.FLOAT16, ir.DataType.BFLOAT16}: self = op.Cast(self, to=FLOAT.dtype) if self_is_scalar: self = op.Unsqueeze(self, op.Constant(value_ints=[0])) @@ -93,44 +87,23 @@ def aten__log_softmax_half( return result -@torch_op("aten::_log_softmax", trace_only=True) -def aten__log_softmax( - self: TFloatHighPrecision, - dim: int, - half_to_float: bool, -) -> TFloatHighPrecision: - """_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" +@torch_op("aten::_softmax", trace_only=True) +def aten__softmax(self: TFloat, dim: int, half_to_float: bool) -> TFloatHighPrecision: + """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" self_is_scalar = len(self.shape) == 0 + + if half_to_float and self.dtype in {ir.DataType.FLOAT16, ir.DataType.BFLOAT16}: + self = op.Cast(self, to=FLOAT.dtype) + if self_is_scalar: self = op.Unsqueeze(self, op.Constant(value_ints=[0])) - result = op.LogSoftmax(self, axis=dim) + result = op.Softmax(self, axis=dim) if self_is_scalar: + # Convert to scalar when input is scalar result = op.Squeeze(result) - return result - - -@torch_op("aten::_softmax", trace_only=True) -def aten__softmax_half(self: Union[FLOAT16, BFLOAT16], dim: int, half_to_float: bool) -> FLOAT: - """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" - - # trace_only because we need to cast conditionally based on half_to_float - if half_to_float: - self = op.Cast(self, 
to=FLOAT.dtype) - - return aten_softmax_no_dtype(self, dim) - - -@torch_op("aten::_softmax", trace_only=True) -def aten__softmax( - self: TFloatHighPrecision, dim: int, half_to_float: bool -) -> TFloatHighPrecision: - """_softmax(Tensor self, int dim, bool half_to_float) -> Tensor""" - # trace_only to reuse aten_softmax_no_dtype - - del half_to_float # Unused - return aten_softmax_no_dtype(self, dim) + return result @torch_op(("aten::abs", "_operator::abs"), trace_only=True) @@ -161,16 +134,35 @@ def aten_acosh(self: TFloat) -> TFloat: return op.Acosh(self) -@torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True) -def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: +@torch_op("aten::add.Tensor", trace_only=True) +def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor: """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" - # TODO(microsoft/onnxruntime#15977): Improve fp16 precision + + if self.dtype == ir.DataType.BOOL: + # alpha can also be bool + if alpha == 0: + return op.Identity(self) + return op.Or(self, other) + if alpha != 1.0: alpha = op.CastLike(alpha, other) other = op.Mul(other, alpha) return op.Add(self, other) +@torch_op("aten::add.Scalar", trace_only=True) +def aten_add_scalar(self: TTensor, other: float, alpha: float = 1.0) -> TTensor: + """add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor""" + + other = op.Constant(value=ir.tensor(other, dtype=self.dtype)) + return aten_add(self, other, alpha=alpha) + + +@torch_op("_operator::add", trace_only=True) +def operator_add(self: TTensor, other: TTensor) -> TTensor: + return op.Add(self, other) + + @torch_op(("aten::add.Tensor", "aten::add.Scalar"), trace_only=True, complex=True) def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor""" @@ -374,7 +366,6 @@ def aten_all_dims(self: TTensor, dim: Sequence[int] = (), 
keepdim: bool = False) return self -@torch_op("aten::all.dims", trace_only=True) def _aten_all_dims_no_dim(self: TTensor, keepdims: bool) -> BOOL: """all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor""" @@ -493,7 +484,6 @@ def aten_any_dims(self: TTensor, dim: Sequence[int] = (), keepdim: bool = False) return self -@torch_op("aten::any.dims", trace_only=True) def _aten_any_dims_no_dim(self: TTensor, keepdims: bool) -> BOOL: if len(self.shape) == 0: result = op.Cast(self, to=BOOL.dtype) @@ -733,7 +723,6 @@ def aten_argmax( return result -@torch_op("aten::argmax", private=True, trace_only=True) def _aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" @@ -746,7 +735,6 @@ def _aten_argmax(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: return result -@torch_op("aten::argmax", private=True, trace_only=True) def _aten_argmax_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64: """argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" @@ -774,7 +762,6 @@ def aten_argmin( return result -@torch_op("aten::argmin", private=True, trace_only=True) def _aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: """argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor""" @@ -787,7 +774,6 @@ def _aten_argmin(self: Union[RealType, UINT8], keepdim: bool = False) -> INT64: return result -@torch_op("aten::argmin", private=True, trace_only=True) def _aten_argmin_dim(self: Union[RealType, UINT8], dim: int, keepdim: bool = False) -> INT64: """argmin(Tensor self, int? 
dim=None, bool keepdim=False) -> Tensor""" @@ -925,16 +911,21 @@ def aten_atan(self: TFloat) -> TFloat: return op.Atan(self) -@torch_op("aten::atan2") +@torch_op("aten::atan2", trace_only=True) def aten_atan2(self: TFloat, other: TFloat) -> TFloat: """atan2(Tensor self, Tensor other) -> Tensor""" # self is y, and other is x on coordinate slope = op.Div(self, other) atan = op.Atan(slope) + zero = common_ops.constant(0.0, dtype=self.dtype) + pi = common_ops.constant(_MATH_PI, dtype=self.dtype) + + second_third_quadrant = op.Where(op.Greater(self, zero), atan + pi, atan - pi) + result = op.Where(op.Less(other, zero), second_third_quadrant, atan) - second_third_quadrant = op.Where(self > 0.0, atan + _MATH_PI, atan - _MATH_PI) - result = op.Where(other < 0.0, second_third_quadrant, atan) + # Map NaN to 0 to match PyTorch behavior + result = op.Where(op.IsNaN(result), zero, result) return result @@ -970,11 +961,11 @@ def reshape_to_1d(tensor): return op.SequenceMap(self, body=reshape_to_1d) -@torch_op("aten::atleast_2d") +@torch_op("aten::atleast_2d", trace_only=True) def aten_atleast_2d(self: TTensor) -> TTensor: """atleast_2d(Tensor self) -> Tensor""" - if Rank(self) <= 1: + if len(self.shape) <= 1: self = op.Reshape(self, op.Constant(value_ints=[1, -1])) return op.Identity(self) @@ -998,7 +989,7 @@ def reshape_to_2d(tensor): def aten_atleast_3d(self: TTensor) -> TTensor: """atleast_3d(Tensor self) -> Tensor""" - rank = Rank(self) + rank = len(self.shape) if rank <= 1: self = op.Reshape(self, op.Constant(value_ints=[1, -1, 1])) elif rank == 2: @@ -1184,6 +1175,7 @@ def aten_bernoulli_p(self: TTensor, p: float) -> TTensor: return op.CastLike(sampled, self) +@torch_op("aten::bilinear", trace_only=True) def aten_bilinear( input1: TensorType, input2: TensorType, @@ -1192,7 +1184,23 @@ def aten_bilinear( ) -> TensorType: """bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? 
bias=None) -> Tensor""" - raise NotImplementedError() + # Bilinear transformation: y = x1^T A x2 + b + # input1 shape: (..., in1_features) + # input2 shape: (..., in2_features) + # weight shape: (out_features, in1_features, in2_features) + # bias shape: (out_features) - optional + # output shape: (..., out_features) + + # Use Einsum to compute the bilinear transformation + # "...i,oij,...j->...o" means: + # - input1[..., i] * weight[o, i, j] * input2[..., j] -> output[..., o] + result = op.Einsum(input1, weight, input2, equation="...i,oij,...j->...o") + + # Add bias if provided + if bias is not None: + result = op.Add(result, bias) + + return result def aten_binary_cross_entropy_with_logits( @@ -1226,212 +1234,178 @@ def aten_binomial( @torch_op( ( "aten::bitwise_and.Tensor", - "aten::bitwise_and.Scalar", - "aten::bitwise_and.Scalar_Tensor", "_operator::and_", ), trace_only=True, ) -def aten_bitwise_and(self: TInt, other: TInt) -> TInt: +def aten_bitwise_and(self: TTensor, other: TTensor) -> TTensor: """bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor""" - # logical_and implements the BOOL variant - return op.BitwiseAnd(self, other) + assert self.dtype == other.dtype or self.dtype is None or other.dtype is None + dtype = self.dtype if self.dtype is not None else other.dtype + assert dtype is not None + if dtype.is_integer(): + return op.BitwiseAnd(self, other) + if dtype == ir.DataType.BOOL: + return op.And(self, other) + raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}") -@torch_op( - ( - "aten::bitwise_left_shift.Tensor", - "aten::bitwise_left_shift.Tensor_Scalar", - "aten::bitwise_left_shift.Scalar_Tensor", - "_operator::__lshift__", - "aten::__lshift__.Scalar", - ), - trace_only=True, -) -def aten_bitwise_left_shift_int16(self: INT16, other: INT16) -> INT16: - """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - # assert other >= 0 - self = op.Cast(self, to=UINT16.dtype) - other = op.Cast(other, 
to=UINT16.dtype) - result = op.BitShift(self, other, direction="LEFT") +@torch_op("aten::bitwise_and.Scalar", trace_only=True) +def aten_bitwise_and_scalar(self: TTensor, other: int) -> TTensor: + """bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor""" - return op.Cast(result, to=INT16.dtype) + other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype)) + return aten_bitwise_and(self, other_tensor) -@torch_op( - ( - "aten::bitwise_left_shift.Tensor", - "aten::bitwise_left_shift.Tensor_Scalar", - "aten::bitwise_left_shift.Scalar_Tensor", - "_operator::__lshift__", - "aten::__lshift__.Scalar", - ), - trace_only=True, -) -def aten_bitwise_left_shift_int32(self: INT32, other: INT32) -> INT32: - """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - # assert other >= 0 - self = op.Cast(self, to=UINT32.dtype) - other = op.Cast(other, to=UINT32.dtype) +@torch_op("aten::bitwise_and.Scalar_Tensor", trace_only=True) +def aten_bitwise_and_scalar_tensor(self: float, other: TTensor) -> TTensor: + """bitwise_and.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" - result = op.BitShift(self, other, direction="LEFT") - - return op.Cast(result, to=INT32.dtype) + self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype)) + return aten_bitwise_and(self_tensor, other) @torch_op( ( "aten::bitwise_left_shift.Tensor", - "aten::bitwise_left_shift.Tensor_Scalar", - "aten::bitwise_left_shift.Scalar_Tensor", "_operator::__lshift__", - "aten::__lshift__.Scalar", ), trace_only=True, ) -def aten_bitwise_left_shift_int64(self: INT64, other: INT64) -> INT64: +def aten_bitwise_left_shift(self: TInt, other: TInt) -> TInt: """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor""" + assert self.dtype == other.dtype or self.dtype is None or other.dtype is None + dtype = self.dtype if self.dtype is not None else other.dtype + assert dtype is not None + # assert other >= 0 - self = op.Cast(self, to=UINT64.dtype) - other = op.Cast(other, 
to=UINT64.dtype) + if dtype.bitwidth == 8: + unsigned_dtype = ir.DataType.UINT8 + signed_dtype = ir.DataType.INT8 + elif dtype.bitwidth == 16: + unsigned_dtype = ir.DataType.UINT16 + signed_dtype = ir.DataType.INT16 + elif dtype.bitwidth == 32: + unsigned_dtype = ir.DataType.UINT32 + signed_dtype = ir.DataType.INT32 + elif dtype.bitwidth == 64: + unsigned_dtype = ir.DataType.UINT64 + signed_dtype = ir.DataType.INT64 + else: + raise NotImplementedError(f"Not implemented for type {dtype}") + + self = op.Cast(self, to=unsigned_dtype) + other = op.Cast(other, to=unsigned_dtype) result = op.BitShift(self, other, direction="LEFT") - return op.Cast(result, to=INT64.dtype) + return op.Cast(result, to=signed_dtype) @torch_op( - ( - "aten::bitwise_left_shift.Tensor", - "aten::bitwise_left_shift.Tensor_Scalar", - "aten::bitwise_left_shift.Scalar_Tensor", - "_operator::__lshift__", - "aten::__lshift__.Scalar", - ), - trace_only=True, + ("aten::bitwise_left_shift.Tensor_Scalar", "aten::__lshift__.Scalar"), trace_only=True ) -def aten_bitwise_left_shift_int8(self: INT8, other: INT8) -> INT8: - """bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - # assert other >= 0 - self = op.Cast(self, to=UINT8.dtype) - other = op.Cast(other, to=UINT8.dtype) +def aten_bitwise_left_shift_tensor_scalar(self: TInt, other: int) -> TInt: + """bitwise_left_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor""" + other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype)) + return aten_bitwise_left_shift(self, other_tensor) - result = op.BitShift(self, other, direction="LEFT") - return op.Cast(result, to=INT8.dtype) +@torch_op("aten::bitwise_left_shift.Scalar_Tensor", trace_only=True) +def aten_bitwise_left_shift_scalar_tensor(self: int, other: TInt) -> TInt: + """bitwise_left_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" + self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype)) + return aten_bitwise_left_shift(self_tensor, other) 
@torch_op("aten::bitwise_not", trace_only=True) -def aten_bitwise_not(self: TInt) -> TInt: +def aten_bitwise_not(self: TTensor) -> TTensor: """bitwise_not(Tensor self) -> Tensor""" - # logical_not implements the BOOL variant - return op.BitwiseNot(self) + if self.dtype == ir.DataType.BOOL: + return op.Not(self) + if self.dtype.is_integer(): + return op.BitwiseNot(self) + raise NotImplementedError(f"Not implemented for type {self.dtype}") @torch_op( ( "aten::bitwise_or.Tensor", - "aten::bitwise_or.Scalar", - "aten::bitwise_or.Scalar_Tensor", "_operator::or_", ), trace_only=True, ) -def aten_bitwise_or(self: TInt, other: TInt) -> TInt: +def aten_bitwise_or(self: TTensor, other: TTensor) -> TTensor: """bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor""" - # logical_or implements the BOOL variant - - return op.BitwiseOr(self, other) + assert self.dtype == other.dtype or self.dtype is None or other.dtype is None + dtype = self.dtype if self.dtype is not None else other.dtype + assert dtype is not None -@torch_op( - ( - "aten::bitwise_right_shift.Tensor", - "aten::bitwise_right_shift.Tensor_Scalar", - "aten::bitwise_right_shift.Scalar_Tensor", - "_operator::__rshift__", - "aten::__rshift__.Scalar", - ) -) -def aten_bitwise_right_shift_int16(self: INT16, other: INT16) -> INT16: - """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - negative = op.Less(self, 0) - self = op.Cast(self, to=UINT16.dtype) - other = op.Cast(other, to=UINT16.dtype) + if dtype.is_integer(): + return op.BitwiseOr(self, other) + if dtype == ir.DataType.BOOL: + return op.Or(self, other) + raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}") - # Simulate arithmetic shift using logical shift - # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting - mask = op.BitShift( - op.Cast(op.Constant(value_int=0xFFFF), to=UINT16.dtype), other, direction="RIGHT" - ) - mask = op.BitwiseNot(mask) - # Do logical shift - 
shifted = op.BitShift(self, other, direction="RIGHT") - # Compute the arithmetic shifted value assuming the sign bit was set - negative_shifted = op.BitwiseOr(shifted, mask) - # Choose the shifted value based on the sign bit - return op.Where( - negative, op.Cast(negative_shifted, to=INT16.dtype), op.Cast(shifted, to=INT16.dtype) - ) +@torch_op("aten::bitwise_or.Scalar", trace_only=True) +def aten_bitwise_or_scalar(self: TTensor, other: int) -> TTensor: + """bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor""" + other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype)) + return aten_bitwise_or(self, other_tensor) -@torch_op( - ( - "aten::bitwise_right_shift.Tensor", - "aten::bitwise_right_shift.Tensor_Scalar", - "aten::bitwise_right_shift.Scalar_Tensor", - "_operator::__rshift__", - "aten::__rshift__.Scalar", - ) -) -def aten_bitwise_right_shift_int32(self: INT32, other: INT32) -> INT32: - """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - negative = op.Less(self, 0) - self = op.Cast(self, to=UINT32.dtype) - other = op.Cast(other, to=UINT32.dtype) - # Simulate arithmetic shift using logical shift - # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting - mask = op.BitShift( - op.Cast(op.Constant(value_int=0xFFFFFFFF), to=UINT32.dtype), other, direction="RIGHT" - ) - mask = op.BitwiseNot(mask) - # Do logical shift - shifted = op.BitShift(self, other, direction="RIGHT") - # Compute the arithmetic shifted value assuming the sign bit was set - negative_shifted = op.BitwiseOr(shifted, mask) - # Choose the shifted value based on the sign bit - return op.Where( - negative, op.Cast(negative_shifted, to=INT32.dtype), op.Cast(shifted, to=INT32.dtype) - ) +@torch_op("aten::bitwise_or.Scalar_Tensor", trace_only=True) +def aten_bitwise_or_scalar_tensor(self: int, other: TTensor) -> TTensor: + """bitwise_or.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" + self_tensor = 
op.Constant(value=ir.tensor(self, dtype=other.dtype)) + return aten_bitwise_or(self_tensor, other) @torch_op( ( "aten::bitwise_right_shift.Tensor", - "aten::bitwise_right_shift.Tensor_Scalar", - "aten::bitwise_right_shift.Scalar_Tensor", "_operator::__rshift__", - "aten::__rshift__.Scalar", - ) + ), + trace_only=True, ) -def aten_bitwise_right_shift_int64(self: INT64, other: INT64) -> INT64: +def aten_bitwise_right_shift(self: TInt, other: TInt) -> TInt: """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" + assert self.dtype == other.dtype or self.dtype is None or other.dtype is None + dtype = self.dtype if self.dtype is not None else other.dtype + assert dtype is not None + + if dtype.bitwidth == 8: + unsigned_dtype = ir.DataType.UINT8 + signed_dtype = ir.DataType.INT8 + mask = ir.tensor(0xFF, dtype=unsigned_dtype) + elif dtype.bitwidth == 16: + unsigned_dtype = ir.DataType.UINT16 + signed_dtype = ir.DataType.INT16 + mask = ir.tensor(0xFFFF, dtype=unsigned_dtype) + elif dtype.bitwidth == 32: + unsigned_dtype = ir.DataType.UINT32 + signed_dtype = ir.DataType.INT32 + mask = ir.tensor(0xFFFFFFFF, dtype=unsigned_dtype) + elif dtype.bitwidth == 64: + unsigned_dtype = ir.DataType.UINT64 + signed_dtype = ir.DataType.INT64 + mask = ir.tensor(0xFFFFFFFFFFFFFFFF, dtype=unsigned_dtype) # 0xFFFFFFFFFFFFFFFF + else: + raise NotImplementedError(f"Not implemented for type {dtype}") + negative = op.Less(self, 0) - self = op.Cast(self, to=UINT64.dtype) - other = op.Cast(other, to=UINT64.dtype) + self = op.Cast(self, to=unsigned_dtype) + other = op.Cast(other, to=unsigned_dtype) # Simulate arithmetic shift using logical shift # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting - mask = op.BitShift( - # 0xFFFFFFFFFFFFFFFF - op.Cast(op.Constant(value_int=-1), to=UINT64.dtype), - other, - direction="RIGHT", - ) + mask = op.BitShift(mask, other, direction="RIGHT") mask = op.BitwiseNot(mask) # Do logical shift shifted = 
op.BitShift(self, other, direction="RIGHT") @@ -1439,54 +1413,53 @@ def aten_bitwise_right_shift_int64(self: INT64, other: INT64) -> INT64: negative_shifted = op.BitwiseOr(shifted, mask) # Choose the shifted value based on the sign bit return op.Where( - negative, op.Cast(negative_shifted, to=INT64.dtype), op.Cast(shifted, to=INT64.dtype) + negative, op.Cast(negative_shifted, to=signed_dtype), op.Cast(shifted, to=signed_dtype) ) @torch_op( - ( - "aten::bitwise_right_shift.Tensor", - "aten::bitwise_right_shift.Tensor_Scalar", - "aten::bitwise_right_shift.Scalar_Tensor", - "_operator::__rshift__", - "aten::__rshift__.Scalar", - ) + ("aten::bitwise_right_shift.Tensor_Scalar", "aten::__rshift__.Scalar"), trace_only=True ) -def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8: - """bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor""" - negative = op.Less(self, 0) - self = op.Cast(self, to=UINT8.dtype) - other = op.Cast(other, to=UINT8.dtype) +def aten_bitwise_right_shift_tensor_scalar(self: TInt, other: int) -> TInt: + """bitwise_right_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor""" + other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype)) + return aten_bitwise_right_shift(self, other_tensor) - # Simulate arithmetic shift using logical shift - # Clear the lower bits of an all one mask to create the mask to simulate the sign bit shifting - mask = op.BitShift( - op.Cast(op.Constant(value_int=0xFF), to=UINT8.dtype), other, direction="RIGHT" - ) - mask = op.BitwiseNot(mask) - # Do logical shift - shifted = op.BitShift(self, other, direction="RIGHT") - # Compute the arithmetic shifted value assuming the sign bit was set - negative_shifted = op.BitwiseOr(shifted, mask) - # Choose the shifted value based on the sign bit - return op.Where( - negative, op.Cast(negative_shifted, to=INT8.dtype), op.Cast(shifted, to=INT8.dtype) - ) +@torch_op("aten::bitwise_right_shift.Scalar_Tensor", trace_only=True) +def 
aten_bitwise_right_shift_scalar_tensor(self: int, other: TInt) -> TInt: + """bitwise_right_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" + self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype)) + return aten_bitwise_right_shift(self_tensor, other) -@torch_op( - ( - "aten::bitwise_xor.Tensor", - "aten::bitwise_xor.Scalar", - "aten::bitwise_xor.Scalar_Tensor", - ), - trace_only=True, -) -def aten_bitwise_xor(self: TInt, other: TInt) -> TInt: + +@torch_op("aten::bitwise_xor.Tensor", trace_only=True) +def aten_bitwise_xor(self: TTensor, other: TTensor) -> TTensor: """bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor""" - # logical_xor implements the BOOL variant - return op.BitwiseXor(self, other) + assert self.dtype == other.dtype or self.dtype is None or other.dtype is None + dtype = self.dtype if self.dtype is not None else other.dtype + assert dtype is not None + + if dtype.is_integer(): + return op.BitwiseXor(self, other) + if dtype == ir.DataType.BOOL: + return op.Xor(self, other) + raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}") + + +@torch_op("aten::bitwise_xor.Scalar", trace_only=True) +def aten_bitwise_xor_scalar(self: TTensor, other: int) -> TTensor: + """bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor""" + other_tensor = op.Constant(value=ir.tensor(other, dtype=self.dtype)) + return aten_bitwise_xor(self, other_tensor) + + +@torch_op("aten::bitwise_xor.Scalar_Tensor", trace_only=True) +def aten_bitwise_xor_scalar_tensor(self: int, other: TTensor) -> TTensor: + """bitwise_xor.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" + self_tensor = op.Constant(value=ir.tensor(self, dtype=other.dtype)) + return aten_bitwise_xor(self_tensor, other) @torch_op("aten::blackman_window", trace_only=True) @@ -1523,10 +1496,10 @@ def aten_broadcast_tensors(tensors: Sequence[TensorType]) -> TensorType: raise NotImplementedError() -@torch_op("aten::broadcast_to") -def aten_broadcast_to(self: 
TTensor, size: INT64) -> TTensor: +@torch_op("aten::broadcast_to", trace_only=True) +def aten_broadcast_to(self: TTensor, size: Sequence[INT64]) -> TTensor: """broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)""" - + size = common_ops.merge_dims(size) return op.Expand(self, size) @@ -1550,7 +1523,7 @@ def aten_cartesian_prod(tensors: Sequence[TensorType]) -> TensorType: raise NotImplementedError() -@torch_op("aten::cat", trace_only=True, complex=True) +@torch_op(("aten::cat", "aten::concat", "aten::concatenate"), trace_only=True, complex=True) def aten_cat_complex(tensors: Sequence[TTensor], dim: int = 0) -> TTensor: """cat(Tensor[] tensors, int dim=0) -> Tensor""" # Real representation unsqueezes the last dimension @@ -1563,8 +1536,18 @@ def aten_cat_complex(tensors: Sequence[TTensor], dim: int = 0) -> TTensor: def aten_cat(tensors: Sequence[TTensor], dim: int = 0) -> TTensor: """cat(Tensor[] tensors, int dim=0) -> Tensor""" - # Remove None tensors - tensors = [tensor for tensor in tensors if tensor is not None] + filtered_tensors = [] + for tensor in tensors: + # Remove None tensors + if tensor is None: + continue + # Remove empty tensors + if tensor.shape == (0,): + continue + filtered_tensors.append(tensor) + assert filtered_tensors, "aten::cat received all None or empty tensors" + if len(filtered_tensors) == 1: + return op.Identity(filtered_tensors[0]) return op.Concat(*tensors, axis=dim) @@ -1861,39 +1844,21 @@ def aten_conj_physical(self: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::constant_pad_nd") -def aten_constant_pad_nd(self: TTensor, pad: INT64, value: float = 0.0) -> TTensor: +@torch_op("aten::constant_pad_nd", trace_only=True) +def aten_constant_pad_nd(self: TTensor, pad: Sequence[INT64], value: float = 0.0) -> TTensor: """constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor""" # The desired order of paddings is # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end. 
# n is the dimension of input. # assume zero-dimensions in the beginning - # rank = len(self.shape) # rank must be scalar - # paddings = list(pad[:]) + [0] * (rank * 2 - len(pad)) + rank = len(self.shape) + paddings = list(pad) + [0] * (rank * 2 - len(pad)) # reverse order and collate first beginnings and then ends - # paddings = paddings[-2::-2] + paddings[-1::-2] - - neg_1 = op.Constant(value_ints=[-1]) - - zero_count = op.Sub(op.Mul(Rank(self), 2), op.Size(pad)) - zero_count = op.Reshape(zero_count, neg_1) - zero = op.Constant(value_ints=[0]) - zeros = op.Expand(zero, zero_count) - torch_paddings = op.Concat(pad, zeros, axis=0) - size_d = op.Size(torch_paddings) - steps = op.Constant(value_ints=[-2]) + paddings = paddings[-2::-2] + paddings[-1::-2] + constant_value = op.Constant(value=ir.tensor(value, dtype=self.dtype)) - starts = steps - ends = op.Sub(starts, size_d) - odd_elements = op.Slice(torch_paddings, starts, ends, zero, steps) - - starts = neg_1 - ends = op.Sub(starts, size_d) - even_elements = op.Slice(torch_paddings, starts, ends, zero, steps) - - onnx_padding = op.Concat(odd_elements, even_elements, axis=0) - return op.Pad(self, onnx_padding, value) + return op.Pad(self, paddings, constant_value) @torch_op("aten::contiguous", trace_only=True) @@ -2128,7 +2093,6 @@ def aten_convolution( return result -@torch_op("aten::convolution", private=True, trace_only=True) def _aten_convolution_onnx( input: TFloat, weight: TFloat, @@ -2231,7 +2195,7 @@ def aten_convolution_overrideable( raise NotImplementedError() -@torch_op("aten::copy") +@torch_op("aten::copy", trace_only=True) def aten_copy( self: TTensor, src: TTensor2, @@ -2600,80 +2564,10 @@ def aten_diagflat(self: TensorType, offset: int = 0) -> TensorType: @torch_op(("aten::diagonal", "aten::diagonal_copy"), trace_only=True) -def aten_diagonal(self: TReal, offset: int = 0, dim1: int = 0, dim2: int = 1) -> TReal: +def aten_diagonal(self: TTensor, offset: int = 0, dim1: int = 0, dim2: int = 1) -> TTensor: 
"""diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)""" - # perm is used to transpose the tensor to make dim1 and dim2 as the last 2 dims - # [0,1,2] -> [2,0,1] when dim1=0 and dim2=1 - # [0,1,2] -> [1,0,2] when dim1=0 and dim2=2 - # [0,1,2] -> [0,1,2] when dim1=1 and dim2=2 - if dim1 < 0: - dim1 = dim1 + len(self.shape) - if dim2 < 0: - dim2 = dim2 + len(self.shape) - - self_rank = len(self.shape) - perm = list(range(self_rank)) - perm.remove(dim1) - perm.remove(dim2) - perm.append(dim1) - perm.append(dim2) - - # If rank=2, then axes=[0]; if rank=3, then axes=[1] - # This is because computing diagonal sum is on dim2 after transpose by perm - axes = [self_rank - 2] - - neg_1 = op.Constant(value_ints=[-1]) - dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1) # row - dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col - mask_shape = op.Concat(dim1_size, dim2_size, axis=0) - mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset) - mask = op.CastLike(mask, self) - self_t = op.Transpose(self, perm=perm) - result = op.Mul(self_t, mask) - result = op.ReduceSum(result, keepdims=False, axes=axes) - # min(row, col) - min_dim_size = op.Min(dim1_size, dim2_size) - # take 2 tensors as example: - # one is 3x5 in size, min_dim_size = 3, dim1_size = 3 - # the other is 5x3 in size, min_dim_size = 3, dim1_size = 5 - # 3 rows x 5 cols 5 rows x 3 cols - # offset diagonal offset diagonal - # ---------------- ---------------- - # -4 0 -6 0 - # -3 0 -5 0 - # -2 1 -4 1 - # -1 2 -3 2 - # 0 3 -2 3 - # 1 3 -1 3 - # 2 3 0 3 - # 3 2 1 2 - # 4 1 2 1 - # 5 0 3 0 - # 6 0 4 0 - - # From above table, we can get the logic below - offset_val = op.Constant(value_ints=[offset]) - if offset < 0: - # row + offset - length = op.Add(dim1_size, offset_val) - start = op.Constant(value_ints=[0]) - else: # offset >= 0 - # col - offset - length = op.Sub(dim2_size, offset_val) - start = offset_val - - # max(min(length, min(row, col)), 0) - length = 
op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0])) - end = op.Add(start, length) - result = op.Slice(result, start, end, axes=axes) - - return result - - -@torch_op("aten::diagonal", trace_only=True) -def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1) -> BOOL: - """diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)""" + is_bool = self.dtype == BOOL.dtype # perm is used to transpose the tensor to make dim1 and dim2 as the last 2 dims # [0,1,2] -> [2,0,1] when dim1=0 and dim2=1 @@ -2700,10 +2594,16 @@ def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1 dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col mask_shape = op.Concat(dim1_size, dim2_size, axis=0) mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset) - self_int = op.Cast(self, to=INT64.dtype) - mask_int = op.Cast(mask, to=INT64.dtype) - self_int_t = op.Transpose(self_int, perm=perm) - result = op.Mul(self_int_t, mask_int) + + if is_bool: + self_int = op.Cast(self, to=INT64.dtype) + mask_int = op.Cast(mask, to=INT64.dtype) + self_int_t = op.Transpose(self_int, perm=perm) + result = op.Mul(self_int_t, mask_int) + else: + mask = op.CastLike(mask, self) + self_t = op.Transpose(self, perm=perm) + result = op.Mul(self_t, mask) result = op.ReduceSum(result, keepdims=False, axes=axes) # min(row, col) min_dim_size = op.Min(dim1_size, dim2_size) @@ -2740,7 +2640,9 @@ def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1 length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0])) end = op.Add(start, length) result = op.Slice(result, start, end, axes=axes) - result = op.Cast(result, to=BOOL.dtype) + + if is_bool: + result = op.Cast(result, to=BOOL.dtype) return result @@ -2851,45 +2753,37 @@ def aten_div_complex(self: TFloat, other: TFloat) -> TFloat: @torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True) -def aten_div_mode(self: 
TFloat, other: TFloat, rounding_mode: Optional[str] = None) -> TFloat: +def aten_div_mode(self: TReal, other: TReal, rounding_mode: Optional[str] = None) -> TReal: """div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor""" assert rounding_mode in {"trunc", "floor", None} - if rounding_mode == "trunc": - # Rounds the results of the division towards zero. - # Equivalent to C-style integer division - return aten_trunc(op.Div(self, other)) - if rounding_mode == "floor": - return op.Floor(op.Div(self, other)) - - return op.Div(self, other) + if self.dtype.is_integer(): + quotient = op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype)) + if rounding_mode == "trunc": + # Rounds the results of the division towards zero. + # Equivalent to C-style integer division + result = aten_trunc(quotient) + return op.CastLike(result, self) + if rounding_mode == "floor": + result = op.Floor(quotient) + return op.CastLike(result, self) -@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True) -def aten_div_mode_int( - self: TInt, other: TInt, rounding_mode: Optional[str] = None -) -> TensorType: - """div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor - - Variant for integer inputs. - """ - assert rounding_mode in {"trunc", "floor", None} + assert rounding_mode is None + # When rounding_mode is None, the return type is float32 + return quotient - quotient = op.Div(op.Cast(self, to=FLOAT.dtype), op.Cast(other, to=FLOAT.dtype)) + # Float inputs if rounding_mode == "trunc": # Rounds the results of the division towards zero. 
# Equivalent to C-style integer division - result = aten_trunc(quotient) - return op.CastLike(result, self) + return aten_trunc(op.Div(self, other)) if rounding_mode == "floor": - result = op.Floor(quotient) - return op.CastLike(result, self) + return op.Floor(op.Div(self, other)) - assert rounding_mode is None - # When rounding_mode is None, the return type is float32 - return quotient + return op.Div(self, other) @torch_op("aten::dot", trace_only=True) @@ -3299,20 +3193,20 @@ def aten_embedding_sparse_backward( @torch_op("aten::empty.memory_format", trace_only=True) def aten_empty( - size: IntType, + size: Sequence[INT64], dtype: int = FLOAT.dtype, layout: str = "", device: str = "", pin_memory: bool = False, memory_format: str = "", ) -> TensorType: # type: ignore[type-var] - # empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + """empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor""" if dtype == -1: dtype = FLOAT.dtype - # using Zeros to simulate np.empty() - size = op.Cast(size, to=INT64.dtype) - zero = op.Constant(value_float=0.0) - zero = op.Cast(zero, to=dtype) + + # using Zeros to simulate empty() + zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype))) + size = common_ops.merge_dims(size) return op.Expand(zero, size) @@ -3347,17 +3241,18 @@ def aten_empty_quantized( @torch_op("aten::empty_strided", trace_only=True) def aten_empty_strided( - size: INT64, + size: Sequence[INT64], stride: INT64, layout: str = "", + dtype: int = FLOAT.dtype, device: str = "", pin_memory: bool = False, ) -> TTensor: # type: ignore[type-var] # empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor # using Zeros to simulate empty() - size = op.Cast(size, to=INT64.dtype) - zero = op.Constant(value_float=0.0) + zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype))) + size = common_ops.merge_dims(size) return op.Expand(zero, size) @@ -3405,13 +3300,14 @@ def aten_exp2(self: TFloat) -> TFloat: @torch_op("aten::expand", trace_only=True) -def aten_expand(self: TTensor, size: TInt, implicit: bool = False) -> TTensor: +def aten_expand(self: TTensor, size: Sequence[INT64], implicit: bool = False) -> TTensor: """expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)""" - size = op.Cast(size, to=INT64.dtype) # NOTE: PyTorch supports `not changing dim` by -1, but ONNX supports `not changing dim` by 1. # To support -1 dim, we need to convert -1 to 1. - size = op.Abs(size) - return op.Expand(self, size) + # Even though in theory a dynamic dim can still be -1, in practice it is very unlikely + # and isn't expected to appear from correct usages of SymInt. + size = [1 if isinstance(s, int) and s == -1 else s for s in size] + return op.Expand(self, common_ops.merge_dims(size)) @torch_op("aten::expand_as", trace_only=True) @@ -3435,17 +3331,58 @@ def aten_eye(n: int) -> TensorType: raise NotImplementedError() +@torch_op("aten::fake_quantize_per_channel_affine", trace_only=True) def aten_fake_quantize_per_channel_affine( - self: TensorType, - scale: TensorType, - zero_point: TensorType, + self: TFloat, + scale: FLOAT, # float32 specifically! + zero_point: Union[INT32, FLOAT, FLOAT16], # int32, float32 or float16 only! axis: int, quant_min: int, quant_max: int, ) -> TensorType: """fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor""" - raise NotImplementedError() + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). 
+ # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: + raise NotImplementedError( + "For (quant_min, quant_max), ONNX allows only " + "(0, 127), (0, 255) and (-128, 127). " + f"Got ({quant_min}, {quant_max})", + ) + + if quant_min == 0: + int_dtype = ir.DataType.UINT8 + else: + int_dtype = ir.DataType.INT8 + + # TODO: When opset >= 19, remove this cast + orig_dtype = self.type.dtype + if self.type.dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}: + self = op.Cast(self, to=ir.DataType.FLOAT) + + if zero_point.type.dtype == ir.DataType.INT32: + zero_point = op.Cast(zero_point, to=int_dtype) + else: + raise NotImplementedError( + "ONNX only supports integer values for the zero_point parameter. " + f"Got {zero_point.type.dtype}", + ) + + quantized = op.QuantizeLinear(self, scale, zero_point, axis=axis) + + # See comment about, PyTorch-specific (0, 127) handling + if (quant_min, quant_max) == (0, 127): + const_127 = op.Cast(127, to=int_dtype) + quantized = op.Clip(quantized, max=const_127) + + output = op.DequantizeLinear(quantized, scale, zero_point, axis=axis) + + # TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear + if orig_dtype != ir.DataType.FLOAT: + output = op.Cast(output, to=orig_dtype) + + return output def aten_fake_quantize_per_channel_affine_cachemask( @@ -3469,12 +3406,79 @@ def aten_fake_quantize_per_channel_affine_cachemask_backward( raise NotImplementedError() +@torch_op("aten::fake_quantize_per_tensor_affine", trace_only=True) def aten_fake_quantize_per_tensor_affine( - self: TensorType, scale: float, zero_point: int, quant_min: int, quant_max: int -) -> TensorType: + self: TFloat, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, +) -> TFloat: """fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> 
Tensor""" - raise NotImplementedError() + return _aten_fake_quantize_per_tensor_affine(self, scale, zero_point, quant_min, quant_max) + + +@torch_op("aten::fake_quantize_per_tensor_affine.tensor_qparams", trace_only=True) +def aten_fake_quantize_per_tensor_affine_tensor_qparams( + self: TFloat, + scale: TReal, + zero_point: TReal, + quant_min: int, + quant_max: int, +) -> TFloat: + """fake_quantize_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor""" + + return _aten_fake_quantize_per_tensor_affine(self, scale, zero_point, quant_min, quant_max) + + +def _aten_fake_quantize_per_tensor_affine( + self: TFloat, + scale: Union[float, TReal], + zero_point: Union[int, TReal], + quant_min: int, + quant_max: int, +) -> TFloat: + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: + raise NotImplementedError( + "For (quant_min, quant_max), ONNX allows only " + "(0, 127), (0, 255) and (-128, 127). 
" + f"Got ({quant_min}, {quant_max})", + ) + + if quant_min == 0: + int_dtype = ir.DataType.UINT8 + else: + int_dtype = ir.DataType.INT8 + + # TODO: When opset >= 19, remove this cast + orig_dtype = self.type.dtype + if self.type.dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}: + self = op.Cast(self, to=ir.DataType.FLOAT) + + # TODO: When opset >= 19, relex the condition for this cast + if isinstance(scale, float) or scale.type.dtype != ir.DataType.FLOAT: + scale = op.Cast(scale, to=ir.DataType.FLOAT) + + if isinstance(zero_point, int) or zero_point.type.dtype != int_dtype: + zero_point = op.Cast(zero_point, to=int_dtype) + + quantized = op.QuantizeLinear(self, scale, zero_point) + + # See comment about, PyTorch-specific (0, 127) handling + if (quant_min, quant_max) == (0, 127): + const_127 = op.Cast(127, to=int_dtype) + quantized = op.Clip(quantized, max=const_127) + + output = op.DequantizeLinear(quantized, scale, zero_point) + + # TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear + if orig_dtype != ir.DataType.FLOAT: + output = op.Cast(output, to=orig_dtype) + + return output def aten_fake_quantize_per_tensor_affine_cachemask( @@ -3671,23 +3675,27 @@ def python_math_floor(self: TFloat) -> TInt: @torch_op("aten::floor_divide", trace_only=True) -def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat: +def aten_floor_divide(self: TTensor, other: TTensor) -> TTensor: """floor_divide(Tensor self, Tensor other) -> Tensor""" - return op.Floor(op.Div(self, other)) + if self.dtype.is_floating_point(): + return op.Floor(op.Div(self, other)) + assert self.dtype.is_integer() -@torch_op("aten::floor_divide", trace_only=True) -def aten_floor_divide_int(self: TInt, other: TInt) -> TInt: - """floor_divide(Tensor self, Tensor other) -> Tensor""" + if not self.dtype.is_signed(): + return op.Div(self, other) - # TODO(justinchuby): This can be simplified if we can constrain the - # inputs to be positive integers. 
Consider how we can embed constraints in the model. - dtype = self.dtype - self = op.Cast(self, to=FLOAT.dtype) - other = op.Cast(other, to=FLOAT.dtype) - result = op.Floor(op.Div(self, other)) - return op.Cast(result, to=dtype) + # Convert truncation to flooring + # Reference: https://github.com/pytorch/pytorch/blob/ffc645c870f0abd368606ba1e2b3b58cacb03046/torch/_refs/__init__.py#L1401C1-L1409C70 + # offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0) + # return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype) + offset = op.And( + op.Not(op.Equal(op.Sign(self), op.Sign(other))), + op.Cast(op.Mod(self, other), to=BOOL.dtype), + ) + offset = op.Cast(offset, to=self.dtype) + return op.Sub(op.Div(self, other), offset) @torch_op("_operator::floordiv", trace_only=True) @@ -3820,11 +3828,15 @@ def aten_gather( else: return op.Expand(self, op.Shape(index)) - if len(index.shape) == 0: - return op.Identity(self) + is_scalar_index = len(index.shape) == 0 + if is_scalar_index: + index = op.Unsqueeze(index, [0]) index = op.Cast(index, to=INT64.dtype) result = op.GatherElements(self, index, axis=dim) + + if is_scalar_index: + result = op.Squeeze(result, [0]) return result @@ -3843,29 +3855,27 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType: @torch_op( - ("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge"), + ("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor"), trace_only=True, ) -def aten_ge(self: TReal, other: TReal) -> BOOL: +def aten_ge(self: TTensor, other: TTensor) -> BOOL: """ge.Tensor(Tensor self, Tensor other) -> Tensor""" - return op.GreaterOrEqual(self, other) - + if self.dtype == ir.DataType.BOOL: + # self, other, self >= other + # F, F, T + # F, T, F + # T, F, T + # T, T, T + return op.Or(self, op.Not(other)) -@torch_op( - ("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge"), - trace_only=True, -) -def aten_ge_bool(self: 
BOOL, other: BOOL) -> BOOL: - """ge.Tensor(Tensor self, Tensor other) -> Tensor""" + return op.GreaterOrEqual(self, other) - # self, other, self >= other - # F, F, T - # F, T, F - # T, F, T - # T, T, T - return op.Or(self, op.Not(other)) +@torch_op("_operator::ge", trace_only=True) +def operator_ge(self: TTensor, other: TTensor) -> BOOL: + # operator.ge for SymInt + return op.GreaterOrEqual(self, other) def aten_geqrf(self: TensorType) -> tuple[TensorType, TensorType]: @@ -3880,6 +3890,192 @@ def aten_ger(self: TensorType, vec2: TensorType) -> TensorType: raise NotImplementedError() +@torch_op("aten::gru.input", trace_only=True) +def aten_gru( + input: TFloat, + hx: TFloat, + params: Sequence[TFloat], + has_biases: bool, + num_layers: int, + dropout: float, + train: bool, + bidirectional: bool, + batch_first: bool, +) -> tuple[TFloat, TFloat]: + """gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)""" + + # Determine number of directions + num_directions = 2 if bidirectional else 1 + + # Get dimensions + if batch_first: + # Convert from [batch, seq, input_size] to [seq, batch, input_size] + input = op.Transpose(input, perm=[1, 0, 2]) + + hidden_size = op.Shape(hx, start=2, end=3) + + # Process each layer + current_input = input + output_h_list = [] + + for layer_idx in range(num_layers): + # Extract hidden state for this layer + layer_start = layer_idx * num_directions + layer_end = (layer_idx + 1) * num_directions + layer_h = op.Slice(hx, layer_start, layer_end, axes=[0]) + + # Extract parameters for this layer + # Parameter layout: [W_ih, W_hh, b_ih, b_hh] for each direction + params_per_direction = 4 if has_biases else 2 + params_per_layer = params_per_direction * num_directions + param_start_idx = layer_idx * params_per_layer + + # Build weight matrices for ONNX GRU + # ONNX expects: W[zrh] shape [num_directions, 3*hidden_size, input_size] + # 
PyTorch provides: W_ih shape [3*hidden_size, input_size] + W_list = [] + R_list = [] + B_list = [] if has_biases else None + + for dir_idx in range(num_directions): + dir_param_start = param_start_idx + dir_idx * params_per_direction + W_ih = params[ + dir_param_start + ] # [3*hidden_size, input_size] - PyTorch order: [r,z,n] + W_hh = params[ + dir_param_start + 1 + ] # [3*hidden_size, hidden_size] - PyTorch order: [r,z,n] + + # Reorder gates from PyTorch [r,z,n] to ONNX [z,r,n] + # Split into individual gates + W_ir = op.Slice(W_ih, starts=[0], ends=hidden_size, axes=[0]) + W_iz = op.Slice(W_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + W_in = op.Slice(W_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + + W_hr = op.Slice(W_hh, starts=[0], ends=hidden_size, axes=[0]) + W_hz = op.Slice(W_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + W_hn = op.Slice(W_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + + # Reorder: [z,r,n] + W_ih_reordered = op.Concat( + W_iz, W_ir, W_in, axis=0 + ) # [3*hidden_size, input_size] - ONNX order + W_hh_reordered = op.Concat( + W_hz, W_hr, W_hn, axis=0 + ) # [3*hidden_size, hidden_size] - ONNX order + + # Add direction dimension + W_ih_expanded = op.Unsqueeze(W_ih_reordered, [0]) # [1, 3*hidden_size, input_size] + W_hh_expanded = op.Unsqueeze( + W_hh_reordered, [0] + ) # [1, 3*hidden_size, hidden_size] + + W_list.append(W_ih_expanded) + R_list.append(W_hh_expanded) + + if has_biases: + b_ih = params[dir_param_start + 2] # [3*hidden_size] - PyTorch order: [r,z,n] + b_hh = params[dir_param_start + 3] # [3*hidden_size] - PyTorch order: [r,z,n] + + # Reorder biases from PyTorch [r,z,n] to ONNX [z,r,n] + b_ir = op.Slice(b_ih, starts=[0], ends=hidden_size, axes=[0]) + b_iz = op.Slice(b_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + b_in = op.Slice(b_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + + b_hr = op.Slice(b_hh, starts=[0], ends=hidden_size, axes=[0]) + b_hz = 
op.Slice(b_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + b_hn = op.Slice(b_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + + # Reorder: [z,r,n] + b_ih_reordered = op.Concat( + b_iz, b_ir, b_in, axis=0 + ) # [3*hidden_size] - ONNX order + b_hh_reordered = op.Concat( + b_hz, b_hr, b_hn, axis=0 + ) # [3*hidden_size] - ONNX order + + # ONNX expects biases concatenated: [Wb[zrh], Rb[zrh]] + b_combined = op.Concat( + b_ih_reordered, b_hh_reordered, axis=0 + ) # [6*hidden_size] + b_expanded = op.Unsqueeze(b_combined, [0]) # [1, 6*hidden_size] + B_list.append(b_expanded) + + # Concatenate weights for all directions + W = op.Concat(*W_list, axis=0) if len(W_list) > 1 else W_list[0] + R = op.Concat(*R_list, axis=0) if len(R_list) > 1 else R_list[0] + B = ( + op.Concat(*B_list, axis=0) + if has_biases and len(B_list) > 1 + else (B_list[0] if has_biases else None) + ) + + # Call ONNX GRU operator + direction = "bidirectional" if bidirectional else "forward" + + # Extract hidden_size from hx shape: [num_layers * num_directions, batch, hidden_size] + hidden_size_attr = hx.shape[2] + + if B is not None: + Y, Y_h = op.GRU( + current_input, + W, + R, + B, + initial_h=layer_h, + direction=direction, + hidden_size=hidden_size_attr, + ) + else: + Y, Y_h = op.GRU( + current_input, + W, + R, + initial_h=layer_h, + direction=direction, + hidden_size=hidden_size_attr, + ) + + # Y shape: [seq_length, num_directions, batch_size, hidden_size] + # Reshape to [seq_length, batch_size, num_directions * hidden_size] + Y = op.Transpose( + Y, perm=[0, 2, 1, 3] + ) # [seq_length, batch_size, num_directions, hidden_size] + Y_shape = op.Shape(Y) + new_shape = op.Concat( + op.Slice(Y_shape, [0], [1]), # seq_length + op.Slice(Y_shape, [1], [2]), # batch_size + op.Reshape( + op.Mul( + op.Slice(Y_shape, [2], [3]), # num_directions + op.Slice(Y_shape, [3], [4]), # hidden_size + ), + op.Constant(value_ints=[-1]), + ), + axis=0, + ) + current_input = op.Reshape(Y, new_shape) + + # 
Apply dropout if not last layer and dropout > 0 + if layer_idx < num_layers - 1 and dropout > 0.0 and train: + current_input, _ = op.Dropout(current_input, dropout, train) + + # Store final hidden state + output_h_list.append(Y_h) + + # Concatenate all layer outputs + final_h = ( + output_h_list[0] if len(output_h_list) == 1 else op.Concat(*output_h_list, axis=0) + ) + + # Handle batch_first for output + if batch_first: + # Convert from [seq, batch, features] to [batch, seq, features] + current_input = op.Transpose(current_input, perm=[1, 0, 2]) + + return current_input, final_h + + @torch_op(("_operator::getitem", "aten::getitem")) def aten_getitem(self: Sequence[TTensor], i: INT64) -> TTensor: return op.SequenceAt(self, i) @@ -3991,28 +4187,28 @@ def aten_gru_cell( @torch_op( - ("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"), + ("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor"), trace_only=True, ) -def aten_gt(self: TReal, other: TReal) -> BOOL: +def aten_gt(self: TTensor, other: TTensor) -> BOOL: """gt.Tensor(Tensor self, Tensor other) -> Tensor""" - return op.Greater(self, other) + if self.dtype == ir.DataType.BOOL: + # self, other, self > other + # F, F, F + # F, T, F + # T, F, T + # T, T, F + return op.And(self, op.Not(other)) -@torch_op( - ("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"), - trace_only=True, -) -def aten_gt_bool(self: BOOL, other: BOOL) -> BOOL: - """gt.Tensor(Tensor self, Tensor other) -> Tensor""" - # self, other, self > other - # F, F, F - # F, T, F - # T, F, T - # T, T, F + return op.Greater(self, other) - return op.And(self, op.Not(other)) + +@torch_op("_operator::gt", trace_only=True) +def operator_gt(self: TTensor, other: TTensor) -> BOOL: + # operator.gt for SymInt + return op.Greater(self, other) @torch_op("aten::hamming_window", trace_only=True) @@ -4125,7 +4321,7 @@ def reshape_to_atleast_2d(tensor): result = op.ConcatFromSequence(tensors_atleast_2d, axis=1, 
new_axis=0) # hstack expects a non-empty sequence of tensors. So we don't need to check for length - rank_1d_or_less = op.Less(Rank(op.SequenceAt(tensors, 0)), 2) + rank_1d_or_less = op.Less(op.Size(op.Shape(op.SequenceAt(tensors, 0))), 2) if rank_1d_or_less: result = op.Reshape(result, op.Constant(value_ints=[-1])) return result @@ -4358,80 +4554,135 @@ def aten_index_put( See implementation of `torch.onnx.symbolic_opset11.index_put `_. """ - - def _make_reshape_list_broadcastable(reshape_list, values_shape): - # Remove ones until the rank of reshape_list matches values_shape. - while len(reshape_list) > len(values_shape) and 1 in reshape_list: - reshape_list.remove(1) - - # Now ensure each dimension is broadcastable: - # This is mandatory when mixing basic and advanced indexing - # Example: data((10, 3, 4)), indices([[0, 1], :, [0, 1]]) values(2, 3) - # the reshape list should be : [[2, 1], [1, 3], [2, 1]] - for i, r in enumerate(reshape_list): - if r not in (1, values_shape[i]): - value_index = values_shape.index(r) - # Swap elements - # For the example above the current reshape list is [1, 2] for last dim, - # to make it broadcastable, we swap the elements - reshape_list[value_index], reshape_list[i] = r, 1 - - return reshape_list - - # Ensure the number of indices matches the tensor rank. + # Ensure the number of indices matches the tensor rank by appending trailing Nones. self_rank = len(self.shape) if len(indices) < self_rank: indices = list(indices) + [None] * (self_rank - len(indices)) - # Get values shape - values_shape = tuple(values.shape) + # The behavior of the op is dependent on whether there are advanced indices (i.e., non-scalar tensors) + # and whether these advanced indices are contiguous. + + # Identify advanced indices. + def is_advanced_index(index): + # Note: In this function, the index is assumed to be either None or an int64 Tensor. 
+ return index is not None + + advanced_indices: list[int] = [] + none_indices: list[int] = [] + num_advanced_indices = 0 + num_none_indices = 0 + + for i, index in enumerate(indices): + if is_advanced_index(index): + advanced_indices.append(i) + num_advanced_indices += 1 + elif index is None: + none_indices.append(i) + num_none_indices += 1 + else: + raise ValueError(f"Unhandled index at position {i}: {index}") + + self_shape = op.Shape(self) + if num_advanced_indices == 0: + return op.Expand(values, self_shape) + + # More than one advanced index may require broadcasting of index values + if num_advanced_indices > 1: + # Check for special case where all advanced indices have same shape. + # But need to ensure none of the shapes have None as a dimension, which + # will invalidate equality-based check. + first_shape = indices[advanced_indices[0]].shape + + def same_shape(other_shape: ir.Shape) -> bool: + return (not any(d is None for d in other_shape)) and other_shape == first_shape + + all_same_shape = all(same_shape(indices[i].shape) for i in advanced_indices) + if not all_same_shape: + # Broadcast advanced indices to a common shape. 
+ advanced_index_rank = max(len(indices[i].shape) for i in advanced_indices) + shapes = [] + for i in advanced_indices: + index = indices[i] + index_rank = len(index.shape) + index_shape = op.Shape(index) + if index_rank < advanced_index_rank: + padding = op.Constant( + value_ints=[1 for _ in range(advanced_index_rank - index_rank)] + ) + index_shape = op.Concat(padding, index_shape, axis=0) + shapes.append(index_shape) + advanced_indices_shape = op.Max(*shapes) + indices = [ + op.Expand(index, advanced_indices_shape) if is_advanced_index(index) else index + for index in indices + ] + else: + advanced_indices_shape = op.Shape(indices[advanced_indices[0]]) + advanced_index_rank = len(indices[advanced_indices[0]].shape) + else: + advanced_indices_shape = op.Shape(indices[advanced_indices[0]]) + advanced_index_rank = len(indices[advanced_indices[0]].shape) + + # ONNX ScatterND supports only the case where all advanced indices appear first, + # followed by None indices. So, we need to transpose self and values so that the + # advanced indices appear first, and then transpose the result back to original + # order at the end. + + none_indices_constant = op.Constant(value_ints=none_indices) + none_indices_shape = op.Gather(self_shape, none_indices_constant, axis=0) + target_shape = op.Concat(advanced_indices_shape, none_indices_shape, axis=0) + target_rank = advanced_index_rank + num_none_indices + + # Generate indices tensor required by ONNX ScatterND by unsqueezing an extra dimension and + # concatenating all advanced indices along this new dimension. 
+ minus_one = op.Constant(value_ints=[-1]) + advanced_index_values = [op.Unsqueeze(indices[i], minus_one) for i in advanced_indices] + onnx_index = op.Concat(*advanced_index_values, axis=-1) + + # Check if advanced indices are contiguous: + contiguous = True + if advanced_indices: + if advanced_indices[-1] - advanced_indices[0] + 1 != len(advanced_indices): + contiguous = False + + # Bring advanced indices to front: + perm = advanced_indices + none_indices + transposed = op.Transpose(self, perm=perm) + + # Expand values to match target shape: + # First, transpose values if necessary to match advanced indices order! + if contiguous: + # values may need to be transposed before expanding to target shape + num_padded_dims = target_rank - len(values.shape) + if num_padded_dims > 0: + unsqueezed_dims = op.Constant(value_ints=list(range(num_padded_dims))) + values = op.Unsqueeze(values, unsqueezed_dims) + initial_none_index_positions = list(range(advanced_indices[0])) + advanced_index_replacement_positions = list( + range(advanced_indices[0], advanced_indices[0] + advanced_index_rank) + ) + final_none_index_positions = list( + range(advanced_indices[0] + advanced_index_rank, target_rank) + ) + values_perm = ( + advanced_index_replacement_positions + + initial_none_index_positions + + final_none_index_positions + ) + values = op.Transpose(values, perm=values_perm) + + expanded_values = op.Expand(values, target_shape) + + updated = op.ScatterND( + transposed, onnx_index, expanded_values, reduction="add" if accumulate else None + ) - index_vectors = [] - for i in range(self_rank): - if indices[i] is None: - # For a full slice along dim i, create a range index [0, self.shape[i]). 
- idx = op.Range(0, self.shape[i], 1) - reshape_update = self.shape[i] - else: - idx = indices[i] - reshape_update = math.prod(idx.shape) - # when Index is more than 1D, flatten it and also the values shape - # Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3) - # Indices -> (2*4,) and values shape (2*4, 32) - if len(idx.shape) > 1: - values_shape = (reshape_update, *values_shape[len(idx.shape) :]) - - # Flatten index (always working with 1D index in each dim) - idx = op.Reshape(idx, [-1]) - - # Create a reshape pattern: one value per index dimension, - # with the current dimension set to the update size. - reshape_list = [1] * len(indices) - reshape_list[i] = reshape_update - - # Adjust the reshape list to match the values shape. - reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape) - - # Reshape and expand the index. - idx = op.Reshape(idx, reshape_list, allowzero=True) - idx = op.Expand(idx, values_shape) - - # Flatten the index to 1D and unsqueeze to form a column vector. - idx = op.Reshape(idx, [-1]) - idx = op.Unsqueeze(idx, axes=[1]) - index_vectors.append(idx) - - # Concatenate the index vectors along axis=1 to form the final indices. 
- new_index = op.Concat(*index_vectors, axis=1) - - # Flatten values to match the indices - flat_values = op.Reshape(values, [-1]) - - if accumulate: - result = op.ScatterND(self, new_index, flat_values, reduction="add") - else: - result = op.ScatterND(self, new_index, flat_values) + # Inverse transpose to restore original dimension order: + inverse_perm = [0] * self_rank + for i, p in enumerate(perm): + inverse_perm[p] = i + result = op.Transpose(updated, perm=inverse_perm) return result @@ -4830,29 +5081,28 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType: @torch_op( - ("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"), + ("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor"), trace_only=True, ) -def aten_le(self: TReal, other: TReal) -> BOOL: +def aten_le(self: TTensor, other: TTensor) -> BOOL: """le.Tensor(Tensor self, Tensor other) -> Tensor""" - return op.LessOrEqual(self, other) + if self.dtype == ir.DataType.BOOL: + # self, other, self <= other + # F, F, T + # F, T, T + # T, F, F + # T, T, T + return op.Or(other, op.Not(self)) -@torch_op( - ("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"), - trace_only=True, -) -def aten_le_bool(self: BOOL, other: BOOL) -> BOOL: - """le.Tensor(Tensor self, Tensor other) -> Tensor""" + return op.LessOrEqual(self, other) - # self, other, self <= other - # F, F, T - # F, T, T - # T, F, F - # T, T, T - return op.Or(other, op.Not(self)) +@torch_op("_operator::le", trace_only=True) +def operator_le(self: TTensor, other: TTensor) -> BOOL: + # operator.le for SymInt + return op.LessOrEqual(self, other) @torch_op(("aten::lerp.Tensor", "aten::lerp.Scalar")) @@ -5012,81 +5262,63 @@ def aten_logdet(self: TFloat) -> TFloat: return op.Log(op.Det(self)) -@torch_op( - ( - "aten::logical_and", - "aten::bitwise_and.Tensor", - "aten::bitwise_and.Scalar", - "aten::bitwise_and.Scalar_Tensor", - ), - trace_only=True, -) -def aten_logical_and(self: 
BOOL, other: BOOL) -> BOOL: +@torch_op("aten::logical_and", trace_only=True) +def aten_logical_and(self: TTensor, other: TTensor) -> BOOL: """logical_and(Tensor self, Tensor other) -> Tensor""" - return op.And(self, other) + assert self.dtype == other.dtype + + if self.dtype == ir.DataType.BOOL: + return op.And(self, other) + return op.And(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) -@torch_op(("aten::logical_not", "aten::bitwise_not"), trace_only=True) -def aten_logical_not(self: BOOL) -> BOOL: +@torch_op("aten::logical_not", trace_only=True) +def aten_logical_not(self: TTensor) -> BOOL: """logical_not(Tensor self) -> Tensor""" - return op.Not(self) + if self.dtype == ir.DataType.BOOL: + return op.Not(self) + return op.Not(op.Cast(self, to=BOOL.dtype)) -@torch_op( - ( - "aten::logical_or", - "aten::bitwise_or.Tensor", - "aten::bitwise_or.Scalar", - "aten::bitwise_or.Scalar_Tensor", - "aten::add.Tensor", - "aten::add.Scalar", - ), - trace_only=True, -) -def aten_logical_or(self: BOOL, other: BOOL) -> BOOL: +@torch_op("aten::logical_or", trace_only=True) +def aten_logical_or(self: TTensor, other: TTensor) -> BOOL: """logical_or(Tensor self, Tensor other) -> Tensor""" - return op.Or(self, other) - - -@torch_op( - ( - "aten::logical_xor", - "aten::bitwise_xor.Tensor", - "aten::bitwise_xor.Scalar", - "aten::bitwise_xor.Scalar_Tensor", - ), - trace_only=True, -) -def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL: - """logical_xor(Tensor self, Tensor other) -> Tensor""" + assert self.dtype == other.dtype - return op.Xor(self, other) + if self.dtype == ir.DataType.BOOL: + return op.Or(self, other) + return op.Or(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) -@torch_op("aten::logit", private=True) -def _aten_logit_onnx(self: TFloat) -> TFloat: - return op.Log(op.Div(self, op.Sub(1.0, self))) - +@torch_op("aten::logical_xor", trace_only=True) +def aten_logical_xor(self: TTensor, other: TTensor) -> BOOL: + """logical_xor(Tensor self, 
Tensor other) -> Tensor""" -@torch_op("aten::logit", private=True) -def _aten_logit_clamp_onnx(self: TFloat, eps: float) -> TFloat: - eps = op.CastLike(eps, self) - one = op.CastLike(1.0, self) - temporary_self = op.Where(self <= one - eps, self, one - eps) - z = op.Where(temporary_self < eps, eps, temporary_self) + assert self.dtype == other.dtype - return op.Log(op.Div(z, op.Sub(one, z))) + if self.dtype == ir.DataType.BOOL: + return op.Xor(self, other) + return op.Xor(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype)) @torch_op("aten::logit", trace_only=True) def aten_logit(self: TFloat, eps: Optional[float] = None) -> TFloat: """logit(Tensor self, float? eps=None) -> Tensor""" + one = ir.tensor(1, dtype=self.dtype) + if eps is None: - return _aten_logit_onnx(self) - return _aten_logit_clamp_onnx(self, eps) + return op.Log(op.Div(self, op.Sub(one, self))) + + one_minus_eps = ir.tensor(1 - eps, dtype=self.dtype) + eps = ir.tensor(eps, dtype=self.dtype) + + temporary_self = op.Where(self <= one_minus_eps, self, one_minus_eps) + z = op.Where(temporary_self < eps, eps, temporary_self) + + return op.Log(op.Div(z, op.Sub(one, z))) def aten_logspace(start: float, end: float, steps: int, base: float = 10.0) -> TensorType: @@ -5141,30 +5373,234 @@ def aten_lstm_mps_backward( raise NotImplementedError() -@torch_op( - ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"), - trace_only=True, -) -def aten_lt(self: TReal, other: TReal) -> BOOL: - """lt.Tensor(Tensor self, Tensor other) -> Tensor""" +@torch_op("aten::lstm.input", trace_only=True) +def aten_lstm( + input: TFloat, + hx: Sequence[TFloat], + params: Sequence[TFloat], + has_biases: bool, + num_layers: int, + dropout: float, + train: bool, + bidirectional: bool, + batch_first: bool, +) -> tuple[TFloat, TFloat, TFloat]: + """lstm.input(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> 
(Tensor, Tensor, Tensor)""" + + # Extract initial hidden and cell states + initial_h = hx[0] # Shape: [num_directions * num_layers, batch_size, hidden_size] + initial_c = hx[1] # Shape: [num_directions * num_layers, batch_size, hidden_size] + + # Determine number of directions + num_directions = 2 if bidirectional else 1 + + # Get dimensions + if batch_first: + # Convert from [batch, seq, input_size] to [seq, batch, input_size] + input = op.Transpose(input, perm=[1, 0, 2]) + + hidden_size = op.Shape(initial_h, start=2, end=3) + + # Process each layer + current_input = input + output_h_list = [] + output_c_list = [] + + for layer_idx in range(num_layers): + # Extract hidden and cell states for this layer + layer_start = layer_idx * num_directions + layer_end = (layer_idx + 1) * num_directions + layer_h = op.Slice(initial_h, layer_start, layer_end, axes=[0]) + layer_c = op.Slice(initial_c, layer_start, layer_end, axes=[0]) + + # Extract parameters for this layer + # Parameter layout: [W_ih, W_hh, b_ih, b_hh] for each direction + params_per_direction = 4 if has_biases else 2 + params_per_layer = params_per_direction * num_directions + param_start_idx = layer_idx * params_per_layer + + # Build weight matrices for ONNX LSTM + # ONNX expects: W[iofc] shape [num_directions, 4*hidden_size, input_size] + # PyTorch provides: W_ih shape [4*hidden_size, input_size] + W_list = [] + R_list = [] + B_list = [] if has_biases else None + + for dir_idx in range(num_directions): + dir_param_start = param_start_idx + dir_idx * params_per_direction + W_ih = params[ + dir_param_start + ] # [4*hidden_size, input_size] - PyTorch order: [i,f,g,o] + W_hh = params[ + dir_param_start + 1 + ] # [4*hidden_size, hidden_size] - PyTorch order: [i,f,g,o] + + # Reorder gates from PyTorch [i,f,g,o] to ONNX [i,o,f,g] + # Split into individual gates + W_ii = op.Slice(W_ih, starts=[0], ends=hidden_size, axes=[0]) + W_if = op.Slice(W_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + W_ig = 
op.Slice(W_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + W_io = op.Slice(W_ih, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) + + W_hi = op.Slice(W_hh, starts=[0], ends=hidden_size, axes=[0]) + W_hf = op.Slice(W_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + W_hg = op.Slice(W_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + W_ho = op.Slice(W_hh, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) + + # Reorder: [i,o,f,g] + W_ih_reordered = op.Concat( + W_ii, W_io, W_if, W_ig, axis=0 + ) # [4*hidden_size, input_size] - ONNX order + W_hh_reordered = op.Concat( + W_hi, W_ho, W_hf, W_hg, axis=0 + ) # [4*hidden_size, hidden_size] - ONNX order + + # Add direction dimension + W_ih_expanded = op.Unsqueeze(W_ih_reordered, [0]) # [1, 4*hidden_size, input_size] + W_hh_expanded = op.Unsqueeze( + W_hh_reordered, [0] + ) # [1, 4*hidden_size, hidden_size] + + W_list.append(W_ih_expanded) + R_list.append(W_hh_expanded) + + if has_biases: + b_ih = params[ + dir_param_start + 2 + ] # [4*hidden_size] - PyTorch order: [i,f,g,o] + b_hh = params[ + dir_param_start + 3 + ] # [4*hidden_size] - PyTorch order: [i,f,g,o] + + # Reorder biases from PyTorch [i,f,g,o] to ONNX [i,o,f,g] + b_ii = op.Slice(b_ih, starts=[0], ends=hidden_size, axes=[0]) + b_if = op.Slice(b_ih, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + b_ig = op.Slice(b_ih, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + b_io = op.Slice(b_ih, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) + + b_hi = op.Slice(b_hh, starts=[0], ends=hidden_size, axes=[0]) + b_hf = op.Slice(b_hh, starts=hidden_size, ends=hidden_size * 2, axes=[0]) + b_hg = op.Slice(b_hh, starts=hidden_size * 2, ends=hidden_size * 3, axes=[0]) + b_ho = op.Slice(b_hh, starts=hidden_size * 3, ends=hidden_size * 4, axes=[0]) + + # Reorder: [i,o,f,g] + b_ih_reordered = op.Concat( + b_ii, b_io, b_if, b_ig, axis=0 + ) # [4*hidden_size] - ONNX order + b_hh_reordered = op.Concat( + b_hi, b_ho, 
b_hf, b_hg, axis=0 + ) # [4*hidden_size] - ONNX order + + # ONNX expects biases concatenated: [Wb[iofc], Rb[iofc]] + b_combined = op.Concat( + b_ih_reordered, b_hh_reordered, axis=0 + ) # [8*hidden_size] + b_expanded = op.Unsqueeze(b_combined, [0]) # [1, 8*hidden_size] + B_list.append(b_expanded) + + # Concatenate weights for all directions + W = op.Concat(*W_list, axis=0) if len(W_list) > 1 else W_list[0] + R = op.Concat(*R_list, axis=0) if len(R_list) > 1 else R_list[0] + B = ( + op.Concat(*B_list, axis=0) + if has_biases and len(B_list) > 1 + else (B_list[0] if has_biases else None) + ) - return op.Less(self, other) + # Call ONNX LSTM operator + direction = "bidirectional" if bidirectional else "forward" + + # Extract hidden_size from initial_h shape: [num_layers * num_directions, batch, hidden_size] + hidden_size_attr = initial_h.shape[2] + + if B is not None: + Y, Y_h, Y_c = op.LSTM( + current_input, + W, + R, + B, + initial_h=layer_h, + initial_c=layer_c, + direction=direction, + hidden_size=hidden_size_attr, + ) + else: + Y, Y_h, Y_c = op.LSTM( + current_input, + W, + R, + initial_h=layer_h, + initial_c=layer_c, + direction=direction, + hidden_size=hidden_size_attr, + ) + + # Y shape: [seq_length, num_directions, batch_size, hidden_size] + # Reshape to [seq_length, batch_size, num_directions * hidden_size] + Y = op.Transpose( + Y, perm=[0, 2, 1, 3] + ) # [seq_length, batch_size, num_directions, hidden_size] + Y_shape = op.Shape(Y) + new_shape = op.Concat( + op.Slice(Y_shape, [0], [1]), # seq_length + op.Slice(Y_shape, [1], [2]), # batch_size + op.Reshape( + op.Mul( + op.Slice(Y_shape, [2], [3]), # num_directions + op.Slice(Y_shape, [3], [4]), # hidden_size + ), + op.Constant(value_ints=[-1]), + ), + axis=0, + ) + current_input = op.Reshape(Y, new_shape) + + # Apply dropout if not last layer and dropout > 0 + if layer_idx < num_layers - 1 and dropout > 0.0 and train: + current_input, _ = op.Dropout(current_input, dropout, train) + + # Store final hidden and 
cell states + output_h_list.append(Y_h) + output_c_list.append(Y_c) + + # Concatenate all layer outputs + final_h = ( + output_h_list[0] if len(output_h_list) == 1 else op.Concat(*output_h_list, axis=0) + ) + final_c = ( + output_c_list[0] if len(output_c_list) == 1 else op.Concat(*output_c_list, axis=0) + ) + + # Handle batch_first for output + if batch_first: + # Convert from [seq, batch, features] to [batch, seq, features] + current_input = op.Transpose(current_input, perm=[1, 0, 2]) + + return current_input, final_h, final_c @torch_op( - ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"), + ("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor"), trace_only=True, ) -def aten_lt_bool(self: BOOL, other: BOOL) -> BOOL: +def aten_lt(self: TTensor, other: TTensor) -> BOOL: """lt.Tensor(Tensor self, Tensor other) -> Tensor""" - # self, other, self < other - # F, F, F - # F, T, T - # T, F, F - # T, T, F + if self.dtype == ir.DataType.BOOL: + # self, other, self < other + # F, F, F + # F, T, T + # T, F, F + # T, T, F + return op.And(other, op.Not(self)) + + return op.Less(self, other) + - return op.And(other, op.Not(self)) +@torch_op("_operator::lt", trace_only=True) +def operator_lt(self: TTensor, other: TTensor) -> BOOL: + # operator.lt for SymInt + return op.Less(self, other) def aten_lu_solve(self: TensorType, LU_data: TensorType, LU_pivots: TensorType) -> TensorType: @@ -5338,18 +5774,14 @@ def aten_max_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, I return result, indices -@torch_op("aten::maximum") -def aten_maximum(self: TReal, other: TReal) -> TReal: +@torch_op("aten::maximum", trace_only=True) +def aten_maximum(self: TTensor, other: TTensor) -> TTensor: """maximum(Tensor self, Tensor other) -> Tensor""" - return op.Max(self, other) - - -@torch_op("aten::maximum") -def aten_maximum_bool(self: BOOL, other: BOOL) -> BOOL: - """maximum(Tensor self, Tensor other) -> Tensor""" + if self.dtype == ir.DataType.BOOL: + 
return op.Or(self, other) - return op.Or(self, other) + return op.Max(self, other) @torch_op("aten::mean") @@ -5384,7 +5816,7 @@ def aten_meshgrid(tensors: Sequence[TensorType]) -> TensorType: raise NotImplementedError() -@torch_op("aten::min") +@torch_op("aten::min", trace_only=True) def aten_min(self: TReal) -> TReal: """min(Tensor self) -> Tensor""" @@ -5405,18 +5837,14 @@ def aten_min_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, T return result, indices -@torch_op("aten::minimum") -def aten_minimum(self: TReal, other: TReal) -> TReal: +@torch_op("aten::minimum", trace_only=True) +def aten_minimum(self: TTensor, other: TTensor) -> TTensor: """minimum(Tensor self, Tensor other) -> Tensor""" - return op.Min(self, other) - + if self.dtype == ir.DataType.BOOL: + return op.And(self, other) -@torch_op("aten::minimum") -def aten_minimum_bool(self: BOOL, other: BOOL) -> BOOL: - """minimum(Tensor self, Tensor other) -> Tensor""" - - return op.And(self, other) + return op.Min(self, other) def aten_miopen_batch_norm( @@ -5756,26 +6184,21 @@ def aten_msort(self: TensorType) -> TensorType: @torch_op( - ("aten::mul", "aten::mul.Tensor", "_operator::mul", "aten::multiply.Tensor"), + ("aten::mul", "aten::mul.Tensor", "aten::multiply.Tensor"), trace_only=True, ) -def aten_mul(self: TReal, other: TReal) -> TReal: +def aten_mul(self: TTensor, other: TTensor) -> TTensor: """mul.Tensor(Tensor self, Tensor other) -> Tensor""" - return op.Mul(self, other) - + if self.dtype == ir.DataType.BOOL: + return op.And(self, other) -@torch_op( - ("aten::mul", "aten::mul.Tensor", "aten::multiply.Tensor"), - trace_only=True, -) -def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: - """ONNX Mul doesn't support Boolean, so use And as an equivalent operator.""" + return op.Mul(self, other) - # TODO(justinchuby): Handle cases where type reconcilation is not enough, - # since different ONNX operators are used based on different data types. 
- return op.And(self, other) +@torch_op("_operator::mul", trace_only=True) +def operator_mul(self: TTensor, other: TTensor) -> TTensor: + return op.Mul(self, other) @torch_op( @@ -6017,7 +6440,6 @@ def aten_native_batch_norm( return norm, input_mean, input_rstd -@torch_op("aten::native_batch_norm", private=True) def _aten_native_batch_norm_training_onnx( input: TFloat, weight: TFloat, @@ -6069,7 +6491,6 @@ def _aten_native_batch_norm_training_onnx( return norm, mean, rstd, running_mean, new_running_var -@torch_op("aten::native_batch_norm", private=True) def _aten_native_batch_norm_inference_onnx( input: TFloat, weight: TFloat, @@ -6239,22 +6660,10 @@ def aten_native_group_norm( if bias is None: # Set to 0.0 as default, the shape is Channel size bias = op.Expand(op.Constant(value_floats=[0.0]), op.Shape(input, start=1, end=2)) - # Accoding to Torch, return rstd instead of var - norm, mean, rstd = _aten_native_group_norm_onnx(input, weight, bias, group, eps) - return norm, mean, rstd - - -@torch_op("aten::native_group_norm", private=True) -def _aten_native_group_norm_onnx( - input: TFloat, - weight: TFloat, - bias: TFloat, - group: INT64, - eps: float, -) -> Tuple[TFloat, TFloat, TFloat]: # Because onnx.GroupNorm() need size=group for weight and bias # But the torch's aten function's input need size=channel, the size mismatched # So we have to use onnx.InstanceNorm() to simulate + # This implementation should be simplified after opset 21 neg_1 = op.Constant(value_ints=[-1]) # Create weight_instance_norm and bias_instance_norm, copied from Torch ONNX converter group_tensor = op.Reshape(group, neg_1) @@ -6270,7 +6679,7 @@ def _aten_native_group_norm_onnx( norm = op.Reshape(norm, op.Shape(input), allowzero=True) # Using the input weight and bias to do affine # But need to unsqueeze to the target shape for broading cast easy - input_rank = Rank(input) + input_rank = len(input.shape) axes_unsqueeze = op.Range(1, input_rank - 1, 1) weight_full_shape = op.Unsqueeze(weight, 
axes_unsqueeze) bias_full_shape = op.Unsqueeze(bias, axes_unsqueeze) @@ -6291,7 +6700,9 @@ def _aten_native_group_norm_onnx( sqr_input_sub_mean = op.Mul(input_sub_mean, input_sub_mean) # In Pytorch, vstd = 1/(sqrt(var + eps)) var = op.ReduceMean(sqr_input_sub_mean, axes, keepdims=False) - rstd = op.Div(1.0, op.Sqrt(var + eps)) + eps = op.Constant(value=ir.tensor(eps, dtype=input.dtype)) + one = op.Constant(value=ir.tensor(1.0, dtype=input.dtype)) + rstd = op.Div(one, op.Sqrt(op.Add(var, eps))) # Get the correct shape [N, group] for mean again mean = op.ReduceMean(input_N_group_neg1, axes, keepdims=False) return norm_result, mean, rstd @@ -6503,16 +6914,7 @@ def aten_norm_except_dim(v: TensorType, pow: int = 2, dim: int = 0) -> TensorTyp raise NotImplementedError() -@torch_op( - ( - "aten::normal.Tensor_float", - "aten::normal.Tensor_Tensor", - "aten::normal.float_Tensor", - "aten::normal.float_float", - "aten::normal_functional", - ), - trace_only=True, -) +@torch_op("aten::normal_functional", trace_only=True) def aten_normal( self: TTensor, mean: float = 0.0, @@ -6541,7 +6943,7 @@ def aten_normal_float_float( return op.Cast(result, to=dtype) -@torch_op("aten::normal.float_Tensor") +@torch_op("aten::normal.float_Tensor", trace_only=True) def aten_normal_float_tensor(mean: FLOAT, std: TFloat) -> TFloat: """normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor""" @@ -6551,7 +6953,7 @@ def aten_normal_float_tensor(mean: FLOAT, std: TFloat) -> TFloat: return op.Add(op.Mul(std, sampled), mean_casted) -@torch_op("aten::normal.Tensor_float") +@torch_op("aten::normal.Tensor_float", trace_only=True) def aten_normal_tensor_float(mean: TFloat, std: FLOAT) -> TFloat: """normal.Tensor_float(Tensor mean, float std=1, *, Generator? 
generator=None) -> Tensor""" @@ -6560,7 +6962,7 @@ def aten_normal_tensor_float(mean: TFloat, std: FLOAT) -> TFloat: return op.Add(op.Mul(op.CastLike(std, sampled), sampled), mean) -@torch_op("aten::normal.Tensor_Tensor") +@torch_op("aten::normal.Tensor_Tensor", trace_only=True) def aten_normal_tensor_tensor(mean: TFloat, std: TFloat) -> TFloat: """normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor""" @@ -6704,34 +7106,41 @@ def aten_pinverse(self: TensorType, rcond: float = 1e-15) -> TensorType: raise NotImplementedError() -@torch_op("aten::pixel_shuffle") +@torch_op("aten::pixel_shuffle", trace_only=True) def aten_pixel_shuffle(self: TReal, upscale_factor: int) -> TReal: """pixel_shuffle(Tensor self, int upscale_factor) -> Tensor""" - self_shape = op.Shape(self) - batch_dims = self_shape[:-3] - chw_in_dims = self_shape[-3:] + if len(self.shape) == 4: + return op.DepthToSpace(self, blocksize=upscale_factor, mode="CRD") + # Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D) + batch_dims = op.Shape(self, end=-3) + chw_in_dims = op.Shape(self, start=-3) + reshaped_self = op.Reshape( self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0) ) depth_to_space = op.DepthToSpace(reshaped_self, blocksize=upscale_factor, mode="CRD") - output_shape = op.Concat(batch_dims, op.Shape(depth_to_space)[1:], axis=0) + final_dims = op.Shape(depth_to_space, start=1) + output_shape = op.Concat(batch_dims, final_dims, axis=0) return op.Reshape(depth_to_space, output_shape, allowzero=True) -@torch_op("aten::pixel_unshuffle") +@torch_op("aten::pixel_unshuffle", trace_only=True) def aten_pixel_unshuffle(self: TReal, downscale_factor: int) -> TReal: """pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor""" + if len(self.shape) == 4: + return op.SpaceToDepth(self, blocksize=downscale_factor) - self_shape = op.Shape(self) - batch_dims = self_shape[:-3] - chw_in_dims = self_shape[-3:] # Reshaping input by 
collapsing all leading dimensions to match ONNX op requirement (4D) + batch_dims = op.Shape(self, end=-3) + chw_in_dims = op.Shape(self, start=-3) + reshaped_self = op.Reshape( self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0) ) space_to_depth = op.SpaceToDepth(reshaped_self, blocksize=downscale_factor) - output_shape = op.Concat(batch_dims, op.Shape(space_to_depth)[1:], axis=0) + final_dims = op.Shape(space_to_depth, start=1) + output_shape = op.Concat(batch_dims, final_dims, axis=0) return op.Reshape(space_to_depth, output_shape, allowzero=True) @@ -7261,9 +7670,9 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType: raise NotImplementedError() -@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"), trace_only=True) -def aten_remainder(self: TFloat, other: TFloat) -> TFloat: - """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" +def _aten_remainder(self: TTensor, other: TTensor, integer: bool) -> TTensor: + if integer: + return op.Mod(self, other) # TODO(justinchuby): Improve fp16 precision by following the logic in # https://github.com/pytorch/pytorch/blob/3a823e46170778cc32783f27596c77d0103084a9/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L264-L277 @@ -7274,12 +7683,32 @@ def aten_remainder(self: TFloat, other: TFloat) -> TFloat: return op.Sub(self, op.Mul(rounded_quotient, other)) -@torch_op( - ("aten::remainder.Tensor", "aten::remainder.Scalar", "_operator::mod"), trace_only=True -) -def aten_remainder_int(self: TInt, other: TInt) -> TInt: +@torch_op("aten::remainder.Tensor", trace_only=True) +def aten_remainder(self: TTensor, other: TTensor) -> TTensor: """remainder.Tensor(Tensor self, Tensor other) -> Tensor""" + return _aten_remainder(self, other, integer=self.dtype.is_integer()) + + +@torch_op("aten::remainder.Scalar", trace_only=True) +def aten_remainder_scalar(self: TTensor, other: float) -> TTensor: + """remainder.Scalar(Tensor self, Scalar other) -> Tensor""" + + other_tensor = 
ir.tensor(other, dtype=self.dtype) + return _aten_remainder(self, other_tensor, integer=self.dtype.is_integer()) + + +@torch_op("aten::remainder.Scalar_Tensor", trace_only=True) +def aten_remainder_scalar_tensor(self: float, other: TTensor) -> TTensor: + """remainder.Scalar_Tensor(Scalar self, Tensor other) -> Tensor""" + + self_tensor = ir.tensor(self, dtype=other.dtype) + return _aten_remainder(self_tensor, other, integer=other.dtype.is_integer()) + + +@torch_op("_operator::mod", trace_only=True) +def operator_mod(self: TTensor, other: TTensor) -> TTensor: + # Modulus operator % on SymInt return op.Mod(self, other) @@ -7305,20 +7734,134 @@ def aten_repeat(self: TTensor, repeats: Sequence[TInt]) -> TTensor: return op.Tile(self_expanded, repeats) -def aten_repeat_interleave( - repeats: TensorType, output_size: Optional[int] = None +@torch_op("aten::repeat_interleave.self_int", trace_only=True) +def aten_repeat_interleave_self_int( + self: TensorType, + repeats: int, + dim: Optional[int] = None, + output_size: Optional[int] = None, ) -> TensorType: - """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor""" + """repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor - raise NotImplementedError() + The trick is to repeat in one direction orthogonal to reshape. + .. code-block:: python -@torch_op("aten::reshape") -def aten_reshape(self: TTensor, shape: IntType) -> TTensor: - """reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)""" + x = torch.tensor([[0, 1, 2], [3, 4, 5]]) + x.repeat_interleave(2, dim=0) + + is equivalent to: + + .. 
code-block:: python + + x = torch.tensor([[0, 1, 2], [3, 4, 5]]) + x.repeat((1, 2)).reshape((-1, t.shape[1])) + """ + if dim is None: + raise NotImplementedError("No conversion available yet when dim is None.") + + self_rank = len(self.shape) + pos_dim = (dim + self_rank) % self_rank + unsqueezed = op.Unsqueeze(self, [pos_dim + 1]) + if isinstance(repeats, int): + tiles = [1] * (self_rank + 1) + tiles[pos_dim + 1] = repeats + tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype)) + else: + # repeats is a symbolic tensor + tile_repeat = op.Concat( + op.Constant(value=ir.tensor([1] * pos_dim, dtype=INT64.dtype)), + op.Reshape(repeats, op.Constant(value=ir.tensor([-1], dtype=INT64.dtype))), + op.Constant(value=ir.tensor([1] * (self_rank - pos_dim), dtype=INT64.dtype)), + axis=0, + ) + tiled = op.Expand(unsqueezed, tile_repeat) + if self_rank == 1: + return op.Identity(tiled) + final_shape = op.Concat( + op.Shape(self, start=0, end=dim), + op.Constant(value_ints=[-1]), + op.Shape(self, start=pos_dim + 1), + axis=0, + ) + return op.Reshape(tiled, final_shape) + + +@torch_op("aten::repeat_interleave.Tensor", trace_only=True) +def aten_repeat_interleave_Tensor( + self: TensorType, repeats: Optional[TensorType] = None, dim: Optional[int] = None +) -> TensorType: + """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor + + When `repeats` is a tensor, each line is multiplied + by a different number. + There are multiple strategies. Here is one. + + .. 
code-block:: python + + import torch + + x = torch.tensor([[0, 1, 2], [3, 4, 5]]) + times = torch.tensor([2, 3], dtype=torch.int64) + y = x.repeat_interleave(times, dim=0) + print("repeat_interleave") + print(y) + + ci = times.cumsum(dim=0) + rows = torch.arange(ci[-1], dtype=torch.int64) < ci.reshape((-1, 1)) + srows = times.shape[0] - rows.to(torch.int64).sum(axis=0) + indices = srows.reshape((-1, )) + print("decomposed") + print(x[indices, :]) + """ + if repeats is None: + repeats = self + self = op.Range(0, op.Squeeze(op.Shape(repeats, start=-1), [0]), 1) + if dim is None: + # flatten + self = op.Reshape(self, [-1]) + rank = 1 + else: + rank = len(self.shape) + + if rank > 2: + shape_x0 = op.Shape(self, start=0, end=1) + shape_x = op.Shape(self, start=1) + self = op.Reshape(self, op.Concat(shape_x0, [-1], axis=0)) + elif rank == 1: + shape_x = None + self = op.Reshape(self, [-1, 1]) + else: + if rank != 2: + raise NotImplementedError( + f"rank(self)={rank} not implemented for repeat_interleave" + ) + shape_x = None + + ci = op.CumSum(repeats, [0]) + last_ci = op.Gather(ci, [-1]) + trange = op.Range(0, op.Squeeze(last_ci, [0]), 1) + rows = op.Less(trange, op.Unsqueeze(ci, [-1])) + srows = op.Sub( + op.Shape(self, start=0, end=1), + op.ReduceSum(op.Cast(rows, to=INT64.dtype), [0]), + ) + indices = op.Reshape(srows, [-1]) + values = op.GatherND(self, op.Unsqueeze(indices, [-1])) + if rank == 2: + return values + # shape_x is None at this stage. 
+ assert shape_x is None # for mypy + return op.Reshape( + values, + op.Concat([-1], shape_x, axis=0) if shape_x else [-1], + ) - # Reshape only support INT64 as 'shape' - shape = op.Cast(shape, to=INT64.dtype) + +@torch_op("aten::reshape", trace_only=True) +def aten_reshape(self: TTensor, shape: Sequence[INT64]) -> TTensor: + """reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)""" + shape = common_ops.merge_dims(shape) return op.Reshape(self, shape) @@ -7390,23 +7933,29 @@ def aten_rnn_tanh_cell( def aten_roll(self: TTensor, shifts: Sequence[int], dims: Sequence[int] = ()) -> TTensor: """roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor""" + if isinstance(shifts, int): + shifts = [shifts] + + if isinstance(dims, int): + dims = [dims] + self_rank = len(self.shape) if self_rank == 0: return op.Identity(self) elif self.shape[0] == 0: # empty tensor return op.Identity(self) + + # NOTE: In pytorch, default value of dims is an empty list. + if len(dims) == 0: # Empty sequence + assert len(shifts) == 1, "shifts should be a single integer if dims is empty" + return _aten_roll_shift_no_dim_onnx(self, shifts[0]) else: - # NOTE: In pytorch, default value of dims is an empty list. 
- if len(dims) == 0: # Empty sequence - # assert isinstance(shifts, int) - return _aten_roll_shift_no_dim_onnx(self, shifts) - else: - # assert len(shifts) == len(dims), but shifts is a tensor, dims is a list - result = self - for i, shift in enumerate(shifts): - dim = dims[i] - result = _aten_roll_shift_and_dim_onnx(result, shift, dim) - return result + assert len(shifts) == len(dims) + result = self + for i, shift in enumerate(shifts): + dim = dims[i] + result = _aten_roll_shift_and_dim_onnx(result, shift, dim) + return result @torch_op("aten::roll", trace_only=True, complex=True) @@ -7415,6 +7964,12 @@ def aten_roll_complex( ) -> TTensor: """roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor""" + if isinstance(shifts, int): + shifts = [shifts] + + if isinstance(dims, int): + dims = [dims] + self_rank = len(self.shape) if self_rank == 1: return op.Identity(self) @@ -7425,37 +7980,34 @@ def aten_roll_complex( self_real = op.Slice(self, [0], [1], axes=[-1]) self_imag = op.Slice(self, [1], [2], axes=[-1]) if not dims: - # assert isinstance(shifts, int) - shift_real = _aten_roll_shift_no_dim_onnx(self_real, shifts) - shift_imag = _aten_roll_shift_no_dim_onnx(self_imag, shifts) + assert len(shifts) == 1, "shifts should be a single integer if dims is empty" + shift_real = _aten_roll_shift_no_dim_onnx(self_real, shifts[0]) + shift_imag = _aten_roll_shift_no_dim_onnx(self_imag, shifts[0]) result = op.Concat(shift_real, shift_imag, axis=-1) else: - # assert len(shifts) == len(dims), but shifts is a tensor, dims is a list + assert len(shifts) == len(dims) for i, dim in enumerate(dims): - shift = op.Gather(shifts, i, axis=0) - self_real = _aten_roll_shift_and_dim_onnx(self_real, shift, dim) - self_imag = _aten_roll_shift_and_dim_onnx(self_imag, shift, dim) + self_real = _aten_roll_shift_and_dim_onnx(self_real, shifts[i], dim) + self_imag = _aten_roll_shift_and_dim_onnx(self_imag, shifts[i], dim) result = op.Concat(self_real, self_imag, axis=-1) return result 
-@torch_op("aten::roll", private=True) -def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: INT64) -> TTensor: +def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: int) -> TTensor: neg_1 = op.Constant(value_ints=[-1]) # flatten the self tensor: from [[A,B],[C,D]] to [A,B,C,D] self_flatten = op.Reshape(self, neg_1) # Compute slice length - shift_tensor = op.Reshape(shift, neg_1) - if shift_tensor < 0: + if shift < 0: # For [A,B,C,D], if shift is -1, slice_length = -(-1) = 1, means move [A] to the end - slice_length = -shift_tensor + slice_length = op.Constant(value_ints=[-shift]) else: # For [A,B,C,D], if shift is 1, slice_length = 4 - 1 = 3, means move [A,B,C] to the end # The effect equals to move [D] to the beginning - slice_length = op.Size(self_flatten) - shift_tensor + slice_length = op.Size(self_flatten) - op.Constant(value_ints=[shift]) # Get second part of the tensor, e.g. [A,B,C] suffix = op.Slice(self_flatten, op.Constant(value_ints=[0]), slice_length) # Get first part of the tensor, e.g. 
[D] @@ -7465,15 +8017,13 @@ def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: INT64) -> TTensor: return op.Reshape(result, op.Shape(self)) -@torch_op("aten::roll", private=True) -def _aten_roll_shift_and_dim_onnx(self: TTensor, shift: INT64, dim: int) -> TTensor: +def _aten_roll_shift_and_dim_onnx(self: TTensor, shift: int, dim: int) -> TTensor: neg_1 = op.Constant(value_ints=[-1]) - dim_tensor = op.Reshape(op.Constant(value_int=dim), neg_1) - shift_tensor = op.Reshape(shift, neg_1) - if shift_tensor < 0: - slice_length = -shift_tensor + dim_tensor = op.Constant(value_ints=[dim]) + if shift < 0: + slice_length = op.Constant(value_ints=[-shift]) else: - slice_length = op.Gather(op.Shape(self), dim_tensor, axis=0) - shift_tensor + slice_length = op.Shape(self, start=dim, end=dim + 1) - op.Constant(value_ints=[shift]) # from [A,B,C,D] -> [D,A,B,C], [D] is prefix, [A,B,C] is suffix suffix = op.Slice(self, op.Constant(value_ints=[0]), slice_length, axes=dim_tensor) prefix = op.Slice(self, slice_length, op.Reshape(op.Size(self), neg_1), axes=dim_tensor) @@ -7552,7 +8102,7 @@ def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: @torch_op("aten::scalar_tensor", trace_only=True) def aten_scalar_tensor( - s: float, + s: TensorType, dtype: int = FLOAT.dtype, layout: str = "", device: str = "", @@ -7561,8 +8111,7 @@ def aten_scalar_tensor( """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor""" if dtype == -1: dtype = FLOAT.dtype - # Set trace_only=True because different if branches return different dtypes - # which is not supported in an ONNX function + return common_ops.cast_to(s, dtype=dtype) @@ -7591,31 +8140,35 @@ def aten_scalar_tensor_complex( return result -@torch_op("aten::scalar_tensor", trace_only=True) -def aten_scalar_tensor_sym_number( - s: TensorType, - dtype: int = FLOAT.dtype, - layout: str = "", - device: str = "", - pin_memory: bool = False, -) -> RealType: - """scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - if dtype == -1: - dtype = FLOAT.dtype - return common_ops.cast_to(s, dtype=dtype) +@torch_op("aten::scatter.src", trace_only=True) +def aten_scatter_src( + self: TTensor, + dim: int, # we have to use int here because ScatterElements() will use this attribute + index: TInt, + src: TTensor, +) -> TTensor: + """scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor""" + if len(index.shape) == 0: + index = op.Unsqueeze(index, [0]) + if len(src.shape) == 0: + src = op.Unsqueeze(src, [0]) + return op.ScatterElements(self, index, src, axis=dim) -@torch_op(("aten::scatter.value", "aten::scatter.src"), trace_only=True) -def aten_scatter( - self: TReal, +@torch_op("aten::scatter.value", trace_only=True) +def aten_scatter_value( + self: TTensor, dim: int, # we have to use int here because ScatterElements() will use this attribute index: TInt, - src: TReal, -) -> TReal: - """scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor""" - - update = op.Expand(src, op.Shape(index)) - return op.ScatterElements(self, index, update, axis=dim) + value: float, +) -> TTensor: + """scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor""" + # Ensure value is a scalar tensor and expand it to match index shape + if len(index.shape) == 0: + index = op.Unsqueeze(index, [0]) + scalar_tensor = 
ir.tensor([value], dtype=self.dtype) + src = op.ConstantOfShape(op.Shape(index), value=scalar_tensor) + return op.ScatterElements(self, index, src, axis=dim) @torch_op("aten::scatter_add", trace_only=True) @@ -7974,7 +8527,7 @@ def aten_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat: if self_is_scalar: self = op.Unsqueeze(self, op.Constant(value_ints=[0])) result = op.Softmax(self, axis=dim) - if dtype != -1: + if dtype != -1 and dtype is not None: result = op.Cast(result, to=dtype) if self_is_scalar: # Convert to scalar when input is scalar @@ -7983,21 +8536,6 @@ def aten_softmax(self: TFloat, dim: int, dtype: int = -1) -> TFloat: return result -@torch_op(("aten::softmax.int", "aten::special_softmax"), trace_only=True) -def aten_softmax_no_dtype(self: TFloat, dim: int) -> TFloat: - """softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor""" - - self_is_scalar = len(self.shape) == 0 - if self_is_scalar: - self = op.Unsqueeze(self, op.Constant(value_ints=[0])) - result = op.Softmax(self, axis=dim) - if self_is_scalar: - # Convert to scalar when input is scalar - result = op.Squeeze(result) - - return result - - @torch_op("aten::sort", trace_only=True) def aten_sort( self: TReal, dim: int = -1, descending: bool = False, stable: bool = False @@ -8211,12 +8749,107 @@ def aten_std_mean_correction( return op.Sqrt(var), mean +def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloat: + left = op.Div(op.Sub(n_fft, win_length), op.Constant(value_ints=[2])) + + right = op.Sub(op.Sub(n_fft, left), win_length) + left = op.Reshape(left, op.Constant(value_ints=[1])) + right = op.Reshape(right, op.Constant(value_ints=[1])) + win_length = op.Reshape(win_length, op.Constant(value_ints=[1])) + + left_win = op.Expand(op.Constant(value_ints=[0]), left) + right_win = op.Expand(op.Constant(value_ints=[0]), right) + window_list = op.Expand(op.Constant(value_ints=[1]), win_length) + return op.Concat(left_win, window_list, right_win, axis=0) + + +def 
_create_window_from_n_fft(n_fft: int) -> TFloat: + n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1])) + window = op.Expand(op.Constant(value_ints=[1]), n_fft_tensor) + return window + + +def _normalize_fft_result(signal: TFloat, result: TFloat, n_fft: int) -> TFloat: + n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1])) + sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal)) + result = op.Div(result, sqrt_nfft) + return result + + +@torch_op("aten::stft", trace_only=True) +def aten_stft( + self: TFloat, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[TFloat] = None, + normalized: bool = False, + onesided: Optional[bool] = None, + return_complex: Optional[bool] = None, +) -> TFloat: + """stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor""" + + # NOTE: regardless of the value of return_complex, we always return a real representation. 
+ del return_complex + + # Get STFT sizes + if hop_length is None: + # core dump + # hop_length = op.Div(op.Constant(value_ints=n_fft), op.Constant(value_ints=[4])) + hop_length = n_fft // 4 + frame_step_const = op.Reshape(hop_length, op.Constant(value_ints=[1])) + + # Pre-process input if needed + is_signal_rank1 = len(self.shape) == 1 + if is_signal_rank1: + # Add a batch dimension + self = op.Identity(op.Unsqueeze(self, op.Constant(value_ints=[0]))) + + # Get window and make sure it's the same size as `win_length` or `n_fft` + if window is not None and window.shape[0] is not None: + # first dimension + n_win = op.Shape(window, start=0, end=1) + # Center window around zeros if needed (required by ONNX's STFT) + if n_win < n_fft: + left = op.Div(op.Sub(n_fft, n_win), op.Constant(value_ints=[2])) + + right = op.Sub(op.Sub(n_fft, left), n_win) + left = op.Reshape(left, op.Constant(value_ints=[1])) + right = op.Reshape(right, op.Constant(value_ints=[1])) + + left_win = op.Expand(op.Constant(value_ints=[0]), left) + right_win = op.Expand(op.Constant(value_ints=[0]), right) + right_win = op.CastLike(right_win, window) + left_win = op.CastLike(left_win, window) + window = op.Concat(left_win, window, right_win, axis=0) + elif window is None: + if win_length is not None: + window = _create_window_from_win_length(win_length, n_fft) + else: + window = _create_window_from_n_fft(n_fft) + + if onesided is None or onesided: + onesided = 1 + else: + onesided = 0 + window = op.CastLike(window, self) + result = op.STFT(self, frame_step_const, window, n_fft, onesided=onesided) + result = op.Transpose(result, perm=[0, 2, 1, 3]) + # Remove batch dimension, if needed + if is_signal_rank1: + result = op.Squeeze(result, op.Constant(value_ints=[0])) + + # Normalize, if needed + if normalized: + result = _normalize_fft_result(self, result, n_fft) + + return result + + @torch_op( ( "aten::sub.Tensor", - "aten::sub.Scalar", "aten::subtract.Tensor", - "aten::subtract.Scalar", 
"_operator::sub", ), trace_only=True, @@ -8229,6 +8862,14 @@ def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: return op.Sub(self, other) +@torch_op(("aten::sub.Scalar", "aten::subtract.Scalar"), trace_only=True) +def aten_sub_scalar(self: TTensor, other: float, alpha: float = 1.0) -> TTensor: + """sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor""" + + other = op.Constant(value=ir.tensor(other, dtype=self.dtype)) + return aten_sub(self, other, alpha=alpha) + + @torch_op( ( "aten::sub.Tensor", @@ -8309,6 +8950,14 @@ def aten_sym_size(self: TensorType, dim: int = 0) -> INT64: return op.Squeeze(op.Shape(self, end=dim + 1, start=dim)) +@torch_op("aten::sym_storage_offset", trace_only=True) +def aten_sym_storage_offset(self: TensorType, dim: int = 0) -> INT64: + """sym_storage_offset(Tensor self, int dim) -> SymInt""" + # storage offset is not used in onnx world. + # the output of this function is not used. + return op.Constant(value_int=0) + + def aten_symeig( self: TensorType, eigenvectors: bool = False, upper: bool = True ) -> tuple[TensorType, TensorType]: @@ -8321,7 +8970,7 @@ def aten_symeig( def aten_t(self: TTensor) -> TTensor: """t(Tensor(a) self) -> Tensor(a)""" - rank = Rank(self) + rank = len(self.shape) if rank == 2: result = op.Transpose(self, perm=[1, 0]) else: @@ -8404,26 +9053,24 @@ def aten_threshold_backward( raise NotImplementedError() -@torch_op("aten::tile") -def aten_tile(self: TTensor, dims: INT64) -> TTensor: +@torch_op("aten::tile", trace_only=True) +def aten_tile(self: TTensor, dims: Sequence[int]) -> TTensor: """tile(Tensor self, int[] dims) -> Tensor""" - self_rank = Rank(self) - dims_rank = op.Size(dims) - diff = op.Sub(self_rank, dims_rank) + self_rank = len(self.shape) + dims_rank = len(dims) + diff = self_rank - dims_rank if diff > 0: # dims is shorter than self.shape # pad dims with 1 - diff_1d = op.Reshape(diff, op.Constant(value_ints=[1])) - exapnd_ones = op.Expand(op.Constant(value_ints=[1]), 
diff_1d) - dims = op.Concat(exapnd_ones, dims, axis=0) + exapnd_ones = [1] * diff + dims = [*exapnd_ones, *dims] - if diff < 0: + elif diff < 0: # dims is longer than self.shape # pad self.shape with 1 - diff_1d = op.Reshape(op.Abs(diff), op.Constant(value_ints=[1])) - exapnd_ones = op.Expand(op.Constant(value_ints=[1]), diff_1d) + exapnd_ones = op.Constant(value_ints=[1] * (-diff)) self_shape = op.Shape(self) self_final_shape = op.Concat(exapnd_ones, self_shape, axis=0) self = op.Reshape(self, self_final_shape, allowzero=True) @@ -8574,7 +9221,7 @@ def aten_triangular_solve( raise NotImplementedError() -@torch_op("aten::tril") +@torch_op("aten::tril", trace_only=True) def aten_tril(self: TTensor, diagonal: int = 0) -> TTensor: """tril(Tensor self, int diagonal=0) -> Tensor""" @@ -8602,7 +9249,7 @@ def aten_triplet_margin_loss( raise NotImplementedError() -@torch_op("aten::triu") +@torch_op("aten::triu", trace_only=True) def aten_triu(self: TTensor, diagonal: int = 0) -> TTensor: """triu(Tensor self, int diagonal=0) -> Tensor""" @@ -8622,6 +9269,14 @@ def aten_trunc(self: TFloat) -> TFloat: return op.Floor(op.Abs(self)) * op.Sign(self) +@torch_op("math::trunc", trace_only=True) +def python_math_trunc(self: TFloat) -> TInt: + """trunc(Tensor self) -> Tensor""" + # NOTE: This is used in SymInt/SymBool/SymFloat context, so + # we don't expect overflow to happen here. 
+ return op.Cast(self, to=INT64.dtype) + + @torch_op("aten::type_as", trace_only=True) def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2: """type_as(Tensor self, Tensor other) -> Tensor""" @@ -8629,12 +9284,27 @@ def aten_type_as(self: TTensor, other: TTensor2) -> TTensor2: return op.CastLike(self, other) -@torch_op("aten::unbind.int") +@torch_op("aten::unbind.int", trace_only=True) def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]: """unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]""" - split_sizes = op.Constant(value_int=1) - return op.SplitToSequence(self, split_sizes, axis=dim, keepdims=False) + if isinstance(self.shape[dim], int): + num_outputs = self.shape[dim] + results = [] + for i in range(num_outputs): + # Slice to get a single element at position i along dim + sliced = op.Slice( + self, + starts=op.Constant(value_ints=[i]), + ends=op.Constant(value_ints=[i + 1]), + axes=op.Constant(value_ints=[dim]), + ) + # Squeeze to remove the dimension of size 1 + squeezed = op.Squeeze(sliced, axes=[dim]) + results.append(squeezed) + return results + + return op.SplitToSequence(self, axis=dim, keepdims=False) @torch_op("aten::unflatten.int", trace_only=True) @@ -8727,15 +9397,57 @@ def aten_unfold_copy(self: TensorType, dimension: int, size: int, step: int) -> raise NotImplementedError() +@torch_op("aten::unique_consecutive", trace_only=True) def aten_unique_consecutive( - self: TensorType, + x: TensorType, return_inverse: bool = False, return_counts: bool = False, dim: Optional[int] = None, ) -> tuple[TensorType, TensorType, TensorType]: """unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? 
dim=None) -> (Tensor, Tensor, Tensor)""" + assert x.dtype in {INT64.dtype, INT32.dtype}, ( + "unique_consecutive not implemented for other type than int32, int64" + ) + rank_x = len(x.shape) - raise NotImplementedError() + zero = op.Constant(value=ir.tensor([0], dtype=x.dtype)) + zero64 = op.Constant(value=ir.tensor([0], dtype=INT64.dtype)) + minus_one = op.Constant(value=ir.tensor([-1], dtype=INT64.dtype)) + + if dim is None: + if rank_x != 1: + x = op.Reshape(x, minus_one) + else: + assert rank_x == 1 and dim == 0, ( + f"Not implemented for x={x!r} with rank={rank_x} and dim={dim}." + ) + + lag = op.Concat( + # Hopefully this will never be equal to the first value of the tensor x + # ideally we could do differently but with a higher cost + op.Constant(value=ir.tensor([_INT32_MAX], dtype=x.dtype)), + op.Slice(x, zero64, minus_one, zero64), + axis=0, + ) + eq = op.Equal(x, lag) + diff = op.Not(eq) + res = op.Compress(x, diff, axis=0) + + zero_no_dim = op.Constant(value=ir.tensor(0, dtype=x.dtype)) + one_no_dim = op.Constant(value=ir.tensor(1, dtype=x.dtype)) + one = op.Constant(value=ir.tensor([1], dtype=x.dtype)) + + inverse = op.Sub(op.CumSum(op.Cast(diff, to=x.dtype), zero), one) + shape_x = op.Shape(x) + indices = op.Range(zero_no_dim, op.Squeeze(shape_x), one_no_dim) + points = op.Compress(indices, diff, axis=0) + lagp = op.Concat( + op.Slice(points, one, op.Shape(points), zero), + shape_x, + axis=0, + ) + counts = op.Sub(lagp, points) + return res, inverse, counts @torch_op("aten::_unique", trace_only=True) @@ -9058,23 +9770,22 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType: @torch_op(("aten::view", "aten::_unsafe_view"), trace_only=True) -def aten_view(self: TTensor, size: IntType) -> TTensor: +def aten_view(self: TTensor, size: Sequence[INT64]) -> TTensor: """view(Tensor(a) self, SymInt[] size) -> Tensor(a)""" - size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input + size = common_ops.merge_dims(size) return 
op.Reshape(self, size, allowzero=True) -@torch_op(("aten::view", "aten::_unsafe_view"), complex=True) -def aten_view_complex(self: TTensor, size: IntType) -> TTensor: +@torch_op(("aten::view", "aten::_unsafe_view"), complex=True, trace_only=True) +def aten_view_complex(self: TTensor, size: Sequence[INT64]) -> TTensor: """view(Tensor(a) self, SymInt[] size) -> Tensor(a)""" - size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input - complex_size = op.Concat(size, op.Constant(value_ints=[2]), axis=0) + complex_size = common_ops.merge_dims([*size, 2]) return op.Reshape(self, complex_size, allowzero=True) -@torch_op("aten::view_as") +@torch_op("aten::view_as", trace_only=True) def aten_view_as(self: TTensor, other: TTensor2) -> TTensor: """view_as(Tensor(a) self, Tensor other) -> Tensor(a)""" @@ -9118,11 +9829,11 @@ def aten_view_as_real_copy(self: TTensor) -> TTensor: return op.Identity(self) -@torch_op("aten::view_copy") -def aten_view_copy(self: TTensor, size: IntType) -> TTensor: +@torch_op("aten::view_copy", trace_only=True) +def aten_view_copy(self: TTensor, size: Sequence[INT64]) -> TTensor: """view_copy(Tensor self, SymInt[] size) -> Tensor""" - size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input + size = common_ops.merge_dims(size) return op.Reshape(self, size) @@ -9150,7 +9861,8 @@ def reshape_to_2d(tensor): "aten::where.ScalarSelf", "aten::where.ScalarOther", "aten::where.self", - ) + ), + trace_only=True, ) def aten_where(condition: BOOL, self: TTensor, other: TTensor) -> TTensor: """where.self(Tensor condition, Tensor self, Tensor other) -> Tensor""" @@ -9166,7 +9878,7 @@ def aten_xor(self: TensorType, other: TensorType) -> TensorType: @torch_op("aten::zeros", trace_only=True) def aten_zeros( - size: IntType, + size: Sequence[INT64], dtype: int = FLOAT.dtype, layout: str = "", device: str = "", @@ -9175,9 +9887,9 @@ def aten_zeros( """zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? 
layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" if dtype == -1: dtype = FLOAT.dtype - size = op.Cast(size, to=INT64.dtype) - zero = op.Constant(value_float=0.0) - zero = op.Cast(zero, to=dtype) + + zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype))) + size = common_ops.merge_dims(size) return op.Expand(zero, size) diff --git a/onnxscript/function_libs/torch_lib/ops/linalg.py b/onnxscript/function_libs/torch_lib/ops/linalg.py index 05bac181ca..c9d870bd86 100644 --- a/onnxscript/function_libs/torch_lib/ops/linalg.py +++ b/onnxscript/function_libs/torch_lib/ops/linalg.py @@ -330,8 +330,9 @@ def aten_linalg_vector_norm( keepdim = False else: dim = op.Reshape(dim, op.Constant(value_ints=[-1])) - self = op.Abs(self) + if math.isinf(ord): + self = op.Abs(self) if ord > 0: return op.ReduceMax(self, dim, keepdims=keepdim) else: @@ -345,6 +346,9 @@ def aten_linalg_vector_norm( elif ord == 2.0: return op.ReduceL2(self, dim, keepdims=keepdim) else: + if ord < 0 or ord % 2 != 0: + # Not an even integer (could be odd, fractional or negative), use Abs + self = op.Abs(self) self_pow = op.Pow(self, ord) exp = op.CastLike(1 / ord, self) return op.Pow(op.ReduceSum(self_pow, dim, keepdims=keepdim), exp) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index bccddb88a6..5edcc233d0 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -18,7 +18,6 @@ from typing import Optional, Sequence, Tuple, TypeVar, Union from onnxscript import BFLOAT16, BOOL, DOUBLE, FLOAT, FLOAT16, INT64, ir -from onnxscript.function_libs.torch_lib.ops import common as common_ops from onnxscript.function_libs.torch_lib.registration import torch_op from onnxscript.function_libs.torch_lib.tensor_typing import ( IntType, @@ -32,7 +31,6 @@ from onnxscript.onnx_types import TensorType _MATH_PI = math.pi -Rank = common_ops.Rank _INT64_MAX = 9223372036854775807 _INT64_MIN = 
-9223372036854775808 @@ -116,6 +114,33 @@ def _adjust_attributes_of_avg_pool( return (kernel_shape, strides, pads) +def _aten_avg_pool_onnx( + self: TFloat, + kernel_shape: Sequence[int], + strides: Sequence[int], + pads: Sequence[int], + ceil_mode: bool, + count_include_pad: bool, +) -> TFloat: + self_rank_is_unbatched_rank = len(self.shape) == len(kernel_shape) + 1 + if self_rank_is_unbatched_rank: # C,H,W -> N,C,H,W and N=1 + self = op.Unsqueeze(self, [0]) + + result = op.AveragePool( + self, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + kernel_shape=kernel_shape, + pads=pads, + strides=strides, + ) + + if self_rank_is_unbatched_rank: + result = op.Squeeze(result, [0]) + + return result + + @torch_op("aten::avg_pool1d", trace_only=True) def aten_avg_pool1d( self: TFloat, @@ -136,16 +161,7 @@ def aten_avg_pool1d( expand_size, kernel_size, stride, padding ) - result = op.AveragePool( - self, - ceil_mode=ceil_mode, - count_include_pad=count_include_pad, - kernel_shape=kernel_shape, - pads=pads, - strides=strides, - ) - - return result + return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad) @torch_op("aten::avg_pool2d", trace_only=True) @@ -169,15 +185,6 @@ def aten_avg_pool2d( expand_size, kernel_size, stride, padding ) - result = op.AveragePool( - self, - ceil_mode=ceil_mode, - count_include_pad=count_include_pad, - kernel_shape=kernel_shape, - pads=pads, - strides=strides, - ) - # TODO: if want to support divisor_override argument, need to op.Mul(result, mask) # mask = [ # 1, 2, 3, S,..3, 2, 1 @@ -191,7 +198,7 @@ def aten_avg_pool2d( # S is stride size, in this case S=4, # S may dup lot of times according to the image size - return result + return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad) def aten_avg_pool2d_backward( @@ -230,15 +237,6 @@ def aten_avg_pool3d( expand_size, kernel_size, stride, padding ) - result = op.AveragePool( - self, - kernel_shape=kernel_shape, 
- strides=strides, - pads=pads, - count_include_pad=count_include_pad, - ceil_mode=ceil_mode, - ) - # TODO: if want to support divisor_override argument, need to op.Mul(result, mask) # mask = [ # 1, 2, 3, S,..3, 2, 1 @@ -252,7 +250,7 @@ def aten_avg_pool3d( # S is stride size, in this case S=4, # S may dup lot of times according to the image size - return result + return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad) def aten_avg_pool3d_backward( @@ -294,20 +292,16 @@ def aten_binary_cross_entropy_backward( @torch_op("aten::celu", trace_only=True) -def aten_celu(self: FLOAT, alpha: float = 1.0) -> FLOAT: +def aten_celu(self: TFloat, alpha: float = 1.0) -> TFloat: """celu(Tensor self, Scalar alpha=1.0) -> Tensor""" - return op.Celu(self, alpha=alpha) # op.Celu only support float32 - + if self.dtype != FLOAT.dtype: + self_upcasted = op.Cast(self, to=FLOAT.dtype) -@torch_op("aten::celu", trace_only=True) -def aten_celu_type_promoted( - self: TFloatUnlessFloat32, alpha: float = 1.0 -) -> TFloatUnlessFloat32: - """celu(Tensor self, Scalar alpha=1.0) -> Tensor""" + # op.Celu only support float32 + return op.Cast(op.Celu(self_upcasted, alpha=alpha), to=self.dtype) - self_upcasted = op.Cast(self, to=FLOAT.dtype) - return op.CastLike(op.Celu(self_upcasted, alpha=alpha), self) + return op.Celu(self, alpha=alpha) @torch_op("aten::col2im", trace_only=True) @@ -580,7 +574,7 @@ def aten_group_norm( norm = op.Reshape(norm, op.Shape(input)) # Using the input weight and bias to do affine # But need to unsqueeze to the target shape for broading cast easy - input_rank = Rank(input) + input_rank = len(input.shape) one = op.Constant(value_int=1) axes_unsqueeze = op.Range(one, op.Sub(input_rank, one), one) weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze) @@ -1003,7 +997,7 @@ def _aten_max_pool_onnx( ceil_mode: bool, unbatched_rank: int, ) -> TFloatOrUInt8: - self_rank_is_unbatched_rank = Rank(self) == unbatched_rank + 
self_rank_is_unbatched_rank = len(self.shape) == unbatched_rank if self_rank_is_unbatched_rank: # C,H,W -> N,C,H,W and N=1 self = op.Unsqueeze(self, [0]) @@ -1137,7 +1131,7 @@ def _aten_max_pool_with_indices_onnx( n_dims_zero: Sequence[int], n_dims_axes: Sequence[int], ) -> Tuple[TFloatOrUInt8, INT64]: - self_rank_is_unbatched_rank = Rank(self) == unbatched_rank + self_rank_is_unbatched_rank = len(self.shape) == unbatched_rank if self_rank_is_unbatched_rank: self = op.Unsqueeze(self, axes=[0]) @@ -1366,11 +1360,11 @@ def aten_nll_loss( ) -> TFloat: """nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor""" - self_rank_is_1 = Rank(self) == 1 + self_rank_is_1 = len(self.shape) == 1 if self_rank_is_1: # self rank should be at least 2 self = op.Unsqueeze(self, [0]) - rank_target = Rank(target) + rank_target = len(target.shape) if rank_target == 0: # target rank should be at least 1 target = op.Unsqueeze(target, [0]) @@ -1503,7 +1497,7 @@ def _process_padding(padding: Sequence[INT64 | int], rank: int) -> INT64: paddings = [*paddings, *zeros] # Interleave the padding values paddings = paddings[-2::-2] + paddings[-1::-2] - return op.Concat(paddings, axis=0) + return op.Concat(*paddings, axis=0) @torch_op("aten::pad", trace_only=True) @@ -1741,12 +1735,70 @@ def _attention_scale(query: TFloat) -> TFloat: return scale +def _attention_repeat_kv_for_group_query( + query: TFloat, key: TFloat, value: TFloat +) -> Tuple[TFloat, TFloat]: + """Expand key and value for group query attention. + + repeat_interleave is applied on key and value to match the number of heads in query. 
+ + Args: + query: Tensor of shape [B, q_num_heads, q_S, E] + key: Tensor of shape [B, k_num_heads, kv_S, E] + value: Tensor of shape [B, v_num_heads, kv_S, E] + + Returns: + Tuple of (expanded_key, expanded_value) where: + - expanded_key: Tensor of shape [B, q_num_heads, kv_S, E] + - expanded_value: Tensor of shape [B, q_num_heads, kv_S, E + """ + + assert ( + query.shape[1] > key.shape[1] == value.shape[1] and query.shape[1] % key.shape[1] == 0 + ), ( + "SDPA (GQA or MQA) requires q_num_heads > kv_num_heads & q_num_heads % kv_num_heads == 0" + ) + + # NOTE: QKV are expected to be 4D tensors + + batch_size = op.Shape(query, start=0, end=1) # [B] + q_num_heads = op.Shape(query, start=1, end=2) # [Hq] + kv_num_heads = op.Shape(key, start=1, end=2) # [Hk] + qk_head_size = op.Shape(key, start=3, end=4) # [Dk] + v_head_size = op.Shape(value, start=3, end=4) # [Dv] + new_kv_seq_len = op.Shape(key, start=2, end=3) # [T] + + interleave_dim = op.Div(q_num_heads, kv_num_heads) # Hq / Hk + two = op.Constant(value_int=2) + k_unsqueezed = op.Unsqueeze(key, two) # [B, Hk, 1, T, Dk] + v_unsqueezed = op.Unsqueeze(value, two) # [B, Hv, 1, T, Dv] + + k_expand_shape = op.Concat( + batch_size, kv_num_heads, interleave_dim, new_kv_seq_len, qk_head_size, axis=0 + ) + k_expand = op.Expand(k_unsqueezed, k_expand_shape) + v_expand_shape = op.Concat( + batch_size, kv_num_heads, interleave_dim, new_kv_seq_len, v_head_size, axis=0 + ) + v_expand = op.Expand(v_unsqueezed, v_expand_shape) + + k_attention_shape = op.Concat( + batch_size, q_num_heads, new_kv_seq_len, qk_head_size, axis=0 + ) + v_attention_shape = op.Concat(batch_size, q_num_heads, new_kv_seq_len, v_head_size, axis=0) + + expanded_key = op.Reshape(k_expand, k_attention_shape) + expanded_value = op.Reshape(v_expand, v_attention_shape) + + return expanded_key, expanded_value + + @torch_op("aten::scaled_dot_product_attention", trace_only=True) def aten_scaled_dot_product_attention( query: TFloat, key: TFloat, value: TFloat, - 
attn_mask: Optional[TFloat] = None, + attn_mask: Optional[TensorType] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, @@ -1772,8 +1824,8 @@ def aten_scaled_dot_product_attention( "is_causal and attn_mask cannot be set at the same time" ) - assert not enable_gqa, ( - "conversion of scaled_dot_product_attention not implemented if enable_gqa is True" + assert len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4, ( + "only 4D query, key, and value are supported" ) # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html @@ -1784,11 +1836,23 @@ def aten_scaled_dot_product_attention( if is_causal: attn_mask = _causal_attention_mask(query, key) + if enable_gqa: + key, value = _attention_repeat_kv_for_group_query(query, key, value) + else: + assert query.shape[1] == key.shape[1] == value.shape[1], ( + "SDPA (MHA) requires q_num_heads = kv_num_heads" + ) + if attn_mask is None: return _aten_scaled_dot_product_attention_no_mask_onnx( query, key, value, scale, dropout_p ) + if attn_mask.dtype == ir.DataType.BOOL: + return _aten_scaled_dot_product_attention_bool_mask_onnx( + query, key, value, attn_mask, scale, dropout_p + ) + return _aten_scaled_dot_product_attention_float_mask_onnx( query, key, value, attn_mask, scale, dropout_p ) @@ -1856,7 +1920,6 @@ def aten__scaled_dot_product_flash_attention( ) -@torch_op("aten::_scaled_dot_product_efficient_attention", private=True) def _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs( query: TFloat, compute_log_sumexp: bool, @@ -1951,62 +2014,6 @@ def aten__scaled_dot_product_efficient_attention( ) -@torch_op("aten::scaled_dot_product_attention", trace_only=True) -def aten_scaled_dot_product_attention_bool_mask( - query: TFloat, - key: TFloat, - value: TFloat, - attn_mask: Optional[BOOL] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: Optional[float] = None, - enable_gqa: bool = False, 
-) -> TFloat: - """scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor - - Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html - - Equivalent to the PyTorch code:: - scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale - attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) if is_causal else attn_mask - attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) if attn_mask.dtype==torch.bool else attn_mask - attn_weight = torch.softmax((Q @ K.transpose(-2, -1) * scale_factor) + attn_mask, dim=-1) - attn_weight = torch.dropout(attn_weight, dropout_p) - return attn_weight @ V - - where Q, K, V are the query, key, and value tensors, respectively. - L is the target sequence length, S is the source sequence length, and E is the embedding size. - """ - # Use trace_only to handle optional inputs - assert (not is_causal) or (is_causal and attn_mask is None), ( - "is_causal and attn_mask cannot be set at the same time" - ) - - assert not enable_gqa, ( - "conversion of scaled_dot_product_attention not implemented if enable_gqa is True" - ) - - if scale is None: - scale = _attention_scale(query) - scale = op.CastLike(scale, query) - - if is_causal: - attn_mask = _causal_attention_mask(query, key) - # The causal mask is always float - return _aten_scaled_dot_product_attention_float_mask_onnx( - query, key, value, attn_mask, scale, dropout_p - ) - - if attn_mask is None: - return _aten_scaled_dot_product_attention_no_mask_onnx( - query, key, value, scale, dropout_p - ) - - return _aten_scaled_dot_product_attention_bool_mask_onnx( - query, key, value, attn_mask, scale, dropout_p - ) - - def _aten_scaled_dot_product_attention_no_mask_onnx( query: TFloat, key: TFloat, diff --git a/onnxscript/function_libs/torch_lib/ops/prims.py 
b/onnxscript/function_libs/torch_lib/ops/prims.py index ed870b0d7d..f53e9c1133 100644 --- a/onnxscript/function_libs/torch_lib/ops/prims.py +++ b/onnxscript/function_libs/torch_lib/ops/prims.py @@ -176,12 +176,33 @@ def prims_bitwise_xor(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() +@torch_op("prims::broadcast_in_dim", trace_only=True) def prims_broadcast_in_dim( - a: TensorType, shape: INT64, broadcast_dimensions: Sequence[int] + a: TensorType, shape: Sequence[INT64], broadcast_dimensions: Sequence[int] ) -> TensorType: """broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)""" - raise NotImplementedError() + target_rank = len(shape) + + if not broadcast_dimensions: + # Special case: no broadcast dimensions - all target dims should be 1 + return op.Expand(a, common_ops.merge_dims(shape)) + + # Create base shape of all 1s + ones = [1] * target_rank + + # For each broadcast dimension, we'll replace the 1 with the actual input dimension + # Since broadcast_dimensions is compile-time known, we can do this with individual operations + intermediate_shape = ones + + for i, broadcast_dim in enumerate(broadcast_dimensions): + # Get the input dimension value + input_dim_value = op.Shape(a, start=i, end=i + 1) + intermediate_shape[broadcast_dim] = input_dim_value + + # Reshape input to intermediate shape and expand to target + reshaped = op.Reshape(a, common_ops.merge_dims(intermediate_shape)) + return op.Expand(reshaped, shape) def prims_cat(tensors: Sequence[TensorType], dim: int) -> TensorType: diff --git a/onnxscript/ir/__init__.py b/onnxscript/ir/__init__.py index 3fa204b405..6240347886 100644 --- a/onnxscript/ir/__init__.py +++ b/onnxscript/ir/__init__.py @@ -1,154 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-"""In-memory intermediate representation for ONNX graphs.""" - -__all__ = [ - # Modules - "serde", - "traversal", - "convenience", - "external_data", - "tape", - # IR classes - "Tensor", - "ExternalTensor", - "StringTensor", - "LazyTensor", - "SymbolicDim", - "Shape", - "TensorType", - "OptionalType", - "SequenceType", - "SparseTensorType", - "TypeAndShape", - "Value", - "Attr", - "RefAttr", - "Node", - "Function", - "Graph", - "GraphView", - "Model", - # Constructors - "AttrFloat32", - "AttrFloat32s", - "AttrGraph", - "AttrGraphs", - "AttrInt64", - "AttrInt64s", - "AttrSparseTensor", - "AttrSparseTensors", - "AttrString", - "AttrStrings", - "AttrTensor", - "AttrTensors", - "AttrTypeProto", - "AttrTypeProtos", - "Input", - # Protocols - "ArrayCompatible", - "DLPackCompatible", - "TensorProtocol", - "ValueProtocol", - "ModelProtocol", - "NodeProtocol", - "GraphProtocol", - "GraphViewProtocol", - "AttributeProtocol", - "ReferenceAttributeProtocol", - "SparseTensorProtocol", - "SymbolicDimProtocol", - "ShapeProtocol", - "TypeProtocol", - "MapTypeProtocol", - "FunctionProtocol", - # Enums - "AttributeType", - "DataType", - # Types - "OperatorIdentifier", - # Protobuf compatible types - "TensorProtoTensor", - # Conversion functions - "from_proto", - "from_onnx_text", - "to_proto", - # Convenience constructors - "tensor", - "node", - # Pass infrastructure - "passes", - # IO - "load", - "save", -] - -from onnx_ir import ( - ArrayCompatible, - Attr, - AttrFloat32, - AttrFloat32s, - AttrGraph, - AttrGraphs, - AttributeProtocol, - AttributeType, - AttrInt64, - AttrInt64s, - AttrSparseTensor, - AttrSparseTensors, - AttrString, - AttrStrings, - AttrTensor, - AttrTensors, - AttrTypeProto, - AttrTypeProtos, - DataType, - DLPackCompatible, - ExternalTensor, - Function, - FunctionProtocol, - Graph, - GraphProtocol, - GraphView, - GraphViewProtocol, - Input, - LazyTensor, - MapTypeProtocol, - Model, - ModelProtocol, - Node, - NodeProtocol, - OperatorIdentifier, - OptionalType, - 
RefAttr, - ReferenceAttributeProtocol, - SequenceType, - Shape, - ShapeProtocol, - SparseTensorProtocol, - SparseTensorType, - StringTensor, - SymbolicDim, - SymbolicDimProtocol, - Tensor, - TensorProtocol, - TensorProtoTensor, - TensorType, - TypeAndShape, - TypeProtocol, - Value, - ValueProtocol, - convenience, - external_data, - from_onnx_text, - from_proto, - load, - node, - passes, - save, - serde, - tape, - tensor, - to_proto, - traversal, -) +# pylint: disable=wildcard-import,unused-wildcard-import +from onnx_ir import * # type: ignore # noqa: F403 diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py index 79312eaefa..78dce2739e 100644 --- a/onnxscript/ir/_tape.py +++ b/onnxscript/ir/_tape.py @@ -17,7 +17,17 @@ class Builder(tape.Tape): - """An extension of the tape that provides a more convenient API for constructing the IR.""" + """An extension of the tape that provides a more convenient API for constructing the IR. + + Example: + >>> from onnxscript import ir + >>> from onnxscript.ir import _tape + >>> op = _tape.Builder() + >>> input = ir.Value(name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))) + >>> relu_val = op.Relu(input, _name="relu_node", _domain="", _version=18, _outputs=["relu_out"]) + + Note: When passing `_name`, ensure it is unique to avoid duplicate node names. 
+ """ def __getattr__(self, op_type: str) -> Any: return lambda *args, **kwargs: self._make_node(op_type, args, kwargs) @@ -26,6 +36,8 @@ def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, domain = kwargs.pop("_domain", "") version = kwargs.pop("_version", None) outputs = kwargs.pop("_outputs", 1) + name = kwargs.pop("_name", None) + if isinstance(outputs, Sequence): num_outputs = len(outputs) else: @@ -34,7 +46,12 @@ def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, if num_outputs == 1: value = super().op( - op_type, inputs=inputs, attributes=kwargs, domain=domain, version=version + op_type, + inputs=inputs, + attributes=kwargs, + domain=domain, + version=version, + name=name, ) if isinstance(outputs, Sequence): value.name = outputs[0] @@ -45,6 +62,7 @@ def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, attributes=kwargs, domain=domain, version=version, + name=name, num_outputs=num_outputs, ) if isinstance(outputs, Sequence): diff --git a/onnxscript/ir/_tape_test.py b/onnxscript/ir/_tape_test.py index 46cbcc23fe..f8210e7a0b 100644 --- a/onnxscript/ir/_tape_test.py +++ b/onnxscript/ir/_tape_test.py @@ -5,6 +5,7 @@ import unittest from onnxscript import ir +from onnxscript.ir import _tape class TestTape(unittest.TestCase): @@ -72,5 +73,32 @@ def test_op_multi_out(self): self.assertEqual([n.op_type for n in tape.nodes], ["SomeOp", "SomeOtherOp"]) +class TestBuilder(unittest.TestCase): + def test_op_name(self): + op = _tape.Builder() + + input_a = ir.Value( + name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ) + input_b = ir.Value( + name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ) + + add = op.Add(input_a, input_b, _name="add_node") + _ = op.Relu(add, _name="relu_node") + self.assertEqual(op.nodes[0].name, "add_node") + self.assertEqual(op.nodes[1].name, "relu_node") + + def test_op_name_multi_out(self): + op = 
_tape.Builder() + + input_a = ir.Value( + name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ) + + _ = op.CustomOp(input_a, _name="custom_node", _outputs=3) + self.assertEqual(op.nodes[0].name, "custom_node") + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index b4d378bd17..4274bf2062 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -214,7 +214,7 @@ def __str__(self): def debug_print(self): if logger.isEnabledFor(logging.DEBUG): - logger.debug("%s: %s", type(self), str(self)) + logger.debug("%s: %s", type(self), self) def to_node_proto(self, node_name: str) -> onnx.NodeProto: n = helper.make_node( @@ -321,6 +321,7 @@ def to_model_proto( input_types: Optional[Sequence[ONNXType]] = None, output_types: Optional[Sequence[ONNXType]] = None, value_infos: dict[str, ONNXType] | None = None, + opset_version: int | None = None, **kwargs, ) -> onnx.ModelProto: """Converts this instance into a `onnx.ModelProto`. @@ -336,6 +337,8 @@ def to_model_proto( are set to be of the corresponding type in this list. value_infos: A dictionary mapping intermediate variable names to ONNX types. Used to set value_info for intermediate variables. + opset_version: The standard opset version to use for the model if it + cannot be inferred. Otherwise defaults to the current opset version. kwargs: Additional parameters given to function :func:`onnx.helper.make_model`. Returns: @@ -393,8 +396,8 @@ def to_proto(f): if "" not in opsets: # No operator is using the standard opset. - # A default value is given. - opsets[""] = onnx_opset_version() + # Use the specified version if provided or the default value. 
+ opsets[""] = opset_version if opset_version is not None else onnx_opset_version() if "ir_version" not in kwargs: kwargs["ir_version"] = select_ir_version(opsets[""]) diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py index 6260829249..978a1b4d65 100644 --- a/onnxscript/optimizer/__init__.py +++ b/onnxscript/optimizer/__init__.py @@ -19,12 +19,8 @@ import onnxscript.optimizer._constant_folding as constant_folding from onnxscript import ir -from onnxscript.optimizer._constant_folding import ( - basic_constant_propagation, -) -from onnxscript.optimizer._constant_folding import ( - fold_constants as fold_constants_ir, -) +from onnxscript.optimizer._constant_folding import basic_constant_propagation +from onnxscript.optimizer._constant_folding import fold_constants as fold_constants_ir from onnxscript.optimizer._optimizer import optimize_ir _ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model) diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 3269f9d51e..27b09557e7 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -5,11 +5,18 @@ from __future__ import annotations +__all__ = [ + "basic_constant_propagation", + "fold_constants", + "FoldConstantsPass", + "FOLDED_FROM_KEY", +] + import dataclasses import logging import math import typing -from typing import Any, Callable, Collection, Iterable, Sequence, Union +from typing import Any, Callable, Iterable, Sequence, Union import numpy as np import onnx @@ -19,10 +26,21 @@ import onnxscript.utils.utils as utils from onnxscript.ir import _tape +DEFAULT_CONSTANT_FOLD_BLACKLIST = [ + # ConstantOfShape is preserved to avoid increasing model size unnecessarily + "ConstantOfShape", + # Quantize/DequantizeLinear are preserved to keep the quantization info + "QuantizeLinear", + "DequantizeLinear", +] + DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 8192 
DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT = 512 * 512 +# Key used to store the metadata +FOLDED_FROM_KEY = "pkg.onnxscript.optimizer.folded_from" + _NON_DETERMINISTIC_OPS = frozenset( { @@ -34,6 +52,13 @@ } ) +# A list of ops to always fold regardless of their input size limits, as long as +# they are the single consumer of the large input tensors +_DEFAULT_ALWAYS_FOLD_OPS = frozenset( + { + ("", "Transpose"), + } +) logger = logging.getLogger(__name__) @@ -59,7 +84,7 @@ def _is_onnx_op(node: ir.Node, op_type: str) -> bool: def _process_constant_node(node: ir.Node) -> None: """Sets const_value of output value of a Constant op node.""" - if node.op_type != "Constant" or node.domain != "": + if not _is_onnx_op(node, "Constant"): return if len(node.attributes) != 1: return @@ -332,12 +357,6 @@ def _get_output(node: ir.Node, index: int) -> ir.Value | None: return None -def _update_type(value: ir.Value, type: ir.TypeProtocol | None) -> None: - if type is not None: - # TODO: merge types - value.type = type - - def _get_input_element_type(node: ir.Node, index: int) -> int: input = _get_input(node, index) if input is not None and input.type is not None: @@ -485,15 +504,6 @@ def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue: if input is None or output is None: return None - # TODO(rama): Parts of the following logic (implementing type/shape inference - # for Cast op) should be unnecessary. Generic incremental shape-inference - # should handle this. Only the optimization to eliminate redundant Cast ops - # should be needed here. 
- - input_shape = input.shape - if input_shape is not None: - output.shape = input_shape.copy() - input_dtype = _get_input_element_type(node, 0) output_dtype = _get_int_attribute(node, "to", None) if output_dtype is not None: @@ -599,6 +609,19 @@ def identity(node: ir.Node, op, state: OptimizerState) -> ReturnValue: input = node.inputs[0] output = node.outputs[0] if input is not None and output is not None: + # NOTE: backward shape inference + try: + input.shape = _merge_shapes(input.shape, output.shape) + except Exception as e: + logger.warning( + "[Constant folder] Cannot merge shapes on Identity node '%s' " + "(folded from: %s) because of error: %s", + node.name, + input.meta.get(FOLDED_FROM_KEY, set()), + e, + ) + if input.type is None: + input.type = output.type state.set_sym_value(output, input) return None @@ -783,6 +806,9 @@ def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue: This allows downstream `SequenceAt` users to be replaced by `split_x` accordingly. """ input = node.inputs[0] + if len(node.inputs) == 1: + # split is not provided + return None split = node.inputs[1] output = node.outputs[0] @@ -800,27 +826,45 @@ def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue: axis = axis + rank if axis < 0 or axis >= rank: return None - split_dimension_size = shape[axis] - if not isinstance(split_dimension_size, int): - return None + # NOTE: Split needs to either be a scalar or a 1-D tensor. We need to + # calculate the number of outputs for Split. + # If split is a scalar, we split into chunks of size 'split' if possible. + # * the split dimension size and split_value has to be known. + # If split is a 1-D tensor, we split into 'size(split)' chunks + # * Get the size from split_value if it's numpy array. + # * Get the size from symbolic shape if split_value is not available. 
split_value = _get_numpy_value(split) - if split_value is None: + split_shape = ( + split.shape.numpy() if split.shape is not None and split.shape.is_static() else None + ) + + # No information about split value or shape. + if split_value is None and split_shape is None: return None - assert isinstance(split_value, np.ndarray) - if split_value.ndim == 0: - # split into chunks all of size 'split' if possible. - num_outputs = math.ceil(split_dimension_size / split_value.item()) + if isinstance(split_shape, tuple) and len(split_shape) == 1: + # If split_shape is known, we can use it to determine the number of outputs. + split_dimension_size = split_shape[0] + assert isinstance(split_dimension_size, int) + num_outputs = split_dimension_size split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] - split_values = op.Split( - input, axis=axis, num_outputs=num_outputs, _outputs=split_outputs - ) + split_values = op.Split(input, split, axis=axis, _outputs=split_outputs) elif split_value.ndim == 1: # split into 'size(split)' chunks num_outputs = split_value.size split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] split_values = op.Split(input, split, axis=axis, _outputs=split_outputs) + elif split_value.ndim == 0: + # split into chunks all of size 'split' if possible. 
+ split_dimension_size = shape[axis] + if not isinstance(split_dimension_size, int): + return None + num_outputs = math.ceil(split_dimension_size / split_value.item()) + split_outputs = [f"{output.name}_split_{i}" for i in range(num_outputs)] + split_values = op.Split( + input, axis=axis, num_outputs=num_outputs, _outputs=split_outputs + ) else: return None @@ -871,7 +915,11 @@ def sequence_at(node: ir.Node, op, state: OptimizerState) -> ReturnValue: return None -def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None: +def _merge_shapes( + preferred_shape: ir.Shape | None, other_shape: ir.Shape | None +) -> ir.Shape | None: + """Merge two shapes, preferring dimensions from preferred_shapes.""" + def merge_dims(dim1, dim2): if dim1 == dim2: return dim1 @@ -883,13 +931,35 @@ def merge_dims(dim1, dim2): return dim2 return dim1 - if shape1 is None: - return shape2 - if shape2 is None: - return shape1 - if len(shape1) != len(shape2): - raise ValueError("Shapes must have the same rank.") - return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)]) + if preferred_shape is None: + return other_shape + if other_shape is None: + return preferred_shape + if len(preferred_shape) != len(other_shape): + raise ValueError( + f"Shapes must have the same rank, got preferred_shape={preferred_shape}, other_shape={other_shape}" + ) + return ir.Shape( + [merge_dims(dim1, dim2) for dim1, dim2 in zip(preferred_shape, other_shape)] + ) + + +def _record_contributing_values(original_node: ir.Node, replacement: Replacement) -> None: + """Record the set of original input values that contributed to the constant-folded outputs.""" + folded_from: set[str] = set() + for input in original_node.inputs: + if input is None: + continue + folded_from.update(input.meta.get(FOLDED_FROM_KEY, set())) + assert input.name is not None + folded_from.add(input.name) + + for new_output in replacement.new_outputs: + if new_output is None: + continue + 
new_output.meta[FOLDED_FROM_KEY] = folded_from + # Store the string representation of the set to metadata_props to persist it across serialization + new_output.metadata_props[FOLDED_FROM_KEY] = repr(sorted(folded_from)) class FoldConstantsPass(ir.passes.InPlacePass): @@ -899,9 +969,10 @@ class FoldConstantsPass(ir.passes.InPlacePass): shape_inference: Whether to perform shape inference. input_size_limit: Maximum size of input tensors to fold. output_size_limit: Maximum size of output tensors to fold. - always_fold_ops: Collection of op types that should always be folded. - For ops from the default opset, only op_type is neede (e.g. "Transpose"), - otherwise specify the domain with ``{domain}::{op_type}``. + should_fold: An optional function that takes a node and returns True if + the node should be considered for folding. + The function should return True/False value to indicate if this particular + node should be folded, or None to use the default folding rules. """ def __init__( @@ -910,18 +981,12 @@ def __init__( shape_inference: bool, input_size_limit: int, output_size_limit: int, - always_fold_ops: Collection[str] = frozenset(["Transpose"]), + should_fold: Callable[[ir.Node], bool | None] = lambda node: None, ) -> None: self.shape_inference = shape_inference self.input_size_limit = input_size_limit self.output_size_limit = output_size_limit - ops = [] - for name in always_fold_ops: - domain, op_type = name.split("::", 1) if "::" in name else ("", name) - if domain == "ai.onnx": - domain = "" - ops.append((domain, op_type)) - self.always_fold_ops: frozenset[tuple[str, str]] = frozenset(ops) + self.should_fold = should_fold self._opset_imports: dict[str, int] = {} self._counts: dict[str, int] = {} @@ -961,7 +1026,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None: input_data = {k: v for k, v in input_data.items() if v is not None} if any(t is None for t in input_types.values()): logger.debug( - "Skipping shape inference for node %s due to missing input 
type.", + "Skipping shape inference for node %r due to missing input type.", node.name, ) else: @@ -983,62 +1048,99 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None: inferred_shape = ir.serde.deserialize_type_proto_for_shape( inferred_type ) + # NOTE: forward shape inference output.shape = _merge_shapes(output.shape, inferred_shape) output.type = ir.serde.deserialize_type_proto_for_type(inferred_type) except Exception as e: logger.debug( - "Skipping shape inference for node %s due to exception: %s", - node.name, + "Skipping shape inference for node %r due to exception: %s", + node, e, ) - def new_constant(self, node: ir.Node, value) -> ir.Node | None: - irvalue = node.outputs[0] - if not isinstance(value, np.ndarray): - # ONNX does not have a way to represent non-tensor constants, eg. a sequence. - # So, a constant-value of type sequence is not folded, but it can be used - # to optimize subsequent operations when possible. + def _prepare_folded_tensor( + self, node: ir.Node, output_name: str, output_array: np.ndarray | Any + ) -> ir.Tensor | None: + """ + Shared helper for constant/init creation: + - Validates the folded Python value is a numpy ndarray. + - Wraps it in an ir.Tensor and names it. + - Applies output_size_limit logic with input-usage compensation. + Returns the ir.Tensor or None if it should be skipped. + """ + if not isinstance(output_array, np.ndarray): logger.info( "Skip storing constant folded value %s due to unsupported type %s.", - irvalue.name, - type(value), + output_name, + type(output_array), ) return None - tensor = ir.tensor(value) - tensor.name = irvalue.name - irvalue.const_value = tensor + tensor = ir.tensor(output_array) + tensor.name = output_name - if value.size > self.output_size_limit: - # Handle examples like Transpose(weight) to be folded even if the size is large, - # as long as weight has no other uses. This won't increase model size. 
+ # Size gating (shared logic) + if output_array.size > self.output_size_limit: removed_input_size = 0 - for input in node.inputs: - if (input is not None) and (len(input.uses()) == 1): - array = _get_numpy_value(input) - if array is not None: - removed_input_size += array.size - increased_size = value.size - removed_input_size + for input_val in node.inputs: + if (input_val is not None) and (len(input_val.uses()) == 1): + input_array = _get_numpy_value(input_val) + if input_array is not None: + removed_input_size += input_array.size + increased_size = output_array.size - removed_input_size if increased_size > 0: logger.info( - "Skip storing constant folded nvalue %s due to large size %s.", - irvalue.name, - value.size, + "Skip storing constant folded array %s due to large size %s.", + output_name, + output_array.size, ) return None + return tensor + + def new_constant(self, node: ir.Node, array: np.ndarray | Any) -> ir.Node | None: + """Create a new Constant node with the given array as its value.""" + original_value = node.outputs[0] + + tensor = self._prepare_folded_tensor(node, original_value.name, array) + if tensor is None: + return None + logger.debug( "New constant for value %s dtype: %s shape: %s", - irvalue.name, - value.dtype, - value.shape, + original_value.name, + array.dtype, + array.shape, ) - attributes = ir.convenience.convert_attributes({"value": tensor}) - node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1) + node = ir.Node("", "Constant", inputs=[], attributes=(ir.AttrTensor("value", tensor),)) return node - def process_node(self, node: ir.Node) -> Replacement | None: + def new_initializer(self, node: ir.Node, array: np.ndarray | Any) -> ir.Value | None: + """Create a new initializer value with the given array as its value.""" + original_value = node.outputs[0] + + tensor = self._prepare_folded_tensor(node, original_value.name, array) + if tensor is None: + return None + + initializer = ir.Value( + 
name=original_value.name, + type=ir.TensorType(ir.DataType(tensor.dtype)), + shape=tensor.shape, # type: ignore[arg-type] + const_value=tensor, + ) + + logger.debug( + "New Initializer for value %s dtype: %s shape: %s", + original_value.name, + array.dtype, + array.shape, + ) + + return initializer + + def process_node(self, node: ir.Node, is_function: bool) -> Replacement | None: """Process a node and return a Replacement if the node can be replaced.""" for i, value in enumerate(node.inputs): sym_value = self._state.get_sym_value(value) @@ -1053,18 +1155,33 @@ def process_node(self, node: ir.Node) -> Replacement | None: self._modified = True # TODO(rama): consider merging type/other info from both values + # Propagate const_value, and manually find out shape and type + # to avoid potentially expensive shape inference on large tensors. + if _is_onnx_op(node, "Constant"): + _process_constant_node(node) # Do incremental shape inference - if self.shape_inference and not _is_control_flow_op(node): + elif self.shape_inference and not _is_control_flow_op(node): self._do_inference(node) if node.domain not in self._opset_imports: + logger.debug( + "Skipping constant folding for node %r due to missing opset import for domain %r.", + node.name, + node.domain, + ) return None + version = self._opset_imports[node.domain] op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version) for optimizer in op_optimizers: assert optimizer context = RewriterContext() - output = optimizer(node, context, self._state) + try: + output = optimizer(node, context, self._state) + except Exception as e: + raise RuntimeError( + f"Error during constant folding for node {node.name!r} ({node.domain}::{node.op_type})" + ) from e if output is not None: if isinstance(output, Replacement): return output @@ -1072,54 +1189,96 @@ def process_node(self, node: ir.Node) -> Replacement | None: output = [output] return Replacement(output, context.nodes) - if _is_control_flow_op(node) or 
_is_non_deterministic_op(node): + if _is_onnx_op(node, "Constant"): + logger.debug("Skipping constant folding for Constant node %r", node.name) return None - if _is_onnx_op(node, "Constant"): - _process_constant_node(node) + if _is_control_flow_op(node): + logger.info( + "Skipping constant folding for control flow op %r (%s::%s) because it is not supported yet", + node.name, + node.domain, + node.op_type, + ) + + return None + + if _is_non_deterministic_op(node): + logger.info( + "Skipping constant folding for non-deterministic op %r (%s::%s)", + node.name, + node.domain, + node.op_type, + ) return None if any(x.is_graph_input() for x in node.inputs if x is not None): - # Do not fold any graph inputs to preserve graph signature + logger.info( + "Skipping constant folding for node %r because it is graph input to preserve graph signature", + node.name, + ) return None - # Ensure all node inputs are constants + # Ensure all node inputs are constants or initializers if any(x.const_value is None for x in node.inputs if x is not None): - if logger.isEnabledFor(logging.DEBUG): - logger.debug( - "Skipping constant folding for node %s because it has non-constant inputs", - node, - [x.name for x in node.inputs if x is not None], - ) return None - input_tensors = [x.const_value if x is not None else None for x in node.inputs] - if any( - tensor.size > self.input_size_limit - for tensor in input_tensors - if tensor is not None - ): - if (node.domain, node.op_type) in self.always_fold_ops and all( - len(input.consumers()) == 1 for input in node.inputs if input is not None - ): - # If the op is in always_fold_ops and all inputs are used only by this node, - # we can still fold it even if the input size exceeds the limit. 
- logger.debug( - "Folding large constant for node %s because it is in the always_fold_ops list", - node, - ) - else: - # Skip folding large tensors - if logger.isEnabledFor(logging.DEBUG): - input_sizes = [ - tensor.size for tensor in input_tensors if tensor is not None - ] - logger.debug( - "Skipping constant folding for node %s due to large input size: %s", - node, - input_sizes, + should_fold = self.should_fold(node) + + if should_fold is False: + logger.info( + "Skipping constant folding for node %r because should_fold returned False", + node.name, + ) + return None + + elif should_fold is None: + # Use default rules to decide whether to fold the node: + # - Nodes in the DEFAULT_CONSTANT_FOLD_BLACKLIST list are not folded + # - If the any tensor input size exceeds the input_size_limit, skip folding the node + for op_type in DEFAULT_CONSTANT_FOLD_BLACKLIST: + if _is_onnx_op(node, op_type): + logger.info( + "Skipping constant folding for node %r because " + "%s is preserved by default", + node.name, + op_type, ) - return None + return None + + input_tensors = [x.const_value if x is not None else None for x in node.inputs] + large_inputs = [ + tensor is not None and tensor.size > self.input_size_limit + for tensor in input_tensors + ] + if any(large_inputs): + # Decide whether to fold large constants + assert len(node.inputs) == len(large_inputs) + if (node.domain, node.op_type) in _DEFAULT_ALWAYS_FOLD_OPS and all( + len(input.consumers()) == 1 or (not is_large) + for input, is_large in zip(node.inputs, large_inputs) + if input is not None + ): + # If the op is in _DEFAULT_ALWAYS_FOLD_OPS and all large inputs are used only by this node, + # we can still fold it even if the input size exceeds the limit + pass + else: + # Skip folding large tensors + if logger.isEnabledFor(logging.INFO): + input_sizes = [ + tensor.size for tensor in input_tensors if tensor is not None + ] + logger.info( + "Skipping constant folding for node %r due to large input sizes: %s", + node, 
+ input_sizes, + ) + return None + else: + logger.info( + "Constant folding node %r because should_fold returned True", + node.name, + ) input_values = [_get_numpy_value(x) for x in node.inputs] @@ -1128,6 +1287,7 @@ def convert(av): return ir.serde.serialize_tensor(av.value) return av.value + # TODO(justinchuby): We should find a way to avoid serializing tensors every time we want to evaluate a node attr_values = {name: convert(attr) for name, attr in node.attributes.items()} outputs = _reference_evaluator.evaluate( node.domain, node.op_type, version, *input_values, **attr_values @@ -1136,23 +1296,46 @@ def convert(av): if outputs is None: return None if len(node.outputs) == 1 and not isinstance(outputs, (tuple, list)): - replacement = self.new_constant(node, outputs) - if _is_onnx_op(node, "ConstantOfShape") or replacement is None: + # We don't support initializers in functions, so we need to create Constant nodes + if is_function: + replacement = self.new_constant(node, outputs) + if replacement is None: + return None + return Replacement(replacement.outputs, [replacement]) + new_initializer_value = self.new_initializer(node, outputs) + if new_initializer_value is None: return None - return Replacement(replacement.outputs, [replacement]) + # Add the new initializer to the graph + assert node.graph is not None + node.graph.register_initializer(new_initializer_value) + return Replacement([new_initializer_value], []) else: logger.warning( "Skipping constant folding for op %s with multiple outputs.", node.op_type ) return None - def replace_node(self, node: ir.Node, replacement, root: ir.Graph | ir.Function) -> None: + def replace_node( + self, node: ir.Node, replacement: Replacement, root: ir.Graph | ir.Function + ) -> None: logger.debug("Replacing node: %s::%s %s", node.domain, node.op_type, node.name) + # Record the names of the values that has contributed to the replacement + _record_contributing_values(node, replacement) + + # Obtain the list of non-None inputs 
to the node before it is cleared by + # replace_nodes_and_values to check for unused initializers later. + node_inputs = [v for v in node.inputs if v is not None] + ir.convenience.replace_nodes_and_values( root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs ) + if isinstance(root, ir.Graph): + # The old node should now be detached from the graph + assert node.graph is None + _clear_unused_initializers(node_inputs) + self._modified = True # TODO: what about new opset_imports? @@ -1168,7 +1351,8 @@ def visit_attribute(self, attr: ir.Attr) -> None: self.visit_graph(graph) def visit_node(self, node: ir.Node, root: ir.Graph | ir.Function) -> None: - replacement = self.process_node(node) + is_function = isinstance(root, ir.Function) + replacement = self.process_node(node, is_function=is_function) if replacement is None: # No change. Process attributes. for attr in node.attributes.values(): @@ -1229,6 +1413,19 @@ def _sym_value_can_replace_graph_output( return True +def _clear_unused_initializers(values: Sequence[ir.Value]) -> None: + # Detach all inputs to the node, then check for unused initializers + for value in values: + if value is None or not value.is_initializer(): + continue + + if not value.uses(): + assert value.is_initializer() + assert value.graph is not None + assert value.name is not None + value.graph.initializers.pop(value.name) + + @dataclasses.dataclass class FoldConstantsResult(ir.passes.PassResult): symbolic_value_map: dict[ir.Value, SymbolicValue] @@ -1245,7 +1442,7 @@ def fold_constants( onnx_shape_inference: bool = False, input_size_limit: int = DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT, output_size_limit: int = DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT, - always_fold_ops: Collection[str] = frozenset(["Transpose"]), + should_fold: Callable[[ir.Node], bool | None] = lambda node: None, ) -> FoldConstantsResult: """ Applies constant folding optimization to the model. 
@@ -1260,10 +1457,9 @@ def fold_constants( output_size_limit: The maximum size of output tensors that can be stored after constant folding. Defaults to `DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT`. - always_fold_ops: A collection of op types that should always be folded, - regardless of their input or output sizes. For ops from the default opset, - only op_type is neede (e.g. "Transpose"), otherwise specify the domain - with ``{domain}::{op_type}``. + should_fold: An optional function that takes a node and returns True if + the node should be considered for folding, False if it should not be folded, + or None to use the default rules. Defaults to a function that always returns None. Returns: An instance of `FoldConstantsResult`. @@ -1273,6 +1469,6 @@ def fold_constants( shape_inference=onnx_shape_inference, input_size_limit=input_size_limit, output_size_limit=output_size_limit, - always_fold_ops=always_fold_ops, + should_fold=should_fold, ) return folder_pass(model) # type: ignore[return-value] diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py index 8c05fbc0a4..ae5c9901bd 100644 --- a/onnxscript/optimizer/_constant_folding_test.py +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -14,13 +14,20 @@ class FoldConstantsTest(unittest.TestCase): - def _fold(self, model: ir.Model | str, onnx_shape_inference=False, **kwargs): + def _fold( + self, + model: ir.Model | str, + onnx_shape_inference: bool = False, + dce: bool = True, + **kwargs, + ): if isinstance(model, str): model = ir.from_onnx_text(model) _constant_folding.fold_constants( model, onnx_shape_inference=onnx_shape_inference, **kwargs ) - optimizer.remove_unused_nodes(model) + if dce: + optimizer.remove_unused_nodes(model) # Ensure the model is valid after optimization onnx.checker.check_model(ir.serde.serialize_model(model)) return model @@ -36,8 +43,8 @@ def test_fold_add(self): """ optimized = self._fold(model) - 
self.assertEqual(len(optimized.graph), 2) - self.assertEqual(optimized.graph[0].outputs[0].name, "four") + self.assertEqual(len(optimized.graph), 1) + self.assertIn("four", optimized.graph.initializers) def test_fold_cast_like(self): model = """ @@ -50,9 +57,16 @@ def test_fold_cast_like(self): } """ - optimized = self._fold(model) - self.assertEqual(len(optimized.graph), 2) - self.assertEqual(optimized.graph[0].outputs[0].name, "four") + optimized = self._fold(model, dce=False) + self.assertIn("four", optimized.graph.initializers) + np.testing.assert_equal( + optimized.graph.initializers["four"].const_value, np.array(4.0) + ) + # Intermediates should be removed + self.assertNotIn("two_float", optimized.graph.initializers) + + optimized = self._fold(model, dce=True) + self.assertEqual(len(optimized.graph), 1) def test_fold_shape(self): model = """ @@ -66,9 +80,18 @@ def test_fold_shape(self): } """ - optimized = self._fold(model) - self.assertEqual(len(optimized.graph), 2) - self.assertEqual(optimized.graph[0].outputs[0].name, "four") + optimized = self._fold(model, dce=False) + self.assertIn("four", optimized.graph.initializers) + np.testing.assert_equal( + optimized.graph.initializers["four"].const_value, np.array(4.0) + ) + # Intermediates should be removed + self.assertNotIn("two_float", optimized.graph.initializers) + self.assertNotIn("rank", optimized.graph.initializers) + self.assertNotIn("shape", optimized.graph.initializers) + + optimized = self._fold(model, dce=True) + self.assertEqual(len(optimized.graph), 1) def test_fold_shape_slice(self): model = """ @@ -83,8 +106,8 @@ def test_fold_shape_slice(self): """ optimized = self._fold(model) - self.assertEqual(len(optimized.graph), 2) - self.assertEqual(optimized.graph[0].outputs[0].name, "four") + self.assertEqual(len(optimized.graph), 1) + self.assertIn("four", optimized.graph.initializers) def test_fold_if_cond(self): model = """ @@ -130,9 +153,11 @@ def test_fold_inside_if_branch(self): optimized = 
self._fold(model) self.assertEqual(len(optimized.graph), 1) then_graph = optimized.graph[0].attributes["then_branch"].as_graph() - self.assertEqual(len(then_graph), 2) + self.assertEqual(len(then_graph), 1) + self.assertIn("temp", then_graph.initializers) else_graph = optimized.graph[0].attributes["else_branch"].as_graph() - self.assertEqual(len(else_graph), 2) + self.assertEqual(len(else_graph), 1) + self.assertIn("temp", else_graph.initializers) def test_fold_if_propagate(self): model = """ @@ -154,9 +179,8 @@ def test_fold_if_propagate(self): """ optimized = self._fold(model) - self.assertEqual(len(optimized.graph), 2) - self.assertEqual(optimized.graph[0].outputs[0].name, "m_square") - self.assertEqual(optimized.graph[0].op_type, "Constant") + self.assertEqual(len(optimized.graph), 1) + self.assertIn("m_square", optimized.graph.initializers) def test_fold_redundant_cast(self): model = """ @@ -209,8 +233,8 @@ def test_shape_inference(self): """ optimized = self._fold(model, onnx_shape_inference=True) - self.assertEqual(len(optimized.graph), 2) - self.assertEqual(optimized.graph[0].outputs[0].name, "C") + self.assertEqual(len(optimized.graph), 1) + self.assertIn("C", optimized.graph.initializers) def test_static_split_to_sequence_with_scalar_split_and_squence_at_is_folded_as_split( self, @@ -346,6 +370,60 @@ def test_split_to_sequence_and_concat_from_sequence_with_new_axis_1( self.assertEqual(len(optimized.graph), 7) self.assertEqual(optimized.graph[6].op_type, "Concat") + def test_dynamic_split_to_sequence_list_shape_rewrite(self): + # split is a graph input with known 1-D static shape [4]; values unknown (not constant) + # Ensures the branch: if isinstance(split_shape, tuple) and len(split_shape) == 1 + model = """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[2,N] x, int64[4] split) => (float[2,N] return_val) { + splits = SplitToSequence (x, split) + i0 = Constant () + s0 = SequenceAt (splits, i0) + i1 = Constant () + s1 = SequenceAt (splits, 
i1) + i2 = Constant () + s2 = SequenceAt (splits, i2) + i3 = Constant () + s3 = SequenceAt (splits, i3) + return_val = Concat (s0, s1, s2, s3) +}""" + optimized = self._fold(model) + # Expect: Split + Concat (index constants & SequenceAt removed) + split_nodes = [n for n in optimized.graph if n.op_type == "Split"] + self.assertEqual(len(split_nodes), 1) + self.assertEqual(len(split_nodes[0].outputs), 4) + self.assertEqual(split_nodes[0].op_type, "Split") + self.assertTrue(all(n.op_type != "SequenceAt" for n in optimized.graph)) + + def test_dynamic_split_to_sequence_list_shape_no_keepdims(self): + # keepdims=0 path with dynamic (non-constant) splits input; triggers squeeze logic. + model = """ +< + ir_version: 8, + opset_import: ["" : 18] +> +func (float[1,M] x, int64[3] split) => (float[1,M] return_val) { + splits = SplitToSequence (x, split) + i0 = Constant () + s0 = SequenceAt (splits, i0) + i1 = Constant () + s1 = SequenceAt (splits, i1) + i2 = Constant () + s2 = SequenceAt (splits, i2) + return_val = Concat (s0, s1, s2) +}""" + optimized = self._fold(model) + split_nodes = [n for n in optimized.graph if n.op_type == "Split"] + self.assertEqual(len(split_nodes), 1) + self.assertEqual(len(split_nodes[0].outputs), 3) + self.assertTrue(all(n.op_type != "SequenceAt" for n in optimized.graph)) + # Each split output should have a corresponding Squeeze (keepdims=0 branch) + squeeze_nodes = [n for n in optimized.graph if n.op_type == "Squeeze"] + self.assertEqual(len(squeeze_nodes), 3) + def test_initializer_input_not_folded(self): model_text = """ @@ -560,7 +638,8 @@ def test_input_size_limit(self): # Since there is no increase in model-size, output-size is not a concern. 
optimized = self._fold(model, input_size_limit=256 * 256, output_size_limit=256 * 256) ops = [node.op_type for node in optimized.graph] - self.assertEqual(ops, ["Constant", "Add"]) + self.assertEqual(ops, ["Add"]) + self.assertIn("w_squared", optimized.graph.initializers) def test_transpose_is_always_folded(self): model_text = """ @@ -579,7 +658,36 @@ def test_transpose_is_always_folded(self): # Input size limit will not prevent folding of Transpose op optimized = self._fold(model, input_size_limit=1) ops = [node.op_type for node in optimized.graph] - self.assertEqual(ops, ["Constant"]) + self.assertEqual(ops, []) + self.assertIn("z", optimized.graph.initializers) + + def test_node_is_folded_if_specified_as_should_fold(self): + model_text = """ + + agraph (float[M, 256] x) => (float[42, 42] z) + + { + z = ConstantOfShape (w) + } + """ + model = ir.from_onnx_text(model_text) + + # ConstantOfShape is not folded by default + optimized = self._fold(model) + ops = [node.op_type for node in optimized.graph] + self.assertEqual(ops, ["ConstantOfShape"]) + + # But ConstantOfShape is folded when specified in should_fold + optimized = self._fold( + model, should_fold=lambda node: node.op_type == "ConstantOfShape" or None + ) + ops = [node.op_type for node in optimized.graph] + self.assertEqual(ops, []) + self.assertIn("z", optimized.graph.initializers) + np.testing.assert_array_equal( + optimized.graph.initializers["z"].const_value, + np.ones((42, 42), dtype=np.int64), + ) def test_multi_graph_identity_output_preserves_output_name(self): model = """ @@ -613,6 +721,26 @@ def test_attribute_reference(self): optimized = self._fold(model) self.assertEqual(len(optimized.graph), 2) + def test_constant_folding_creates_constant_nodes_in_function(self): + model = """ + + model (float x) => (float return_val) { + return_val = this.function (x) + } + + function (x) => (return_val) { + tmp = Constant () + tmp_0 = Cast (tmp) + return_val = Sub (tmp_0, x) + } + """ + optimized = 
self._fold(model) + self.assertEqual(len(optimized.functions), 1) + for func in optimized.functions.values(): + # Ensure that constant folding has created constant nodes in the function + constant_nodes = [n for n in func.graph if n.op_type == "Constant"] + self.assertEqual(len(constant_nodes), 1) + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/optimizer/_optimizer.py b/onnxscript/optimizer/_optimizer.py index 384cc12fd4..307144462f 100644 --- a/onnxscript/optimizer/_optimizer.py +++ b/onnxscript/optimizer/_optimizer.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +from typing import Callable import onnx_ir as ir import onnx_ir.passes.common as common_passes @@ -21,6 +22,7 @@ def optimize_ir( stop_if_no_change: bool = True, input_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT, output_size_limit: int = _constant_folding.DEFAULT_CONSTANT_FOLD_OUTPUT_SIZE_LIMIT, + should_fold: Callable[[ir.Node], bool | None] = lambda node: None, inline: bool = True, ) -> None: """Optimizes a model. @@ -29,11 +31,15 @@ def optimize_ir( model: The model to be optimized. num_iterations: Number of times the optimization loop is repeated. onnx_shape_inference: Applies node-level shape-inference as part of optimization + stop_if_no_change: Stop the optimization loop if no change is detected in an iteration. input_size_limit: Will not apply constant folding to ops with any input of size greater than this. Does not apply to special ops like Shape() and Size(). output_size_limit: Will not rewrite any foldable-op into a Constant op if the size of the output tensor is greater than this. - stop_if_no_change: Stop the optimization loop if no change is detected in an iteration. + should_fold: An optional function that takes a node and returns True if + the node should be considered for folding. 
+ The function should return True/False value to indicate if this particular + node should be folded, or None to use the default folding rules. inline: If True, inlines all functions in the model. """ passes = [ @@ -43,6 +49,7 @@ def optimize_ir( shape_inference=onnx_shape_inference, input_size_limit=input_size_limit, output_size_limit=output_size_limit, + should_fold=should_fold, ), rewriter.RewritePass(rewriter._DEFAULT_REWRITE_RULES), common_passes.RemoveUnusedNodesPass(), diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index d3e7a7891e..fb93bc703f 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -5,6 +5,7 @@ from typing import Sequence, TypeVar, Union __all__ = [ + "merge_metadata", "pattern", "rewrite", "RewritePass", @@ -16,41 +17,50 @@ "RewriterContext", "MatchingTracer", "MatchStatus", + "RULE_NAME_TAG", ] import onnx import onnx_ir.passes.common as common_passes from onnxscript import ir -from onnxscript.rewriter import ( - basic_rules, - broadcast_to_matmul, - cast_constant_of_shape, - collapse_slices, - fuse_pad_into_conv, - fuse_relus_clips, - no_op, - pattern, - redundant_scatter_nd, -) +from onnxscript.rewriter import pattern from onnxscript.rewriter._basics import MatchContext, MatchingTracer, MatchResult, MatchStatus from onnxscript.rewriter._rewrite_rule import ( + RULE_NAME_TAG, RewriterContext, RewriteRule, RewriteRuleClassBase, RewriteRuleSet, + merge_metadata, +) +from onnxscript.rewriter.rules.common import ( + _basic_rules, + _broadcast_to_matmul, + _cast_constant_of_shape, + _collapse_slices, + _fuse_batchnorm, + _fuse_pad_into_conv, + _fuse_relus_clips, + _min_max_to_clip, + _no_op, + _redundant_scatter_nd, + _remove_optional_bias, ) _ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model) _DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = ( - *no_op.rules.rules, # TODO: merge this rule into constant folding? 
- *broadcast_to_matmul.rules.rules, - *cast_constant_of_shape.rules.rules, - *collapse_slices.rules.rules, - *fuse_relus_clips.fuse_relus_clips_rules().rules, - *basic_rules.basic_optimization_rules().rules, - *redundant_scatter_nd.rules.rules, - *fuse_pad_into_conv.fuse_pad_into_conv_rule_set().rules, + *_no_op.rules, # TODO: merge this rule into constant folding? + *_broadcast_to_matmul.rules, + *_cast_constant_of_shape.rules, + *_collapse_slices.rules, + *_min_max_to_clip.rules, + *_fuse_relus_clips.rules, + *_basic_rules.basic_optimization_rules(), + *_redundant_scatter_nd.rules, + *_fuse_pad_into_conv.rules, + *_fuse_batchnorm.rules, + *_remove_optional_bias.rules, ) diff --git a/onnxscript/rewriter/_fusion_utils.py b/onnxscript/rewriter/_fusion_utils.py index dbf16ae3d3..f6a7204ac8 100644 --- a/onnxscript/rewriter/_fusion_utils.py +++ b/onnxscript/rewriter/_fusion_utils.py @@ -13,7 +13,7 @@ Dim = Union[int, ir.SymbolicDim] -def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool: +def check_shape_bool(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> bool: if val.shape is None: return False if val.shape.rank() != len(shape): diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index 6af84dd1d8..953d5f33d5 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -78,23 +78,34 @@ def get_numpy_value(val: ir.Value | None) -> np.ndarray | None: return None -def get_singleton_value(val: ir.Value | None, rank: int | None = None): +def get_singleton_value(val: ir.Value | None, rank: int | Sequence[int] | None = None): """Returns element of a single element tensor constant value, and None otherwise. - If rank is specified, it checks that the value has the given rank. + If an int rank is specified, it checks that the value has the given rank. + If the rank is a sequence of ints, it checks that the value has one of the given ranks. 
+ + Thus, `rank=0` checks for a scalar, `rank=1` checks for a 1D tensor, and + `rank=(0,1)` checks for either a scalar or a 1D tensor. """ np_val = get_numpy_value(val) if np_val is not None and np_val.size == 1: - if rank is None or (np_val.ndim == rank): - return np_val.item() + value = np_val.item() + if (rank is None) or (isinstance(rank, int) and (np_val.ndim == rank)): + return value + if isinstance(rank, Sequence) and (np_val.ndim in rank): + return value return None def is_singleton_value( - val: ir.Value | None, expected: float | int | Callable, *, rtol: float | None = None + val: ir.Value | None, + expected: float | int | Callable, + *, + rtol: float | None = None, + rank: int | Sequence[int] | None = None, ) -> bool: """Returns True if the value is a single element tensor with given value, and False otherwise.""" - scalar = get_singleton_value(val) + scalar = get_singleton_value(val, rank=rank) if scalar is None: return False if callable(expected): @@ -141,3 +152,27 @@ def get_dim(value: ir.Value | None, dim: int) -> ir.SymbolicDim | int | None: if dim < 0 or dim >= shape.rank(): return None return shape[dim] + + +def same_shape(shape1: ir.Shape | None, shape2: ir.Shape | None) -> bool: + """Check if two shapes are semantically the same.""" + if shape1 is None or shape2 is None: + return False + + # If any dim is unknown, the shapes are not the same + if shape1.has_unknown_dim() or shape2.has_unknown_dim(): + return False + + return shape1 == shape2 + + +def same_dim(dim1: ir.SymbolicDim | int, dim2: ir.SymbolicDim | int) -> bool: + """Check if two dimensions are semantically the same.""" + if type(dim1) is not type(dim2): + return False + if isinstance(dim1, int) and isinstance(dim2, int): + return dim1 == dim2 + assert isinstance(dim1, ir.SymbolicDim) and isinstance(dim2, ir.SymbolicDim) + if dim1.value is None or dim2.value is None: + return False + return dim1.value == dim2.value diff --git a/onnxscript/rewriter/_matcher.py 
b/onnxscript/rewriter/_matcher.py index e347b98375..f54b77033f 100644 --- a/onnxscript/rewriter/_matcher.py +++ b/onnxscript/rewriter/_matcher.py @@ -87,7 +87,7 @@ def _match_constant(self, pattern_constant: _pattern_ir.Constant, value: ir.Valu ) try: - constant_value_numpy = constant_value.numpy() + numpy_value = constant_value.numpy() except FileNotFoundError: return self.fail(f"Constant value of {value.name} not available.") @@ -95,11 +95,13 @@ def _match_constant(self, pattern_constant: _pattern_ir.Constant, value: ir.Valu if isinstance(pattern_constant_value, list): expected_shape = (len(pattern_constant_value),) - if constant_value_numpy.shape != expected_shape: - return self.fail(f"Value has mismatching shape, expecting {expected_shape}.") + if numpy_value.shape != expected_shape: + return self.fail( + f"Value {value.name} has shape {numpy_value.shape}, expecting {expected_shape}." + ) if not all( math.isclose( - constant_value_numpy.item(i), + numpy_value.item(i), pattern_constant_value[i], rel_tol=pattern_constant._rel_tol, abs_tol=pattern_constant._abs_tol, @@ -107,24 +109,24 @@ def _match_constant(self, pattern_constant: _pattern_ir.Constant, value: ir.Valu for i in range(len(pattern_constant_value)) ): return self.fail( - f"Value mismatch: expected {pattern_constant_value}, got {constant_value_numpy}." + f"Value mismatch: expected {pattern_constant_value}, got {numpy_value}." ) return True # TODO (rama): allow users to specify shape requirement, if desired. 
- if constant_value_numpy.size != 1: + if numpy_value.ndim != 0: return self.fail( f"Value {value.name} is not a scalar, expecting {pattern_constant_value}.", ) if not math.isclose( - constant_value_numpy.item(), + numpy_value.item(), pattern_constant_value, rel_tol=pattern_constant._rel_tol, abs_tol=pattern_constant._abs_tol, ): return self.fail( - f"Constant value mismatch: expected {pattern_constant_value}, got {constant_value_numpy.item()}.", + f"Constant value mismatch: expected {pattern_constant_value}, got {numpy_value.item()}.", ) return True diff --git a/onnxscript/rewriter/_pattern_ir.py b/onnxscript/rewriter/_pattern_ir.py index f64d3fca3c..9b81e33581 100644 --- a/onnxscript/rewriter/_pattern_ir.py +++ b/onnxscript/rewriter/_pattern_ir.py @@ -126,7 +126,14 @@ def __init__(self, value: SupportedAttrTypes): self._value = value def matches(self, attr: ir.Attr) -> bool: - return isinstance(attr, ir.Attr) and attr.value == self._value + if attr.type in { + ir.AttributeType.INTS, + ir.AttributeType.FLOATS, + ir.AttributeType.STRINGS, + }: + # Since the type of attr.value is Sequence, we need to convert to the same type for comparison. + return tuple(attr.value) == tuple(self._value) + return attr.value == self._value def __str__(self) -> str: return str(self._value) diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py index 9481ca5077..7c73a738ce 100644 --- a/onnxscript/rewriter/_rewrite_rule.py +++ b/onnxscript/rewriter/_rewrite_rule.py @@ -18,6 +18,7 @@ import onnxscript.rewriter._ir_utils as _ir_utils import onnxscript.rewriter._matcher as _matcher import onnxscript.rewriter._pattern_ir as _pattern_ir +import onnxscript.utils.metadata_merger as metadata_merger from onnxscript import ir from onnxscript.ir import _tape, convenience @@ -25,6 +26,11 @@ RewriterContext = _tape.Builder +# TODO(rama): Standardize metadata property keys. May be worth standardizing at ONNX level for +# source/producer metadata. 
+ +RULE_NAME_TAG = "pkg.onnxscript.rewriter.rule_name" + @dataclasses.dataclass class ReplacementSubgraph: @@ -82,12 +88,7 @@ def __init__( if isinstance(matcher, _matcher.PatternMatcher): self._matcher = matcher elif matcher is None: - if target_pattern.has_single_output_node: - self._matcher = _matcher.SimplePatternMatcher(self._target_pattern) - else: - import onnxscript.rewriter.generic_pattern as generic_pattern - - self._matcher = generic_pattern.GenericPatternMatcher(self._target_pattern) + self._matcher = _matcher.SimplePatternMatcher(self._target_pattern) else: self._matcher = matcher(self._target_pattern) self._verbose = verbose @@ -392,7 +393,7 @@ def check(cls, context, x: ir.Value, perm: ir.Attr) -> bool: if perm.is_ref(): return False if perm.type == ir.AttributeType.INTS: - if perm.as_ints() == list(range(len(perm.as_ints()))): + if list(perm.as_ints()) == list(range(len(perm.as_ints()))): return True return False """ @@ -463,7 +464,7 @@ def check(cls, context, x: ir.Value, perm: ir.Attr) -> bool: if perm.is_ref(): return False if perm.type == ir.AttributeType.INTS: - if perm.as_ints() == list(range(len(perm.as_ints()))): + if list(perm.as_ints()) == list(range(len(perm.as_ints()))): return True return False @@ -614,6 +615,15 @@ def _get_new_overload(model: ir.Model, domain: str, name: str) -> str: overload += 1 +_default_metadata_merger: metadata_merger.MetadataMerger = metadata_merger.MetadataMerger( + {RULE_NAME_TAG: metadata_merger.comma_separator_merger} +) + +# TODO(rama): Generalize this to support custom metadata mergers. For now, we just allow +# enabling/disabling the default merger. +merge_metadata: bool = True + + class RewriteRuleSet: def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None: if not rules: @@ -724,6 +734,13 @@ def _apply_to_graph_or_function( _ir_utils.display_nodes(delta.new_nodes) print("++++End Replacement Nodes++++") + # Capture rewrite rule name as metadata. 
+ # TODO(rama): This is just a basic version. We may wish to compose "source" metadata + # from multiple rules in future. + if rule.name: + for n in delta.new_nodes: + n.metadata_props[RULE_NAME_TAG] = rule.name + convenience.replace_nodes_and_values( graph_or_function, node, @@ -733,6 +750,11 @@ def _apply_to_graph_or_function( delta.new_outputs, ) + if merge_metadata: + _default_metadata_merger.copy_merged_metadata( + delta.match.nodes, delta.new_nodes + ) + count += 1 break diff --git a/onnxscript/rewriter/generic_pattern.py b/onnxscript/rewriter/generic_pattern.py deleted file mode 100644 index 12827b3116..0000000000 --- a/onnxscript/rewriter/generic_pattern.py +++ /dev/null @@ -1,702 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -import collections -import inspect -import os -import textwrap -import warnings -from typing import Any, Callable, Iterator, Sequence - -import onnxscript.rewriter.pattern as orp -from onnxscript import ir - - -class PatternMatchResult: - """Stores information about a match if a match was successful. 
- - * pattern: the GraphPattern which found this result - * model_nodes: the graph nodes that matched the pattern - * matched_pattern_to_model_value: a mapping from ValuePattern to ir.Value - * kwargs: additional attributes the user may add through the method - :meth:`PatternMatchResult.add_kwargs` - """ - - def __init__( - self, - pattern: orp.GraphPattern, - model_nodes: Sequence[ir.Node], - ): - pattern_nodes: list[orp.NodePattern] = list(pattern) - assert len(model_nodes) == len(pattern_nodes) - self.pattern = pattern - self.model_nodes = model_nodes - self.kwargs: dict[str, Any] = {} - self.matched_pattern_to_model_value: dict[orp.ValuePattern, ir.Value] = {} - - for graph_node, pattern_node in zip(model_nodes, pattern_nodes): - assert graph_node.op_identifier() == pattern_node.op_identifier(), ( - f"Unexpected type mismatch {graph_node.op_identifier()!r} != {pattern_node.op_identifier()!r}" - ) - assert len(graph_node.inputs) == len(pattern_node.inputs), ( - f"Unexpected number of inputs for type {graph_node.op_identifier()}" - ) - for a, b in zip(graph_node.inputs, pattern_node.inputs): - if b is None: - # optional input or not an interesting input - continue - self._bind(b, a) - - assert len(graph_node.outputs) == len(pattern_node.outputs), ( - f"Unexpected number of outputs for type {graph_node.op_identifier()}" - ) - for a, b in zip(graph_node.outputs, pattern_node.outputs): - self._bind(b, a) - - def _bind(self, value_pattern: orp.ValuePattern, value: ir.Value) -> None: - map = self.matched_pattern_to_model_value - if value_pattern in map: - assert map[value_pattern] == value, ( - f"Ambiguities, pattern output {value_pattern!r} means " - f"{value!r} or {map[value_pattern]}" - ) - else: - map[value_pattern] = value - - def add_kwargs(self, name: str, value: Any): - """Adds an attribute, it can be done when the match is being validated, - this attribute can be used when building the replacement nodes. 
- """ - self.kwargs[name] = value - - def __repr__(self) -> str: - return ( - f"PatternMatchResult: {len(self.model_nodes)} nodes ..., {self.pattern.inputs}, " - f"{self.pattern.outputs})" - ) - - -def _to_match_result(pmr: PatternMatchResult) -> orp.MatchResult: - """Converts a PatternMatchResult into a MatchResult. - - TODO: This is a temporary hack until MatchResult and PatternMatchResult are unified. - """ - result = orp.MatchResult() - for node in pmr.model_nodes: - result.add_node(node) - - for var, val in pmr.matched_pattern_to_model_value.items(): - if var.name is not None: - result.bind(var.name, val) - result.outputs.extend([pmr.matched_pattern_to_model_value[v] for v in pmr.pattern.outputs]) - return result - - -def _value_to_str(value: ir.Value | orp.ValuePattern) -> str: - return value.name if value.name is not None else "anonymous:" + str(id(value)) - - -def _opt_value_to_str(value: ir.Value | orp.ValuePattern | None) -> str: - return _value_to_str(value) if value is not None else "None" - - -def _node_to_str(node: ir.Node | orp.NodePattern) -> str: - inputs = ", ".join(_opt_value_to_str(input) for input in node.inputs) - outputs = ", ".join(_opt_value_to_str(output) for output in node.outputs) - op_type = node.op_type - domain = str(node.domain) - qualified_op = f"{domain}.{op_type}" if domain else op_type - return f"{outputs} = {qualified_op}({inputs})" - - -# def _pattern_node_to_str(node: orp.NodePattern) -> str: -# inputs = ", ".join(_opt_value_to_str(input) for input in node.inputs) -# outputs = ", ".join(_opt_value_to_str(output) for output in node.outputs) -# return f"{outputs} = {node.op_type}({inputs})" - - -class GenericPatternMatcher(orp.PatternMatcher): - """ - Implements a pattern optimization for quick experimentation. - - Current limitation: - - * The current implementation does match on domain name (easy fix). - * It does not compares attributes either (easy fix as well). 
- """ - - def __init__(self, pattern: orp.GraphPattern) -> None: - super().__init__(pattern) - - def enumerate_matches( - self, - model: ir.Model, - graph_or_function: ir.Graph | ir.Function, - node: ir.Node | None = None, - verbose: int = 0, - ) -> Iterator: - """Enumerates all the matches.""" - if node is None: - matched = [] - for node in graph_or_function: - res = self.match(model, graph_or_function, node, verbose=verbose) - if res: - matched.append(res) - yield res - else: - res = self.match(model, graph_or_function, node, verbose=verbose) - if res: - yield res - - def none( - self, - node: ir.Node | None = None, - lineno: int | None = None, - msg: str = "", - ) -> None: - """Must be called every time a match fails to trace it. - - It may be useful which reason made a pattern matching fail. - Instead of returning None, method *match* can return the following - expression: - - :: - - return self.none(node, inspect.currentframe().f_lineno) - - By setting the verbosity (see next Section), the user may then know - which lines in the code returned None and which condition failed. - If logs are fully enabled, it shows information about matched none - and the line deciding the matched failed. - For example, this tells the matching failed at line 601 in ``generic_pattern.py``. - It happens when propagating the match in the backward directions. - The unmatched types are Mul, MatMul and below, - it shows the matched nodes. The first one was Cast. - And the failure happened at iteration 5. - ``139774002356544-139774000632672`` is the pair of ids used in container ``matched``. - ``id(node)`` is used as a unique identifiers of the nodes. 
- - :: - - [RotaryEmbeddingPattern.match] NONE - line: 601:__main__, op_type=Cast - --hint--: BACKWARD: different node types - --pattern - Mul(pos_ids, cast) -> (mul) - -- model - MatMul(/_original_modu...Expand_output_0, /_original_modu...b/Cast_output_0) -> (/_original_modu...MatMul_output_0) - iteration=5 - --matched-- #6 - Cast(/_original_modu...mb/Cos_output_0) ~ Cast(cos) [139774002356544-139774000632672] - Cos(/_original_modu...ncat_1_output_0) ~ Cos(concattraining-transpose-0) [139774002356448-139774000632048] - ConcatTraining(/_original_modu...nspose_output_0,/_original_modu...nspose_output_0) ~ ConcatTraining(transpose,transpose) [139774002356352-139774000631712] - Transpose(/_original_modu...MatMul_output_0) ~ Transpose(mul) [139774002356256-139774000631184] - Sin(/_original_modu...ncat_1_output_0) ~ Sin(concattraining-transpose-0) [139774002358512-139774000631568] - Cast(/_original_modu...mb/Sin_output_0) ~ Cast(sin) [139774002358608-139774000632384] - len(stack)=0:[] - - 'hints' are not added everywhere. More can easily be added with method ``_hint``. 
- """ - if node and self.verbose: - if self.verbose >= 10: - if hasattr(self, "_debug"): - msg2 = self._debug_print() - if msg2: - msg2 = f"\n{textwrap.indent(msg2, ' ')}" - else: - msg2 = "" - print( - f"[{self.__class__.__name__}.match] Match failed at line: {lineno}:" - f"{os.path.split(self.__class__.__module__)[-1]}, " - f"op_type={node.op_type}{msg}{msg2}" - ) - return None - - def print_match(self, graph_node: ir.Node, pattern_node: orp.NodePattern) -> str: - s1 = _node_to_str(graph_node) - s2 = _node_to_str(pattern_node) - return f"match {s1} with pattern: {s2}" - - def _debug_print(self) -> str: - if not hasattr(self, "_debug"): - return "" - - def _s(s: str) -> str: - if len(s) <= 30: - return s - return f"{s[:15]}...{s[-15:]}" - - def _p(n: ir.Node, full: bool = False) -> str: - if full: - return str(n) - return _node_to_str(n) - - rows = [] - for k, v in sorted(self._debug.items()): - if k == "stack": - rows.append(f"len({k})={len(v)}:{v}") # type: ignore[arg-type] - continue - if k == "iteration": - rows.append(f"{k}={v}") - continue - if k == "matched": - rows.append(f"--matched-- #{len(v)}") # type: ignore[arg-type] - for pattern_node, graph_node in v.items(): - rows.append( - f" {_p(pattern_node)} ~ {_p(graph_node)} [{id(pattern_node)}-{id(graph_node)}]" - ) - continue - if k == "hint": - rows.append(f"--hint--: {v[0]}") # type: ignore[arg-type] - for i in v[1:]: - if isinstance(i, str): - rows.append(" " + i) - if isinstance(i, ir.Node): - rows.append(" " + _p(i, full=True)) - continue - if k in {"node", "pattern", "pattern_node", "pattern_nodes"}: - continue - rows.append(f"-- not shown {k}") - - return "\n".join(rows) - - def _hint(self, *args: Any) -> None: - """Add debugging information to help users.""" - self._debug["hint"] = args - - def _match_backward( - self, - starting_node: ir.Node, - matched: dict[orp.NodePattern, ir.Node], - stack: list[orp.NodePattern], - graph_node: ir.Node, - pattern_node: orp.NodePattern, - ) -> int | None: - """ 
- Matches backward. - - Args: - starting_node: root node (the node the matched begain with, used only for debugging) - matched: nodes of the pattern matched as already matched - stack: next node to look into - graph_node: node coming from the graph - pattern_node: node coming from the pattern - - Returns: - number of matched nodes, None or False to indicate a failed match - """ - match_count = 0 - - # predecessors - if len(graph_node.inputs) != len(pattern_node.inputs): - # not the same number of inputs - self._hint( - "BACKWARD: not the same number of inputs", - "-- pattern", - pattern_node, - "-- model", - graph_node, - ) - return self.none(starting_node, inspect.currentframe().f_lineno) - - for graph_input, pattern_input in zip(graph_node.inputs, pattern_node.inputs): - if len(graph_input.uses()) != len(pattern_input.uses()): - self._hint( - "BACKWARD: one input is used outside the pattern", - "-- pattern", - pattern_node, - "-- model", - graph_node, - ) - return self.none(starting_node, inspect.currentframe().f_lineno) - - for graph_value, pattern_value in zip(graph_node.inputs, pattern_node.inputs): - # TODO(rama): Handle constant-pattern - pattern_pred = pattern_value.producer() - if pattern_pred is None: - # pattern_pred is None means the pattern backward search ends here. - result = self._match_values_forward( - starting_node, matched, stack, graph_value, pattern_value - ) - if result is None: - return result - match_count += result - continue - graph_pred = graph_value.producer() - if graph_pred is None: - # No node in the graph. 
- return self.none(starting_node, inspect.currentframe().f_lineno) - if graph_pred.op_identifier() != pattern_pred.op_identifier(): - self._hint( - "BACKWARD: different node types", - "--pattern", - _node_to_str(pattern_pred), - "-- model", - _node_to_str(graph_pred), - ) - return self.none(starting_node, inspect.currentframe().f_lineno) - # matching backward - if pattern_pred not in matched: - if self.verbose >= 10: - print( - f"[GenericPattern._match_backward] {self.print_match(graph_pred, pattern_pred)}" - ) - matched[pattern_pred] = graph_pred - stack.append(pattern_pred) - match_count += 1 - if self.verbose > 5 and match_count > 0: - print(f"[GenericPatternMatcher._match_backward] add {match_count} nodes") - return match_count - - def _match_values_forward( - self, - starting_node: ir.Node, - matched: dict[orp.NodePattern, ir.Node], - stack: list[orp.NodePattern], - graph_value: ir.Value, - pattern_value: orp.ValuePattern, - ) -> int | None: - """ - Matches forward. - - Args: - starting_node: root node (the node the match begins with, used only for debugging) - matched: nodes of the pattern matched as already matched - stack: next node to look into - graph_value: value coming from the graph - pattern_value: pattern value coming from the pattern - - Returns: - number of matched nodes to continue, None or False to indicate a failed match - """ - match_count = 0 - graph_node_users = [user for user, _ in graph_value.uses()] - pattern_node_users = [user for user, _ in pattern_value.uses()] - if not pattern_node_users: - # The pattern has no node forward, the matching stops. - return match_count - if len(graph_node_users) < len(pattern_node_users): - # Not enough node in the graph to match the pattern. A match is not possible - return self.none(starting_node, inspect.currentframe().f_lineno) - - # Here comes the fun part, there is the same number of successors or more - # nodes in the graph to match with the pattern. 
- # And we have to handle the nodes already matched as found. - # Hopefully, there is only one option. - - if len(graph_node_users) == len(pattern_node_users) == 1: - # Let's deal with the simple case - if graph_node_users[0].op_identifier() != pattern_node_users[0].op_identifier(): - return self.none(starting_node, inspect.currentframe().f_lineno) - - node = pattern_node_users[0] - if node not in matched: - if self.verbose >= 10: - print( - f"[GenericPatternMatcher._match_values_forward]{self.print_match(graph_node_users[0], pattern_node_users[0])}" - ) - matched[node] = graph_node_users[0] - stack.append(node) - match_count += 1 - return match_count - - # Let's remove the nodes already matched. - pattern_node_users_not_matched = [ - unmatched_node - for unmatched_node in pattern_node_users - if unmatched_node not in matched - ] - pattern_node_users_matched = [ - matched[matched_node] - for matched_node in pattern_node_users - if matched_node in matched - ] - assert len(pattern_node_users_matched) + len(pattern_node_users_not_matched) == len( - pattern_node_users - ), ( - f"pattern_node_users_not_matched={pattern_node_users_not_matched}, " - f"pattern_node_users_matched={pattern_node_users_matched}, " - f"pattern_node_users={pattern_node_users}, " - f"matched={matched}" - ) - free = list(set(graph_node_users) - set(pattern_node_users_matched)) - if not pattern_node_users_not_matched: - # Everything is already matched. - return match_count - if len(free) < len(pattern_node_users_not_matched): - # Not enough successors to match the remaining patterns. - return self.none(starting_node, inspect.currentframe().f_lineno) - if len(pattern_node_users_not_matched) == len(free) == 1: - # Only one option again. 
- graph_node = free[0] - if pattern_node_users_not_matched[0].op_identifier() != graph_node.op_identifier(): - return self.none(starting_node, inspect.currentframe().f_lineno) - - key = pattern_node_users_not_matched[0] - if self.verbose >= 10: - print( - f"[GenericPatternMatcher._match_values_forward] {self.print_match(graph_node, pattern_node_users_not_matched[0])}" - ) - matched[key] = graph_node - stack.append(key) - match_count += 1 - return match_count - - # And now another fun part, let's try to handle the case when - # there is only one option, matching on node type only returns one - # option. - expected_op_type = [_.op_identifier() for _ in pattern_node_users_not_matched] - got_op_type = [_.op_identifier() for _ in free] - - ec = collections.Counter(expected_op_type) - gc = collections.Counter(got_op_type) - if len(ec) != len(gc) or set(ec) != set(gc): - # unique operator types is different. - self._hint( - "FORWARD: unique operator types are different", - "-- pattern", - ec, - pattern_value, - "-- model", - gc, - graph_value, - "-- model-matched", - pattern_node_users_matched, - ) - return self.none(starting_node, inspect.currentframe().f_lineno) - for k, v in ec.items(): - if gc[k] < v: - # Not enough types to match. - return self.none(starting_node, inspect.currentframe().f_lineno) - - # At this stage, we know matching the types is possible. - # We first mark whatever is possible. 
- ptype_to_node = {_.op_identifier(): _ for _ in pattern_node_users_not_matched} - gtype_to_node = {_.op_identifier(): _ for _ in free} - missing = [] - for k, v in ec.items(): - if gc[k] == v == 1: - key = id(ptype_to_node[k]) - if key not in matched: - if self.verbose >= 10: - print( - f"[GenericPatternMatcher._match_values_forward] match " - f"{self.print_match(gtype_to_node[k], ptype_to_node[k])}" - ) - matched[key] = gtype_to_node[k] - stack.append(key) - match_count += 1 - else: - missing.append(k) - - if not missing: - return match_count - - # At this stage, there are mutiple options for matching. We can: - # 1. make assumptions and continue - # 2. mark the node as incomplete matching, we could end up stuck anyway. - raise NotImplementedError( - f"There are more than one option, this will be implemented later, ec={ec}, gc={gc}" - ) - - def _match_forward( - self, - starting_node: ir.Node, - matched: dict[orp.NodePattern, ir.Node], - stack: list[orp.NodePattern], - graph_node: ir.Node, - pattern_node: orp.NodePattern, - ) -> int | None: - """ - Matches forward. 
- - Args: - starting_node: root node (the node the match begins with, used only for debugging) - matched: nodes of the pattern matched as already matched - stack: next node to look into - graph_node: node coming from the graph - pattern_node: node coming from the pattern - - Returns: - number of matched nodes to continue, None or False to indicate a failed match - """ - match_count = 0 - - # successors - if len(graph_node.outputs) != len(pattern_node.outputs): - # not the same number of outputs - self._hint( - "FORWARD: not the same number of output_names", - "-- pattern", - pattern_node, - "-- model", - graph_node, - ) - return self.none(starting_node, inspect.currentframe().f_lineno) - - for graph_output, pattern_output in zip(graph_node.outputs, pattern_node.outputs): - result = self._match_values_forward( - starting_node, matched, stack, graph_output, pattern_output - ) - if result is None: - return result - match_count += result - - if self.verbose > 5 and match_count > 0: - print(f"[GenericPatternMatcher._match_forward] add {match_count} nodes") - return match_count - - def match( - self, - model: ir.Model, - graph_or_function: ir.Graph | ir.Function, - node: ir.Node, - *, - verbose: int = 0, - remove_nodes: bool = True, - tracer: orp.MatchingTracer | None = None, - ) -> orp.MatchResult | None: - if not remove_nodes: - raise NotImplementedError( - "remove_nodes=False is not implemented in GenericPatternMatcher" - ) - del model - del graph_or_function - self.verbose = verbose - self._debug = {} - - # Let's match the last node. - # Then we need to match successors and predecessors. - last_pattern_node = self.pattern.node(-1) - if node.op_identifier() != last_pattern_node.op_identifier(): - # The last node does not have the same op_identifier(). 
- return self.none() - - if self.verbose > 5: - print( - f"[GenericPatternMatcher.match] Matching started at node: {_node_to_str(node)}" - ) - if self.verbose >= 10: - print(f"[GenericPatternMatcher.match] match pattern {self}") - - all_pattern_nodes = set(self.pattern) - matched: dict[orp.NodePattern, ir.Node] = {last_pattern_node: node} - stack: list[orp.NodePattern] = [last_pattern_node] - iteration = 0 - - if self.verbose > 5: - self._debug = dict( - pattern=self.pattern, - matched=matched, - stack=stack, - iteration=iteration, - node=node, - pattern_node=last_pattern_node, - pattern_nodes=self.pattern, - ) - - max_iter = self.pattern.num_nodes() * 2 - while stack and iteration < max_iter: - nodes_not_in_pattern = set(matched.keys()) - all_pattern_nodes - assert not nodes_not_in_pattern, ( - f"Some nodes are not part of the pattern: {nodes_not_in_pattern}" - f"\nall_pattern_nodes={all_pattern_nodes}" - ) - - # TODO(justinchuby): Change to a for loop - iteration += 1 - if self.verbose > 5: - print( - f"[GenericPatternMatcher.match] iteration={iteration} " - f"n_matched={len(matched)}, n_stack={len(stack)}, " - f"matched_types={collections.Counter(_.op_identifier() for _ in matched)}" - ) - next_pattern_node = stack.pop() - next_graph_node = matched[next_pattern_node] - - result = self._match_backward( - node, matched, stack, next_graph_node, next_pattern_node - ) - if result is None: - if self.verbose > 5: - print("[GenericPatternMatcher.match] done. backward failed.") - return result - - nodes_not_in_pattern = set(matched.keys()) - all_pattern_nodes - assert not nodes_not_in_pattern, ( - f"Some nodes are not part of the pattern: {nodes_not_in_pattern}" - ) - - result = self._match_forward( - node, matched, stack, next_graph_node, next_pattern_node - ) - if result is None: - if self.verbose > 5: - print("[GenericPatternMatcher.match] done. 
forward failed.") - return result - - nodes_not_in_pattern = set(matched.keys()) - all_pattern_nodes - assert not nodes_not_in_pattern, ( - f"Some nodes are not part of the pattern: {nodes_not_in_pattern}" - ) - - if self.verbose > 5: - self._debug["iteration"] = iteration - - if iteration >= max_iter and stack: - self._hint(f"reached {iteration}>={max_iter} iterations") - return self.none(node, inspect.currentframe().f_lineno) - - if self.verbose > 5: - print(f"[GenericPatternMatcher.match] done. {len(matched)} matched nodes") - - # At this point, the pattern is matched but let's make sure. - assert len(matched) == self.pattern.num_nodes(), ( - f"Number of matched nodes is different, {len(matched)} matched nodes, " - f"and {len(self.pattern)} nodes in the pattern, matched is {matched}" - ) - assert len(stack) == 0, f"There are still {len(stack)} nodes to explore." - - # We order the matched nodes in the same order than the pattern - # to let next functions to be able to build the matching again. - matched_nodes = [matched[pattern_node] for pattern_node in self.pattern] - return _to_match_result(PatternMatchResult(self.pattern, matched_nodes)) - - -def make_pattern_rule( - match_pattern_function: Callable, - apply_pattern_function: Callable, - validate_mapping: Callable | None = None, - verbose: int = 0, -) -> orp.RewriteRule: - """ - Creates a rewriting rule from a callable or a function proto. - - Args: - match_pattern_function: an onnxscript-like function that defines - the pattern subgraph (nodes) to be replaced - apply_pattern_function: an onnxscript-like function that constructs - the replacement subgraph (new nodes replacing the matched nodes) - validate_mapping: a function that validates the matching subgraph once - it is found. If it returns False the pattern is not applied. 
- If not specified, it is equivalent to a function that always return True - verbose: verbosity level - - Returns: - the rewriting rule - """ - - warnings.warn( - "make_pattern_rule(...) is deprecated, use pattern.RewriteRule(...) instead", - FutureWarning, - stacklevel=2, - ) - pattern = orp._to_graph_pattern(match_pattern_function) - matcher = GenericPatternMatcher(pattern) - return orp.RewriteRule( - pattern, - apply_pattern_function, - validate_mapping, - matcher, - verbose=verbose, - ) diff --git a/onnxscript/rewriter/generic_pattern_test.py b/onnxscript/rewriter/generic_pattern_test.py deleted file mode 100644 index dadaf5e8bb..0000000000 --- a/onnxscript/rewriter/generic_pattern_test.py +++ /dev/null @@ -1,607 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -import contextlib -import io -import os -import unittest - -import numpy as np -import onnx -import onnx.parser -import onnx.reference -import onnxruntime as ort -import parameterized - -from onnxscript import ir -from onnxscript.rewriter import generic_pattern, pattern - -FLOAT = onnx.TensorProto.FLOAT - - -@parameterized.parameterized_class( - ("matcher_algo",), - [ - (generic_pattern.GenericPatternMatcher,), - (pattern.SimplePatternMatcher,), - ], -) -class GenericPatternTest(unittest.TestCase): - def _range(self, *shape, bias: float | None = None): - n = np.prod(shape) - x = np.arange(n).astype(np.float32) / n - if bias: - x = x + bias - return x.reshape(tuple(shape)).astype(np.float32) - - def test_graph_pattern_builder(self): - """Test replacing Add + Add by AddAdd.""" - - def match_pattern(op, x, y, z): - """Builds the pattern to match.""" - tmp = op.Add(x, y) - return op.Add(tmp, z) - - def apply_pattern(op, x, y, z, **_): - """Builds the replacement graph.""" - return op.AddAdd(x, y, z, _domain="ZZZ") - - def validate_mapping(context, x, y, z, **_) -> bool: - """Validates the mapping.""" - del context - return True - - rule 
= pattern.RewriteRule( - match_pattern, - apply_pattern, - validate_mapping, - self.matcher_algo, - ) - - class AddAdd(onnx.reference.op_run.OpRun): - op_domain = "ZZZ" - - def _run(self, x, y, z): - return (x + y + z,) - - model = onnx.helper.make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("Add", ["x", "y"], ["gggg"]), - onnx.helper.make_node("Add", ["gggg", "z"], ["final"]), - ], - "dummy", - [ - onnx.helper.make_tensor_value_info("x", FLOAT, [None, None]), - onnx.helper.make_tensor_value_info("y", FLOAT, [None, None]), - onnx.helper.make_tensor_value_info("z", FLOAT, [None, None]), - ], - [onnx.helper.make_tensor_value_info("final", FLOAT, [None, None])], - ), - opset_imports=[onnx.helper.make_opsetid("", 18)], - ir_version=9, - ) - onnx.checker.check_model(model) - - model = onnx.shape_inference.infer_shapes(model) - ir_model = ir.serde.deserialize_model(model) - - rule.apply_to_model(ir_model) - self.assertEqual( - ["AddAdd"], - [n.op_type for n in ir_model.graph], - ) - # TODO: do that in pattern.py. 
- ir_model.opset_imports["ZZZ"] = 1 - rewriten_model = ir.serde.serialize_model(ir_model) - self.assertEqual( - ["AddAdd"], - [n.op_type for n in rewriten_model.graph.node], - ) - - feeds = { - "x": self._range(5, 6), - "y": self._range(5, 6), - "z": self._range(5, 6), - } - ref1 = onnx.reference.ReferenceEvaluator(model) - expected = ref1.run(None, feeds) - - self.assertEqual(0, len(rewriten_model.graph.initializer)) - opsets = {v.domain: v.version for v in rewriten_model.opset_import} - self.assertIn("ZZZ", opsets) - self.assertEqual(opsets["ZZZ"], 1) - - ref2 = onnx.reference.ReferenceEvaluator(rewriten_model, new_ops=[AddAdd]) - got = ref2.run(None, feeds) - np.testing.assert_almost_equal(expected[0], got[0]) - - def test_graph_pattern_builder_multi_outputs(self): - def match_pattern(op, x, y, w, z): - """Builds the pattern to match.""" - tmp = op.Add(x, y) - tmp2 = op.Add(tmp, w) - r1 = op.Add(tmp, z) - return tmp2, r1 - - def apply_pattern(op, x, y, w, z, **_): - """Builds the pattern to match.""" - return op.AddAddAddAdd(x, y, w, z, _domain="ZZZ", _outputs=2) - - def validate_mapping(context, **_) -> bool: - return True - - rule = pattern.RewriteRule( - match_pattern, - apply_pattern, - validate_mapping, - self.matcher_algo, - verbose=10, - ) - - class AddAddAddAdd(onnx.reference.op_run.OpRun): - op_domain = "ZZZ" - - def _run(self, x, y, w, z): - return (x + y + w, x + y + z) - - model = onnx.helper.make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("Add", ["x", "y"], ["gggg"]), - onnx.helper.make_node("Add", ["gggg", "w"], ["f1"]), - onnx.helper.make_node("Add", ["gggg", "z"], ["f2"]), - ], - "dummy", - [ - onnx.helper.make_tensor_value_info("x", FLOAT, [None, None]), - onnx.helper.make_tensor_value_info("y", FLOAT, [None, None]), - onnx.helper.make_tensor_value_info("z", FLOAT, [None, None]), - onnx.helper.make_tensor_value_info("w", FLOAT, [None, None]), - ], - [ - onnx.helper.make_tensor_value_info("f1", FLOAT, [None, None]), - 
onnx.helper.make_tensor_value_info("f2", FLOAT, [None, None]), - ], - ), - opset_imports=[onnx.helper.make_opsetid("", 18)], - ir_version=9, - ) - onnx.checker.check_model(model) - - model = onnx.shape_inference.infer_shapes(model) - ir_model = ir.serde.deserialize_model(model) - - rule.apply_to_model(ir_model) - self.assertEqual( - ["AddAddAddAdd"], - [n.op_type for n in ir_model.graph], - ) - # TODO: do that in pattern.py. - ir_model.opset_imports["ZZZ"] = 1 - - rewriten_model = ir.serde.serialize_model(ir_model) - - self.assertEqual( - ["AddAddAddAdd"], - [n.op_type for n in rewriten_model.graph.node], - ) - - feeds = { - "x": self._range(5, 6), - "y": self._range(5, 6), - "w": self._range(5, 6), - "z": self._range(5, 6), - } - ref1 = onnx.reference.ReferenceEvaluator(model) - expected = ref1.run(None, feeds) - - self.assertEqual(0, len(rewriten_model.graph.initializer)) - opsets = {v.domain: v.version for v in rewriten_model.opset_import} - self.assertIn("ZZZ", opsets) - self.assertEqual(opsets["ZZZ"], 1) - - ref2 = onnx.reference.ReferenceEvaluator(rewriten_model, new_ops=[AddAddAddAdd]) - got = ref2.run(None, feeds) - np.testing.assert_almost_equal(expected[0], got[0]) - - def check_with_ort(self, model: onnx.ModelProto, providers=None): - if providers is None: - providers = ["CPUExecutionProvider"] - - if isinstance(model, onnx.ModelProto): - model = model.SerializeToString() - session = ort.InferenceSession(model, providers=providers) - return session - - def get_rotary_model(self): - inputs = [ - onnx.helper.make_tensor_value_info("x", onnx.TensorProto.INT64, shape=[]), - onnx.helper.make_tensor_value_info("pos_ids", FLOAT, shape=[]), - onnx.helper.make_tensor_value_info("axis", onnx.TensorProto.INT64, shape=[]), - ] - nodes = [ - onnx.helper.make_node("Unsqueeze", ["x", "axis"], ["_onx_unsqueeze0"]), - onnx.helper.make_node("Cast", ["_onx_unsqueeze0"], ["_onx_cast0"], to=1), - onnx.helper.make_node("MatMul", ["pos_ids", "_onx_cast0"], ["_onx_matmul0"]), - 
onnx.helper.make_node("Transpose", ["_onx_matmul0"], ["_onx_transpose0"]), - onnx.helper.make_node( - "ConcatTraining", - ["_onx_transpose0", "_onx_transpose0"], - ["_onx_concattraining0", "_onx_concattraining1"], - domain="com.microsoft", - ), - onnx.helper.make_node("Sin", ["_onx_concattraining0"], ["_onx_sin0"]), - onnx.helper.make_node("Cast", ["_onx_sin0"], ["_onx_cast02"], to=1), - onnx.helper.make_node("Cos", ["_onx_concattraining0"], ["_onx_cos0"]), - onnx.helper.make_node("Cast", ["_onx_cos0"], ["_onx_cast03"], to=1), - ] - outputs = [ - onnx.helper.make_tensor_value_info("_onx_cast02", onnx.TensorProto.UNDEFINED, []), - onnx.helper.make_tensor_value_info("_onx_cast03", onnx.TensorProto.UNDEFINED, []), - ] - model = onnx.helper.make_model( - onnx.helper.make_graph( - nodes, - "experiment", - inputs, - outputs, - ), - opset_imports=[ - onnx.helper.make_opsetid("", 18), - onnx.helper.make_opsetid("com.microsoft", 18), - ], - ) - return model - - def test_shared_root_value_test(self): - def match_pattern(op, x): - t1 = op.Sin(x) - t2 = op.Cos(x) - return t1, t2 - - def apply_pattern(op, x, **_): - return op.SinCos(x, _domain="com.microsoft", _outputs=2) - - rule = pattern.RewriteRule(match_pattern, apply_pattern, matcher=self.matcher_algo) - model_proto = onnx.parser.parse_model( - """ - - agraph (float[N] y) => (float[N] z) - { - temp1 = Sin(y) - temp2 = Cos(y) - z = Add(temp1, temp2) - } - """ - ) - onnx.checker.check_model(model_proto) - model = onnx.shape_inference.infer_shapes(model_proto) - ir_model = ir.serde.deserialize_model(model) - rule.apply_to_model(ir_model) - rewritten_model = ir.serde.serialize_model(ir_model) - graph = rewritten_model.graph - self.assertEqual(len(graph.node), 2) - self.assertEqual(graph.node[0].op_type, "SinCos") - - def test_shared_root_value_extra_use(self): - if self.matcher_algo is generic_pattern.GenericPatternMatcher: - raise unittest.SkipTest("GenericPatternMatcher does not support extra uses yet.") - - def 
match_pattern(op, x): - t1 = op.Sin(x) - t2 = op.Cos(x) - return t1, t2 - - def apply_pattern(op, x, **_): - return op.SinCos(x, _domain="com.microsoft", _outputs=2) - - rule = pattern.RewriteRule( - match_pattern, - apply_pattern, - matcher=self.matcher_algo, - ) - model_proto = onnx.parser.parse_model( - """ - - agraph (float[N] y) => (float[N] z) - { - temp1 = Sin(y) - temp2 = Cos(y) - w = Add(temp1, temp2) - z = Mul(w, y) - } - """ - ) - onnx.checker.check_model(model_proto) - model = onnx.shape_inference.infer_shapes(model_proto) - ir_model = ir.serde.deserialize_model(model) - rule.apply_to_model(ir_model) - graph = ir_model.graph - self.assertEqual(len(graph), 3) - self.assertEqual(graph.node(0).op_type, "SinCos") - - def test_rotary_embedding(self): - # The test work on a model if it has the expected name. - # A dummy model is used if not present (not implemented yet). - - def match_pattern(op, x, pos_ids, axis): - # original code: the code does verifies the constant yet - # unsqueeze = op.Unsqueeze(x, [1]) - - unsqueeze = op.Unsqueeze(x, axis) - cast = op.Cast(unsqueeze, to=FLOAT) - - matmul = op.MatMul(pos_ids, cast) - transpose = op.Transpose(matmul) - output, _length = op.ConcatTraining( - transpose, - transpose, - _domain="com.microsoft", - _outputs=2, - ) - - sin = op.Sin(output) - cast1 = op.Cast(sin, to=FLOAT) - cos = op.Cos(output) - cast2 = op.Cast(cos, to=FLOAT) - return cast1, cast2 - - def validate_mapping(match_result, **_) -> bool: - del match_result - return True - - def apply_pattern(op, x, pos_ids, axis, **_): - del axis - cos_cache = op.Constant( - value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16)) - ) - sin_cache = op.Constant( - value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16)) - ) - return op.RotaryEmbedding( - x, - pos_ids, - cos_cache, - sin_cache, - _domain="com.microsoft", - _outputs=2, - ) - - rule = pattern.RewriteRule( - match_pattern, - apply_pattern, - 
validate_mapping, - self.matcher_algo, - verbose=10, - ) - - model = self.get_rotary_model() - - buffer = io.StringIO() - with contextlib.redirect_stdout(buffer): - # back to ir - model = onnx.shape_inference.infer_shapes(model) - ir_model = ir.serde.deserialize_model(model) - - # starts matching - rule.apply_to_model(ir_model) - ir_model.opset_imports["com.microsoft"] = 1 - - rewriten_model = ir.serde.serialize_model(ir_model) - - expected = ["Constant", "Constant", "RotaryEmbedding"] - self.assertEqual(expected, [n.op_type for n in rewriten_model.graph.node]) - out = buffer.getvalue() - # TODO(Rama): What is this assertion testing? Is it to check that `verbose` is working? - if self.matcher_algo is generic_pattern.GenericPatternMatcher: - self.assertIn("[GenericPatternMatcher.match", out) - - def test_rotary_embedding_onnxscript(self): - # The test work on a model if it has the expected name. - # A dummy model is used if not present (not implemented yet). - - def rotary_match_pattern(op, x, pos_ids, axis): - unsqueeze = op.Unsqueeze(x, axis) - cast = op.Cast(unsqueeze, to=FLOAT) - - matmul = op.MatMul(pos_ids, cast) - transpose = op.Transpose(matmul) - output, _length = op.ConcatTraining( - transpose, transpose, _domain="com.microsoft", _outputs=2 - ) - - sin = op.Sin(output) - cast1 = op.Cast(sin, to=FLOAT) - cos = op.Cos(output) - cast2 = op.Cast(cos, to=FLOAT) - return cast1, cast2 - - def validate_rotary_mapping(match_result, **_) -> bool: - # If some pattern needs to be rejected. 
- del match_result - return True - - def rotary_apply_pattern(op, x, pos_ids, axis, **_): - cos_cache = op.Constant( - value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16)) - ) - sin_cache = op.Constant( - value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16)) - ) - part1, part2 = op.RotaryEmbedding( - x, pos_ids, cos_cache, sin_cache, _domain="com.microsoft", _outputs=2 - ) - return part1, part2 - - rule = pattern.RewriteRule( - rotary_match_pattern, - rotary_apply_pattern, - validate_rotary_mapping, - self.matcher_algo, - verbose=10, - ) - - model = self.get_rotary_model() - - buffer = io.StringIO() - with contextlib.redirect_stdout(buffer): - # back to ir - model = onnx.shape_inference.infer_shapes(model) - ir_model = ir.serde.deserialize_model(model) - - # starts matching - rule.apply_to_model(ir_model) - ir_model.opset_imports["com.microsoft"] = 1 - - rewriten_model = ir.serde.serialize_model(ir_model) - - expected = ["Constant", "Constant", "RotaryEmbedding"] - self.assertEqual(expected, [n.op_type for n in rewriten_model.graph.node]) - out = buffer.getvalue() - # TODO(justinchuby): Remove this assert - capturing stdout is not robust - if self.matcher_algo is generic_pattern.GenericPatternMatcher: - self.assertIn("[GenericPatternMatcher.match", out) - - def test_rotary_emb_file_onnxscript(self): - # The test work on a model if it has the expected name. - # A dummy model is used if not present (not implemented yet). 
- - def rotary_match_pattern(op, x, pos_ids, axis): - unsqueeze = op.Unsqueeze(x, axis) - cast = op.Cast(unsqueeze, to=FLOAT) - - matmul = op.MatMul(pos_ids, cast) - transpose = op.Transpose(matmul) - output, _length = op.ConcatTraining( - transpose, transpose, _domain="com.microsoft", _outputs=2 - ) - - sin = op.Sin(output) - cast1 = op.Cast(sin, to=FLOAT) - cos = op.Cos(output) - cast2 = op.Cast(cos, to=FLOAT) - return cast1, cast2 - - def validate_rotary_mapping(match_result, **_) -> bool: - # If some pattern needs to be rejected. - del match_result - return True - - def rotary_apply_pattern(op, x, pos_ids, axis): - cos_cache = op.Constant( - value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16)) - ) - sin_cache = op.Constant( - value=onnx.numpy_helper.from_array(np.random.rand(256, 256).astype(np.float16)) - ) - part1, part2 = op.RotaryEmbedding( - x, pos_ids, cos_cache, sin_cache, _domain="com.microsoft", _outputs=2 - ) - return part1, part2 - - model_path = "gemma_optimized_pre_grad_training_2.onnx" - if not os.path.exists(model_path): - raise unittest.SkipTest(f"{model_path!r} is missing") - model = onnx.load(model_path) - model = onnx.shape_inference.infer_shapes(model) - ir_model = ir.serde.deserialize_model(model) - - rule = pattern.RewriteRule( - rotary_match_pattern, - rotary_apply_pattern, - validate_rotary_mapping, - self.matcher_algo, - verbose=10, - ) - - rule.apply_to_model(ir_model) - # TODO: do that in pattern.py. - ir_model.opset_imports["ZZZ"] = 1 - - rewriten_model = ir.serde.serialize_model(ir_model) - - buffer = rewriten_model.SerializeToString() - with open(f"{model}.opt.onnx", "wb") as f: - f.write(buffer) - self.check_with_ort(rewriten_model) - - def test_transpose_transpose_onnxscript(self): - # TODO(rama): Attribute-parameters not yet supported in multi-output matching. 
- # def transpose_transpose_pattern(op, X, perm0, perm1): - # xt = op.Transpose(X, perm=perm0) - # Y = op.Transpose(xt, perm=perm1) - # return Y - - def transpose_transpose_pattern(op, X): - XT = op.Transpose(X, _outputs=["XT"]) - Y = op.Transpose(XT, _outputs=["Y"]) - return Y - - def transpose_transpose_mapping(perm0, perm1): - new_perm = [0 for p in perm0] - for i, p in enumerate(perm1): - new_perm[i] = perm0[p] - # replace by return [perm0[p] for p in perm1] ? - return new_perm - - def transpose_transpose_check(op, **_) -> bool: - return True - - def transpose_transpose_apply_pattern(op, X, XT: ir.Value, Y, **_): - perm0 = XT.producer().attributes.get("perm") - if perm0 is not None: - perm0 = perm0.value # TODO(rama): handle RefAttr - perm1 = Y.producer().attributes.get("perm") - if perm1 is not None: - perm1 = perm1.value # TODO(rama): handle RefAttr - if perm0 is None and perm1 is None: - return op.Identity(X) - if perm0 is None: - perm0 = range(len(perm1) - 1, -1, -1) - if perm1 is None: - perm1 = range(len(perm0) - 1, -1, -1) - composed_perm = transpose_transpose_mapping(perm0, perm1) - return op.Transpose(X, perm=composed_perm) - - rule = pattern.RewriteRule( - transpose_transpose_pattern, - transpose_transpose_apply_pattern, - transpose_transpose_check, - self.matcher_algo, - verbose=0, - ) - - model = onnx.helper.make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("Transpose", ["X"], ["xt"], perm=[1, 2, 0]), - onnx.helper.make_node("Transpose", ["xt"], ["Y"], perm=[1, 2, 0]), - ], - "name", - [onnx.helper.make_tensor_value_info("X", FLOAT, [None, None, None])], - [onnx.helper.make_tensor_value_info("Y", FLOAT, [None, None, None])], - ), - opset_imports=[onnx.helper.make_opsetid("", 18)], - ) - - # back to ir - ir_model = ir.serde.deserialize_model(model) - - # starts matching - - rule.apply_to_model(ir_model) - rewriten_model = ir.serde.serialize_model(ir_model) - - expected = ["Transpose"] - self.assertEqual(expected, [n.op_type for n in 
rewriten_model.graph.node]) - node = rewriten_model.graph.node[0] - self.assertEqual(len(node.attribute), 1) - att = node.attribute[0] - self.assertEqual(att.name, "perm") - self.assertEqual(list(att.ints), [2, 0, 1]) - - -if __name__ == "__main__": - unittest.main(verbosity=2) diff --git a/onnxscript/rewriter/models/_rotary_embedding_models.py b/onnxscript/rewriter/models/_rotary_embedding_models.py index ecdb7d138b..3709cd04f7 100644 --- a/onnxscript/rewriter/models/_rotary_embedding_models.py +++ b/onnxscript/rewriter/models/_rotary_embedding_models.py @@ -26,8 +26,8 @@ def _test_case_1_script(x: FLOAT[1, 4, 8, 8], position_ids: INT64[1, 8]) -> FLOA emb = op.Concat(freqs, freqs, axis=-1) cos = op.Cos(emb) sin = op.Sin(emb) - cos_4d = op.Unsqueeze(cos, 1) - sin_4d = op.Unsqueeze(sin, 1) + cos_4d = op.Unsqueeze(cos, [1]) + sin_4d = op.Unsqueeze(sin, [1]) x1 = op.Slice(x, [0], [4], [3], [1]) x2 = op.Slice(x, [4], [8], [3], [1]) @@ -73,8 +73,8 @@ def _test_case_2_script(x: FLOAT[1, 4, 8, 8], position_ids: INT64[8]) -> FLOAT[1 emb = op.Concat(freqs, freqs, axis=-1) cos = op.Cos(emb) sin = op.Sin(emb) - cos_4d = op.Unsqueeze(cos, 1) - sin_4d = op.Unsqueeze(sin, 1) + cos_4d = op.Unsqueeze(cos, [1]) + sin_4d = op.Unsqueeze(sin, [1]) x1 = op.Slice(x, [0], [4], [3], [1]) x2 = op.Slice(x, [4], [8], [3], [1]) @@ -127,8 +127,8 @@ def _partial_rotary_script(position_ids, query): # Split the query for partial embedding to_embed = op.Slice(query, [0], [32], [3], [1]) unembedded = op.Slice(query, [32], [9223372036854775807], [3], [1]) - cos_4d = op.Unsqueeze(cos_3d, 1) # [B, 1, S, rd] - sin_4d = op.Unsqueeze(sin_3d, 1) # [B, 1, S, rd] + cos_4d = op.Unsqueeze(cos_3d, [1]) # [B, 1, S, rd] + sin_4d = op.Unsqueeze(sin_3d, [1]) # [B, 1, S, rd] # Compute rotation of X as X * cos + rotate_half(X) * sin, where rotate_half(X) # essentially represents X rotated by 90 degrees to_embed_times_cos = op.Mul(to_embed, cos_4d) diff --git a/onnxscript/rewriter/models/_smollm_1.py 
b/onnxscript/rewriter/models/_smollm_1.py index d592eb2572..e3efecfe17 100644 --- a/onnxscript/rewriter/models/_smollm_1.py +++ b/onnxscript/rewriter/models/_smollm_1.py @@ -59,8 +59,8 @@ def main_graph( minus_inf_10x10 = opset18.ConstantOfShape([10, 10], [-3.4028234663852886e38]) mask_10x10 = opset18.Trilu(minus_inf_10x10, 1) slice_5 = opset18.Reshape(mask_10x10, [1, 1, 10, 10]) - unsqueeze_2 = opset18.Unsqueeze(input1, 1) - unsqueeze_3 = opset18.Unsqueeze(unsqueeze_2, 2) + unsqueeze_2 = opset18.Unsqueeze(input1, [1]) + unsqueeze_3 = opset18.Unsqueeze(unsqueeze_2, [2]) add = slice_5 + unsqueeze_3 eq = add == 0.0 slice_10 = slice_5 @@ -69,7 +69,7 @@ def main_graph( slice_scatter = opset18.Transpose(val_179, perm=[2, 1, 0, 3]) val_191 = opset18.Transpose(slice_scatter, perm=[1, 0, 2, 3]) slice_scatter_1 = opset18.Transpose(val_191, perm=[1, 0, 2, 3]) - unsqueeze_6 = opset18.Unsqueeze(input2, 1) + unsqueeze_6 = opset18.Unsqueeze(input2, [1]) to_copy_1 = opset18.Cast(unsqueeze_6, to=1) view_1 = opset18.Constant( value=ir.tensor( @@ -138,8 +138,8 @@ def main_graph( transpose_2 = opset18.Transpose(view_11, perm=[0, 2, 1, 3]) view_12 = opset18.Reshape(view_9, [1, 10, 32, 64], allowzero=0) transpose_3 = opset18.Transpose(view_12, perm=[0, 2, 1, 3]) - unsqueeze_7 = opset18.Unsqueeze(cos, 1) - unsqueeze_8 = opset18.Unsqueeze(sin, 1) + unsqueeze_7 = opset18.Unsqueeze(cos, [1]) + unsqueeze_8 = opset18.Unsqueeze(sin, [1]) mul_5 = transpose_1 * unsqueeze_7 val_267 = opset18.Constant(value_ints=[1]) slice_19 = opset18.Slice(transpose_1, [0], [32], [3], val_267) diff --git a/onnxscript/rewriter/models/_smollm_2.py b/onnxscript/rewriter/models/_smollm_2.py index 62d857a2d6..47ad451895 100644 --- a/onnxscript/rewriter/models/_smollm_2.py +++ b/onnxscript/rewriter/models/_smollm_2.py @@ -51,7 +51,7 @@ def main_graph( gt = arange_1 > view convert_element_type_default = opset18.Cast(gt, to=1) mul = triu * convert_element_type_default - dim__2 = opset18.Constant(value_int=0) + dim__2 = 
opset18.Constant(value_ints=[0]) dim_0__2 = opset18.Cast(dim__2, to=7) unsqueeze = opset18.Unsqueeze(model_rotary_emb_inv_freq, dim_0__2) val_15 = opset18.Cast(0, to=7) @@ -65,7 +65,7 @@ def main_graph( val_25 = opset18.Reshape(val_23, val_24, allowzero=0) val_26 = opset18.Constant(value_ints=[1]) slice_1 = opset18.Slice(unsqueeze, val_17, val_21, val_25, val_26) - dim__3 = opset18.Constant(value_int=2) + dim__3 = opset18.Constant(value_ints=[2]) dim_0__3 = opset18.Cast(dim__3, to=7) unsqueeze_1 = opset18.Unsqueeze(slice_1, dim_0__3) _to_copy = opset18.Cast(unsqueeze_1, to=1) @@ -83,7 +83,7 @@ def main_graph( val_36 = opset18.Reshape(val_34, val_35, allowzero=0) val_37 = opset18.Constant(value_ints=[1]) slice_2 = opset18.Slice(position_ids, val_30, val_33, val_36, val_37) - dim__5 = opset18.Constant(value_int=1) + dim__5 = opset18.Constant(value_ints=[1]) dim_0__5 = opset18.Cast(dim__5, to=7) unsqueeze_2 = opset18.Unsqueeze(slice_2, dim_0__5) val_38 = opset18.Cast(0, to=7) @@ -160,10 +160,10 @@ def main_graph( val_71 = opset18.Cast([1, 30, 32, 64], to=7) view_12 = opset18.Reshape(view_9, val_71, allowzero=0) transpose_3 = opset18.Transpose(view_12, perm=[0, 2, 1, 3]) - dim__8 = opset18.Constant(value_int=1) + dim__8 = opset18.Constant(value_ints=[1]) dim_0__8 = opset18.Cast(dim__8, to=7) unsqueeze_3 = opset18.Unsqueeze(_to_copy_4, dim_0__8) - dim__9 = opset18.Constant(value_int=1) + dim__9 = opset18.Constant(value_ints=[1]) dim_0__9 = opset18.Cast(dim__9, to=7) unsqueeze_4 = opset18.Unsqueeze(_to_copy_5, dim_0__9) mul_5 = transpose_1 * unsqueeze_3 @@ -222,10 +222,10 @@ def main_graph( add_2 = mul_7 + mul_8 cat_3 = opset18.Concat(past_key_values_0_0, add_2, axis=-2) cat_4 = opset18.Concat(past_key_values_0_1, transpose_3, axis=-2) - dim__10 = opset18.Constant(value_int=0) + dim__10 = opset18.Constant(value_ints=[0]) dim_0__10 = opset18.Cast(dim__10, to=7) unsqueeze_5 = opset18.Unsqueeze(mul, dim_0__10) - dim__11 = opset18.Constant(value_int=1) + dim__11 = 
opset18.Constant(value_ints=[1]) dim_0__11 = opset18.Cast(dim__11, to=7) unsqueeze_6 = opset18.Unsqueeze(unsqueeze_5, dim_0__11) val_114 = opset18.Cast(0, to=7) diff --git a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py index 0a45f3017c..008a995764 100644 --- a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py +++ b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py @@ -4,7 +4,7 @@ import onnx_ir as ir -from onnxscript.rewriter.onnx_fusions import _rms_normalization, _rotary_embedding +from onnxscript.rewriter.rules.fusion import _gqa, _rms_normalization, _rotary_embedding def _get_onnx_opset_version(model: ir.Model) -> int | None: @@ -24,6 +24,7 @@ def _opset_23_fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]: counts: dict[str, int] = {} counts["RMSNormalization"] = _rms_normalization.fuse_rms_normalization(model, debug=debug) counts["RotaryEmbedding"] = _rotary_embedding.fuse_rotary_embedding(model, debug=debug) + counts["GQA"] = _gqa.fuse_gqa(model, debug=debug) return counts diff --git a/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py index b9666fba3a..c527855bb7 100644 --- a/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py +++ b/onnxscript/rewriter/onnxruntime/bfloat16_utils/bfloat16_converter_test.py @@ -3,7 +3,6 @@ import unittest import numpy as np -import onnx import onnx.checker import onnx.shape_inference import onnxruntime @@ -14,11 +13,11 @@ class Bfloat16ConversionTest(unittest.TestCase): def setUp(self) -> None: - self.v0 = ir.Input(name="v0", shape=ir.Shape([2, 3, 4])) + self.v0 = ir.val(name="v0", shape=ir.Shape([2, 3, 4])) self.v0.dtype = ir.DataType.BFLOAT16 - self.v1 = ir.Input(name="v1", shape=ir.Shape([2, 3, 4])) + self.v1 = ir.val(name="v1", shape=ir.Shape([2, 3, 4])) self.v1.dtype = ir.DataType.BFLOAT16 - self.v2 = ir.Input(name="v2", 
shape=ir.Shape([2, 3, 4])) + self.v2 = ir.val(name="v2", shape=ir.Shape([2, 3, 4])) self.v2.dtype = ir.DataType.BFLOAT16 self.add_node = ir.Node("", "Add", inputs=(self.v0, self.v1), num_outputs=1) diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py index faca1f9ba8..fa1f0c109b 100644 --- a/onnxscript/rewriter/ort_fusions/_core.py +++ b/onnxscript/rewriter/ort_fusions/_core.py @@ -8,7 +8,7 @@ import onnxscript.rewriter.ort_fusions.fused_matmul_rule_sets as fused_matmul_rule_sets import onnxscript.rewriter.ort_fusions.shape_optimization as shape_optimization from onnxscript.optimizer import optimize -from onnxscript.rewriter import gemm_to_matmul_add, rewrite +from onnxscript.rewriter import rewrite from onnxscript.rewriter.ort_fusions import ( instance_to_group_normalization, softmax, @@ -29,10 +29,12 @@ fuse_rotary_embedding, ) from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa +from onnxscript.rewriter.ort_fusions.sdpa_via_mha import replace_sdpa_by_mha from onnxscript.rewriter.ort_fusions.skip_normalization import ( fuse_skip_layer_normalization, fuse_skip_rms_normalization, ) +from onnxscript.rewriter.rules.common import _gemm_to_matmul_add ORT_PATTERN_REWRITE_RULES = [ *softmax.rules.rules, @@ -103,6 +105,7 @@ def fuse(func, **kwargs): fusion_count["attention"] = fuse(fuse_attention) fusion_count["gelu"] = fuse(fuse_gelu) fusion_count["bias_gelu"] = fuse(fuse_bias_gelu) + fusion_count["sdpa_via_mha"] = fuse(replace_sdpa_by_mha) # Finally: inline any intermediate fusion functions introduced that were not # consumed by other fusions, and eliminate any remaining unused nodes. optimize(model) @@ -114,6 +117,7 @@ def optimize_for_ort( config_name: str | None = None, *, debug: bool = False, + clear_metadata: bool = False, ) -> tuple[ir.Model, dict[str, int]]: """ Optimize the model for ORT backend. @@ -127,13 +131,14 @@ def optimize_for_ort( Typically it identifies the Execution Provider (EP) to optimize for. 
If None, the default configuration will be used. debug: If debug is True, enable pattern matching tracer for debugging. + clear_metadata: If True, clear metadata and doc strings from the model. Returns: A tuple containing: - The optimized `ir.Model` after applying transformer-specific fusions. - A dictionary with a count of each of the fusions applied. """ - rewrite(model, [gemm_to_matmul_add.rule]) + rewrite(model, [_gemm_to_matmul_add.gemm_to_matmul_add_rule]) model, fusion_count = fuse_xformers( model, debug=debug, @@ -144,14 +149,16 @@ def optimize_for_ort( passes = ir.passes.Sequential( # Apply the ORT optimization passes. # https://github.com/microsoft/onnxruntime/blob/74dcf7e296639095dfa55d31336998b6f719ed76/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py#L172 - common_passes.ClearMetadataAndDocStringPass(), # https://github.com/microsoft/onnxruntime/blob/74dcf7e296639095dfa55d31336998b6f719ed76/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py#L139 common_passes.LiftConstantsToInitializersPass(lift_all_constants=False, size_limit=1), common_passes.RemoveInitializersFromInputsPass(), common_passes.ShapeInferencePass(), - common_passes.CheckerPass(), ) assert passes.in_place result = passes(model) assert result.model is model + + if clear_metadata: + common_passes.ClearMetadataAndDocStringPass()(model) + return model, fusion_count diff --git a/onnxscript/rewriter/ort_fusions/attention.py b/onnxscript/rewriter/ort_fusions/attention.py index 4a4cd0ad8e..ce234bbb63 100644 --- a/onnxscript/rewriter/ort_fusions/attention.py +++ b/onnxscript/rewriter/ort_fusions/attention.py @@ -160,7 +160,7 @@ def check( self.bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _fusion_utils._check_shape(self.bindings, val, dims) + return not _fusion_utils.check_shape_bool(self.bindings, val, dims) if no_match(input, ["B", "S", "D"]): return check_result.fail( diff --git 
a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py index cba06d2fb7..8e6ec1d9da 100644 --- a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py +++ b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py @@ -148,8 +148,8 @@ def pattern( sin = op.Sin(emb) if self._cast: sin = op.Cast(sin, to=dtype) - cos_4d = op.Unsqueeze(cos, 1) # convert - sin_4d = op.Unsqueeze(sin, 1) + cos_4d = op.Unsqueeze(cos, [1]) # convert + sin_4d = op.Unsqueeze(sin, [1]) return op.RotaryEmbedding( x, cos_4d, diff --git a/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py b/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py index 48842aa429..4245916c64 100644 --- a/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py +++ b/onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py @@ -45,7 +45,7 @@ def test_cos_sin_fusion(self, name, test_data_constructor): original_outputs = ort_run("original", model, inputs) count = fuse_rotary_embedding(model) self.assertGreater(count, 0) - count = fuse_cos_sin_cache(model) + count = fuse_cos_sin_cache(model, debug=True) self.assertGreater(count, 0) new_outputs = ort_run("optimized", model, inputs) assert_allclose(new_outputs, original_outputs) diff --git a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py index 5082c20464..cdc50c99ae 100644 --- a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py +++ b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py @@ -79,7 +79,7 @@ def check( # Check that last two dimensions are swapped expected_perm = list(range(len(perm))) expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2] - if perm != expected_perm: + if list(perm) != expected_perm: return check_result.fail("Permutation values for Transpose are not correct.") elif (self._pos == 1 and not _ir_utils.has_rank(x, 2)) or ( self._pos == 2 and not _ir_utils.has_rank(y, 2) @@ -188,7 +188,7 @@ def check( 
trans_batch_property = "transBatchA" if self._pos == 1 else "transBatchB" trans_batch = fused_node.attributes.get_int(trans_batch_property, 0) transposed_node = _get_node(transposed, "Transpose") - perm = transposed_node.attributes["perm"].as_ints() + perm = list(transposed_node.attributes["perm"].as_ints()) if not perm: return check_result.fail("Permutation values for Transpose are not correct.") @@ -296,7 +296,7 @@ def check(self, context, x, y, transposed: ir.Value, **_) -> orp.MatchResult: if _ir_utils.has_rank(x, 2) and _ir_utils.has_rank(y, 2): if perm: # Check that the two dimensions are swapped - if perm != [1, 0]: + if tuple(perm) != (1, 0): return check_result.fail( "Permutation values for Transpose are not correct." ) diff --git a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py index 527d4826d5..f82702d557 100644 --- a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py +++ b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py @@ -284,7 +284,7 @@ def _check_model( opt = onnx.reference.ReferenceEvaluator(optimized_model, new_ops=[FusedMatMul]) expected = ref.run(None, feeds) got = opt.run(None, feeds) - self.assertEqual(len(expected), len(got)) + self.assertEqual(len(got), len(expected)) for a, b in zip(expected, got): np.testing.assert_allclose(a, b, atol=atol, rtol=rtol) @@ -319,7 +319,7 @@ def test_fused_matmul_div_models(self, name, script_func, input_types, output_ty rule_set = fused_matmul_rule_sets.fused_matmul_rule_sets() rule_set.apply_to_model(ir_model) rewritten_model = ir.serde.serialize_model(ir_model) - self.assertEqual(["Constant", "FusedMatMul"], [n.op_type for n in ir_model.graph]) + self.assertEqual([n.op_type for n in ir_model.graph], ["Constant", "FusedMatMul"]) self._check_model(model_proto, rewritten_model, atol=1e-6) @parameterized.parameterized.expand( @@ -354,7 +354,7 @@ def test_fused_matmul_with_transpose(self, _, script_func): 
ir_model = ir.serde.deserialize_model(model_proto) self._apply_fusion_rules(ir_model) rewritten_model = ir.serde.serialize_model(ir_model) - self.assertEqual(["FusedMatMul"], [n.op_type for n in ir_model.graph]) + self.assertEqual([n.op_type for n in ir_model.graph], ["FusedMatMul"]) self._check_model(model_proto, rewritten_model, atol=1e-6) @parameterized.parameterized.expand([("should_not_match", _should_not_match)]) @@ -366,8 +366,8 @@ def test_should_not_match(self, _, script_func): self._apply_fusion_rules(ir_model) rewritten_model = ir.serde.serialize_model(ir_model) self.assertEqual( - ["Transpose", "MatMul", "Transpose"], [n.op_type for n in ir_model.graph], + ["Transpose", "MatMul", "Transpose"], ) self._check_model(model_proto, rewritten_model, atol=1e-6) @@ -391,7 +391,7 @@ def test_fused_matmul_with_other_node_in_middle(self, _, script_func): common_passes.ShapeInferencePass()(ir_model) self._apply_fusion_rules(ir_model) rewritten_model = ir.serde.serialize_model(ir_model) - self.assertEqual(["Identity", "FusedMatMul"], [n.op_type for n in ir_model.graph]) + self.assertEqual([n.op_type for n in ir_model.graph], ["Identity", "FusedMatMul"]) self._check_model(model_proto, rewritten_model, atol=1e-6) @parameterized.parameterized.expand( @@ -440,7 +440,7 @@ def test_transpose_fused_matmul_with_batch(self, _, script_func): ir_model = ir.serde.deserialize_model(model_proto) self._apply_fusion_rules(ir_model) rewritten_model = ir.serde.serialize_model(ir_model) - self.assertEqual(["FusedMatMul"], [n.op_type for n in ir_model.graph]) + self.assertEqual([n.op_type for n in ir_model.graph], ["FusedMatMul"]) self._check_model(model_proto, rewritten_model, atol=1e-6) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 99852f712a..bf883c58bc 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -163,14 +163,38 @@ def pattern( ): # Reshape query from (B, S, D) to (B, S, H, 
D/H) query_BSHDh = op.Reshape(query_BSD, pattern.ANY_VALUE, _outputs=["query_BSHDh"]) + # Qwen variant uses normalization of query/key before rotary embedding: + # The normalization can happen before (eg., Qwen) or after the Transpose (eg., Gemma). + query_BSHDh_normalized = op.SimplifiedLayerNormalization( + query_BSHDh, pattern.ANY_VALUE, axis=-1, _outputs=["query_BSHDh_normalized"] + ) + query_BSHDh = pattern.OrValue([query_BSHDh, query_BSHDh_normalized]) + # Transpose from (B, S, H, D/H) to (B, H, S, D/H) query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) + # Gemma variant uses normalization of query/key before rotary embedding: + query_BHSDh_normalized = op.SimplifiedLayerNormalization( + query_BHSDh, pattern.ANY_VALUE, axis=-1, _outputs=["query_BHSDh_normalized"] + ) + query_BHSDh = pattern.OrValue([query_BHSDh, query_BHSDh_normalized]) + # Reshape key from (B, S, Dkv) to (B, S, Hkv, D/H) key_BSHkvDh = op.Reshape(key_BSDkv, pattern.ANY_VALUE, _outputs=["key_BSHkvDh"]) + key_BSHkvDh_normalized = op.SimplifiedLayerNormalization( + key_BSHkvDh, pattern.ANY_VALUE, axis=-1, _outputs=["key_BSHkvDh_normalized"] + ) + key_BSHkvDh = pattern.OrValue([key_BSHkvDh, key_BSHkvDh_normalized]) + # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H) key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) + # Gemma variant uses normalization of query/key before rotary embedding: + key_BHkvSDh_normalized = op.SimplifiedLayerNormalization( + key_BHkvSDh, pattern.ANY_VALUE, axis=-1, _outputs=["key_BHkvSDh_normalized"] + ) + key_BHkvSDh = pattern.OrValue([key_BHkvSDh, key_BHkvSDh_normalized]) + # Reshape value from (B, S, Dkv) to (B, S, Hkv, D/H) value_BSHkvDh = op.Reshape(value_BSDkv, pattern.ANY_VALUE, _outputs=["value_BSHkvDh"]) # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H) @@ -197,7 +221,9 @@ def pattern( # that share key/value. 
key_seq_BHkvTDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) - key_seq_BHkv1TDh = op.Unsqueeze(key_seq_BHkvTDh, 2) + # Concat with past_key is optional: + key_seq_BHkvTDh = pattern.OrValue([key_seq_BHkvTDh, key_BHkvSDh_rope]) + key_seq_BHkv1TDh = op.Unsqueeze(key_seq_BHkvTDh, [2]) key_seq_BHkvGTDh = op.Expand(key_seq_BHkv1TDh, pattern.ANY_VALUE) key_seq_BHTDh = op.Reshape( key_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["key_seq_BHTDh"] @@ -206,7 +232,9 @@ def pattern( # Concatenate past_value cache and current value, expand across heads # that share key/value. value_seq_BHkvTDh = op.Concat(past_value, value_BHkvSDh, axis=-2) - value_seq_BHkv1TDh = op.Unsqueeze(value_seq_BHkvTDh, 2) + # Concat with past_value is optional: + value_seq_BHkvTDh = pattern.OrValue([value_seq_BHkvTDh, value_BHkvSDh]) + value_seq_BHkv1TDh = op.Unsqueeze(value_seq_BHkvTDh, [2]) value_seq_BHkvGTDh = op.Expand(value_seq_BHkv1TDh, pattern.ANY_VALUE) value_seq_BHTDh = op.Reshape( value_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["value_seq_BHTDh"] @@ -242,12 +270,27 @@ def check( query_BSHDh, key_BSHkvDh, mask, + query_BSHDh_normalized=None, + query_BHSDh_normalized=None, + key_BSHkvDh_normalized=None, + key_BHkvSDh_normalized=None, **_, ): + result = pattern.MatchResult() + if query_BSHDh_normalized is not None and query_BHSDh_normalized is not None: + return result.fail( + "Query normalized twice", + [query_BSHDh_normalized, query_BHSDh_normalized], + ) + if key_BSHkvDh_normalized is not None and key_BHkvSDh_normalized is not None: + return result.fail( + "Key normalized twice", + [key_BSHkvDh_normalized, key_BHkvSDh_normalized], + ) bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _fusion_utils._check_shape(bindings, val, dims) + return not _fusion_utils.check_shape_bool(bindings, val, dims) if no_match(query_BSD, ["B", "S", "D"]): return False @@ -256,9 +299,9 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool: if 
no_match(value_BSDkv, ["B", "S", "Dkv"]): return False - if no_match(past_key, ["B", "Hkv", "P", "Dh"]): + if past_key is not None and no_match(past_key, ["B", "Hkv", "P", "Dh"]): return False - if no_match(past_value, ["B", "Hkv", "P", "Dv"]): + if past_value is not None and no_match(past_value, ["B", "Hkv", "P", "Dv"]): return False # TODO: verify Reshapes: @@ -266,7 +309,6 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool: # and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]: # or check Reshape's shape-input value - result = pattern.MatchResult() num_heads = _ir_utils.get_dim(query_BSHDh, 2) kv_num_heads = _ir_utils.get_dim(key_BSHkvDh, 2) if not isinstance(num_heads, int): @@ -316,6 +358,12 @@ def rewrite( cos, sin, mask, + query_BSHDh, + key_BSHkvDh, + query_BSHDh_normalized=None, + query_BHSDh_normalized=None, + key_BSHkvDh_normalized=None, + key_BHkvSDh_normalized=None, **_, ): # Note that the following optimization is specific to current ORT GenAI attention-mask @@ -335,6 +383,31 @@ def rewrite( seqlens_k = op.Cast(seqlens_k_int64, to=ir.DataType.INT32) max_seq_length = op.ReduceMax(seqlens_k, zero_int64_1d, keepdims=0) total_seq_length_int32 = op.Add(max_seq_length, one_int32_0d) + + normalized_query = query_BHSDh_normalized or query_BSHDh_normalized + if normalized_query is not None: + # We apply normalization without the transpose, which is fused into GQA + norm_node = normalized_query.producer() + norm_attrs = norm_node.attributes + norm_scale = norm_node.inputs[1] + query_BSHDh_normalized = op.SimplifiedLayerNormalization( + query_BSHDh, norm_scale, **norm_attrs + ) + reshape_BSHDh_to_BSD = op.Constant(value_ints=[0, 0, -1]) + query_BSD = op.Reshape(query_BSHDh_normalized, reshape_BSHDh_to_BSD) + + normalized_key = key_BHkvSDh_normalized or key_BSHkvDh_normalized + if normalized_key is not None: + # We apply normalization without the transpose, which is fused into GQA + norm_node = normalized_key.producer() + norm_attrs = 
norm_node.attributes + norm_scale = norm_node.inputs[1] + key_BSHkvDh_normalized = op.SimplifiedLayerNormalization( + key_BSHkvDh, norm_scale, **norm_attrs + ) + reshape_BSHkvDh_to_BSDkv = op.Constant(value_ints=[0, 0, -1]) + key_BSDkv = op.Reshape(key_BSHkvDh_normalized, reshape_BSHkvDh_to_BSDkv) + return op.GroupQueryAttention( query_BSD, key_BSDkv, diff --git a/onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py b/onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py index 0d404b2754..51355fc8cf 100644 --- a/onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py +++ b/onnxscript/rewriter/ort_fusions/gqa_packed_qkv.py @@ -84,7 +84,7 @@ def check( self.bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _fusion_utils._check_shape(self.bindings, val, dims) + return not _fusion_utils.check_shape_bool(self.bindings, val, dims) # Check that if x is being split into q, k, v correctly # based on hidden sizes diff --git a/onnxscript/rewriter/ort_fusions/gqa_test.py b/onnxscript/rewriter/ort_fusions/gqa_test.py index 64cb84d18e..1a79b9c29f 100644 --- a/onnxscript/rewriter/ort_fusions/gqa_test.py +++ b/onnxscript/rewriter/ort_fusions/gqa_test.py @@ -10,6 +10,7 @@ import onnx_ir as ir import onnx_ir.passes.common.shape_inference as shape_inference import onnxruntime as ort +import parameterized import torch import onnxscript @@ -194,11 +195,11 @@ def gqa(query, key, value, past_key, past_value, cos, sin): value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2) # Now, expand from shared heads to all heads - key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2) + key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, [2]) key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh) key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh) - value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, 2) + value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, [2]) value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh) value_BHSDh = op.Reshape(value_BHkvGSDh, 
shape_BHSDh) @@ -361,6 +362,313 @@ def test_fusion(self): assert_allclose(outputs3, source_model_outputs) +@parameterized.parameterized_class( + [ + {"with_past": True, "transpose_first": True}, + {"with_past": True, "transpose_first": False}, + {"with_past": False, "transpose_first": True}, + {"with_past": False, "transpose_first": False}, + ] +) +class GemmaGQAFusionTest(unittest.TestCase): + with_past = True + transpose_first = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Config parameters + self.batchsize = 1 # Note: GQA (cpu) seems to require batch-size 1? + self.seqlen = 8 + self.kv_seqlen = self.seqlen + self.past_seqlen = 16 if self.with_past else 0 + self.head_size = 16 + self.num_heads = 20 + self.kv_num_heads = 10 + + # Computed config parameters + self.hidden_size = self.head_size * self.num_heads + self.kv_hidden_size = self.head_size * self.kv_num_heads + assert (self.num_heads % self.kv_num_heads) == 0, ( + "num_heads must be divisible by kv_num_heads" + ) + self.num_groups = self.num_heads // self.kv_num_heads + self.total_seqlen = self.seqlen + self.past_seqlen + + # Abbreviations + B = self.batchsize + S = self.seqlen + P = self.past_seqlen + D = self.hidden_size + Dkv = self.kv_hidden_size + Dh = self.head_size + Hkv = self.kv_num_heads + total_seqlen = S + P + max_seqlen = total_seqlen + + # Input/output types have some dimensions as dynamic (even though the + # test case instance has specific values above). 
+ self.input_types = ( + FLOAT["B", "S", D], # query + FLOAT["B", "S", Dkv], # key + FLOAT["B", "S", Dkv], # value + FLOAT["B", Hkv, "P", Dh], # past_key + FLOAT["B", Hkv, "P", Dh], # past_value + FLOAT["max_seqlen", Dh // 2], # cos + FLOAT["max_seqlen", Dh // 2], # sin + FLOAT["Dh"], # query_scale + FLOAT["Dh"], # key_scale + ) + self.output_types = ( + FLOAT["B", "S", D], # attention + FLOAT["B", Hkv, "T", Dh], # present_key + FLOAT["B", Hkv, "T", Dh], # present_value + ) + + self.inputs = { + "query": np.random.rand(B, S, D).astype(np.float32), + "key": np.random.rand(B, S, Dkv).astype(np.float32), + "value": np.random.rand(B, S, Dkv).astype(np.float32), + "past_key": np.random.rand(B, Hkv, P, Dh).astype(np.float32), + "past_value": np.random.rand(B, Hkv, P, Dh).astype(np.float32), + "cos": np.random.rand(max_seqlen, Dh // 2).astype(np.float32), + "sin": np.random.rand(max_seqlen, Dh // 2).astype(np.float32), + "query_scale": np.random.rand(Dh).astype(np.float32), + "key_scale": np.random.rand(Dh).astype(np.float32), + } + + def source_model_script(self): + with_past = self.with_past + transpose_first = self.transpose_first + scale_factor = math.sqrt(math.sqrt(self.head_size)) + minval = torch.finfo(torch.float32).min + minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval]) + H = [self.num_heads] + Hkv = [self.kv_num_heads] + Dh = [self.head_size] + G = [self.num_groups] + minus_1 = [-1] # inferred dimension in Reshape op + plus_1 = [1] + + @script() + def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scale): + # Shapes used for Reshape ops. Note that we have a few different options on how shapes are + # specified in an ONNX Reshape op (which supports special values 0 and -1 to propagate + # existing dimension and one inferred dimension respectively). The following shapes are + # based on what is observed in Phi models generated by the exporter. 
+ B = op.Shape(query, start=0, end=1) + S = op.Shape(query, start=1, end=2) + past_seq_length = op.Shape(past_key, start=2, end=3) + total_seq_length = op.Add(past_seq_length, S) + + shape_BSHDh = op.Concat(B, S, minus_1, Dh, axis=0) + shape_BSHkvDh = op.Concat(B, S, minus_1, Dh, axis=0) + shape_BSD = op.Concat(B, S, minus_1, axis=0) + shape_BHkvGSDh = op.Concat(B, Hkv, G, total_seq_length, Dh, axis=0) + + shape_BHSDh = op.Concat(B, H, total_seq_length, Dh, axis=0) + + # First, get Q, K, V into right shapes. Inputs are 3D tensors in the BSD format. + # D is different for Q and K/V (not reflected in the names, unfortunately). + # We convert them into BHSDh (i.e., BHSd) format. In this version, we have only + # one sequence length (S) for all Q, K, and V (with no cache). + query_BSHDh = op.Reshape(query, shape_BSHDh) + key_BSHkvDh = op.Reshape(key, shape_BSHkvDh) + + if transpose_first: + query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) + query_BHSDh_normalized = op.SimplifiedLayerNormalization( + query_BHSDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1 + ) + key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) + key_BHkvSDh_normalized = op.SimplifiedLayerNormalization( + key_BHkvSDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1 + ) + else: + query_BSHDh_normalized = op.SimplifiedLayerNormalization( + query_BSHDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1 + ) + query_BHSDh_normalized = op.Transpose( + query_BSHDh_normalized, perm=[0, 2, 1, 3] + ) + key_BSHkvDh_normalized = op.SimplifiedLayerNormalization( + key_BSHkvDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1 + ) + key_BHkvSDh_normalized = op.Transpose( + key_BSHkvDh_normalized, perm=[0, 2, 1, 3] + ) + + value_BSHkvDh = op.Reshape(value, shape_BSHkvDh) + value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) + + # Concat past and do rotary embedding + position_ids_1d = op.Range(past_seq_length, total_seq_length, 1) + position_ids_q = op.Unsqueeze(position_ids_1d, [0]) 
+ position_ids_k = op.Unsqueeze(position_ids_1d, [0]) + + query_BHSDh_rope = msft_op.RotaryEmbedding( + query_BHSDh_normalized, + position_ids_q, + cos, + sin, + ) + key_BHkvSDh_rope = msft_op.RotaryEmbedding( + key_BHkvSDh_normalized, + position_ids_k, + cos, + sin, + ) + + if with_past: + key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) + value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2) + else: + key_seq_BHkvSkvDh = key_BHkvSDh_rope + value_seq_BHkvSkvDh = value_BHkvSDh + + # Now, expand from shared heads to all heads + key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, [2]) + key_BHkvGSDh = op.Expand(key_BHkv1SDh, shape_BHkvGSDh) + key_BHSDh = op.Reshape(key_BHkvGSDh, shape_BHSDh) + + value_BHkv1SDh = op.Unsqueeze(value_seq_BHkvSkvDh, [2]) + value_BHkvGSDh = op.Expand(value_BHkv1SDh, shape_BHkvGSDh) + value_BHSDh = op.Reshape(value_BHkvGSDh, shape_BHSDh) + + # Generate causal mask: + # where every row looks like [0, 0, ..., /*diagonal=*/ 0, minval, minval, ...] + seq_len = op.Shape(query, end=2, start=1) + seq_len_0D = op.Squeeze(seq_len) + + past_seq_len_0D = op.Squeeze(past_seq_length) + + total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D) + total_seq_len = op.Reshape(total_seq_len_0D, [-1]) + + # The Phi modeling code generates the following +1 as the target-length, which seems + # unnecessary in this context. But duplicating same logic here. 
+ total_seq_len_plus_1_0D = op.Add(total_seq_len_0D, 1) + total_seq_len_plus_1 = op.Reshape(total_seq_len_plus_1_0D, [-1]) + + current_range = op.Range(past_seq_len_0D, total_seq_len_0D, 1) + mask_shape = op.Concat(seq_len, total_seq_len_plus_1, axis=0) + min_val = op.Constant(value=minval_tp) + mask_all_min = op.Expand(min_val, mask_shape) + total_range_as_row = op.Range(0, total_seq_len_plus_1_0D, 1) + current_range_as_column = op.Reshape(current_range, [-1, 1]) + boolean_mask = op.Greater(total_range_as_row, current_range_as_column) + float_0_1_mask = op.Cast(boolean_mask, to=1) + float_0_min_mask = op.Mul(mask_all_min, float_0_1_mask) + mask_4d = op.Unsqueeze(float_0_min_mask, [0, 1]) + shape_B111 = op.Concat(B, plus_1, plus_1, plus_1, axis=0) + mask_B1ST_plus = op.Expand(mask_4d, shape_B111) + + # Get rid of the extra +1 added above: total_seq_len is enough, no + # need for total_seq_len+1. + mask_B1ST = op.Slice(mask_B1ST_plus, [0], total_seq_len, [3], [1]) + + # Now, compute attention: + key_transposed = op.Transpose(key_BHSDh, perm=[0, 1, 3, 2]) + divisor = op.Constant(value_float=scale_factor) + scaled_query = op.Div(query_BHSDh_rope, divisor) + scaled_key = op.Div(key_transposed, divisor) + attn_score = op.MatMul(scaled_query, scaled_key) + masked_attn_score = op.Add(attn_score, mask_B1ST) + attn_weight = op.Softmax(masked_attn_score, axis=-1) + attention_BHSDh = op.MatMul(attn_weight, value_BHSDh) + + # Reshape back to BSD format + attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) + attention_BSD = op.Reshape(attention_BSHDh, shape_BSD) + + return attention_BSD, key_seq_BHkvSkvDh, value_seq_BHkvSkvDh + + return gqa + + def test_fusion(self): + """Test that GQA fusion is successful on source model and produces an equivalent model.""" + inputs = self.inputs + + source_model = self.source_model_script().to_model_proto( + input_types=self.input_types, + output_types=self.output_types, + ) + session = ort.InferenceSession( + 
source_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + source_model_outputs = session.run(None, inputs) + + # Some shapes need to be present in input model for fusion to be successful. + # (i) Shape inference doesn't handle ORT contrib ops. + # (ii) TODO: investigate if Reshape(..., ["B", "S", -1, Dh]) handled precisely + # by shape inference. + query_BHSDh_rope_value_info = onnx.helper.make_tensor_value_info( + "query_BHSDh_rope", + onnx.TensorProto.FLOAT, + ["B", self.num_heads, self.seqlen, self.head_size], + ) + key_BHkvSDh_rope_value_info = onnx.helper.make_tensor_value_info( + "key_BHkvSDh_rope", + onnx.TensorProto.FLOAT, + ["B", self.kv_num_heads, self.seqlen, self.head_size], + ) + query_BSHDh_value_info = onnx.helper.make_tensor_value_info( + "query_BSHDh", + onnx.TensorProto.FLOAT, + ["B", self.seqlen, self.num_heads, self.head_size], + ) + key_BHSDh_value_info = onnx.helper.make_tensor_value_info( + "key_BHSDh", + onnx.TensorProto.FLOAT, + ["B", self.num_heads, self.total_seqlen, self.head_size], + ) + key_BSHkvDh_value_info = onnx.helper.make_tensor_value_info( + "key_BSHkvDh", + onnx.TensorProto.FLOAT, + ["B", self.seqlen, self.kv_num_heads, self.head_size], + ) + key_transposed_value_info = onnx.helper.make_tensor_value_info( + "key_transposed", + onnx.TensorProto.FLOAT, + ["B", self.num_heads, self.head_size, self.total_seqlen], + ) + value_BHSDh_value_info = onnx.helper.make_tensor_value_info( + "value_BHSDh", + onnx.TensorProto.FLOAT, + ["B", self.num_heads, self.total_seqlen, self.head_size], + ) + source_model.graph.value_info.extend( + [ + query_BHSDh_rope_value_info, + key_BHkvSDh_rope_value_info, + query_BSHDh_value_info, + key_BHSDh_value_info, + key_BSHkvDh_value_info, + key_transposed_value_info, + value_BHSDh_value_info, + ] + ) + + source_model_ir = ir.serde.from_proto(source_model) + inferred_model = shape_inference.infer_shapes(source_model_ir) + onnxscript.optimizer.optimize(inferred_model) + + count = 
fuse_sdpa(inferred_model, debug=True) + self.assertGreater(count, 0) + + count = fuse_gqa(inferred_model, debug=True) + self.assertGreater(count, 0) + + fused_model = ir.serde.to_proto(inferred_model) + session = ort.InferenceSession( + fused_model.SerializeToString(), providers=("CPUExecutionProvider",) + ) + outputs3 = session.run(None, inputs) + + self.assertEqual(len(outputs3), len(source_model_outputs)) + assert_allclose(outputs3, source_model_outputs) + + class GQAFusionTest2(unittest.TestCase): @unittest.skip("Needs too much memory.") def test_phi4lm(self): diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py index e2987cfc5e..321e895f44 100644 --- a/onnxscript/rewriter/ort_fusions/mha.py +++ b/onnxscript/rewriter/ort_fusions/mha.py @@ -157,7 +157,7 @@ def check( bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _fusion_utils._check_shape(bindings, val, dims) + return not _fusion_utils.check_shape_bool(bindings, val, dims) if no_match(query_BSD, ["B", "S", "D"]): return check_result.fail( diff --git a/onnxscript/rewriter/ort_fusions/mha_bias.py b/onnxscript/rewriter/ort_fusions/mha_bias.py index 28b9646ddc..9ecf2ce017 100644 --- a/onnxscript/rewriter/ort_fusions/mha_bias.py +++ b/onnxscript/rewriter/ort_fusions/mha_bias.py @@ -78,7 +78,7 @@ def check( self.bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _fusion_utils._check_shape(self.bindings, val, dims) + return not _fusion_utils.check_shape_bool(self.bindings, val, dims) if query_matmul.dtype not in valid_float_types: return check_result.fail("Query is not a float or float16 type.", query_matmul) diff --git a/onnxscript/rewriter/ort_fusions/rms_normalization.py b/onnxscript/rewriter/ort_fusions/rms_normalization.py index de6e51a5c0..6e9810ce63 100644 --- a/onnxscript/rewriter/ort_fusions/rms_normalization.py +++ 
b/onnxscript/rewriter/ort_fusions/rms_normalization.py @@ -31,6 +31,10 @@ class RmsNormFusion(pattern.RewriteRuleClassBase): + def __init__(self, name: str, _mul_order: bool): + super().__init__(name) + self._mul_order = _mul_order + def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype): x = pattern.OrValue([op.Cast(x, to=compute_dtype), x]) x_square = op.Pow(x, 2.0) @@ -42,7 +46,11 @@ def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype): normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized]) # To support float16, we need to ensure the scale is casted or not. scale = pattern.OrValue([op.Cast(scale, to=compute_dtype), scale]) - return op.Mul(scale, normalized) + # Workaround: can't use OrValue for final (returned) value + if self._mul_order: + return op.Mul(normalized, scale) + else: + return op.Mul(scale, normalized) def check( self, op, x, scale, epsilon, compute_dtype, target_dtype, **_ @@ -77,8 +85,10 @@ def rewrite(self, op, x, scale, epsilon, **_): ) -_rule = RmsNormFusion.rule() -rms_normalization_rules = [_rule] +_rule1 = RmsNormFusion.rule("RmsNormFusion1", _mul_order=False) +_rule2 = RmsNormFusion.rule("RmsNormFusion2", _mul_order=True) + +rms_normalization_rules = [_rule1, _rule2] rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules) diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py index 1d339f43e7..821537afe5 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa.py +++ b/onnxscript/rewriter/ort_fusions/sdpa.py @@ -12,6 +12,18 @@ Dim = Union[int, ir.SymbolicDim] +# This file contains a fusion rule that recognizes various patterns of scaled dot-product attention +# (SDPA) implementations and replaces them with a single SDPA op. The SDPA op is a temporary fusion +# op defined in the ai.onnxruntime._fusion domain. 
Subsequent fusion rules will map it into one +# of the various ops defined in ORT: MHA, GQA, or Attention depending on the input patterns. +# The SDPA is a standard scalar dot-product attention with an optional mask input and scaling factor. +# Currently, it is restricted to query, key, and values of rank 4 with shapes: +# Query: [batch_size, num_heads, seq_len, head_size_qk] +# Key: [batch_size, num_heads, seq_len_kv, head_size_qk] +# or [batch_size, seq_len_kv, num_heads, head_size_qk]) +# Value: [batch_size, num_heads, seq_len_kv, head_size_v] +# The key_format attribute indicates which of the two formats the key uses and can be either "BHSd" or "BSHd". + class SDPA(pattern.RewriteRuleClassBase): _scale: float | None @@ -88,6 +100,9 @@ def pattern( ) attn_weight = op.Softmax(attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + adj_attn_weight = op.Where(is_nan, 0.0, attn_weight) + attn_weight = pattern.OrValue([adj_attn_weight, attn_weight]) attn_output = op.MatMul(attn_weight, value) return attn_output diff --git a/onnxscript/rewriter/ort_fusions/sdpa_test.py b/onnxscript/rewriter/ort_fusions/sdpa_test.py index 90bcd26097..3b29418cc6 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa_test.py +++ b/onnxscript/rewriter/ort_fusions/sdpa_test.py @@ -44,7 +44,10 @@ def _unmasked_pre_div_sdpa_script(query, key, value): scaled_key = op.Div(key_transposed, divisor) attn_score = op.MatMul(scaled_query, scaled_key) attn_weight = op.Softmax(attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -56,7 +59,10 @@ def _unmasked_pre_mul_sdpa_script(query, key, value): scaled_key = op.Mul(key_transposed, multiplier) attn_score = op.MatMul(scaled_query, scaled_key) attn_weight = op.Softmax(attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = 
op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -67,7 +73,10 @@ def _unmasked_post_div_sdpa_script(query, key, value): attn_score = op.MatMul(query, key_transposed) scaled_attn_score = op.Div(attn_score, divisor) attn_weight = op.Softmax(scaled_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -78,7 +87,10 @@ def _unmasked_post_mul_sdpa_script(query, key, value): attn_score = op.MatMul(query, key_transposed) scaled_attn_score = op.Mul(attn_score, multiplier) attn_weight = op.Softmax(scaled_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -90,7 +102,10 @@ def _custom_scale_pre_div_sdpa_script(query, key, value): scaled_key = op.Div(key_transposed, divisor) attn_score = op.MatMul(scaled_query, scaled_key) attn_weight = op.Softmax(attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -102,7 +117,10 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value): scaled_key = op.Mul(key_transposed, multiplier) attn_score = op.MatMul(scaled_query, scaled_key) attn_weight = op.Softmax(attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = 
op.MatMul(adj_attn_weight, value) return attn_output @@ -115,7 +133,10 @@ def _custom_multi_scale_pre_mul_sdpa_script(query, key, value): scaled_key = op.Mul(key_transposed, multiplier_k) attn_score = op.MatMul(scaled_query, scaled_key) attn_weight = op.Softmax(attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -126,7 +147,10 @@ def _custom_scale_post_div_sdpa_script(query, key, value): attn_score = op.MatMul(query, key_transposed) scaled_attn_score = op.Div(attn_score, divisor) attn_weight = op.Softmax(scaled_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -137,7 +161,10 @@ def _custom_scale_post_mul_sdpa_script(query, key, value): attn_score = op.MatMul(query, key_transposed) scaled_attn_score = op.Mul(attn_score, multiplier) attn_weight = op.Softmax(scaled_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -150,7 +177,10 @@ def _masked_pre_div_sdpa_script(query, key, value, mask): attn_score = op.MatMul(scaled_query, scaled_key) masked_attn_score = op.Add(attn_score, mask) attn_weight = op.Softmax(masked_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -163,7 +193,10 @@ def _masked_pre_mul_sdpa_script(query, key, 
value, mask): attn_score = op.MatMul(scaled_query, scaled_key) masked_attn_score = op.Add(attn_score, mask) attn_weight = op.Softmax(masked_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -175,7 +208,10 @@ def _masked_post_div_sdpa_script(query, key, value, mask): scaled_attn_score = op.Div(attn_score, divisor) masked_attn_score = op.Add(scaled_attn_score, mask) attn_weight = op.Softmax(masked_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -187,7 +223,10 @@ def _masked_post_mul_sdpa_script(query, key, value, mask): scaled_attn_score = op.Mul(attn_score, multiplier) masked_attn_score = op.Add(scaled_attn_score, mask) attn_weight = op.Softmax(masked_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -200,7 +239,10 @@ def _masked_custom_scale_pre_div_sdpa_script(query, key, value, mask): attn_score = op.MatMul(scaled_query, scaled_key) masked_attn_score = op.Add(attn_score, mask) attn_weight = op.Softmax(masked_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -213,7 +255,10 @@ def _masked_custom_scale_pre_mul_sdpa_script(query, key, value, mask): attn_score = op.MatMul(scaled_query, scaled_key) masked_attn_score = 
op.Add(attn_score, mask) attn_weight = op.Softmax(masked_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -225,7 +270,10 @@ def _masked_custom_scale_post_div_sdpa_script(query, key, value, mask): scaled_attn_score = op.Div(attn_score, divisor) masked_attn_score = op.Add(scaled_attn_score, mask) attn_weight = op.Softmax(masked_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output @@ -237,24 +285,48 @@ def _masked_custom_scale_post_mul_sdpa_script(query, key, value, mask): scaled_attn_score = op.Mul(attn_score, multiplier) masked_attn_score = op.Add(scaled_attn_score, mask) attn_weight = op.Softmax(masked_attn_score, axis=-1) - attn_output = op.MatMul(attn_weight, value) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) + return attn_output + + +# This tests a scenario where the key is in BSHd format instead of BHSd, which +# happens due to an optimization that fuses two transposes together, the one +# to convert from BSHd to BHSd and then to BHdS before MatMul. Hence, the first +# transpose down below is different from other test cases. 
+@script() +def _unmasked_pre_div_sdpa_BSHd_key_script(query, key, value): + key_transposed = op.Transpose(key, perm=[0, 2, 3, 1]) # BSHd to BHdS + divisor = op.Constant(value_float=SQRT_SCALE_FACTOR) + scaled_query = op.Div(query, divisor) + scaled_key = op.Div(key_transposed, divisor) + attn_score = op.MatMul(scaled_query, scaled_key) + attn_weight = op.Softmax(attn_score, axis=-1) + is_nan = op.IsNaN(attn_weight) + zero = op.Constant(value_float=0.0) + adj_attn_weight = op.Where(is_nan, zero, attn_weight) + attn_output = op.MatMul(adj_attn_weight, value) return attn_output class SDPATestCase: - def __init__(self, script_func, *, with_mask): + def __init__(self, script_func, *, with_mask, BSHd_key=False): self.script_func = script_func self.with_mask = with_mask + self.BSHd_key = BSHd_key def get_onnx_model(self): if not hasattr(self, "_onnx_model"): - qkv_type = FLOAT[B, N, S, H] + qv_type = FLOAT[B, N, S, H] mask_type = FLOAT[B, N, S, S] - input_types = [qkv_type, qkv_type, qkv_type] + k_type = FLOAT[B, S, N, H] if self.BSHd_key else FLOAT[B, N, S, H] + input_types = [qv_type, k_type, qv_type] if self.with_mask: input_types.append(mask_type) model_proto = self.script_func.to_model_proto( - input_types=input_types, output_types=[qkv_type] + input_types=input_types, output_types=[qv_type] ) self._onnx_model = ir.serde.deserialize_model(model_proto) return self._onnx_model @@ -263,7 +335,9 @@ def get_ort_inputs(self): if not hasattr(self, "_ort_inputs"): inputs = { "query": numpy.random.rand(B, N, S, H).astype(numpy.float32), - "key": numpy.random.rand(B, N, S, H).astype(numpy.float32), + "key": numpy.random.rand(B, S, N, H).astype(numpy.float32) + if self.BSHd_key + else numpy.random.rand(B, N, S, H).astype(numpy.float32), "value": numpy.random.rand(B, N, S, H).astype(numpy.float32), } if self.with_mask: @@ -323,10 +397,13 @@ class TestSDPAFusion(unittest.TestCase): "_custom_multi_scale_pre_mul_sdpa_script", _custom_multi_scale_pre_mul_sdpa_script, ), + 
("pre_div_sdpa_BSHd_key", _unmasked_pre_div_sdpa_BSHd_key_script), ] ) def test_sdpa_fusion(self, name, script_func): - test_case = SDPATestCase(script_func, with_mask="masked" in name) + test_case = SDPATestCase( + script_func, with_mask="masked" in name, BSHd_key="BSHd_key" in name + ) model = test_case.get_onnx_model() onnxscript.optimizer.optimize(model) diff --git a/onnxscript/rewriter/ort_fusions/sdpa_via_mha.py b/onnxscript/rewriter/ort_fusions/sdpa_via_mha.py index e6484406a9..acbc0705fa 100644 --- a/onnxscript/rewriter/ort_fusions/sdpa_via_mha.py +++ b/onnxscript/rewriter/ort_fusions/sdpa_via_mha.py @@ -7,43 +7,57 @@ import onnx_ir as ir from onnxscript.rewriter import _fusion_utils, pattern +from onnxscript.rewriter._basics import MatchFailureError Dim = Union[int, ir.SymbolicDim] class SDPAImplementation(pattern.RewriteRuleClassBase): - def pattern(self, op, query, key, value): + def pattern(self, op, query, key, value, key_format): + """Pattern matches any call to SDPA. See sdpa.py for documentation on the SDPA op.""" return op.SDPA( query, key, value, - key_format="BHSd", + key_format=key_format, _allow_other_inputs=True, # Mask is optional _outputs=["sdpa_output"], _domain="ai.onnxruntime._fusion", ) - def check(self, context, query, key, value, sdpa_output): + def check(self, context, query, key, value, key_format, sdpa_output): bindings: dict[str, Dim] = {} _fusion_utils.check_shape(bindings, query, ["B", "H", "S", "Dh"]) - _fusion_utils.check_shape(bindings, key, ["B", "H", "Skv", "Dh"]) _fusion_utils.check_shape(bindings, value, ["B", "H", "Skv", "Dv"]) + if key_format.value == "BHSd": + _fusion_utils.check_shape(bindings, key, ["B", "H", "Skv", "Dh"]) + elif key_format.value == "BSHd": + _fusion_utils.check_shape(bindings, key, ["B", "Skv", "H", "Dh"]) + else: + raise MatchFailureError( + f"Unexpected key_format value: {key_format.value}", key_format + ) + self._num_heads = bindings["H"] if not isinstance(self._num_heads, int): return False 
self._use_mask_broadcast = True # TODO: optimize to avoid broadcast if not needed return isinstance(self._num_heads, int) - def rewrite(self, op, query, key, value, sdpa_output): + def rewrite(self, op, query, key, value, key_format, sdpa_output): sdpa_node = sdpa_output.producer() scale = sdpa_node.attributes.get("scale", None) to_3d_shape = op.Constant(value_ints=[0, 0, -1]) to_4d_shape = op.Constant(value_ints=[0, 0, self._num_heads, -1]) query_3d = op.Reshape(op.Transpose(query, perm=[0, 2, 1, 3]), to_3d_shape) - key_3d = op.Reshape(op.Transpose(key, perm=[0, 2, 1, 3]), to_3d_shape) value_3d = op.Reshape(op.Transpose(value, perm=[0, 2, 1, 3]), to_3d_shape) + if key_format.value == "BHSd": + key_3d = op.Reshape(op.Transpose(key, perm=[0, 2, 1, 3]), to_3d_shape) + else: # BSHd + key_3d = op.Reshape(key, to_3d_shape) + inputs = [query_3d, key_3d, value_3d] if len(sdpa_node.inputs) > 3: mask = sdpa_node.inputs[3] diff --git a/onnxscript/rewriter/ort_fusions/skip_normalization.py b/onnxscript/rewriter/ort_fusions/skip_normalization.py index f7a376aef9..c76a7454cb 100644 --- a/onnxscript/rewriter/ort_fusions/skip_normalization.py +++ b/onnxscript/rewriter/ort_fusions/skip_normalization.py @@ -60,7 +60,7 @@ def check( bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _fusion_utils._check_shape(bindings, val, dims) + return not _fusion_utils.check_shape_bool(bindings, val, dims) if no_match(input, ["B", "S", "D"]): return check_result.fail( @@ -184,7 +184,7 @@ def check( bindings: dict[str, Dim] = {} def no_match(val: ir.Value, dims: Sequence[str]) -> bool: - return not _fusion_utils._check_shape(bindings, val, dims) + return not _fusion_utils.check_shape_bool(bindings, val, dims) if no_match(input, ["B", "S", "D"]): return check_result.fail( diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index bf5940e97c..f296b5320c 100644 --- a/onnxscript/rewriter/pattern_test.py +++ 
b/onnxscript/rewriter/pattern_test.py @@ -10,9 +10,11 @@ import onnx.parser import onnxscript.optimizer +import onnxscript.rewriter from onnxscript import FLOAT, ir, script from onnxscript import opset17 as op -from onnxscript.rewriter import cast_constant_of_shape, pattern +from onnxscript.rewriter import pattern +from onnxscript.rewriter.rules.common import _cast_constant_of_shape logger = logging.getLogger(__name__) @@ -306,7 +308,7 @@ def test_delayed_run_provides_correct_bindings_for_multiple_matches(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = cast_constant_of_shape.rules.apply_to_model(model) + count = _cast_constant_of_shape.rules.apply_to_model(model) self.assertEqual(count, 2) self.assertEqual(len(model.graph), 2) self.assertEqual(model.graph[0].attributes["value"].value.dtype, 10) @@ -673,7 +675,7 @@ def test_model(x: FLOAT[1024, 512], y: FLOAT[1024, 512]) -> FLOAT[512, 1024]: function = model.functions[function_id] self.assertEqual([x.op_type for x in function], ["Add", "Transpose"]) transpose_node = function[1] - self.assertEqual(transpose_node.attributes["perm"].value, [1, 0]) + self.assertEqual(list(transpose_node.attributes["perm"].value), [1, 0]) onnxscript.optimizer.inline(model) self.assertEqual([x.op_type for x in model.graph], ["Add", "Transpose"]) @@ -935,6 +937,44 @@ def add_pattern(op, x, y): match_result = rule_pattern.match(model, model.graph, add_nodes[2]) self.assertFalse(bool(match_result)) + def test_rule_name_metadata(self): + """Test that RewriteRule carries name metadata.""" + + class ReciprocalMulRule(pattern.RewriteRuleClassBase): + def __init__(self, name: str | None = None): + super().__init__(name) + + def pattern(self, op, x, y): + return (1 / x) * y + + def rewrite(self, op, x, y): + return op.Div(y, x) + + @script() + def test_script(x: FLOAT[1024], y: FLOAT[1024]) -> FLOAT[1024]: + return op.Mul(op.Div(op.Constant(value_float=1.0), x), y) + + rule = ReciprocalMulRule.rule(name="ReciprocalMulToDiv") 
+ model_proto = test_script.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + count = rule.apply_to_model(model) + self.assertEqual(count, 1) + for node in model.graph: + if node.op_type == "Div": + tag = onnxscript.rewriter.RULE_NAME_TAG + self.assertEqual(node.metadata_props.get(tag), "ReciprocalMulToDiv") + + # By default, the rule name is the class name (if not provided) + rule = ReciprocalMulRule.rule() + model_proto = test_script.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + count = rule.apply_to_model(model) + self.assertEqual(count, 1) + for node in model.graph: + if node.op_type == "Div": + tag = onnxscript.rewriter.RULE_NAME_TAG + self.assertEqual(node.metadata_props.get(tag), "ReciprocalMulRule") + class PatternBuilderTest(unittest.TestCase): def test_pattern_builder_context(self): diff --git a/onnxscript/rewriter/rules/__init__.py b/onnxscript/rewriter/rules/__init__.py new file mode 100644 index 0000000000..59e481eb93 --- /dev/null +++ b/onnxscript/rewriter/rules/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/onnxscript/rewriter/rules/common/__init__.py b/onnxscript/rewriter/rules/common/__init__.py new file mode 100644 index 0000000000..76d9e4f4b0 --- /dev/null +++ b/onnxscript/rewriter/rules/common/__init__.py @@ -0,0 +1,133 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+__all__ = [ + "add_0_rule", + "affine_conv_fusion_rule", + "cast_cast_rule", + "cast_constant_of_shape_rule", + "cast_constant_of_shape_without_value_rule", + "collapse_slice_rule", + "collapse_slice2_rule", + "conv_affine_fusion_rule", + "div_by_1_rule", + "dropout_inference_rule", + "dropout_zero_rule", + "flatten_to_reshape_rule", + "fuse_batchnorm_into_conv_rule", + "fuse_batchnorm_into_conv_transpose_rule", + "fuse_batchnorm_into_gemm_rule", + "fuse_hardswish_rules", + "fuse_pad_into_conv_integer_rule", + "fuse_pad_into_conv_rule", + "min_min_rule", + "max_max_rule", + "min_max_rule", + "max_min_rule", + "gemm_to_matmul_add_rule", + "matmul_add_to_gemm_rule", + "mul_by_1_rule", + "no_op_cast_rule", + "no_op_dynamic_scatter_nd_rule", + "no_op_expand_rule", + "no_op_static_scatter_nd_rule", + "no_op_transpose_rule", + "normalize_pad_format_conv_integer_rule", + "normalize_pad_format_conv_rule", + "one_reshape_matmul_reshape_rule", + "remove_optional_bias_from_conv_rule", + "remove_optional_bias_from_conv_transpose_rule", + "remove_optional_bias_from_gemm_rule", + "remove_optional_bias_from_qlinear_conv_rule", + "reshape_reshape_rule", + "slice_split_rule", + "squeeze_reshape_1d_rule", + "sub_0_rule", + "successive_clip_relu_rule", + "successive_clip_rule", + "successive_relu_clip_rule", + "successive_relu_rule", + "transpose_a_matmul_add_to_gemm_rule", + "transpose_ab_matmul_add_to_gemm_rule", + "transpose_b_matmul_add_to_gemm_rule", + "transpose_transpose_rule", + "two_reshapes_matmul_reshape_rule", + "unsqueeze_unsqueeze_rule", +] + +from onnxscript.rewriter.rules.common._basic_rules import ( + cast_cast_rule, + flatten_to_reshape_rule, + no_op_cast_rule, + no_op_expand_rule, + no_op_transpose_rule, + reshape_reshape_rule, + slice_split_rule, + squeeze_reshape_1d_rule, + transpose_transpose_rule, + unsqueeze_unsqueeze_rule, +) +from onnxscript.rewriter.rules.common._broadcast_to_matmul import ( + one_reshape_matmul_reshape_rule, + 
two_reshapes_matmul_reshape_rule, +) +from onnxscript.rewriter.rules.common._cast_constant_of_shape import ( + cast_constant_of_shape_rule, + cast_constant_of_shape_without_value_rule, +) +from onnxscript.rewriter.rules.common._collapse_slices import ( + collapse_slice2_rule, + collapse_slice_rule, +) +from onnxscript.rewriter.rules.common._fuse_batchnorm import ( + fuse_batchnorm_into_conv_rule, + fuse_batchnorm_into_conv_transpose_rule, + fuse_batchnorm_into_gemm_rule, +) +from onnxscript.rewriter.rules.common._fuse_conv_affine import ( + affine_conv_fusion_rule, + conv_affine_fusion_rule, +) +from onnxscript.rewriter.rules.common._fuse_hardswish import fuse_hardswish_rules +from onnxscript.rewriter.rules.common._fuse_pad_into_conv import ( + fuse_pad_into_conv_integer_rule, + fuse_pad_into_conv_rule, + normalize_pad_format_conv_integer_rule, + normalize_pad_format_conv_rule, +) +from onnxscript.rewriter.rules.common._fuse_relus_clips import ( + successive_clip_relu_rule, + successive_clip_rule, + successive_relu_clip_rule, + successive_relu_rule, +) +from onnxscript.rewriter.rules.common._gemm_to_matmul_add import gemm_to_matmul_add_rule +from onnxscript.rewriter.rules.common._matmul_add_to_gemm import ( + matmul_add_to_gemm_rule, + transpose_a_matmul_add_to_gemm_rule, + transpose_ab_matmul_add_to_gemm_rule, + transpose_b_matmul_add_to_gemm_rule, +) +from onnxscript.rewriter.rules.common._min_max_to_clip import ( + max_max_rule, + max_min_rule, + min_max_rule, + min_min_rule, +) +from onnxscript.rewriter.rules.common._no_op import ( + add_0_rule, + div_by_1_rule, + dropout_inference_rule, + dropout_zero_rule, + mul_by_1_rule, + sub_0_rule, +) +from onnxscript.rewriter.rules.common._redundant_scatter_nd import ( + no_op_dynamic_scatter_nd_rule, + no_op_static_scatter_nd_rule, +) +from onnxscript.rewriter.rules.common._remove_optional_bias import ( + remove_optional_bias_from_conv_rule, + remove_optional_bias_from_conv_transpose_rule, + 
remove_optional_bias_from_gemm_rule, + remove_optional_bias_from_qlinear_conv_rule, +) diff --git a/onnxscript/rewriter/basic_rules.py b/onnxscript/rewriter/rules/common/_basic_rules.py similarity index 75% rename from onnxscript/rewriter/basic_rules.py rename to onnxscript/rewriter/rules/common/_basic_rules.py index 2788cb7cda..b7a648880a 100644 --- a/onnxscript/rewriter/basic_rules.py +++ b/onnxscript/rewriter/rules/common/_basic_rules.py @@ -11,6 +11,8 @@ from typing import ClassVar, Sequence +import numpy as np + from onnxscript import ir from onnxscript.rewriter import _ir_utils as ir_utils from onnxscript.rewriter._basics import MatchResult @@ -123,16 +125,37 @@ def pattern(self, op, x, shape_ignored, shape): return op.Reshape(op.Reshape(x, shape_ignored), shape) def rewrite(self, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value): - return op.Reshape(x, shape) + new_shape = op.initializer(ir.Tensor(self._new_shape, name=shape.name)) + return op.Reshape(x, new_shape, allowzero=self._allowzero) def check(self, context, x, shape_ignored, shape) -> MatchResult: check_result = MatchResult() - if shape_ignored.const_value is None: - return check_result.fail("Shape ignored is not a constant.") - if shape.const_value is None: + + # Shape must be a constant. + if (np_shape := ir_utils.get_numpy_value(shape)) is None: return check_result.fail("Shape is not a constant.") - if shape.const_value.numpy().min() <= 0: - return check_result.fail("Shape has non-positive values.") + # Convert to array to support assignment destination. + self._new_shape = np.array(np_shape, np_shape.dtype) + + # Try to replace {0,-1} values in shape if reshape output is known. + if (reshape_output := context.output_values[0].shape) is not None: + for i, dim in enumerate(reshape_output): + if isinstance(dim, int) and dim > 0: + self._new_shape[i] = dim + + # Constraints for shape. 
+ self._allowzero = context.nodes[0].attributes.get_int("allowzero", 0) + if self._allowzero == 1 and any(self._new_shape == 0): + return check_result + if any(self._new_shape == 0) and any(self._new_shape < 0): + return check_result.fail("Shape cannot contain both 0 and -1 dimensions.") + elif np.count_nonzero(self._new_shape == 0) > 1: + return check_result.fail("Shape cannot contain more than one 0 dimension.") + + # At this point, we can safely replace '0' with '-1'. + # Note allowzero is removed since at this point it does not have any effect. + self._allowzero = None + self._new_shape = np.where(self._new_shape == 0, -1, self._new_shape) return check_result @@ -279,16 +302,66 @@ def check(self, context, x, axes1, axes2) -> MatchResult: return check_result +class Flatten2Reshape(RewriteRuleClassBase): + """Convert ``Flatten(x)`` to Reshape.""" + + def pattern(self, op, x: ir.Value): + return op.Flatten(x) + + def rewrite(self, op, x: ir.Value): + new_shape = op.initializer(ir.Tensor(self._new_shape, name=f"{x.name}/shape")) + return op.Reshape(x, new_shape) + + def check(self, context, x: ir.Value) -> MatchResult: + check_result = MatchResult() + self._new_shape = np.array([-1, -1], "int64") + + # Convert axis in a positive value if possible. + axis = context.root.attributes.get_int("axis", 1) + input_rank = None + if (input_shape := x.shape) is not None: + input_rank = len(input_shape) + if axis < 0: + axis += input_rank + + # Compute reshape shape following axis attribute. + if axis == 0: + self._new_shape[0] = 1 + elif axis == 1: + self._new_shape[0] = 0 + elif axis == input_rank: + self._new_shape[1] = 1 + + # Try to update shape if output is known. + if (output_shape := context.output_values[0].shape) is not None: + for i, dim in enumerate(output_shape): + if isinstance(dim, int): + self._new_shape[i] = dim + + # Try to update shape if input is known. 
+ if input_shape is not None: + if all(isinstance(dim, int) for dim in input_shape[:axis]): + self._new_shape[0] = np.prod(input_shape[:axis]) + if all(isinstance(dim, int) for dim in input_shape[axis:]): + self._new_shape[1] = np.prod(input_shape[axis:]) + + # Verify if it is possible to apply rule. + if np.count_nonzero(self._new_shape == -1) > 1: + return check_result.fail("Impossible to compute new shape.") + return check_result + + # Create rule instances cast_cast_rule = CastCast.rule() -cast_identity_rule = CastIdentity.rule() -expand_identity_rule = ExpandIdentity.rule() +no_op_cast_rule = CastIdentity.rule() +no_op_expand_rule = ExpandIdentity.rule() reshape_reshape_rule = ReshapeReshape.rule() slice_split_rule = SlicesSplit.rule() -transpose_identity_rule = TransposeIdentity.rule() +no_op_transpose_rule = TransposeIdentity.rule() transpose_transpose_rule = TransposeTranspose.rule() unsqueeze_unsqueeze_rule = UnsqueezeUnsqueeze.rule() squeeze_reshape_1d_rule = SqueezeReshape.rule() +flatten_to_reshape_rule = Flatten2Reshape.rule() def basic_optimization_rules() -> RewriteRuleSet: @@ -309,11 +382,13 @@ def basic_optimization_rules() -> RewriteRuleSet: return RewriteRuleSet( [ cast_cast_rule, - cast_identity_rule, - expand_identity_rule, + no_op_cast_rule, + no_op_expand_rule, + # flatten_to_reshape_rule is order sensitive to reshape_reshape_rule + flatten_to_reshape_rule, reshape_reshape_rule, slice_split_rule, - transpose_identity_rule, + no_op_transpose_rule, transpose_transpose_rule, unsqueeze_unsqueeze_rule, squeeze_reshape_1d_rule, diff --git a/onnxscript/rewriter/basic_rules_test.py b/onnxscript/rewriter/rules/common/_basic_rules_test.py similarity index 64% rename from onnxscript/rewriter/basic_rules_test.py rename to onnxscript/rewriter/rules/common/_basic_rules_test.py index bcb6db4aa8..7d4e9d9b33 100644 --- a/onnxscript/rewriter/basic_rules_test.py +++ b/onnxscript/rewriter/rules/common/_basic_rules_test.py @@ -12,9 +12,11 @@ import onnxscript 
import onnxscript.onnx_types as ot -import onnxscript.rewriter.basic_rules as basic_rules from onnxscript import ir from onnxscript.onnx_opset import opset18 +from onnxscript.rewriter import MatchingTracer, testing +from onnxscript.rewriter import pattern as orp +from onnxscript.rewriter.rules.common import _basic_rules FLOAT = onnx.TensorProto.FLOAT @@ -29,6 +31,10 @@ def _make_model(*args, **kwargs) -> ir.Model: return ir.serde.deserialize_model(onnx.helper.make_model(*args, **kwargs)) +def clone_model(model: ir.Model) -> ir.Model: + return ir.from_proto(ir.to_proto(model)) + + class BasicRulesTest(unittest.TestCase): def _get_random_inputs(self, model: onnx.ModelProto) -> dict[str, Any]: feeds: dict[str, Any] = {} @@ -98,7 +104,7 @@ def _check_model( ] ) def test_basic_optimization_rules_identity(self, _: str, model: ir.Model): - rule_set = basic_rules.basic_optimization_rules() + rule_set = _basic_rules.basic_optimization_rules() model_proto = ir.serde.serialize_model(model) rule_set.apply_to_model(model) rewritten_model = ir.serde.serialize_model(model) @@ -126,7 +132,7 @@ def test_basic_optimization_rules_identity(self, _: str, model: ir.Model): ] ) def test_basic_optimization_rules_transpose_transpose(self, _: str, model: ir.Model): - rule_set = basic_rules.basic_optimization_rules() + rule_set = _basic_rules.basic_optimization_rules() model_proto = ir.serde.serialize_model(model) rule_set.apply_to_model(model) rewritten_model = ir.serde.serialize_model(model) @@ -153,7 +159,7 @@ def cast_cast_model(x): ] ) def test_cast_cast_rule(self, _: str, type1, type2, type3): - rule = basic_rules.cast_cast_rule + rule = _basic_rules.cast_cast_rule model_proto = self._double_cast_model(type1, type2, type3) model = ir.serde.deserialize_model(model_proto) rule.apply_to_model(model) @@ -172,7 +178,7 @@ def test_cast_cast_rule(self, _: str, type1, type2, type3): ] ) def test_cast_identity_rule(self, _: str, model: ir.Model): - rule_set = 
basic_rules.basic_optimization_rules() + rule_set = _basic_rules.basic_optimization_rules() model_proto = ir.serde.serialize_model(model) rule_set.apply_to_model(model) rewritten_model = ir.serde.serialize_model(model) @@ -228,7 +234,7 @@ def test_cast_identity_rule(self, _: str, model: ir.Model): def test_expand_identity_rule( self, _: str, model: ir.Model, expected_nodes: tuple[str, ...] ): - rule_set = basic_rules.basic_optimization_rules() + rule_set = _basic_rules.basic_optimization_rules() model_proto = ir.serde.serialize_model(model) rule_set.apply_to_model(model) rewritten_model = ir.serde.serialize_model(model) @@ -310,7 +316,7 @@ def test_expand_identity_rule( ] ) def test_unsqueeze_unsqueeze_rule(self, _: str, model: ir.Model): - rule_set = basic_rules.basic_optimization_rules() + rule_set = _basic_rules.basic_optimization_rules() model_proto = ir.serde.serialize_model(model) rule_set.apply_to_model(model) rewritten_model = ir.serde.serialize_model(model) @@ -318,65 +324,6 @@ def test_unsqueeze_unsqueeze_rule(self, _: str, model: ir.Model): self.assertEqual(["Constant", "Unsqueeze"], [n.op_type for n in model.graph]) self._check_model(model_proto, rewritten_model) - @parameterized.parameterized.expand( - [ - ( - "double_reshape_1", - _make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("Reshape", ["X", "shape_"], ["Xu"]), - onnx.helper.make_node("Reshape", ["Xu", "shape"], ["Y"]), - ], - "name", - [onnx.helper.make_tensor_value_info("X", FLOAT, [3, 4, 5])], - [onnx.helper.make_tensor_value_info("Y", FLOAT, [5, 4, 3])], - [ - onnx.numpy_helper.from_array( - np.array([4, 5, 3], dtype=np.int64), name="shape_" - ), - onnx.numpy_helper.from_array( - np.array([5, 4, 3], dtype=np.int64), name="shape" - ), - ], - ), - opset_imports=[onnx.helper.make_opsetid("", 18)], - ), - ), - ( - "double_reshape_2", - _make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("Reshape", ["X", "shape_"], ["Xu"]), - onnx.helper.make_node("Reshape", ["Xu", 
"shape"], ["Y"]), - ], - "name", - [onnx.helper.make_tensor_value_info("X", FLOAT, [3, 4, 5])], - [onnx.helper.make_tensor_value_info("Y", FLOAT, [5, 4, 3])], - [ - onnx.numpy_helper.from_array( - np.array([-1], dtype=np.int64), name="shape_" - ), - onnx.numpy_helper.from_array( - np.array([5, 4, 3], dtype=np.int64), name="shape" - ), - ], - ), - opset_imports=[onnx.helper.make_opsetid("", 18)], - ), - ), - ] - ) - def test_reshape_reshape_rule(self, _: str, model: ir.Model): - rule_set = basic_rules.basic_optimization_rules() - model_proto = ir.serde.serialize_model(model) - rule_set.apply_to_model(model) - rewritten_model = ir.serde.serialize_model(model) - - self.assertEqual(["Reshape"], [n.op_type for n in model.graph]) - self._check_model(model_proto, rewritten_model) - @classmethod def _slices_split_models(cls): models = [ @@ -420,7 +367,7 @@ def _slices_split_models(cls): def test_slices_split_rule(self): for model_proto in self._slices_split_models(): ir_model = ir.serde.deserialize_model(model_proto) - rule_set = basic_rules.basic_optimization_rules() + rule_set = _basic_rules.basic_optimization_rules() rule_set.apply_to_model(ir_model) rewritten_model = ir.serde.serialize_model(ir_model) @@ -428,7 +375,7 @@ def test_slices_split_rule(self): self._check_model(model_proto, rewritten_model) def test_squeeze_reshape_1d_rule(self): - rule = basic_rules.squeeze_reshape_1d_rule + rule = _basic_rules.squeeze_reshape_1d_rule def check(model_script, expected_count) -> None: model_proto = model_script.to_model_proto() @@ -465,5 +412,204 @@ def model3(X: ot.FLOAT[1, 1]): check(model3, 0) +class ReshapeReshapeTest(unittest.TestCase): + @staticmethod + def create_model( + input_shape, shape1, shape2, allowzero1=0, allowzero2=0, infer_shape=False + ): + def _convert_shape(shape, name): + if isinstance(shape, np.ndarray): + shape = tape.initializer(ir.Tensor(shape, name=name)) + elif isinstance(shape, (list, tuple)): + shape = ir.val(name, ir.DataType.INT64, 
ir.Shape(shape)) + tape.graph_like.inputs.append(shape) + else: + raise TypeError(f"Unsupported type {type(shape)} for shape.") + return shape + + x = ir.val("X", ir.DataType.FLOAT, ir.Shape(input_shape)) + y = ir.val("Y", ir.DataType.FLOAT) + tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20})) + + # Build the graph. + reshape = tape.op( + "Reshape", + inputs=[x, _convert_shape(shape1, "shape_")], + attributes={"allowzero": allowzero1}, + ) + tape.op( + "Reshape", + inputs=[reshape, _convert_shape(shape2, "shape")], + attributes={"allowzero": allowzero2}, + output=y, + ) + model = ir.Model(tape.graph_like, ir_version=10) + + # Infer shapes. + if infer_shape: + model = ir.passes.common.ShapeInferencePass()(model).model + return model + + @parameterized.parameterized.expand( + [ + ((3, 4, 5), [4, 5, 3], [5, 4, 3]), + ((3, 4, 5), [4, 5, 3], [5, 4, 3]), + ((3, 4, 8), [2, 0, 3, -1], [0, 3, 2, 8]), + ((3, 4, 8), [3, 4, -1], [-1, 12], 1), + ((3, 4, 2), [0, 4, -1], [12, -1], 0, 1), + ((3, 0, 8), [4, 2, 0, 0], [3, 0], 1, 1), + ] + ) + def test_reshape_reshape_rule( + self, input_shape, shape1, shape2, allowzero1=0, allowzero2=0 + ): + model = self.create_model( + input_shape, + np.array(shape1, dtype="int64"), + np.array(shape2, dtype="int64"), + allowzero1=allowzero1, + allowzero2=allowzero2, + ) + updated_model = clone_model(model) + + # check rewrite approach. + count = _basic_rules.reshape_reshape_rule.apply_to_model(updated_model) + self.assertEqual(count, 1) + self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph]) + + # Check inference. 
+ inputs = np.random.default_rng(10).random(input_shape, dtype="float32") + testing.assert_numerically_equal(model, updated_model, (inputs,), atol=0, rtol=0) + + @parameterized.parameterized.expand([([3, 2, 3, 3, 3], 1), ([0, -1, 3, 2], 0)]) + def test_reshape_dynamic_reshape_rule(self, shape1, allowzero1=0): + input_shape = (3, 6, 9) + shape1 = np.array(shape1, dtype="int64") + # Build the model with unknown shape1. + model = self.create_model( + input_shape, + (shape1.size,), + np.array((1, 6, 27), dtype="int64"), + allowzero1=allowzero1, + ) + updated_model = clone_model(model) + + # check rewrite approach. + count = _basic_rules.reshape_reshape_rule.apply_to_model(updated_model) + self.assertEqual(count, 1) + self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph]) + + # Check inference. + feeds = { + "X": np.random.default_rng(2).random(input_shape, dtype="float32"), + "shape_": shape1, + } + testing.assert_numerically_equal(model, updated_model, feeds, atol=0, rtol=0) + + @parameterized.parameterized.expand( + [((3, 6, 9), [0, 3, 2, -1]), ((0, 6, 2), [0, 0, 3], 1)] + ) + def test_reshape_reshape_dynamic_rule(self, input_shape, shape2, allowzero2=0): + # Note that shape inference is required for this test to be valid. + shape2 = np.array(shape2, dtype="int64") + model = self.create_model( + input_shape, + np.array((3, 2, -1), dtype="int64"), + shape2, + allowzero2=allowzero2, + infer_shape=True, + ) + updated_model = clone_model(model) + + # check rewrite approach. + count = _basic_rules.reshape_reshape_rule.apply_to_model(updated_model) + self.assertEqual(count, 1) + self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph]) + + # Check inference. 
+ inputs = np.random.default_rng(7).random(input_shape, dtype="float32") + testing.assert_numerically_equal(model, updated_model, (inputs,), atol=0, rtol=0) + + @parameterized.parameterized.expand( + [ + ((3,), "is not a constant"), + (np.array([0, -1], dtype="int64"), "both 0 and -1 dimensions"), + (np.array([0, 0, 3], dtype="int64"), "more than one 0 dimension"), + ] + ) + def test_unsupported_reshape_reshape(self, shape2, error_msg): + model = self.create_model((1, 2, 3), np.array([1, 6], dtype="int64"), shape2) + + # Check rewrite approach. + tracer = MatchingTracer() + count = _basic_rules.reshape_reshape_rule.apply_to_model(model, tracer=tracer) + self.assertEqual(count, 0) + + # Check that the error message is the expected one + tracer_match = tracer.best_matches_map[_basic_rules.reshape_reshape_rule][0] + self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED) + self.assertRegex(tracer_match.match_result.reason, error_msg) + + +class Flatten2ReshapeTest(unittest.TestCase): + @staticmethod + def create_model(input_shape, axis=1): + x = ir.val("X", ir.DataType.FLOAT, ir.Shape(input_shape)) + y = ir.val("Y", ir.DataType.FLOAT) + tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20})) + + # Build the graph. + tape.op("Flatten", inputs=[x], attributes={"axis": axis}, output=y) + model = ir.Model(tape.graph_like, ir_version=10) + return model + + @parameterized.parameterized.expand(list(range(-5, 6))) + def test_flatten_to_reshape_rule(self, axis): + input_shape = (1, 4, 8, 7, 5) + model = self.create_model(input_shape=input_shape, axis=axis) + updated_model = clone_model(model) + + # check rewrite approach. + count = _basic_rules.flatten_to_reshape_rule.apply_to_model(updated_model) + self.assertEqual(count, 1) + self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph]) + + # Check inference. 
+ inputs = np.random.default_rng(13).random(input_shape, dtype="float32") + testing.assert_numerically_equal(model, updated_model, (inputs,), atol=0, rtol=0) + + @parameterized.parameterized.expand(list(range(-4, 5))) + def test_flatten_to_reshape_dynamic_input(self, axis): + model = self.create_model(input_shape=("N", "C1", "C2", "C3"), axis=axis) + # Rule is supported in all cases if the output shape is known for non-special cases. + input_shape = (1, 2, 3, 4) + if axis not in {-3, 0, 1, 4}: + out_shape = ir.Shape((np.prod(input_shape[:axis]), np.prod(input_shape[axis:]))) + model.graph.outputs[0].shape = out_shape + updated_model = clone_model(model) + + # check rewrite approach. + count = _basic_rules.flatten_to_reshape_rule.apply_to_model(updated_model) + self.assertEqual(count, 1) + self.assertEqual(["Reshape"], [n.op_type for n in updated_model.graph]) + + # Check inference. + inputs = np.random.default_rng(17).random(input_shape, dtype="float32") + testing.assert_numerically_equal(model, updated_model, (inputs,), atol=0, rtol=0) + + def test_unsupported_flatten_to_reshape(self): + model = self.create_model(input_shape=("N", "C1", "C2"), axis=2) + + # Check rewrite approach. 
+ tracer = MatchingTracer() + count = _basic_rules.flatten_to_reshape_rule.apply_to_model(model, tracer=tracer) + self.assertEqual(count, 0) + + # Check that the error message is the expected one + tracer_match = tracer.best_matches_map[_basic_rules.flatten_to_reshape_rule][0] + self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED) + self.assertRegex(tracer_match.match_result.reason, "Impossible to compute new shape") + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxscript/rewriter/broadcast_to_matmul.py b/onnxscript/rewriter/rules/common/_broadcast_to_matmul.py similarity index 100% rename from onnxscript/rewriter/broadcast_to_matmul.py rename to onnxscript/rewriter/rules/common/_broadcast_to_matmul.py diff --git a/onnxscript/rewriter/broadcast_to_matmul_test.py b/onnxscript/rewriter/rules/common/_broadcast_to_matmul_test.py similarity index 94% rename from onnxscript/rewriter/broadcast_to_matmul_test.py rename to onnxscript/rewriter/rules/common/_broadcast_to_matmul_test.py index c2f3b31f90..4e33544986 100644 --- a/onnxscript/rewriter/broadcast_to_matmul_test.py +++ b/onnxscript/rewriter/rules/common/_broadcast_to_matmul_test.py @@ -9,7 +9,7 @@ import parameterized from onnxscript import ir -from onnxscript.rewriter import broadcast_to_matmul +from onnxscript.rewriter.rules.common import _broadcast_to_matmul def _infer_shapes(model: ir.Model) -> ir.Model: @@ -38,7 +38,7 @@ def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) @@ -108,7 +108,7 @@ def test_reshape_matmul_reshape_does_not_replace_when_output_sizes_do_not_match( """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = 
_broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 7) model = _infer_shapes(model) @@ -151,7 +151,7 @@ def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable_in_nest ) ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.functions), 1) self.assertEqual(len(model.functions[("pkg.custom", "afunction", "")]), 4) @@ -178,7 +178,7 @@ def test_reshape_matmul_reshape_remain_when_input_last_dim_and_second_last_dim_n """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 7) @@ -202,7 +202,7 @@ def test_reshape_matmul_reshape_remain_one_reshape_when_inputs_are_not_broadcast ) model_proto = onnx.shape_inference.infer_shapes(model_proto) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) # subset pattern matched self.assertEqual(count, 1) self.assertEqual(len(model.graph), 5) @@ -226,7 +226,7 @@ def test_reshape_matmul_reshape_replace_when_inputs_are_broadcastable_with_one_i """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) @@ -249,7 +249,7 @@ def test_reshape_matmul_reshape_replace_when_first_input_is_one_dimension_and_br """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) @@ -272,7 
+272,7 @@ def test_reshape_matmul_reshape_replace_when_first_input_is_one_dimension_and_se """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) @@ -295,7 +295,7 @@ def test_reshape_matmul_reshape_remain_when_first_input_is_one_dimension_and_not """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 7) @@ -318,7 +318,7 @@ def test_reshape_matmul_reshape_replace_when_second_input_is_one_dimension_and_b """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) @@ -342,7 +342,7 @@ def test_reshape_matmul_reshape_remain_one_reshape_when_second_input_is_one_dime ) model_proto = onnx.shape_inference.infer_shapes(model_proto) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) # subset pattern matched self.assertEqual(count, 1) self.assertEqual(len(model.graph), 5) @@ -366,7 +366,7 @@ def test_reshape_matmul_reshape_remain_when_output_is_not_matmul_broadcasted( """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = _broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 7) @@ -387,7 +387,7 @@ def test_reshape_matmul_reshape_replace_when_nd_inputs_are_broadcastable(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = broadcast_to_matmul.rules.apply_to_model(model) + count = 
_broadcast_to_matmul.rules.apply_to_model(model) self.assertEqual(count, 1) # The constant nodes are not removed. They should be removed by a subsequent DCE in optimizer. self.assertEqual(len(model.graph), 3) diff --git a/onnxscript/rewriter/cast_constant_of_shape.py b/onnxscript/rewriter/rules/common/_cast_constant_of_shape.py similarity index 100% rename from onnxscript/rewriter/cast_constant_of_shape.py rename to onnxscript/rewriter/rules/common/_cast_constant_of_shape.py diff --git a/onnxscript/rewriter/cast_constant_of_shape_test.py b/onnxscript/rewriter/rules/common/_cast_constant_of_shape_test.py similarity index 89% rename from onnxscript/rewriter/cast_constant_of_shape_test.py rename to onnxscript/rewriter/rules/common/_cast_constant_of_shape_test.py index 35151e17d9..794491024b 100644 --- a/onnxscript/rewriter/cast_constant_of_shape_test.py +++ b/onnxscript/rewriter/rules/common/_cast_constant_of_shape_test.py @@ -6,7 +6,7 @@ import onnx.parser from onnxscript import ir -from onnxscript.rewriter import cast_constant_of_shape +from onnxscript.rewriter.rules.common import _cast_constant_of_shape class CastConstantOfShapeTest(unittest.TestCase): @@ -23,7 +23,7 @@ def test_cast_after_constant_of_shape_is_fused(self): ) onnx.checker.check_model(input_model_proto, True) model = ir.serde.deserialize_model(input_model_proto) - count = cast_constant_of_shape.rules.apply_to_model(model) + count = _cast_constant_of_shape.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 1) self.assertEqual(model.graph[0].attributes["value"].value.dtype, 10) @@ -42,7 +42,7 @@ def test_cast_after_constant_of_shape_without_value_is_fused(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = cast_constant_of_shape.rules.apply_to_model(model) + count = _cast_constant_of_shape.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 1) self.assertEqual(model.graph[0].attributes["value"].value.dtype, 
10) diff --git a/onnxscript/rewriter/collapse_slices.py b/onnxscript/rewriter/rules/common/_collapse_slices.py similarity index 91% rename from onnxscript/rewriter/collapse_slices.py rename to onnxscript/rewriter/rules/common/_collapse_slices.py index 291128157d..21b2694b82 100644 --- a/onnxscript/rewriter/collapse_slices.py +++ b/onnxscript/rewriter/rules/common/_collapse_slices.py @@ -5,7 +5,7 @@ import logging from onnxscript import ir -from onnxscript.rewriter._ir_utils import is_singleton_value +from onnxscript.rewriter import _ir_utils from onnxscript.rewriter._rewrite_rule import RewriteRule, RewriteRuleSet logger = logging.getLogger(__name__) @@ -82,20 +82,20 @@ def _same_shape(op, data: ir.Value, slice_output: ir.Value, steps: ir.Value, **_ if data.shape is None or slice_output.shape is None: return False - if not is_singleton_value(steps, 1): + if not _ir_utils.is_singleton_value(steps, 1): return False - return data.shape == slice_output.shape + return _ir_utils.same_shape(data.shape, slice_output.shape) # Register the rewrite rules -remove_redundant_slice = RewriteRule( +collapse_slice_rule = RewriteRule( _potential_redundant_slice, _identity_to_itself, _check_if_redundant_slice, ) -remove_redundant_slice2 = RewriteRule( +collapse_slice2_rule = RewriteRule( _potential_redundant_slice, _identity_to_itself, _same_shape, @@ -104,4 +104,4 @@ def _same_shape(op, data: ir.Value, slice_output: ir.Value, steps: ir.Value, **_ # NOTE: The second rule subsumes the first one. So, we may be able to remove the first one, # provided shape-inference is run before the rewriter and computes the shape of the slice output. 
-rules = RewriteRuleSet([remove_redundant_slice, remove_redundant_slice2]) +rules = RewriteRuleSet([collapse_slice_rule, collapse_slice2_rule]) diff --git a/onnxscript/rewriter/collapse_slices_test.py b/onnxscript/rewriter/rules/common/_collapse_slices_test.py similarity index 91% rename from onnxscript/rewriter/collapse_slices_test.py rename to onnxscript/rewriter/rules/common/_collapse_slices_test.py index 52b59f9037..727240344d 100644 --- a/onnxscript/rewriter/collapse_slices_test.py +++ b/onnxscript/rewriter/rules/common/_collapse_slices_test.py @@ -6,10 +6,10 @@ import numpy as np import onnx.parser -import onnx.shape_inference from onnxscript import ir -from onnxscript.rewriter import collapse_slices, testing +from onnxscript.rewriter import testing +from onnxscript.rewriter.rules.common import _collapse_slices _INT64_MAX = 9223372036854775807 @@ -30,7 +30,7 @@ def test_slice_is_redundant_when_ends_is_greater_than_input_shape(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = collapse_slices.rules.apply_to_model(model) + count = _collapse_slices.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 5) self.assertIn("Identity", [node.op_type for node in model.graph]) @@ -55,7 +55,7 @@ def test_slice_is_redundant_when_ends_reaches_int64_max(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = collapse_slices.rules.apply_to_model(model) + count = _collapse_slices.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 5) self.assertIn("Identity", [node.op_type for node in model.graph]) @@ -80,7 +80,7 @@ def test_slice_unequal_dynamic_shape(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = collapse_slices.rules.apply_to_model(model) + count = _collapse_slices.rules.apply_to_model(model) self.assertEqual(count, 0) def test_slice_equal_dynamic_shape(self): @@ -98,7 +98,7 @@ def test_slice_equal_dynamic_shape(self): """ ) model = 
ir.serde.deserialize_model(model_proto) - count = collapse_slices.rules.apply_to_model(model) + count = _collapse_slices.rules.apply_to_model(model) self.assertEqual(count, 1) def test_slice_equal_dynamic_shape_but_step_reverse(self): @@ -116,6 +116,6 @@ def test_slice_equal_dynamic_shape_but_step_reverse(self): """ ) model = ir.serde.deserialize_model(model_proto) - count = collapse_slices.rules.apply_to_model(model) + count = _collapse_slices.rules.apply_to_model(model) # Should not change the output shape if we did not use the default step of 1 self.assertEqual(count, 0) diff --git a/onnxscript/rewriter/fuse_batchnorm.py b/onnxscript/rewriter/rules/common/_fuse_batchnorm.py similarity index 84% rename from onnxscript/rewriter/fuse_batchnorm.py rename to onnxscript/rewriter/rules/common/_fuse_batchnorm.py index 51e4e20db3..e3298ffbd8 100644 --- a/onnxscript/rewriter/fuse_batchnorm.py +++ b/onnxscript/rewriter/rules/common/_fuse_batchnorm.py @@ -15,7 +15,7 @@ """ from abc import ABC, abstractmethod -from typing import Mapping +from typing import ClassVar, Mapping import numpy as np @@ -33,16 +33,6 @@ def _reshape_for_broadcast(x: np.ndarray, rank: int, axis: int = 1) -> np.ndarra class _FuseBatchNormBase(RewriteRuleClassBase, ABC): """Interface for BatchNormalization nodes fusion.""" - def __init__( - self, - op_type: str, - name: str | None = None, - remove_nodes: bool = True, - as_function: bool = False, - ) -> None: - super().__init__(name=name, remove_nodes=remove_nodes, as_function=as_function) - self.op_type = op_type - @abstractmethod def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: """Return the axis along which BatchNorm scale should be broadcasted.""" @@ -78,7 +68,10 @@ def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Valu bias_name = inbound_node.inputs[2].name else: original_bias = np.zeros_like(input_mean) - bias_name = x.name + "_bias" + # Use inbound input 1 (should be weight) to derive a name for the 
bias + # to avoid name collision on initializer creation when there are multiple patterns + # sharing the same parent nodes. + bias_name = inbound_node.inputs[1].name + "_bias" fused_bias = ir.tensor((original_bias - input_mean) * scale_factor + beta) return op.op( @@ -116,8 +109,7 @@ def check(self, context, x, inbound_out: ir.Value, batchnorm_out: ir.Value) -> M class FuseBatchNormIntoConv(_FuseBatchNormBase): """Replaces ``BatchNormalization(Conv(x))`` with ``Conv(x)``.""" - def __init__(self): - super().__init__("Conv") + op_type: ClassVar = "Conv" def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: return 0 @@ -133,8 +125,7 @@ def pattern(self, op, x): class FuseBatchNormIntoConvTranspose(_FuseBatchNormBase): """Replaces ``BatchNormalization(ConvTranspose(x))`` with ``ConvTranspose(x)``.""" - def __init__(self): - super().__init__("ConvTranspose") + op_type: ClassVar = "ConvTranspose" def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: return 1 @@ -150,8 +141,7 @@ def pattern(self, op, x): class FuseBatchNormIntoGemm(_FuseBatchNormBase): """Replaces ``BatchNormalization(Gemm(x))`` with ``Gemm(x)``.""" - def __init__(self): - super().__init__("Gemm") + op_type: ClassVar = "Gemm" def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int: return ( @@ -167,21 +157,14 @@ def pattern(self, op, x): fuse_batchnorm_into_conv_rule = FuseBatchNormIntoConv().rule() -fuse_batchnorm_into_convtranspose_rule = FuseBatchNormIntoConvTranspose().rule() +fuse_batchnorm_into_conv_transpose_rule = FuseBatchNormIntoConvTranspose().rule() fuse_batchnorm_into_gemm_rule = FuseBatchNormIntoGemm().rule() -def fuse_batchnorm_rule_set() -> RewriteRuleSet: - """Returns a set of rewrite rules that fuse BatchNormalization nodes - into preceding nodes such as Conv, ConvTranspose, and Gemm. 
- - Returns: - RewriteRuleSet - """ - return RewriteRuleSet( - [ - fuse_batchnorm_into_conv_rule, - fuse_batchnorm_into_convtranspose_rule, - fuse_batchnorm_into_gemm_rule, - ] - ) +rules = RewriteRuleSet( + [ + fuse_batchnorm_into_conv_rule, + fuse_batchnorm_into_conv_transpose_rule, + fuse_batchnorm_into_gemm_rule, + ] +) diff --git a/onnxscript/rewriter/fuse_batchnorm_test.py b/onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py similarity index 73% rename from onnxscript/rewriter/fuse_batchnorm_test.py rename to onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py index 20d272abd7..2007033ef6 100644 --- a/onnxscript/rewriter/fuse_batchnorm_test.py +++ b/onnxscript/rewriter/rules/common/_fuse_batchnorm_test.py @@ -8,7 +8,8 @@ import parameterized from onnxscript import ir -from onnxscript.rewriter import fuse_batchnorm, testing +from onnxscript.rewriter import testing +from onnxscript.rewriter.rules.common import _fuse_batchnorm class FuseBatchnormTest(unittest.TestCase): @@ -73,7 +74,7 @@ def test_fuse_batchnorm_convtranspose(self, _: str, convtranspose_bias: bool): model = ir.serde.deserialize_model(model_proto) # Apply rule - count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + count = _fuse_batchnorm.rules.apply_to_model(model) # Check that BatchNorm was fused self.assertEqual(count, 1) @@ -132,7 +133,7 @@ def test_fuse_batchnorm_conv(self, _: str, conv_bias: bool): model = ir.serde.deserialize_model(model_proto) # Apply rule - count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + count = _fuse_batchnorm.rules.apply_to_model(model) # Check that BatchNorm was fused self.assertEqual(count, 1) @@ -196,7 +197,7 @@ def test_fuse_batchnorm_gemm(self, _: str, gemm_bias: bool, transB: int): model = ir.serde.deserialize_model(model_proto) # Apply rule - count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + count = _fuse_batchnorm.rules.apply_to_model(model) # Check that BatchNorm was fused 
self.assertEqual(count, 1) @@ -223,7 +224,7 @@ def test_fuse_batchnorm_non_initializers(self): """) onnx.checker.check_model(model_proto, True) model = ir.serde.deserialize_model(model_proto) - count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + count = _fuse_batchnorm.rules.apply_to_model(model) # No changes were applied self.assertEqual(count, 0) @@ -247,11 +248,69 @@ def test_fuse_batchnorm_graph_inputs(self): onnx.checker.check_model(model_proto, True) model = ir.serde.deserialize_model(model_proto) - count = fuse_batchnorm.fuse_batchnorm_rule_set().apply_to_model(model) + count = _fuse_batchnorm.rules.apply_to_model(model) # No changes were applied as W is a graph input self.assertEqual(count, 0) + def test_fuse_batchnorm_does_not_collide_names_with_same_parent_node(self): + model_proto = onnx.parser.parse_model(""" + < ir_version: 7, opset_import: ["" : 17] > + test_model (float[N, 32, 14, 16] X) => (float [N, ?, ?, ?] Y1, float [N, ?, ?, ?] Y2) + { + X1 = MaxPool(X) + X2 = Conv(X1, W1) + Y1 = BatchNormalization(X2, gamma_64, beta_64, input_mean_64, input_var_64) + X3 = Conv(X1, W2) + Y2 = BatchNormalization(X3, gamma_256, beta_256, input_mean_256, input_var_256) + } + """) + initializers = [ + onnx.numpy_helper.from_array( + np.random.randn(64, 32, 3, 3).astype(np.float32), name="W1" + ), + onnx.numpy_helper.from_array( + np.random.randn(64).astype(np.float32), name="gamma_64" + ), + onnx.numpy_helper.from_array( + np.random.randn(64).astype(np.float32), name="beta_64" + ), + onnx.numpy_helper.from_array( + np.random.randn(64).astype(np.float32), name="input_mean_64" + ), + onnx.numpy_helper.from_array( + np.abs(np.random.randn(64)).astype(np.float32), name="input_var_64" + ), + onnx.numpy_helper.from_array( + np.random.randn(256, 32, 3, 3).astype(np.float32), name="W2" + ), + onnx.numpy_helper.from_array( + np.random.randn(256).astype(np.float32), name="gamma_256" + ), + onnx.numpy_helper.from_array( + 
np.random.randn(256).astype(np.float32), name="beta_256" + ), + onnx.numpy_helper.from_array( + np.random.randn(256).astype(np.float32), name="input_mean_256" + ), + onnx.numpy_helper.from_array( + np.abs(np.random.randn(256)).astype(np.float32), name="input_var_256" + ), + ] + model_proto.graph.initializer.extend(initializers) + onnx.checker.check_model(model_proto, True) + model = ir.serde.deserialize_model(model_proto) + count = _fuse_batchnorm.rules.apply_to_model(model) + + # Applied twice, once for each BatchNorm + self.assertEqual(count, 2) + # it should have different bias names for the two fused Conv nodes + conv_nodes = [node for node in model.graph if node.op_type == "Conv"] + self.assertEqual(len(conv_nodes), 2) + bias_names_1 = conv_nodes[0].inputs[2].name + bias_names_2 = conv_nodes[1].inputs[2].name + self.assertNotEqual(bias_names_1, bias_names_2) + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/rewriter/rules/common/_fuse_conv_affine.py b/onnxscript/rewriter/rules/common/_fuse_conv_affine.py new file mode 100644 index 0000000000..2aaba5cd73 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_fuse_conv_affine.py @@ -0,0 +1,112 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+"""Absorbs affine operation into convolution (best effort): +- Conv(Mul(Add(x))) -> Conv (only conv without padding can be fused) +- Add(Mul(Conv)) -> Conv (for all convolutions) +""" + +from __future__ import annotations + +import numpy as np +import onnx_ir as ir + +from onnxscript.rewriter import pattern +from onnxscript.rewriter._basics import MatchResult +from onnxscript.rewriter._ir_utils import get_const_value, get_singleton_value + + +class _ConvAffineFusionBase(pattern.RewriteRuleClassBase): + def check( + self, + context, + x: ir.Value, + w: ir.Value, + b: ir.Value, + scale: ir.Value, + offset: ir.Value, + conv_out: ir.Value, + ) -> MatchResult: + check_result = MatchResult() + if get_const_value(w) is None: + return check_result.fail("The weight of Conv should be constant") + if get_const_value(b) is None: + return check_result.fail("The bias of Conv should be constant") + if get_singleton_value(scale) is None: + return check_result.fail("Operand for Mul should be constant scalar value") + if get_singleton_value(offset) is None: + return check_result.fail("Operand for Add should be constant scalar value") + return check_result + + +class AffineConvFusion(_ConvAffineFusionBase): + """Pattern: scalar Mul + scalar Add + Conv (1x1) --> Conv(1x1)""" + + def pattern( + self, op, x: ir.Value, w: ir.Value, b: ir.Value, scale: ir.Value, offset: ir.Value + ) -> ir.Value: + return op.Conv( + x * scale + offset, + w, + b, + pads=[0, 0, 0, 0], + _allow_other_attributes=True, + _outputs=["conv_out"], + ) + + def rewrite( + self, + op: ir.tape.Tape, + x: ir.Value, + w: ir.Value, + b: ir.Value, + scale: ir.Value, + offset: ir.Value, + conv_out: ir.Value, + ) -> ir.Value: + scale_value = scale.const_value.numpy() + offset_value = offset.const_value.numpy() + w_value = w.const_value.numpy() + b_value = b.const_value.numpy() + scaled_w_value = op.initializer(ir.tensor(w_value * scale_value), w.name + "_scaled") + offset_bias = ir.tensor( + b_value + np.sum(w_value * 
offset_value, axis=(1, 2, 3), keepdims=False) + ) + offset_bias = op.initializer(offset_bias, b.name + "_offset") + conv_attributes = conv_out.producer().attributes + return op.Conv(x, scaled_w_value, offset_bias, **conv_attributes) + + +class ConvAffineFusion(_ConvAffineFusionBase): + """Pattern: Conv + scalar Mul + scalar Add --> Conv(1x1)""" + + def pattern( + self, op, x: ir.Value, w: ir.Value, b: ir.Value, scale: ir.Value, offset: ir.Value + ) -> ir.Value: + return ( + op.Conv(x, w, b, _allow_other_attributes=True, _outputs=["conv_out"]) * scale + + offset + ) + + def rewrite( + self, + op: ir.tape.Tape, + x: ir.Value, + w: ir.Value, + b: ir.Value, + scale: ir.Value, + offset: ir.Value, + conv_out: ir.Value, + ) -> ir.Value: + scale_value = scale.const_value.numpy() + offset_value = offset.const_value.numpy() + w_value = w.const_value.numpy() + b_value = b.const_value.numpy() + scaled_w_weight = op.initializer(ir.tensor(w_value * scale_value), w.name + "_scaled") + offset_bias = ir.tensor(b_value * scale_value + offset_value) + offset_bias = op.initializer(offset_bias, b.name + "_offset") + conv_attributes = conv_out.producer().attributes + return op.Conv(x, scaled_w_weight, offset_bias, **conv_attributes) + + +affine_conv_fusion_rule = AffineConvFusion().rule() +conv_affine_fusion_rule = ConvAffineFusion().rule() diff --git a/onnxscript/rewriter/rules/common/_fuse_conv_affine_test.py b/onnxscript/rewriter/rules/common/_fuse_conv_affine_test.py new file mode 100644 index 0000000000..d456cab76b --- /dev/null +++ b/onnxscript/rewriter/rules/common/_fuse_conv_affine_test.py @@ -0,0 +1,111 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+import unittest + +import numpy as np + +from onnxscript import ir +from onnxscript.rewriter import rewrite, testing +from onnxscript.rewriter.rules.common import ( + affine_conv_fusion_rule, + conv_affine_fusion_rule, +) + + +class FuseConvAffineTest(unittest.TestCase): + def clone_model(self, model: ir.Model) -> ir.Model: + return ir.from_proto(ir.to_proto(model)) + + def test_conv_affine_fusion(self): + tape = ir.tape.Tape() + x = ir.val("x", dtype=ir.DataType.FLOAT, shape=ir.Shape([1, 3, 32, 32])) + w = tape.initializer(ir.tensor(np.ones((3, 3, 3, 3), dtype=np.float32), name="w")) + b = tape.initializer(ir.tensor(np.ones((3,), dtype=np.float32), name="b")) + scale = tape.initializer(ir.tensor(np.array([2.0], dtype=np.float32), name="scale")) + offset = tape.initializer(ir.tensor(np.array([3.0], dtype=np.float32), name="offset")) + + conv_out = tape.op("Conv", [x, w, b], attributes={"pads": [1, 1, 1, 1]}) + mul_out = tape.op("Mul", [conv_out, scale]) + z = tape.op( + "Add", + [mul_out, offset], + output=ir.val( + "z", + dtype=ir.DataType.FLOAT, + shape=ir.Shape([1, 3, 32, 32]), + ), + ) + + model = ir.Model( + ir.Graph( + inputs=[x], + outputs=[z], + nodes=tape.nodes, + initializers=tape.initializers, + opset_imports={"": 17}, + ), + ir_version=8, + ) + rewritten_model = self.clone_model(model) + rewritten_model = rewrite( + rewritten_model, + pattern_rewrite_rules=[conv_affine_fusion_rule], + ) + # Check that Mul and Add are fused into Conv + self.assertEqual(model.graph.num_nodes() - 2, rewritten_model.graph.num_nodes()) + + # Check that the results are numerically equal + rng = np.random.default_rng(42) + inputs = [ + rng.random((1, 3, 32, 32), dtype=np.float32), + ] + testing.assert_numerically_equal(model, rewritten_model, inputs) + + def test_affine_conv_fusion_without_pad(self): + tape = ir.tape.Tape() + x = ir.val("x", dtype=ir.DataType.FLOAT, shape=ir.Shape([1, 3, 32, 32])) + w = tape.initializer(ir.tensor(np.ones((3, 3, 3, 3), dtype=np.float32), 
name="w")) + b = tape.initializer(ir.tensor(np.ones((3,), dtype=np.float32), name="b")) + scale = tape.initializer(ir.tensor(np.array([2.0], dtype=np.float32), name="scale")) + offset = tape.initializer(ir.tensor(np.array([3.0], dtype=np.float32), name="offset")) + + mul_out = tape.op("Mul", [x, scale]) + z = tape.op( + "Add", + [mul_out, offset], + output=ir.val( + "z", + dtype=ir.DataType.FLOAT, + shape=ir.Shape([1, 3, 32, 32]), + ), + ) + conv_out = tape.op("Conv", [z, w, b], attributes={"pads": [0, 0, 0, 0]}) + + model = ir.Model( + ir.Graph( + inputs=[x], + outputs=[conv_out], + nodes=tape.nodes, + initializers=tape.initializers, + opset_imports={"": 17}, + ), + ir_version=8, + ) + rewritten_model = self.clone_model(model) + rewritten_model = rewrite( + rewritten_model, + pattern_rewrite_rules=[affine_conv_fusion_rule], + ) + # Check that Mul and Add are fused into Conv + self.assertEqual(model.graph.num_nodes() - 2, rewritten_model.graph.num_nodes()) + + # Check that the results are numerically equal + rng = np.random.default_rng(42) + inputs = [ + rng.random((1, 3, 32, 32), dtype=np.float32), + ] + testing.assert_numerically_equal(model, rewritten_model, inputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/rules/common/_fuse_hardswish.py b/onnxscript/rewriter/rules/common/_fuse_hardswish.py new file mode 100644 index 0000000000..6d2e8c84e1 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_fuse_hardswish.py @@ -0,0 +1,142 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+"""Does the following transformation: +- Div(Clip(Add(x))) -> HardSigmoid +- Mul(HardSigmoid(x), x) -> HardSwish +""" + +from __future__ import annotations + +import numpy as np +import onnx_ir as ir + +from onnxscript.rewriter import pattern +from onnxscript.rewriter._basics import MatchResult +from onnxscript.rewriter._ir_utils import is_singleton_value +from onnxscript.rewriter._rewrite_rule import RewriteRuleSet + + +class _HardSigmoidFusionBase(pattern.RewriteRuleClassBase): + """HardSwish requires constant values so we check in base class.""" + + def check( + self, + op, + x: ir.Value, + clip_min: ir.Value, + clip_max: ir.Value, + bias: ir.Value, + divisor: ir.Value, + ) -> MatchResult: + check_result = MatchResult() + + if not is_singleton_value(clip_min, 0.0, rtol=1e-4): + return check_result.fail("Swish requires min value of 0 for clip") + if not is_singleton_value(clip_max, 6.0, rtol=1e-4): + return check_result.fail("Swish requires max value of 6 for clip") + if not is_singleton_value(bias, 3.0, rtol=1e-4): + return check_result.fail("Swish requires bias value of 3") + if not is_singleton_value(divisor, 6.0, rtol=1e-4): + return check_result.fail("Swish requires divisor value of 6") + return check_result + + +class HardSwishFusion(_HardSigmoidFusionBase): + """Fuse Add(_, 3) + Clip<0, 6>(_) + Mul + Div(_, 6) into HardSwish + + In this case we can't make HardSigmoid fusion first. The Mul + is placed before Div while HardSigmoid requires Add+Clip+Div. 
+ """ + + def pattern( + self, + op, + x: ir.Value, + clip_min: ir.Value, + clip_max: ir.Value, + bias: ir.Value, + divisor: ir.Value, + ) -> ir.Value: + out = op.Clip(x + bias, clip_min, clip_max) * x + out = out / divisor + return out + + def rewrite( + self, + op, + x: ir.Value, + clip_min: ir.Value, + clip_max: ir.Value, + bias: ir.Value, + divisor: ir.Value, + ) -> ir.Value: + return op.HardSwish(x) + + +class HardSwishFusionFromHardSigmoid(pattern.RewriteRuleClassBase): + """Fuse HardSigmoid + Mul into HardSwish""" + + def pattern(self, op, x: ir.Value) -> ir.Value: + # Floating point matching for 1/6 is not exact, so we use isclose below + out = op.HardSigmoid(x, _allow_other_attributes=True, _outputs=["hardsigmoid_out"]) + out = out * x + return out + + def check(self, op, x: ir.Value, hardsigmoid_out: ir.Value) -> MatchResult: + check_result = MatchResult() + hardsigmoid = hardsigmoid_out.producer() + # Use getter to protect when 'alpha' / 'beta' is not in attributes + alpha = hardsigmoid.attributes.get_float("alpha", -1) + beta = hardsigmoid.attributes.get_float("beta", -1) + if not np.isclose(alpha, 1 / 6): + return check_result.fail( + "HardSigmoid alpha must be 1/6 to get fused into HardSwish" + ) + if not np.isclose(beta, 0.5): + return check_result.fail( + "HardSigmoid beta must be 0.5 to get fused into HardSwish" + ) + return check_result + + def rewrite(self, op, x: ir.Value, hardsigmoid_out: ir.Value) -> ir.Value: + return op.HardSwish(x) + + +class HardSigmoidFusion(_HardSigmoidFusionBase): + """Fuse HardSigmoid only for HardSwish hyper-parameters: alpha=1/6, beta=0.5""" + + def pattern( + self, + op, + x: ir.Value, + clip_min: ir.Value, + clip_max: ir.Value, + bias: ir.Value, + divisor: ir.Value, + ) -> ir.Value: + out = op.Clip(x + bias, clip_min, clip_max) + out = out / divisor + return out + + def rewrite( + self, + op, + x: ir.Value, + clip_min: ir.Value, + clip_max: ir.Value, + bias: ir.Value, + divisor: ir.Value, + ) -> ir.Value: + return 
op.HardSigmoid(x, alpha=1 / 6, beta=0.5) + + +def fuse_hardswish_rules() -> RewriteRuleSet: + """Returns the rewrite rules for fusing HardSwish and HardSigmoid.""" + return RewriteRuleSet( + [ + HardSwishFusion().rule(), + HardSigmoidFusion().rule(), + HardSwishFusionFromHardSigmoid().rule(), + ], + commute=True, + ) diff --git a/onnxscript/rewriter/rules/common/_fuse_hardswish_test.py b/onnxscript/rewriter/rules/common/_fuse_hardswish_test.py new file mode 100644 index 0000000000..36556e9cff --- /dev/null +++ b/onnxscript/rewriter/rules/common/_fuse_hardswish_test.py @@ -0,0 +1,117 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import unittest + +import numpy as np +import onnx +import onnx_ir as ir +import onnxruntime as ort +from onnx_ir.passes.common import onnx_checker, shape_inference + +from onnxscript import optimizer +from onnxscript.rewriter import testing +from onnxscript.rewriter.rules.common import fuse_hardswish_rules + + +class FuseHardSwishTest(unittest.TestCase): + @property + def rng(self): + return np.random.default_rng(20250621) + + def clone_model(self, model: ir.Model) -> ir.Model: + return ir.from_proto(ir.to_proto(model)) + + def run_test( + self, + base_model: ir.Model, + expected_op_types: list[str], + dtype: str = "float", + ): + onnx_checker.CheckerPass(True)(base_model) + base_model = shape_inference.infer_shapes(base_model) + updated_model = self.clone_model(base_model) + _ = fuse_hardswish_rules().apply_to_model(updated_model) + + # Polish model to remove unused constants + updated_model = optimizer.optimize(updated_model) + + # Check expected op_types + self.assertEqual([node.op_type for node in updated_model.graph], expected_op_types) + + # Check inference + inputs = (self.rng.integers(low=-10, high=10, size=(2 * 32), dtype=np.int32),) + if dtype == "float": + inputs = (inputs[0].astype(np.float32),) + + testing.assert_numerically_equal( + base_model, + updated_model, + inputs, + 
ort_optimization_level=ort.GraphOptimizationLevel.ORT_DISABLE_ALL, + ) + + # Validate serialized model + output_model_proto = ir.to_proto(updated_model) + onnx.checker.check_model(output_model_proto, full_check=True) + + def test_hardsigmoid_fusion(self): + model_text = """ + + hardsigmoid (float[N] x) => (float[N] y) { + three = Constant () + six = Constant () + zero = Constant () + x_plus_3 = Add(x, three) + clipped = Clip(x_plus_3, zero, six) + y = Div(clipped, six) + } + """ + model = ir.from_onnx_text(model_text) + self.run_test(model, ["HardSigmoid"]) + + def test_hardswish_fusion(self): + model_text = """ + + hardswish (float[N] x) => (float[N] y) { + three = Constant () + six = Constant () + zero = Constant () + x_plus_3 = Add(x, three) + clipped = Clip(x_plus_3, zero, six) + mul_x = Mul(clipped, x) + y = Div(mul_x, six) + } + """ + model = ir.from_onnx_text(model_text) + self.run_test(model, ["HardSwish"]) + + def test_hardswish_fusion_mul_last(self): + model_text = """ + + hardswish (float[N] x) => (float[N] y) { + three = Constant () + six = Constant () + zero = Constant () + x_plus_3 = Add(x, three) + clipped = Clip(x_plus_3, zero, six) + div_x = Div(clipped, six) + y = Mul(div_x, x) + } + """ + model = ir.from_onnx_text(model_text) + self.run_test(model, ["HardSwish"]) + + def test_hardswish_fusion_from_sigmoid(self): + model_text = """ + + hardswish (float[N] x) => (float[N] y) { + hardsigmoid_out = HardSigmoid(x) + y = Mul(hardsigmoid_out, x) + } + """ + model = ir.from_onnx_text(model_text) + self.run_test(model, ["HardSwish"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/fuse_pad_into_conv.py b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv.py similarity index 95% rename from onnxscript/rewriter/fuse_pad_into_conv.py rename to onnxscript/rewriter/rules/common/_fuse_pad_into_conv.py index 7aeae57ccd..39aab00eda 100644 --- a/onnxscript/rewriter/fuse_pad_into_conv.py +++ 
b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv.py @@ -327,25 +327,17 @@ def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value: return op.ConvInteger(x, _allow_other_inputs=True, _outputs=["conv"]) -normalize_pad_format_conv = NormalizePadFormatConv.rule() -normalize_pad_format_conv_integer = NormalizePadFormatConvInteger.rule() -fuse_pad_into_conv = FuseConvPad.rule() -fuse_pad_into_conv_integer = FuseConvIntegerPad.rule() - - -def fuse_pad_into_conv_rule_set() -> orp.RewriteRuleSet: - """Returns a set of rewrite rules that fuse Pad nodes into preceding: - - Conv - - ConvInteger - - Returns: - RewriteRuleSet - """ - return orp.RewriteRuleSet( - [ - normalize_pad_format_conv, - normalize_pad_format_conv_integer, - fuse_pad_into_conv, - fuse_pad_into_conv_integer, - ] - ) +normalize_pad_format_conv_rule = NormalizePadFormatConv.rule() +normalize_pad_format_conv_integer_rule = NormalizePadFormatConvInteger.rule() +fuse_pad_into_conv_rule = FuseConvPad.rule() +fuse_pad_into_conv_integer_rule = FuseConvIntegerPad.rule() + + +rules = orp.RewriteRuleSet( + [ + normalize_pad_format_conv_rule, + normalize_pad_format_conv_integer_rule, + fuse_pad_into_conv_rule, + fuse_pad_into_conv_integer_rule, + ] +) diff --git a/onnxscript/rewriter/fuse_pad_into_conv_test.py b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py similarity index 93% rename from onnxscript/rewriter/fuse_pad_into_conv_test.py rename to onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py index dfbf117bd1..ded57fe023 100644 --- a/onnxscript/rewriter/fuse_pad_into_conv_test.py +++ b/onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py @@ -12,10 +12,10 @@ from onnxscript.rewriter import pattern as orp from onnxscript.rewriter import testing -from onnxscript.rewriter.fuse_pad_into_conv import ( - fuse_pad_into_conv, - fuse_pad_into_conv_rule_set, - normalize_pad_format_conv, +from onnxscript.rewriter.rules.common import _fuse_pad_into_conv +from 
onnxscript.rewriter.rules.common._fuse_pad_into_conv import ( + fuse_pad_into_conv_rule, + normalize_pad_format_conv_rule, ) @@ -61,13 +61,13 @@ def build_model( # Register operations in the tape idtype = ir.DataType.UINT8 if op_type == "ConvInteger" else ir.DataType.FLOAT - x = ir.Input("X", shape=input_shape, type=ir.TensorType(idtype)) + x = ir.val("X", shape=input_shape, type=ir.TensorType(idtype)) y = tape.op("Pad", inputs=[x, *pad_inputs], attributes=pad_attributes) y = tape.op( op_type, inputs=[y, self.get_conv_weights(weight_shape, tape)], attributes=conv_attributes, - output=ir.Input("Y", shape=output_shape, type=ir.TensorType(x.dtype)), + output=ir.val("Y", shape=output_shape, type=ir.TensorType(x.dtype)), ) if op_type == "ConvInteger": y.dtype = ir.DataType.INT32 @@ -118,7 +118,7 @@ def test_fuse_pad_into_conv(self, pad_pads, const_value, axes, conv_pads, conv_a updated_model = _clone_model(base_model) # Apply rule - count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model) + count = _fuse_pad_into_conv.rules.apply_to_model(updated_model) # Check that Pad was fused self.assertEqual(count, 1 if conv_auto_pad is None else 2) @@ -209,11 +209,11 @@ def test_unsupported_fuse_pad_into_conv( # Apply rule and check it was not applied tracer = orp.MatchingTracer() - count = fuse_pad_into_conv.apply_to_model(base_model, tracer=tracer) + count = fuse_pad_into_conv_rule.apply_to_model(base_model, tracer=tracer) self.assertEqual(count, 0) # Check that the error message is the expected one - tracer_match = tracer.best_matches_map[fuse_pad_into_conv][0] + tracer_match = tracer.best_matches_map[fuse_pad_into_conv_rule][0] self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED) self.assertRegex(tracer_match.match_result.reason, err_msg) @@ -255,7 +255,7 @@ def test_fuse_pad_into_conv_integer( updated_model = _clone_model(base_model) # Apply rule - count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model) + count = 
_fuse_pad_into_conv.rules.apply_to_model(updated_model) # Check that Pad was fused self.assertEqual(count, 1 if conv_auto_pad is None else 2) @@ -290,12 +290,12 @@ def build_model( raise ValueError(f"Unsupported type for pad input ({x}): {type(x)}.") # Register operations in the tape - x = ir.Input("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT)) + x = ir.val("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT)) y = tape.op( "Conv", inputs=[x, *conv_inputs], attributes=conv_attributes, - output=ir.Input("Y", shape=output_shape, type=x.type), + output=ir.val("Y", shape=output_shape, type=x.type), ) # Build the model @@ -344,7 +344,7 @@ def test_normalize_pad_format(self, dynamic_shape, strides, kernel_shape, auto_p updated_model = _clone_model(base_model) # Apply rule - count = fuse_pad_into_conv_rule_set().apply_to_model(updated_model) + count = _fuse_pad_into_conv.rules.apply_to_model(updated_model) onnx_checker.CheckerPass(True)(updated_model) # Check conv has changed @@ -372,11 +372,11 @@ def test_unsupported_normalize_pad_format(self, input_shape, infer_shapes, error # Apply rule and check it was not applied tracer = orp.MatchingTracer() - count = normalize_pad_format_conv.apply_to_model(base_model, tracer=tracer) + count = normalize_pad_format_conv_rule.apply_to_model(base_model, tracer=tracer) self.assertEqual(count, 0) # Check that the error message is the expected one - tracer_match = tracer.best_matches_map[normalize_pad_format_conv][0] + tracer_match = tracer.best_matches_map[normalize_pad_format_conv_rule][0] self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED) self.assertRegex(tracer_match.match_result.reason, error_msg) @@ -393,11 +393,11 @@ def test_unsupported_normalize_pad_format_on_weights(self): # Apply rule and check it was not applied tracer = orp.MatchingTracer() - count = normalize_pad_format_conv.apply_to_model(base_model, tracer=tracer) + count = 
normalize_pad_format_conv_rule.apply_to_model(base_model, tracer=tracer) self.assertEqual(count, 0) # Check that the error message is the expected one - tracer_match = tracer.best_matches_map[normalize_pad_format_conv][0] + tracer_match = tracer.best_matches_map[normalize_pad_format_conv_rule][0] self.assertEqual(tracer_match.status.value, orp.MatchStatus.CONDITION_FAILED) self.assertRegex(tracer_match.match_result.reason, "same length than kernel_shape") diff --git a/onnxscript/rewriter/fuse_relus_clips.py b/onnxscript/rewriter/rules/common/_fuse_relus_clips.py similarity index 89% rename from onnxscript/rewriter/fuse_relus_clips.py rename to onnxscript/rewriter/rules/common/_fuse_relus_clips.py index 484ca679fc..5d294cdbd7 100644 --- a/onnxscript/rewriter/fuse_relus_clips.py +++ b/onnxscript/rewriter/rules/common/_fuse_relus_clips.py @@ -169,25 +169,17 @@ def pattern(self, op, x): return op.Relu(op.Clip(x, _allow_other_inputs=True, _outputs=["out_first_clip"])) -fuse_successive_relu_rule = FuseSuccessiveRelu().rule() -fuse_successive_clip_rule = FuseSuccessiveClip().rule() -fuse_successive_clip_relu_rule = FuseSuccessiveClipRelu().rule() -fuse_successive_relu_clip_rule = FuseSuccessiveReluClip().rule() - - -def fuse_relus_clips_rules() -> RewriteRuleSet: - """Returns a set of rewrite rules that fuse successive Relu/Clip nodes. 
- - Returns: - RewriteRuleSet - """ - - # Order is important - return RewriteRuleSet( - [ - fuse_successive_clip_relu_rule, - fuse_successive_relu_clip_rule, - fuse_successive_relu_rule, - fuse_successive_clip_rule, - ] - ) +successive_relu_rule = FuseSuccessiveRelu().rule() +successive_clip_rule = FuseSuccessiveClip().rule() +successive_clip_relu_rule = FuseSuccessiveClipRelu().rule() +successive_relu_clip_rule = FuseSuccessiveReluClip().rule() + + +rules = RewriteRuleSet( + [ + successive_clip_relu_rule, + successive_relu_clip_rule, + successive_relu_rule, + successive_clip_rule, + ] +) diff --git a/onnxscript/rewriter/fuse_relus_clips_test.py b/onnxscript/rewriter/rules/common/_fuse_relus_clips_test.py similarity index 94% rename from onnxscript/rewriter/fuse_relus_clips_test.py rename to onnxscript/rewriter/rules/common/_fuse_relus_clips_test.py index d58b493fb4..df2d669930 100644 --- a/onnxscript/rewriter/fuse_relus_clips_test.py +++ b/onnxscript/rewriter/rules/common/_fuse_relus_clips_test.py @@ -13,13 +13,13 @@ MatchingTracer, MatchStatus, RewriteRule, - fuse_relus_clips, testing, ) -from onnxscript.rewriter.fuse_relus_clips import ( - fuse_successive_clip_relu_rule, - fuse_successive_clip_rule, - fuse_successive_relu_clip_rule, +from onnxscript.rewriter.rules.common import _fuse_relus_clips +from onnxscript.rewriter.rules.common._fuse_relus_clips import ( + successive_clip_relu_rule, + successive_clip_rule, + successive_relu_clip_rule, ) @@ -40,7 +40,7 @@ def run_test( onnx_checker.CheckerPass(True)(base_model) base_model = shape_inference.infer_shapes(base_model) updated_model = self.clone_model(base_model) - _ = fuse_relus_clips.fuse_relus_clips_rules().apply_to_model(updated_model) + _ = _fuse_relus_clips.rules.apply_to_model(updated_model) # Check expected op_types self.assertEqual([node.op_type for node in updated_model.graph], expected_op_types) @@ -214,7 +214,7 @@ def test_successful_fuse_successive_relu_clip_no_min(self, _, nodes): x1 = Relu(X) Y = 
Clip(x1, min) """, - fuse_successive_clip_relu_rule, + successive_clip_relu_rule, ), ( "clip_then_relu", @@ -222,7 +222,7 @@ def test_successful_fuse_successive_relu_clip_no_min(self, _, nodes): x1 = Clip(X, min) Y = Relu(x1) """, - fuse_successive_relu_clip_rule, + successive_relu_clip_rule, ), ] ) @@ -245,7 +245,7 @@ def test_fail_fuse_successive_relu_clip_non_initializers(self, _, nodes, rewrite x1 = Relu(X) Y = Clip(x1, min) """, - fuse_successive_clip_relu_rule, + successive_clip_relu_rule, ), ( "clip_then_relu", @@ -253,7 +253,7 @@ def test_fail_fuse_successive_relu_clip_non_initializers(self, _, nodes, rewrite x1 = Clip(X, min) Y = Relu(x1) """, - fuse_successive_relu_clip_rule, + successive_relu_clip_rule, ), ] ) @@ -334,7 +334,7 @@ def test_fail_fuse_successive_clips_non_initializers(self): Y = Clip(x1, min2) } """) - self.run_failed_condition_test(model, fuse_successive_clip_rule, "is not a constant.") + self.run_failed_condition_test(model, successive_clip_rule, "is not a constant.") def test_fail_fuse_successive_clips_graph_inputs(self): model = ir.from_onnx_text(""" @@ -346,7 +346,7 @@ def test_fail_fuse_successive_clips_graph_inputs(self): Y = Clip(x1, min2) } """) - self.run_failed_condition_test(model, fuse_successive_clip_rule, "is a graph input.") + self.run_failed_condition_test(model, successive_clip_rule, "is a graph input.") class FuseReluClipIntegrationTest(_FuseReluClipTestBase): diff --git a/onnxscript/rewriter/gemm_to_matmul_add.py b/onnxscript/rewriter/rules/common/_gemm_to_matmul_add.py similarity index 76% rename from onnxscript/rewriter/gemm_to_matmul_add.py rename to onnxscript/rewriter/rules/common/_gemm_to_matmul_add.py index 09666466d3..e51b4b22fa 100644 --- a/onnxscript/rewriter/gemm_to_matmul_add.py +++ b/onnxscript/rewriter/rules/common/_gemm_to_matmul_add.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
from onnxscript.rewriter._rewrite_rule import RewriteRule -from onnxscript.rewriter.broadcast_to_matmul import check_if_not_need_reshape +from onnxscript.rewriter.rules.common._broadcast_to_matmul import check_if_not_need_reshape # Pattern to match against @@ -18,4 +18,6 @@ def matmul_add(op, input_a, input_b, input_c, **_): return op.Add(matmul, input_c) -rule = RewriteRule(reshape_gemm_reshape_pattern, matmul_add, check_if_not_need_reshape) +gemm_to_matmul_add_rule = RewriteRule( + reshape_gemm_reshape_pattern, matmul_add, check_if_not_need_reshape +) diff --git a/onnxscript/rewriter/gemm_to_matmul_add_test.py b/onnxscript/rewriter/rules/common/_gemm_to_matmul_add_test.py similarity index 92% rename from onnxscript/rewriter/gemm_to_matmul_add_test.py rename to onnxscript/rewriter/rules/common/_gemm_to_matmul_add_test.py index aab56cc3fe..90551d8d3b 100644 --- a/onnxscript/rewriter/gemm_to_matmul_add_test.py +++ b/onnxscript/rewriter/rules/common/_gemm_to_matmul_add_test.py @@ -5,7 +5,7 @@ import onnx.parser from onnxscript import ir -from onnxscript.rewriter import gemm_to_matmul_add +from onnxscript.rewriter.rules.common import _gemm_to_matmul_add class ReshapeGemmReshapeTest(unittest.TestCase): @@ -25,7 +25,7 @@ def test_reshape_gemm_reshape_replace_when_nd_inputs_are_broadcastable(self): ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) @@ -70,7 +70,7 @@ def test_reshape_gemm_reshape_replace_when_nd_inputs_are_broadcastable_in_nested ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.functions), 1) self.assertEqual(len(model.functions[("pkg.custom", "afunction", "")]), 4) @@ -94,7 
+94,7 @@ def test_reshape_gemm_reshape_remain_when_input_last_dim_and_second_last_dim_not """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -115,7 +115,7 @@ def test_reshape_gemm_reshape_remain_when_inputs_are_not_broadcastable( """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -136,7 +136,7 @@ def test_reshape_gemm_reshape_replace_when_inputs_are_broadcastable_with_one_in_ """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) self.assertEqual(model.graph[2].op_type, "MatMul") @@ -159,7 +159,7 @@ def test_reshape_gemm_reshape_replace_when_first_input_is_one_dimension_and_broa """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) self.assertEqual(model.graph[2].op_type, "MatMul") @@ -182,7 +182,7 @@ def test_reshape_gemm_reshape_remain_when_first_input_is_one_dimension_and_not_b """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -203,7 +203,7 @@ def test_reshape_gemm_reshape_replace_when_second_input_is_one_dimension_and_bro """ ) model = ir.serde.deserialize_model(model_proto) - 
count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 4) self.assertEqual(model.graph[2].op_type, "MatMul") @@ -226,7 +226,7 @@ def test_reshape_gemm_reshape_remain_when_second_input_is_one_dimension_and_not_ """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -247,7 +247,7 @@ def test_reshape_gemm_reshape_replaces_when_inputs_are_two_dimensional_and_broad """ ) model = ir.serde.deserialize_model(model_proto) - replacement_count = gemm_to_matmul_add.rule.apply_to_model(model) + replacement_count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(replacement_count, 1) self.assertEqual(len(model.graph), 4) @@ -268,7 +268,7 @@ def test_reshape_gemm_reshape_remain_when_inputs_are_two_dimension_and_not_broad """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) @@ -289,7 +289,7 @@ def test_reshape_gemm_reshape_remain_when_output_is_not_matmul_broadcasted( """ ) model = ir.serde.deserialize_model(model_proto) - count = gemm_to_matmul_add.rule.apply_to_model(model) + count = _gemm_to_matmul_add.gemm_to_matmul_add_rule.apply_to_model(model) self.assertEqual(count, 0) self.assertEqual(len(model.graph), 5) diff --git a/onnxscript/rewriter/matmul_add_to_gemm.py b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm.py similarity index 84% rename from onnxscript/rewriter/matmul_add_to_gemm.py rename to onnxscript/rewriter/rules/common/_matmul_add_to_gemm.py index dc0364a778..fe7a4a6cd8 100644 --- 
a/onnxscript/rewriter/matmul_add_to_gemm.py +++ b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm.py @@ -84,20 +84,11 @@ def pattern(self, op, input_a, input_b, input_c): transpose_ab_matmul_add_to_gemm_rule = TransABMatMulAddToGemm().rule() -def gemm_rule_set() -> RewriteRuleSet: - """Returns a set of rewrite rules that fuse MatMul + Add patterns into a single Gemm node, - handling cases where one or both MatMul inputs are transposed. - - Returns: - RewriteRuleSet - """ - - # Order is important - return RewriteRuleSet( - [ - transpose_ab_matmul_add_to_gemm_rule, - transpose_a_matmul_add_to_gemm_rule, - transpose_b_matmul_add_to_gemm_rule, - matmul_add_to_gemm_rule, - ] - ) +rules = RewriteRuleSet( + [ + transpose_ab_matmul_add_to_gemm_rule, + transpose_a_matmul_add_to_gemm_rule, + transpose_b_matmul_add_to_gemm_rule, + matmul_add_to_gemm_rule, + ] +) diff --git a/onnxscript/rewriter/matmul_add_to_gemm_test.py b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py similarity index 91% rename from onnxscript/rewriter/matmul_add_to_gemm_test.py rename to onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py index fd08125807..4c643801fc 100644 --- a/onnxscript/rewriter/matmul_add_to_gemm_test.py +++ b/onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py @@ -9,8 +9,8 @@ from parameterized import parameterized from onnxscript import ir -from onnxscript.rewriter import MatchingTracer, MatchStatus, matmul_add_to_gemm, testing -from onnxscript.rewriter.matmul_add_to_gemm import matmul_add_to_gemm_rule +from onnxscript.rewriter import MatchingTracer, MatchStatus, testing +from onnxscript.rewriter.rules.common import _matmul_add_to_gemm class _MatMulAddToGemmTestBase(unittest.TestCase): @@ -46,10 +46,10 @@ def get_test_model( bias_shape = weight_shape[0] if transB else weight_shape[-1] output_shape = ir.Shape(("?",) * input_shape.rank()) - x = ir.Input("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT)) + x = ir.val("X", 
shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT)) if weight_as_inputs: - w = ir.Input("W", shape=weight_shape, type=ir.TensorType(ir.DataType.FLOAT)) + w = ir.val("W", shape=weight_shape, type=ir.TensorType(ir.DataType.FLOAT)) inputs.append(w) else: w = ir.tensor( @@ -58,7 +58,7 @@ def get_test_model( w = tape.initializer(w) if bias_as_inputs: - b = ir.Input( + b = ir.val( "B", shape=ir.Shape([bias_shape]), type=ir.TensorType(ir.DataType.FLOAT) ) inputs.append(b) @@ -77,7 +77,7 @@ def get_test_model( y = tape.op( "Add", inputs=[y, b], - output=ir.Input("Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)), + output=ir.val("Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)), ) # Build the model @@ -101,13 +101,15 @@ def check_matmul_add_to_gemm_incompatible_shapes(self, **kwargs): updated_model = self.clone_model(base_model) tracer = MatchingTracer() - count = matmul_add_to_gemm_rule.apply_to_model(updated_model, tracer=tracer) + count = _matmul_add_to_gemm.matmul_add_to_gemm_rule.apply_to_model( + updated_model, tracer=tracer + ) # Check that the model is unchanged self.assertEqual(count, 0) # Check that the error message is the expected one - tracer_match = tracer.best_matches_map[matmul_add_to_gemm_rule][0] + tracer_match = tracer.best_matches_map[_matmul_add_to_gemm.matmul_add_to_gemm_rule][0] self.assertEqual(tracer_match.status.value, MatchStatus.CONDITION_FAILED) self.assertRegex( tracer_match.match_result.reason, "Rank of input_a and input_b must be 2" @@ -129,7 +131,7 @@ def test_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_inputs): bias_as_inputs=bias_as_inputs, ) updated_model = self.clone_model(base_model) - count = matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model) + count = _matmul_add_to_gemm.rules.apply_to_model(updated_model) # Check MatMul + Add are fused into Gemm self.assertEqual(count, 1) @@ -176,7 +178,7 @@ def test_transpose_a_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_input 
transA=True, ) updated_model = self.clone_model(base_model) - count = matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model) + count = _matmul_add_to_gemm.rules.apply_to_model(updated_model) # Check MatMul(Transpose, W) + Add are fused into Gemm self.assertEqual(count, 1) @@ -225,7 +227,7 @@ def test_transpose_b_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_input transB=True, ) updated_model = self.clone_model(base_model) - count = matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model) + count = _matmul_add_to_gemm.rules.apply_to_model(updated_model) # Check MatMul(X, Transpose) + Add are fused into Gemm self.assertEqual(count, 1) @@ -275,7 +277,7 @@ def test_transpose_ab_matmul_add_to_gemm(self, _, weight_as_inputs, bias_as_inpu transB=True, ) updated_model = self.clone_model(base_model) - count = matmul_add_to_gemm.gemm_rule_set().apply_to_model(updated_model) + count = _matmul_add_to_gemm.rules.apply_to_model(updated_model) # Check MatMul(Transpose, Transpose) + Add are fused into Gemm self.assertEqual(count, 1) diff --git a/onnxscript/rewriter/rules/common/_min_max_to_clip.py b/onnxscript/rewriter/rules/common/_min_max_to_clip.py new file mode 100644 index 0000000000..88ae495dbc --- /dev/null +++ b/onnxscript/rewriter/rules/common/_min_max_to_clip.py @@ -0,0 +1,253 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Fuses successive Min/Max patterns in ONNX graphs. + +Supported transformations: +- Min(Min(X, c1, c2, ...), d1, d2, ...) → Min(X, fused_const) +- Max(Max(X, c1, c2, ...), d1, d2, ...) → Max(X, fused_const) +- Min(Max(X, lb1, lb2, ...), ub1, ub2, ...) → Clip(X, lb, ub) +- Max(Min(X, ub1, ub2, ...), lb1, lb2, ...) → Clip(X, lb, ub) + +Where: + - fused_const is the reduction (min or max) over all constant inputs. + - For Clip fusion: + * All constant inputs must be scalars. + * The effective lower bound is the maximum of all lower-bound constants. 
+ * The effective upper bound is the minimum of all upper-bound constants. + + For the case of Max(Min(X, upper_bound), lower_bound): + * The rule applies only if lower_bound ≤ upper_bound. + +General constraints: + - The first input may be any tensor. + - All other inputs must be constant tensors (from Constant nodes or initializers). +""" + +import abc +import functools +from typing import ClassVar + +import numpy as np +import onnx_ir as ir + +from onnxscript.rewriter._basics import MatchResult +from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet + + +class _FuseMinMaxBase(RewriteRuleClassBase, abc.ABC): + """Base class for Min/Max fusion rewrites. + + Constraints: + - All inputs except the first must be constants (from Constant nodes or initializers). + - If ``need_scalars`` is True (Clip fusion), all constants must be scalars. + - If ``check_bounds`` is True (Clip fusion in the pattern Max(Min(X, upper_bound), lower_bound)), lower_bound ≤ upper_bound. + """ + + need_scalars: ClassVar = False + check_bounds: ClassVar = False + + @abc.abstractmethod + def compute_constants( + self, + first_node: ir.Node, + second_node: ir.Node, + input_name: str = "", + ) -> list[tuple[ir.Tensor, str]]: ... + + def rewrite(self, op, x, out1, out2): + first_node = out1.producer() + second_node = out2.producer() + + # Compute new constants for the fused op + constants = self.compute_constants(first_node, second_node, x.name) + + initializers = [op.initializer(constant, name=name) for constant, name in constants] + + return op.op( + self.op_type, + inputs=[x, *initializers], + ) + + def _is_scalar(self, v: np.ndarray) -> bool: + return np.isscalar(v) or np.size(v) == 1 + + def check(self, context, out1, out2, **_): + """Condition to check if we need to replace the pattern. + + Conditions: + - The min and max input nodes must not be graph inputs. + - These inputs (except the first) must be constant values (from Constant nodes or initializers). 
+ - In the case of Min(Max) and Max(Min) patterns: + * All inputs must be scalars (as Clip requires scalars). + For Max(Min) pattern: + * The lower bound must be less than or equal to the upper bound. + + Returns: + MatchResult: + Success if we need to replace the pattern, Failure otherwise. + """ + del context # Not used + check_result = MatchResult() + + first_node = out1.producer() + second_node = out2.producer() + + # Ensure all inputs except the first are constants + for input_ in first_node.inputs[1:] + second_node.inputs[1:]: + if ir.convenience.get_const_tensor(input_) is None: + return check_result.fail(f"{input_.name} is not a constant.") + + # If scalars are required (Clip fusion), enforce scalar-ness + if self.need_scalars and not self._is_scalar(input_.const_value.numpy()): + return check_result.fail(f"{input_.name} is not a scalar.") + + if self.need_scalars and self.check_bounds: + # For Clip fusion in the case of Max(Min(X, upper_bound), lower_bound): check that lower_bound <= upper_bound + lower_bound, upper_bound = self.compute_constants(first_node, second_node) + if lower_bound[0].numpy() > upper_bound[0].numpy(): + return check_result.fail( + f"Invalid bounds: lower bound ({lower_bound[0].numpy()}) is greater " + f"than upper bound ({upper_bound[0].numpy()})." + ) + + return check_result + + +class FuseSuccessiveMin(_FuseMinMaxBase): + """Replaces ``Min(Min(X, c1, c2, ...), d1, d2, ...)`` with ``Min(X, fused_const)``. + + Constraints: + - All inputs except the first must be constants (from Constant nodes or initializers). 
+ """ + + op_type: ClassVar = "Min" + + def compute_constants( + self, + first_node: ir.Node, + second_node: ir.Node, + input_name: str = "", + ) -> list[tuple[ir.Tensor, str]]: + inputs = first_node.inputs[1:] + second_node.inputs[1:] + values = [input_.const_value.numpy() for input_ in inputs] + return [(ir.tensor(functools.reduce(np.minimum, values)), f"{input_name}_min")] + + def pattern(self, op, x): + return op.Min( + op.Min(x, _allow_other_inputs=True, _outputs=["out1"]), + _allow_other_inputs=True, + _outputs=["out2"], + ) + + +class FuseSuccessiveMax(_FuseMinMaxBase): + """Replaces ``Max(Max(X, c1, c2, ...), d1, d2, ...)`` with ``Max(X, fused_const)``. + + Constraints: + - All inputs except the first must be constants (from Constant nodes or initializers). + """ + + op_type: ClassVar = "Max" + + def compute_constants( + self, + first_node: ir.Node, + second_node: ir.Node, + input_name: str = "", + ) -> list[tuple[ir.Tensor, str]]: + inputs = first_node.inputs[1:] + second_node.inputs[1:] + values = [input_.const_value.numpy() for input_ in inputs] + return [(ir.tensor(functools.reduce(np.maximum, values)), f"{input_name}_max")] + + def pattern(self, op, x): + return op.Max( + op.Max(x, _allow_other_inputs=True, _outputs=["out1"]), + _allow_other_inputs=True, + _outputs=["out2"], + ) + + +class FuseMaxMinToClip(_FuseMinMaxBase): + """Replaces ``Min(Max(X, lb1, lb2, ...), ub1, ub2, ...)`` with ``Clip(X, lb, ub)``. + + Constraints: + - All inputs except the first must be constants (from Constant nodes or initializers). + - All constant inputs must be scalars. + - The effective lower bound is ``max(lb1, lb2, ...)``. + - The effective upper bound is ``min(ub1, ub2, ...)``. 
+ """ + + op_type: ClassVar = "Clip" + need_scalars: ClassVar = True + + def compute_constants( + self, + first_node: ir.Node, + second_node: ir.Node, + input_name: str = "", + ) -> list[tuple[ir.Tensor, str]]: + lower_bound = np.max([input_.const_value.numpy() for input_ in first_node.inputs[1:]]) + upper_bound = np.min([input_.const_value.numpy() for input_ in second_node.inputs[1:]]) + return [ + (ir.tensor(lower_bound), f"{input_name}_min"), + (ir.tensor(upper_bound), f"{input_name}_max"), + ] + + def pattern(self, op, x): + return op.Min( + op.Max(x, _allow_other_inputs=True, _outputs=["out1"]), + _allow_other_inputs=True, + _outputs=["out2"], + ) + + +class FuseMinMaxToClip(_FuseMinMaxBase): + """Replaces ``Max(Min(X, ub1, ub2, ...), lb1, lb2, ...)`` with ``Clip(X, lb, ub)``. + + Constraints: + - All inputs except the first must be constants (from Constant nodes or initializers). + - All constant inputs must be scalars. + - The effective lower bound is ``max(lb1, lb2, ...)``. + - The effective upper bound is ``min(ub1, ub2, ...)``. + - Requires ``lower_bound <= upper_bound``. 
+ """ + + op_type: ClassVar = "Clip" + need_scalars: ClassVar = True + check_bounds: ClassVar = True + + def compute_constants( + self, + first_node: ir.Node, + second_node: ir.Node, + input_name: str = "", + ) -> list[tuple[ir.Tensor, str]]: + upper_bound = np.min([input_.const_value.numpy() for input_ in first_node.inputs[1:]]) + lower_bound = np.max([input_.const_value.numpy() for input_ in second_node.inputs[1:]]) + return [ + (ir.tensor(lower_bound), f"{input_name}_min"), + (ir.tensor(upper_bound), f"{input_name}_max"), + ] + + def pattern(self, op, x): + return op.Max( + op.Min(x, _allow_other_inputs=True, _outputs=["out1"]), + _allow_other_inputs=True, + _outputs=["out2"], + ) + + +min_min_rule = FuseSuccessiveMin().rule() +max_max_rule = FuseSuccessiveMax().rule() +min_max_rule = FuseMinMaxToClip().rule() +max_min_rule = FuseMaxMinToClip().rule() + + +rules = RewriteRuleSet( + [ + min_min_rule, + max_max_rule, + min_max_rule, + max_min_rule, + ] +) diff --git a/onnxscript/rewriter/rules/common/_min_max_to_clip_test.py b/onnxscript/rewriter/rules/common/_min_max_to_clip_test.py new file mode 100644 index 0000000000..dd09078a9e --- /dev/null +++ b/onnxscript/rewriter/rules/common/_min_max_to_clip_test.py @@ -0,0 +1,367 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import unittest + +import numpy as np +import onnx +import onnx_ir as ir +from onnx_ir.passes.common import onnx_checker, shape_inference +from parameterized import parameterized + +from onnxscript.rewriter import MatchingTracer, MatchStatus, RewriteRule, testing +from onnxscript.rewriter.rules.common._min_max_to_clip import ( + max_max_rule, + max_min_rule, + min_max_rule, + min_min_rule, + rules, +) + + +class _TestMinMaxToClipBase(unittest.TestCase): + @property + def rng(self): + return np.random.default_rng(20250817) + + def clone_model(self, model: ir.Model) -> ir.Model: + return ir.from_proto(ir.to_proto(model)) + + def run_test( + self, + base_model: ir.Model, + expected_op_types: list[str], + dtype: str = "float", + ): + onnx_checker.CheckerPass(True)(base_model) + base_model = shape_inference.infer_shapes(base_model) + updated_model = self.clone_model(base_model) + _ = rules.apply_to_model(updated_model) + + # Check expected op_types + self.assertEqual([node.op_type for node in updated_model.graph], expected_op_types) + + # Check inference + inputs = ( + self.rng.integers( + low=-10, + high=10, + size=(2, *updated_model.graph.inputs[0].shape[1:]), + dtype=np.int32, + ), + ) + if dtype == "float": + inputs = (inputs[0].astype(np.float32),) + + testing.assert_numerically_equal( + base_model, + updated_model, + inputs, + ) + + # Validate serialized model + output_model_proto = ir.serde.serialize_model(updated_model) + onnx.checker.check_model(output_model_proto, full_check=True) + + def run_failed_condition_test( + self, + base_model: ir.Model, + rewrite_rule: RewriteRule, + expected_message: str, + ): + onnx_checker.CheckerPass(True)(base_model) + + updated_model = self.clone_model(base_model) + tracer = MatchingTracer() + count = rewrite_rule.apply_to_model(updated_model, tracer=tracer) + + # Check that the model is unchanged + self.assertEqual(count, 0) + + # Check that the error message is the expected one + tracer_match = 
tracer.best_matches_map[rewrite_rule][0] + self.assertEqual(tracer_match.status.value, MatchStatus.CONDITION_FAILED) + self.assertRegex(tracer_match.match_result.reason, expected_message) + + +class TestFuseSuccessiveMinOrMax(_TestMinMaxToClipBase): + @parameterized.expand( + [ + ("int32_min", "int32", "Min"), + ("int32_max", "int32", "Max"), + ("float32_min", "float", "Min"), + ("float32_max", "float", "Max"), + ] + ) + def test_successful_fuse_successive_min_or_max(self, _, dtype, op_type): + base_model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 20] > + test_model ({dtype}[N, 32, 14, 17] X) => ({dtype} [N, ?, ?, ?] Y) + <{dtype}[1] cst1 = {{3}}, {dtype}[1] cst2 = {{6}}> + {{ + x1 = {op_type}(X, cst1) + Y = {op_type}(x1, cst2) + }} + """) + self.run_test(base_model, expected_op_types=[op_type], dtype=dtype) + + @parameterized.expand( + [ + ("int32_min_multi", "int32", "Min"), + ("int32_max_multi", "int32", "Max"), + ("float32_min_multi", "float", "Min"), + ("float32_max_multi", "float", "Max"), + ] + ) + def test_successful_fuse_successive_min_or_max_multiple_inputs(self, _, dtype, op_type): + base_model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 20] > + test_model ({dtype}[N, 3, 3] X) => ({dtype}[N, 3, 3] Y) + <{dtype}[3] cst1 = {{2, 5, 8}}, + {dtype}[1] cst2 = {{4}}, + {dtype}[3] cst3 = {{3, 1, -6}}, + {dtype}[1] cst4 = {{10}}, + {dtype}[3] cst5 = {{-2, 7, 9}}, + {dtype}[1] cst6 = {{0}}, + {dtype}[3] cst7 = {{11, -3, 4}}> + {{ + x1 = {op_type}(X, cst1, cst2, cst3, cst4) + Y = {op_type}(x1, cst5, cst6, cst7) + }} + """) + self.run_test(base_model, expected_op_types=[op_type], dtype=dtype) + + @parameterized.expand( + [ + ("int32_min", "Min"), + ("int32_max", "Max"), + ("float32_min", "Min"), + ("float32_max", "Max"), + ] + ) + def test_successful_fuse_successive_min_or_max_constants(self, _, op_type): + base_model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] 
X) => (float [N, ?, ?, ?] Y) + + {{ + x1 = {op_type}(X, cst1) + cst2 = Constant() + Y = {op_type}(x1, cst2) + }} + """) + self.run_test(base_model, expected_op_types=["Constant", op_type]) + + @parameterized.expand( + [ + ("min_nonconst", "Min", min_min_rule), + ("max_nonconst", "Max", max_max_rule), + ] + ) + def test_failure_fuse_successive_min_or_max_non_constant(self, _, op_type, rewrite_rule): + model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] Y) + + {{ + cst1 = ReduceMean(X) + x1 = {op_type}(X, cst1) + Y = {op_type}(x1, cst2) + }} + """) + self.run_failed_condition_test(model, rewrite_rule, "is not a constant.") + + @parameterized.expand( + [ + ("min_graph_input", "Min"), + ("max_graph_input", "Max"), + ] + ) + def test_successful_fuse_successive_min_or_max_graph_inputs_as_constants(self, _, op_type): + base_model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X, float[1] cst1, float[1] cst2) => (float[N, ?, ?, ?] Y) + + {{ + x1 = {op_type}(X, cst1) + Y = {op_type}(x1, cst2) + }} + """) + self.run_test(base_model, expected_op_types=[op_type]) + + +class TestMinMaxToClip(_TestMinMaxToClipBase): + def test_successful_min_max_to_clip(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + { + x1 = Min(X, min) + Y = Max(x1, max) + } + """) + self.run_test(base_model, expected_op_types=["Clip"]) + + def test_successful_min_max_to_clip_constants(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] 
Y) + + { + x1 = Min(X, min) + max = Constant() + Y = Max(x1, max) + } + """) + self.run_test(base_model, expected_op_types=["Constant", "Clip"]) + + def test_successful_min_max_to_clip_graph_inputs_as_constants(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X, float[1] min, float[1] max) => (float [N, ?, ?, ?] Y) + + { + x1 = Min(X, min) + Y = Max(x1, max) + } + """) + self.run_test(base_model, expected_op_types=["Clip"]) + + def test_failure_min_max_to_clip_invalid_bounds(self): + """Min node should have the max value and Max node should have the min value.""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + { + x1 = Min(X, min) + Y = Max(x1, max) + } + """) + self.run_failed_condition_test(base_model, min_max_rule, "Invalid bounds:") + + def test_failure_fuse_min_max_to_clip_non_constant(self): + model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + { + min = ReduceMean(X) + x1 = Min(X, min) + Y = Max(x1, max) + } + """) + self.run_failed_condition_test(model, min_max_rule, "is not a constant.") + + def test_failure_min_max_to_clip_need_scalars(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 4, 4] X) => (float [N, ?, ?] Y) + + { + x1 = Min(X, min) + Y = Max(x1, max) + } + """) + self.run_failed_condition_test(base_model, min_max_rule, "is not a scalar") + + +class TestMaxMinToClip(_TestMinMaxToClipBase): + def test_successful_max_min_to_clip(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] 
Y) + + { + x1 = Max(X, max) + Y = Min(x1, min) + } + """) + self.run_test(base_model, expected_op_types=["Clip"]) + + def test_successful_max_min_to_clip_constants(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + { + x1 = Max(X, max) + min = Constant() + Y = Min(x1, min) + } + """) + self.run_test(base_model, expected_op_types=["Constant", "Clip"]) + + def test_successful_max_min_to_clip_graph_inputs_as_constants(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X, float[1] min, float[1] max) => (float [N, ?, ?, ?] Y) + + { + x1 = Max(X, max) + Y = Min(x1, min) + } + """) + self.run_test(base_model, expected_op_types=["Clip"]) + + def test_successful_max_min_to_clip_check_bounds(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + { + x1 = Max(X, max) + Y = Min(x1, min) + } + """) + self.run_test(base_model, expected_op_types=["Clip"]) + + def test_failure_fuse_max_min_to_clip_non_constant(self): + model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14, 17] X) => (float [N, ?, ?, ?] Y) + + { + min = ReduceMean(X) + x1 = Max(X, max) + Y = Min(x1, min) + } + """) + self.run_failed_condition_test(model, max_min_rule, "is not a constant.") + + def test_failure_max_min_to_clip_need_scalars(self): + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 4, 4] X) => (float [N, ?, ?] 
Y) + + { + x1 = Max(X, min) + Y = Min(x1, max) + } + """) + self.run_failed_condition_test(base_model, max_min_rule, "is not a scalar") + + +class TestIntegrationMinMaxToClip(_TestMinMaxToClipBase): + def test_successful_full_chain_fusion(self): + model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 20] > + test_model (float[N, 32, 14] X) => (float [N, ?, ?] Y) + + { + x1 = Min(X, min1) + x2 = Min(x1, min2) + x3 = Max(x2, max1) + x4 = Max(x3, max2) + x5 = Min(x4, min3) + x6 = Max(x5, max3) + Y = Min(x6, min4) + } + """) + self.run_test(model, expected_op_types=["Clip", "Clip", "Clip"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/no_op.py b/onnxscript/rewriter/rules/common/_no_op.py similarity index 100% rename from onnxscript/rewriter/no_op.py rename to onnxscript/rewriter/rules/common/_no_op.py diff --git a/onnxscript/rewriter/no_op_test.py b/onnxscript/rewriter/rules/common/_no_op_test.py similarity index 90% rename from onnxscript/rewriter/no_op_test.py rename to onnxscript/rewriter/rules/common/_no_op_test.py index 2b2a57f32a..2c2f9e6e2b 100644 --- a/onnxscript/rewriter/no_op_test.py +++ b/onnxscript/rewriter/rules/common/_no_op_test.py @@ -5,16 +5,21 @@ import parameterized from onnxscript import ir -from onnxscript.rewriter import no_op +from onnxscript.rewriter.rules.common import _no_op class NoOpTest(unittest.TestCase): def _check(self, model_text: str) -> None: model = ir.from_onnx_text(model_text) - count = no_op.rules.apply_to_model(model) + count = _no_op.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(model.graph[-1].op_type, "Identity") + def _check_no_optimization(self, model_text: str) -> None: + model = ir.from_onnx_text(model_text) + count = _no_op.rules.apply_to_model(model) + self.assertEqual(count, 0) + @parameterized.parameterized.expand( [ ("float one input", "float[M]", "value_float=1.0", "one, input"), @@ -195,6 +200,17 @@ def 
test_dropout_zero_or_inference_no_op_with_initializer(self, _, attribute: st ) # TODO: Test the negative cases + def test_broadcast_is_not_eliminated(self): + model_text = """ + + agraph (float[M] input) => (float[1, 1, M] output) + + { + output = Add(zero, input) + } + """ + self._check_no_optimization(model_text) + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/rewriter/redundant_scatter_nd.py b/onnxscript/rewriter/rules/common/_redundant_scatter_nd.py similarity index 91% rename from onnxscript/rewriter/redundant_scatter_nd.py rename to onnxscript/rewriter/rules/common/_redundant_scatter_nd.py index 5852e85dc3..09c5db7735 100644 --- a/onnxscript/rewriter/redundant_scatter_nd.py +++ b/onnxscript/rewriter/rules/common/_redundant_scatter_nd.py @@ -20,7 +20,7 @@ import onnx_ir as ir import onnxscript.rewriter -from onnxscript.rewriter import _ir_utils as ir_utils +from onnxscript.rewriter import _ir_utils from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet @@ -41,7 +41,7 @@ def check(self, context, data, axis, transposed_data, **_): # Check that updated-indices represent the full range of the first dimension of the transposed data. # That is: check that the data.shape[axis] matches transposed_data.shape[0]. result = onnxscript.rewriter.MatchResult() - axis_value = ir_utils.get_singleton_value(axis) + axis_value = _ir_utils.get_singleton_value(axis) if not isinstance(axis_value, int): return result.fail("Axis value must be a constant integer.", axis) shape: ir.Shape | None = data.shape @@ -54,7 +54,7 @@ def check(self, context, data, axis, transposed_data, **_): "Transposed data shape is not statically known.", transposed_data ) actual_dim_value = transposed_data_shape[0] - if updated_dim_value != actual_dim_value: + if not _ir_utils.same_dim(updated_dim_value, actual_dim_value): # The first dimension of the transposed data does not match the updated dimension, # so we cannot apply this rule. 
return result.fail( @@ -87,7 +87,7 @@ def check(self, context, data, indices, updates, **_): return result.fail("The value 'data' shape is not statically known.", data) if updates.shape is None: return result.fail("The value 'updates' shape is not statically known.", updates) - if data.shape != updates.shape: + if not _ir_utils.same_shape(data.shape, updates.shape): return result.fail( "The shape of 'data' and 'updates' are different.", [data, updates] ) @@ -107,7 +107,7 @@ def rewrite(self, op, updates, **_): return op.Identity(updates) -rule = ScatterAllDynamic.rule() -static_rule = ScatterAllStatic.rule() +no_op_dynamic_scatter_nd_rule = ScatterAllDynamic.rule() +no_op_static_scatter_nd_rule = ScatterAllStatic.rule() -rules = RewriteRuleSet([rule, static_rule]) +rules = RewriteRuleSet([no_op_dynamic_scatter_nd_rule, no_op_static_scatter_nd_rule]) diff --git a/onnxscript/rewriter/redundant_scatter_nd_test.py b/onnxscript/rewriter/rules/common/_redundant_scatter_nd_test.py similarity index 96% rename from onnxscript/rewriter/redundant_scatter_nd_test.py rename to onnxscript/rewriter/rules/common/_redundant_scatter_nd_test.py index d2ba51eec4..96e3bcc80c 100644 --- a/onnxscript/rewriter/redundant_scatter_nd_test.py +++ b/onnxscript/rewriter/rules/common/_redundant_scatter_nd_test.py @@ -13,7 +13,7 @@ import onnxscript.optimizer from onnxscript import FLOAT, script from onnxscript import opset18 as op -from onnxscript.rewriter import redundant_scatter_nd +from onnxscript.rewriter.rules.common import _redundant_scatter_nd shape_inference = ShapeInferencePass() onnx_check = CheckerPass(True) @@ -48,7 +48,7 @@ def model_script( onnx_check(model) shape_inference(model) onnxscript.optimizer.fold_constants(model) - count = redundant_scatter_nd.rules.apply_to_model(model) + count = _redundant_scatter_nd.rules.apply_to_model(model) self.assertEqual(count, 1) onnx_check(model) optimized_model_proto = ir.serde.serialize_model(model) @@ -94,7 +94,7 @@ def 
test_redundant_scatter_nd_static_indices(self): model.graph.initializers["indices"] = indices_value original_model_proto = ir.serde.serialize_model(model) - count = redundant_scatter_nd.rules.apply_to_model(model) + count = _redundant_scatter_nd.rules.apply_to_model(model) self.assertEqual(count, 1) self.assertEqual(len(model.graph), 1) self.assertIn("Identity", [node.op_type for node in model.graph]) diff --git a/onnxscript/rewriter/rules/common/_remove_optional_bias.py b/onnxscript/rewriter/rules/common/_remove_optional_bias.py new file mode 100644 index 0000000000..ead8a73eab --- /dev/null +++ b/onnxscript/rewriter/rules/common/_remove_optional_bias.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Remove optional bias when it is all zero from Conv, ConvTranspose, Gemm and QLinearConv operations.""" + +from __future__ import annotations + +from typing import ClassVar + +import numpy as np + +from onnxscript import ir +from onnxscript.rewriter._basics import MatchResult +from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet + + +class _RemoveOptionalBias(RewriteRuleClassBase): + def rewrite(self, op: ir.tape.Tape, out: ir.Value, **_) -> ir.Value: + node = out.producer() + + return op.op( + self.op_type, + inputs=node.inputs[:-1], + attributes=node.attributes, + ) + + def check(self, context, b: ir.Value, **_) -> MatchResult: + """Condition to check if we need to replace the pattern. + + The pattern is applied only when the bias is all zeros. The bias should be + a constant value (i.e., provided by Constant nodes or initializers). + + Returns: + MatchResult: + Success if we need to replace the pattern, Failure otherwise. 
+ """ + del context # Unused + check_result = MatchResult() + + # Check if bias is a constant/initializer + bias_tensor = ir.convenience.get_const_tensor(b) + if bias_tensor is None: + return check_result.fail("Bias is not a constant/initializer.") + + # Check if bias is all zeros + bias_array = bias_tensor.numpy() + if not np.equal(bias_array, 0.0).all(): + return check_result.fail("Bias is not all zeros.") + + return check_result + + +class RemoveOptionalBiasFromConv(_RemoveOptionalBias): + """Remove zero bias from Conv operation.""" + + op_type: ClassVar[str] = "Conv" + + def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value: + return op.Conv(x, w, b, _outputs=["out"]) + + +class RemoveOptionalBiasFromConvTranspose(_RemoveOptionalBias): + """Remove zero bias from ConvTranspose operation.""" + + op_type: ClassVar[str] = "ConvTranspose" + + def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value: + return op.ConvTranspose(x, w, b, _outputs=["out"]) + + +class RemoveOptionalBiasFromQLinearConv(_RemoveOptionalBias): + """Remove zero bias from QLinearConv operation.""" + + op_type: ClassVar[str] = "QLinearConv" + + def pattern( + self, + op: ir.tape.Tape, + x, + x_scale, + x_zero_point, + w, + w_scale, + w_zero_point, + y_scale, + y_zero_point, + b: ir.Value, + ) -> ir.Value: + return op.QLinearConv( + x, + x_scale, + x_zero_point, + w, + w_scale, + w_zero_point, + y_scale, + y_zero_point, + b, + _outputs=["out"], + ) + + +class RemoveOptionalBiasFromGemm(_RemoveOptionalBias): + """Remove zero bias from Gemm operation.""" + + op_type: ClassVar[str] = "Gemm" + + def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value: + return op.Gemm(x, w, b, _outputs=["out"]) + + +remove_optional_bias_from_conv_rule = RemoveOptionalBiasFromConv().rule() +remove_optional_bias_from_conv_transpose_rule = RemoveOptionalBiasFromConvTranspose().rule() +remove_optional_bias_from_qlinear_conv_rule 
= RemoveOptionalBiasFromQLinearConv().rule() +remove_optional_bias_from_gemm_rule = RemoveOptionalBiasFromGemm().rule() + +rules = RewriteRuleSet( + [ + remove_optional_bias_from_conv_rule, + remove_optional_bias_from_conv_transpose_rule, + remove_optional_bias_from_qlinear_conv_rule, + remove_optional_bias_from_gemm_rule, + ] +) diff --git a/onnxscript/rewriter/rules/common/_remove_optional_bias_test.py b/onnxscript/rewriter/rules/common/_remove_optional_bias_test.py new file mode 100644 index 0000000000..4349d7aae3 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_remove_optional_bias_test.py @@ -0,0 +1,237 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import unittest + +import numpy as np +import onnx +import onnx_ir as ir +from onnx_ir.passes.common import onnx_checker + +from onnxscript.rewriter import MatchingTracer, MatchStatus, RewriteRule, testing +from onnxscript.rewriter.rules.common import _remove_optional_bias +from onnxscript.rewriter.rules.common._remove_optional_bias import ( + remove_optional_bias_from_conv_rule, + remove_optional_bias_from_conv_transpose_rule, + remove_optional_bias_from_gemm_rule, + remove_optional_bias_from_qlinear_conv_rule, +) + + +class _RemoveOptionalBiasTestBase(unittest.TestCase): + @property + def rng(self): + return np.random.default_rng(20251016) + + def clone_model(self, model: ir.Model) -> ir.Model: + return ir.from_proto(ir.to_proto(model)) + + def _get_test_model( + self, + op_type: str, + input_shape: ir.Shape, + weight_shape: ir.Shape, + zero_bias: bool, + attributes=None, + ): + tape = ir.tape.Tape() + bias_shape = weight_shape[1] if op_type == "ConvTranspose" else weight_shape[0] + output_shape = ir.Shape(("?",) * input_shape.rank()) + + x = ir.val("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT)) + + w = tape.initializer( + ir.tensor(self.rng.uniform(-0.5, 0.5, weight_shape).astype(np.float32), name="W") + ) + + if zero_bias: + bias = np.zeros(bias_shape, 
dtype=np.float32) + else: + bias = self.rng.uniform(-0.5, 0.5, bias_shape).astype(np.float32) + + b = tape.initializer(ir.tensor(bias, name="B")) + y = tape.op( + op_type, + inputs=[x, w, b], + attributes=attributes, + output=ir.val("Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)), + ) + + # Build the model + ir_model = ir.Model( + ir.Graph( + inputs=[x], + outputs=[y], + nodes=tape.nodes, + initializers=tape.initializers, + opset_imports={"": 20}, + name="test_model", + ), + ir_version=10, + ) + onnx_checker.CheckerPass(True)(ir_model) + return ir_model + + def run_test( + self, + base_model: ir.Model, + input_shape: tuple, + input_dtype=np.float32, + ): + updated_model = self.clone_model(base_model) + count = _remove_optional_bias.rules.apply_to_model(updated_model) + + # Check rule is applied + self.assertEqual(count, 1) + + # Check number of inputs is reduced + self.assertEqual( + len(updated_model.graph[0].inputs), len(base_model.graph[0].inputs) - 1 + ) + + # Prepare inputs + inputs = (self.rng.random(input_shape).astype(input_dtype),) + + # Check inference + testing.assert_numerically_equal(base_model, updated_model, inputs) + + # Validate serialized model + output_model_proto = ir.serde.serialize_model(updated_model) + onnx.checker.check_model(output_model_proto, full_check=True) + + def run_failed_condition_test( + self, + base_model: ir.Model, + rewrite_rule: RewriteRule, + expected_message: str, + ): + onnx_checker.CheckerPass(True)(base_model) + + updated_model = self.clone_model(base_model) + tracer = MatchingTracer() + count = rewrite_rule.apply_to_model(updated_model, tracer=tracer) + + # Check that the model is unchanged + self.assertEqual(count, 0) + + # Check that the error message is the expected one + tracer_match = tracer.best_matches_map[rewrite_rule][0] + self.assertEqual(tracer_match.status.value, MatchStatus.CONDITION_FAILED) + self.assertRegex(tracer_match.match_result.reason, expected_message) + + +class 
RemoveOptionalBiasGemmTest(_RemoveOptionalBiasTestBase): + def test_successful_remove_optional_bias_gemm(self): + input_shape = (512, 256) + base_model = self._get_test_model( + op_type="Gemm", + input_shape=ir.Shape(input_shape), + weight_shape=ir.Shape((64, 256)), + zero_bias=True, + attributes={"transB": 1}, + ) + self.run_test(base_model, input_shape) + + def test_fail_remove_optional_bias_gemm(self): + input_shape = (512, 256) + base_model = self._get_test_model( + op_type="Gemm", + input_shape=ir.Shape(input_shape), + weight_shape=ir.Shape((64, 256)), + zero_bias=False, + attributes={"transB": 1}, + ) + self.run_failed_condition_test( + base_model, remove_optional_bias_from_gemm_rule, "Bias is not all zeros." + ) + + +class RemoveOptionalBiasGonvTest(_RemoveOptionalBiasTestBase): + def test_successful_remove_optional_bias_conv(self): + input_shape = (1, 3, 32, 32) + base_model = self._get_test_model( + op_type="Conv", + input_shape=ir.Shape(input_shape), + weight_shape=ir.Shape((16, 3, 3, 3)), + zero_bias=True, + attributes={"strides": (2, 2)}, + ) + self.run_test(base_model, input_shape) + + def test_fail_remove_optional_bias_conv(self): + input_shape = (1, 3, 32, 32) + base_model = self._get_test_model( + op_type="Conv", + input_shape=ir.Shape(input_shape), + weight_shape=ir.Shape((16, 3, 3, 3)), + zero_bias=False, + ) + self.run_failed_condition_test( + base_model, remove_optional_bias_from_conv_rule, "Bias is not all zeros." 
+ ) + + +class RemoveOptionalBiasGonvTransposeTest(_RemoveOptionalBiasTestBase): + def test_successful_remove_optional_bias_conv_transpose(self): + input_shape = (1, 3, 32, 32) + base_model = self._get_test_model( + op_type="ConvTranspose", + input_shape=ir.Shape(input_shape), + weight_shape=ir.Shape((3, 16, 3, 3)), + zero_bias=True, + ) + self.run_test(base_model, input_shape) + + def test_fail_remove_optional_bias_conv_transpose(self): + input_shape = (1, 3, 32, 32) + base_model = self._get_test_model( + op_type="ConvTranspose", + input_shape=ir.Shape(input_shape), + weight_shape=ir.Shape((3, 16, 3, 3)), + zero_bias=False, + ) + self.run_failed_condition_test( + base_model, remove_optional_bias_from_conv_transpose_rule, "Bias is not all zeros." + ) + + +class RemoveOptionalBiasQLinearConvTest(_RemoveOptionalBiasTestBase): + def _get_test_model(self, zero_bias): + if zero_bias: + bias = np.zeros((16,), dtype=np.int32) + else: + bias = self.rng.uniform(-5, 5, (16,)).astype(np.int32) + + w = ir.tensor(self.rng.uniform(-5, 5, (16, 3, 3, 3)).astype(np.uint8), name="W") + b = ir.tensor(bias, name="B") + + model = ir.from_onnx_text( + """ + < ir_version: 10, opset_import: ["" : 20] > + test_model (uint8[N, 3, 32, 32] X) => (uint8 [N, ?, ?, ?] Y) + + { + Y = QLinearConv(X, x_scale, x_zero_point, W, w_scale, w_zero_point, y_scale, y_zero_point, B) + } + """, + initializers=[w, b], + ) + onnx_checker.CheckerPass(True)(model) + return model + + def test_successful_remove_optional_bias_qlinear_conv(self): + input_shape = (1, 3, 32, 32) + base_model = self._get_test_model(zero_bias=True) + self.run_test(base_model, input_shape, np.uint8) + + def test_fail_remove_optional_bias_qlinear_conv(self): + base_model = self._get_test_model(zero_bias=False) + self.run_failed_condition_test( + base_model, remove_optional_bias_from_qlinear_conv_rule, "Bias is not all zeros." 
+ ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/rules/fusion/__init__.py b/onnxscript/rewriter/rules/fusion/__init__.py new file mode 100644 index 0000000000..59e481eb93 --- /dev/null +++ b/onnxscript/rewriter/rules/fusion/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/onnxscript/rewriter/rules/fusion/_gqa.py b/onnxscript/rewriter/rules/fusion/_gqa.py new file mode 100644 index 0000000000..c12dcc7140 --- /dev/null +++ b/onnxscript/rewriter/rules/fusion/_gqa.py @@ -0,0 +1,114 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from typing import Union + +import onnx_ir as ir + +import onnxscript.rewriter._fusion_utils as _fusion_utils +from onnxscript.rewriter import _basics, pattern + +Dim = Union[int, ir.SymbolicDim] + + +class OnnxGroupQueryAttention(pattern.RewriteRuleClassBase): + def __init__(self): + super().__init__("ONNXGQA", remove_nodes=False) + + def pattern( + self, + op, + query_BHSD, + key_BHkvSD, + value_BHkvSD, + past_key_BHkvSpD, + past_value_BHkvSpD, + ): + # Concatenate past_key cache and current key, expand across heads + # that share key/value. + + present_key_BHkvStD = op.Concat(past_key_BHkvSpD, key_BHkvSD, axis=-2) + present_key_BHkv1StD = op.Unsqueeze(present_key_BHkvStD, 2) + present_key_BHkvGStD = op.Expand(present_key_BHkv1StD, pattern.ANY_VALUE) + present_key_BHStD = op.Reshape( + present_key_BHkvGStD, pattern.ANY_VALUE, _outputs=["present_key_BHStD"] + ) + + # Concatenate past_value cache and current value, expand across heads + # that share key/value. 
+ present_value_BHkvStD = op.Concat(past_value_BHkvSpD, value_BHkvSD, axis=-2) + present_value_BHkv1StD = op.Unsqueeze(present_value_BHkvStD, 2) + present_value_BHkvGStD = op.Expand(present_value_BHkv1StD, pattern.ANY_VALUE) + present_value_BHStD = op.Reshape( + present_value_BHkvGStD, pattern.ANY_VALUE, _outputs=["present_value_BHStD"] + ) + + attention_BHSDh = op.Attention( + query_BHSD, + present_key_BHStD, + present_value_BHStD, + pattern.Var("mask", can_match_none=True), + _outputs=["attention_BHSDh"], + ) + + return attention_BHSDh, present_key_BHkvStD, present_value_BHkvStD + + def check( + self, + context: _basics.MatchContext, + query_BHSD, + key_BHkvSD, + value_BHkvSD, + past_key_BHkvSpD, + past_value_BHkvSpD, + present_key_BHStD, + present_value_BHStD, + **_, + ): + bindings: dict[str, Dim] = {} + # Check that inputs to new Attention node have expected shapes + _fusion_utils.check_shape(bindings, query_BHSD, ["B", "H", "S", "D"]) + _fusion_utils.check_shape(bindings, key_BHkvSD, ["B", "Hkv", "S", "D"]) + _fusion_utils.check_shape(bindings, value_BHkvSD, ["B", "Hkv", "S", "D"]) + _fusion_utils.check_shape(bindings, past_key_BHkvSpD, ["B", "Hkv", "P", "D"]) + _fusion_utils.check_shape(bindings, past_value_BHkvSpD, ["B", "Hkv", "P", "D"]) + # We need to check that the Expand/Reshape arguments are as expected. + # As a substitute, we check that the outputs of Expand=>Reshape have expected shapes. + # TODO (rama): May be better to check the actual Expand/Reshape arguments. 
+ _fusion_utils.check_shape(bindings, present_key_BHStD, ["B", "H", "S+P", "D"]) + _fusion_utils.check_shape(bindings, present_value_BHStD, ["B", "H", "S+P", "D"]) + + return True + + def rewrite( + self, + op, + query_BHSD, + key_BHkvSD, + value_BHkvSD, + past_key_BHkvSpD, + past_value_BHkvSpD, + mask, + attention_BHSDh, + **_, + ): + original_attention_node = attention_BHSDh.producer() + original_attrs = original_attention_node.attributes + return op.Attention( + query_BHSD, + key_BHkvSD, + value_BHkvSD, + mask, + past_key_BHkvSpD, + past_value_BHkvSpD, + **original_attrs, + _outputs=3, + ) + + +_basic_gqa_rule = OnnxGroupQueryAttention.rule() + +gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule]) + +fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules) diff --git a/onnxscript/rewriter/rules/fusion/_gqa_test.py b/onnxscript/rewriter/rules/fusion/_gqa_test.py new file mode 100644 index 0000000000..baf80c4b8c --- /dev/null +++ b/onnxscript/rewriter/rules/fusion/_gqa_test.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import unittest + +import onnx +import onnx_ir as ir +from packaging import version + +import onnxscript +import onnxscript.optimizer +import onnxscript.rewriter.testing +from onnxscript import FLOAT, script +from onnxscript.rewriter.rules.fusion._gqa import fuse_gqa + +op = onnxscript.values.Opset("", 23) + +H = [8] # Number of attention heads +Hkv = [4] # Number of key/value heads (H should be divisible by Hkv) +D = [64] # Head size +G = [2] # Number of groups + + +@script(ir_version=10) +def _gqa_script( + query_BHSD: FLOAT[2, 8, 4, 64], # B=2, H=8, S=4, D=64 + key_BHkvSD: FLOAT[2, 4, 4, 64], # B=2, Hkv=4, S=4, D=64 + value_BHkvSD: FLOAT[2, 4, 4, 64], # B=2, Hkv=4, S=4, D=64 + past_key_BHkvPD: FLOAT[2, 4, 8, 64], # B=2, Hkv=4, P=8, D=64 + past_value_BHkvPD: FLOAT[2, 4, 8, 64], # B=2, Hkv=4, P=8, D=64 +) -> FLOAT[2, 8, 4, 64]: + """Basic GQA pattern that should be fused into an Attention op.""" + + # Concatenate past_key cache and current key + present_key_BHkvStD = op.Concat(past_key_BHkvPD, key_BHkvSD, axis=-2) # [B, Hkv, S+P, D] + + # Unsqueeze to add group dimension + present_key_BHkv1StD = op.Unsqueeze(present_key_BHkvStD, 2) # [B, Hkv, 1, S+P, D] + + # Calculate shapes dynamically + B = op.Shape(query_BHSD, start=0, end=1) # [B] + T = op.Shape(present_key_BHkvStD, start=2, end=3) # [S+P] + + # Create expand shape [B, Hkv, G, S+P, D] + expand_shape = op.Concat(B, Hkv, G, T, D, axis=0) + present_key_BHkvGStD = op.Expand(present_key_BHkv1StD, expand_shape) # [B, Hkv, G, S+P, D] + + # Create reshape shape [B, H, S+P, D] + reshape_shape = op.Concat(B, H, T, D, axis=0) + present_key_BHStD = op.Reshape(present_key_BHkvGStD, reshape_shape) # [B, H, S+P, D] + + # Same for value + present_value_BHkvStD = op.Concat( + past_value_BHkvPD, value_BHkvSD, axis=-2 + ) # [B, Hkv, S+P, D] + present_value_BHkv1StD = op.Unsqueeze(present_value_BHkvStD, 2) # [B, Hkv, 1, S+P, D] + present_value_BHkvGStD = op.Expand( + present_value_BHkv1StD, expand_shape + ) # [B, Hkv, G, S+P, 
D] + present_value_BHStD = op.Reshape(present_value_BHkvGStD, reshape_shape) # [B, H, S+P, D] + + # Attention computation + attention_BHSDh = op.Attention( + query_BHSD, + present_key_BHStD, + present_value_BHStD, + ) + + return attention_BHSDh + + +class GQAFusionTest(unittest.TestCase): + def test_basic_gqa_fusion(self): + """Test basic GQA fusion pattern.""" + model_proto = _gqa_script.to_model_proto() + + # Apply GQA fusion + model = ir.serde.deserialize_model(model_proto) + onnxscript.optimizer.optimize(model) + count = fuse_gqa(model) + self.assertGreater(count, 0, "GQA fusion should have occurred") + + # We can't yet test numerical equivalence because of a bug in the op spec/implementation. + onnx_ver = version.parse(onnx.__version__) + if onnx_ver >= version.parse("1.19.1") and not ( + onnx_ver.is_prerelease or onnx_ver.is_devrelease + ): + # Only official releases >= 1.19.1 + onnxscript.optimizer.remove_unused_nodes(model) + rewritten_model_proto = ir.serde.serialize_model(model) + onnxscript.rewriter.testing.assert_numerically_equal( + model_proto, rewritten_model_proto, use_reference=True + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/onnx_fusions/_layer_norm.py b/onnxscript/rewriter/rules/fusion/_layer_norm.py similarity index 100% rename from onnxscript/rewriter/onnx_fusions/_layer_norm.py rename to onnxscript/rewriter/rules/fusion/_layer_norm.py diff --git a/onnxscript/rewriter/onnx_fusions/_layer_norm_test.py b/onnxscript/rewriter/rules/fusion/_layer_norm_test.py similarity index 97% rename from onnxscript/rewriter/onnx_fusions/_layer_norm_test.py rename to onnxscript/rewriter/rules/fusion/_layer_norm_test.py index 6c9734d058..5e13f5e479 100644 --- a/onnxscript/rewriter/onnx_fusions/_layer_norm_test.py +++ b/onnxscript/rewriter/rules/fusion/_layer_norm_test.py @@ -5,12 +5,11 @@ import onnx_ir as ir -import onnxscript import onnxscript.optimizer import onnxscript.rewriter.testing from onnxscript import FLOAT, 
OnnxFunction, script from onnxscript import opset18 as op -from onnxscript.rewriter.onnx_fusions._layer_norm import fuse_layer_normalization +from onnxscript.rewriter.rules.fusion._layer_norm import fuse_layer_normalization @script() diff --git a/onnxscript/rewriter/onnx_fusions/_rms_normalization.py b/onnxscript/rewriter/rules/fusion/_rms_normalization.py similarity index 100% rename from onnxscript/rewriter/onnx_fusions/_rms_normalization.py rename to onnxscript/rewriter/rules/fusion/_rms_normalization.py diff --git a/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py b/onnxscript/rewriter/rules/fusion/_rms_normalization_test.py similarity index 53% rename from onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py rename to onnxscript/rewriter/rules/fusion/_rms_normalization_test.py index 59a460005a..e70c4ec7a0 100644 --- a/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py +++ b/onnxscript/rewriter/rules/fusion/_rms_normalization_test.py @@ -5,14 +5,12 @@ import unittest import onnx_ir as ir -from parameterized import parameterized import onnxscript -import onnxscript.rewriter.onnx_fusions as onnx_fusions -from onnxscript.rewriter.models import _rotary_embedding_models +from onnxscript.rewriter.rules.fusion import _rms_normalization -class OnnxFusionsTest(unittest.TestCase): +class RmsNormOnnxFusionsTest(unittest.TestCase): def test_rms_normalization_fusion(self): opset23 = onnxscript.values.Opset("", 23) @@ -34,34 +32,10 @@ def rms_norm_script(embedding, layernorm_weight): output_types=[onnxscript.FLOAT[128]], ) model = ir.serde.deserialize_model(rms_norm_model_proto) - onnx_fusions.fuse(model, debug=True) + count = _rms_normalization.fuse_rms_normalization(model) + self.assertEqual(count, 1) self.assertEqual(model.graph.node(-1).op_type, "RMSNormalization") - @parameterized.expand( - [ - ( - "test_case_1", - _rotary_embedding_models.test_case_1, - ), - ( - "test_case_2", - _rotary_embedding_models.test_case_2, - ), - ] - ) - def 
test_rotary_embedding_fusion(self, _: str, test_data_constructor): - test = test_data_constructor() - for opset_version in [22, 23]: - model: ir.Model = test.get_onnx_model() - model.graph.opset_imports[""] = opset_version - onnxscript.optimizer.optimize(model) - onnx_fusions.fuse(model) - op_types = [n.op_type for n in model.graph] - if opset_version == 22: - self.assertNotIn("RotaryEmbedding", op_types) - else: - self.assertIn("RotaryEmbedding", op_types) - if __name__ == "__main__": unittest.main() diff --git a/onnxscript/rewriter/onnx_fusions/_rotary_embedding.py b/onnxscript/rewriter/rules/fusion/_rotary_embedding.py similarity index 84% rename from onnxscript/rewriter/onnx_fusions/_rotary_embedding.py rename to onnxscript/rewriter/rules/fusion/_rotary_embedding.py index 2009c6953f..b659afdbc0 100644 --- a/onnxscript/rewriter/onnx_fusions/_rotary_embedding.py +++ b/onnxscript/rewriter/rules/fusion/_rotary_embedding.py @@ -30,13 +30,24 @@ def _rotate_half_pattern(op, x, start1, end1, start2, end2): class RotaryEmbedding23Fusion(pattern.RewriteRuleClassBase): def __init__(self): - super().__init__(name="RotaryEmbedding23") + super().__init__(name="RotaryEmbedding23", remove_nodes=False) - def pattern(self, op, x, cos, sin, start1, end1, start2, end2): - return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin + def pattern(self, op, x, freqs, start1, end1, start2, end2, one1, one2): + freqs_repeated = op.Concat(freqs, freqs, axis=-1) + cos = op.Cos(freqs_repeated) + sin = op.Sin(freqs_repeated) + cos_4d = op.Unsqueeze(cos, one1) + sin_4d = op.Unsqueeze(sin, one2) + return x * cos_4d + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin_4d - def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult: # type: ignore[name-defined] + def check(self, op, x, start1, end1, start2, end2, one1, one2, **_) -> pattern.MatchResult: # type: ignore[name-defined] check_result = pattern.MatchResult() + + if not 
_ir_utils.is_singleton_value(one1, 1): + return check_result.fail("Unsqueeze axes is not [1]", one1) + if not _ir_utils.is_singleton_value(one2, 1): + return check_result.fail("Unsqueeze axes is not [1]", one2) + # x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads) if x is None or x.shape is None or len(x.shape) != 4: return check_result.fail("Input is not known to be a 4D tensor.", x) @@ -59,8 +70,10 @@ def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult: ) return check_result - def rewrite(self, op, x, cos, sin, **_): + def rewrite(self, op, x, freqs, **_): num_heads = x.shape[1] + cos = op.Cos(freqs) + sin = op.Sin(freqs) return op.RotaryEmbedding( x, cos, diff --git a/onnxscript/rewriter/rules/fusion/_rotary_embedding_test.py b/onnxscript/rewriter/rules/fusion/_rotary_embedding_test.py new file mode 100644 index 0000000000..b8ffe95cac --- /dev/null +++ b/onnxscript/rewriter/rules/fusion/_rotary_embedding_test.py @@ -0,0 +1,53 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+from __future__ import annotations + +import unittest + +import onnx +import onnx_ir as ir +from packaging.version import Version +from parameterized import parameterized + +import onnxscript +import onnxscript.rewriter.testing +from onnxscript.rewriter.models import _rotary_embedding_models +from onnxscript.rewriter.rules.fusion import _rotary_embedding + + +class RotaryEmbeddingOnnxFusionTest(unittest.TestCase): + @parameterized.expand( + [ + ( + "test_case_1", + _rotary_embedding_models.test_case_1, + ), + ( + "test_case_2", + _rotary_embedding_models.test_case_2, + ), + ] + ) + def test_rotary_embedding_fusion(self, _: str, test_data_constructor): + test = test_data_constructor() + model: ir.Model = test.get_onnx_model() + model.graph.opset_imports[""] = 23 + model_proto = ir.serde.serialize_model(model) + onnxscript.optimizer.optimize(model) + _rotary_embedding.fuse_rotary_embedding(model) + op_types = [n.op_type for n in model.graph] + self.assertIn("RotaryEmbedding", op_types) + rewritten_model_proto = ir.serde.serialize_model(model) + inputs = test.get_ort_inputs() + + onnx_version = Version(onnx.__version__) + min_version = Version("1.19.1") + is_stable = not (onnx_version.is_devrelease or onnx_version.is_prerelease) + if onnx_version >= min_version and is_stable: + onnxscript.rewriter.testing.assert_numerically_equal( + model_proto, rewritten_model_proto, args=inputs, use_reference=True + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/testing.py b/onnxscript/rewriter/testing.py index 591f9387c2..2a9d24ee01 100644 --- a/onnxscript/rewriter/testing.py +++ b/onnxscript/rewriter/testing.py @@ -6,6 +6,7 @@ import numpy as np import onnx +import onnx.reference import onnxruntime as ort from onnxscript import ir @@ -32,10 +33,11 @@ def generate_random_inputs(model: onnx.ModelProto) -> dict[str, Any]: def assert_numerically_equal( original_model_proto: onnx.ModelProto | ir.Model, rewritten_model_proto: onnx.ModelProto | 
ir.Model, - args: tuple[Any, ...] | dict[str, Any], + args: tuple[Any, ...] | dict[str, Any] | None = None, ort_optimization_level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_ALL, rtol: float = 1, atol: float = 1e-3, + use_reference: bool = False, ): """Assert that the two models are numerically equal. @@ -46,6 +48,7 @@ def assert_numerically_equal( ort_optimization_level: Onnxruntime optimization level. rtol: Relative tolerance. atol: Absolute tolerance. + use_reference: If True, use ONNX reference implementation instead of ONNXRuntime. """ if isinstance(original_model_proto, ir.Model): @@ -53,7 +56,10 @@ def assert_numerically_equal( if isinstance(rewritten_model_proto, ir.Model): rewritten_model_proto = ir.serde.serialize_model(rewritten_model_proto) - if isinstance(args, dict): + if args is None: + original_proto_ort_inputs = generate_random_inputs(original_model_proto) + the_rewritten_proto_ort_inputs = original_proto_ort_inputs + elif isinstance(args, dict): original_proto_ort_inputs = args the_rewritten_proto_ort_inputs = args else: @@ -64,21 +70,34 @@ def assert_numerically_equal( k.name: v for k, v in zip(rewritten_model_proto.graph.input, args) } - original_proto_ort_inference_session = _ort_session_initializer( - original_model_proto.SerializeToString(), ort_optimization_level - ) - run_options = ort.RunOptions() - run_options.log_severity_level = 3 # 3: Error - original_outputs = original_proto_ort_inference_session.run( - None, original_proto_ort_inputs, run_options=run_options - ) - - the_rewritten_proto_ort_inference_session = _ort_session_initializer( - rewritten_model_proto.SerializeToString(), ort_optimization_level - ) - the_rewritten_outputs = the_rewritten_proto_ort_inference_session.run( - None, the_rewritten_proto_ort_inputs, run_options=run_options - ) + if use_reference: + # Use ONNX reference implementation + original_evaluator = _reference_session( + original_model_proto.SerializeToString(), ort_optimization_level 
+ ) + original_outputs = original_evaluator.run(None, original_proto_ort_inputs) + + rewritten_evaluator = _reference_session( + rewritten_model_proto.SerializeToString(), ort_optimization_level + ) + the_rewritten_outputs = rewritten_evaluator.run(None, the_rewritten_proto_ort_inputs) + else: + # Use ONNXRuntime + original_proto_ort_inference_session = _ort_session_initializer( + original_model_proto.SerializeToString(), ort_optimization_level + ) + run_options = ort.RunOptions() + run_options.log_severity_level = 3 # 3: Error + original_outputs = original_proto_ort_inference_session.run( + None, original_proto_ort_inputs, run_options=run_options + ) + + the_rewritten_proto_ort_inference_session = _ort_session_initializer( + rewritten_model_proto.SerializeToString(), ort_optimization_level + ) + the_rewritten_outputs = the_rewritten_proto_ort_inference_session.run( + None, the_rewritten_proto_ort_inputs, run_options=run_options + ) np.testing.assert_allclose( original_outputs, the_rewritten_outputs, rtol=rtol, atol=atol, equal_nan=True @@ -103,3 +122,18 @@ def _ort_session_initializer( provider for provider in possible_providers if provider in available_providers ] return ort.InferenceSession(model, providers=providers, sess_options=session_options) + + +def _reference_session( + model: str | bytes, ort_optimization_level: ort.GraphOptimizationLevel +) -> onnx.reference.ReferenceEvaluator: + """Initialize an ONNX reference evaluator with the specified model.""" + # Parse the model from bytes if needed + if isinstance(model, (str, bytes)): + model_proto = onnx.load_from_string(model) + else: + model_proto = model + + # Note: ort_optimization_level is ignored for reference implementation + # as it doesn't have equivalent optimization levels + return onnx.reference.ReferenceEvaluator(model_proto) diff --git a/onnxscript/utils/metadata_merger.py b/onnxscript/utils/metadata_merger.py new file mode 100644 index 0000000000..121d8db8c8 --- /dev/null +++ 
b/onnxscript/utils/metadata_merger.py @@ -0,0 +1,99 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Merging metadata_props""" + +from __future__ import annotations + +from typing import Callable, Iterable + +import onnx_ir as ir + +# Utilities for merging metadata properties, represented as strings. +# The merging-logic will take care of special cases like missing metadata or +# empty string metadata, and so the functions defined below need not handle +# special cases like empty string. (This does assume that an empty string is +# the same as no metadata, which is a reasonable assumption for most metadata.) + +StringMerger = Callable[[str, str], str] + + +def overwrite(_: str, new: str) -> str: + return new + + +def join(separator: str) -> StringMerger: + """Creates a StringMerger that joins two strings with the given separator. + + Args: + separator (str): The separator to use when joining the strings. + + Returns: + StringMerger: A function that joins two strings with the specified separator. + """ + + def merger(first: str, second: str) -> str: + return f"{first}{separator}{second}" + + return merger + + +comma_separator_merger = join(", ") + + +class MetadataMerger: + """Merges metadata properties using specified merging logic. + + Attributes: + mergers: A mapping from metadata property keys to their corresponding merging functions. + default: The default merging function to use when a specific key does not have a defined merger. + If None, the first value is used. (Specify `overwrite` to always use the second value.) + """ + + def __init__( + self, mergers: dict[str, StringMerger], default: StringMerger | None = None + ) -> None: + self.mergers = mergers + self.default = default + + def update_dict(self, updated: dict[str, str], updates: dict[str, str]) -> None: + """Updates the first metadata property dictionary with values from the second. + + Args: + updated: The metadata dictionary to be updated. 
+ updates: The updates metadata dictionary. + """ + for key, new_value in updates.items(): + if new_value == "": + continue + if (key in updated) and ((updated_value := updated[key]) != ""): + merger = self.mergers.get(key, self.default) + if merger is not None: + updated[key] = merger(updated_value, new_value) + else: + updated[key] = new_value + + def copy_merged_metadata( + self, from_nodes: Iterable[ir.Node], to: ir.Node | Iterable[ir.Node] + ) -> None: + """Merges metadata from multiple nodes and assigns it to one or more target nodes. + + Args: + from_nodes: The source nodes from which to merge metadata. + to: The target node(s) to which the merged metadata will be assigned. + """ + if isinstance(to, ir.Node): + updated = to.metadata_props + for node in from_nodes: + self.update_dict(updated, node.metadata_props) + elif len(to) == 1: + # Handle single node in iterable case + target_node = next(iter(to)) + updated = target_node.metadata_props + for node in from_nodes: + self.update_dict(updated, node.metadata_props) + else: + merged_metadata: dict[str, str] = {} + for node in from_nodes: + self.update_dict(merged_metadata, node.metadata_props) + for target_node in to: + self.update_dict(target_node.metadata_props, merged_metadata) diff --git a/onnxscript/utils/replace.py b/onnxscript/utils/replace.py new file mode 100644 index 0000000000..d46493155d --- /dev/null +++ b/onnxscript/utils/replace.py @@ -0,0 +1,49 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""A utility function to replace custom operations in a model with their expansions""" + +from typing import Sequence + +import onnx +import onnx_ir as ir +import onnx_ir.passes.common as common_passes + + +def replace_functions_inplace(irmodel: ir.Model, irfunctions: Sequence[ir.Function]) -> None: + """A utility function to replace custom operations in a model with their expansions: + + The model is updated in-place. 
+ + Args: + irmodel: An ONNX model possibly containing calls to custom operations. + irfunctions: A sequence of functions defining the expansions for the custom operations. + + + """ + model_functions = irmodel.functions + if len(model_functions) != 0: + # Since we use inlining, check that there are no model-local functions. + raise ValueError("Input model cannot have model-local functions.") + for func in irfunctions: + model_functions[func.identifier()] = func + + # TODO (rama): Ideally, we should provide users more control over renaming strategy for inlined values. + common_passes.InlinePass()(irmodel) + common_passes.RemoveUnusedOpsetsPass()(irmodel) + + +def replace_functions( + model: onnx.ModelProto, functions: Sequence[onnx.FunctionProto] +) -> onnx.ModelProto: + """A utility function to replace custom operations in a model with their expansions: + Args: + model: An ONNX ModelProto possibly containing calls to custom operations. + functions: A sequence of FunctionProto defining the expansions for the custom operations. + + Returns: + An updated ModelProto with custom operations replaced by their expansions. 
+ """ + irmodel = ir.from_proto(model) + irfunctions = [ir.from_proto(func) for func in functions] + replace_functions_inplace(irmodel, irfunctions) + return ir.to_proto(irmodel) diff --git a/onnxscript/version_converter/_version_converter.py b/onnxscript/version_converter/_version_converter.py index dddf11150c..cb7a6c43ad 100644 --- a/onnxscript/version_converter/_version_converter.py +++ b/onnxscript/version_converter/_version_converter.py @@ -155,12 +155,13 @@ def _get_str_attribute(node: ir.Node, name: str, default: str | None = None) -> @register("DFT", node_version=19, up_conversion=True) def dft_19_20(node: ir.Node, op): input = node.inputs[0] + dft_length = node.inputs[1] if len(node.inputs) > 1 else None inverse = _get_int_attribute(node, "inverse", 0) onesided = _get_int_attribute(node, "onesided", 0) axis = _get_int_attribute(node, "axis", None) if axis is not None: axis_value = op.Constant(value_int=axis) - return op.DFT(input, axis_value, inverse=inverse, onesided=onesided) + return op.DFT(input, dft_length, axis_value, inverse=inverse, onesided=onesided) return None diff --git a/onnxscript/version_converter/_version_converter_test.py b/onnxscript/version_converter/_version_converter_test.py index cf6507196b..021c6e72bb 100644 --- a/onnxscript/version_converter/_version_converter_test.py +++ b/onnxscript/version_converter/_version_converter_test.py @@ -144,7 +144,7 @@ def test_version_convert_compatible(self): self.assertEqual(model.graph.node(3).version, 20) self.assertEqual(model.graph.node(3).op_type, "DFT") self.assertEqual(model.graph.node(3).version, 20) - self.assertEqual(len(model.graph.node(3).inputs), 2) + self.assertEqual(len(model.graph.node(3).inputs), 3) def test_version_convert_gridsample_linear(self): model = ir.from_onnx_text( @@ -241,7 +241,7 @@ def test_version_convert_inline(self): self.assertEqual(model.graph.node(4).attributes["mode"].value, "linear") self.assertEqual(model.graph.node(6).op_type, "DFT") 
self.assertEqual(model.graph.node(6).version, 20) - self.assertEqual(len(model.graph.node(6).inputs), 2) + self.assertEqual(len(model.graph.node(6).inputs), 3) class VersionConverter20to21Test(unittest.TestCase): diff --git a/pyproject.toml b/pyproject.toml index 1f720c1168..1e6a99f656 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,12 +22,13 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "License :: OSI Approved :: MIT License", ] dependencies = [ "ml_dtypes", "numpy", - "onnx_ir>=0.1.7,<2", # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range. + "onnx_ir>=0.1.12,<2", # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range. "onnx>=1.16", "packaging", "typing_extensions>=4.10", @@ -41,7 +42,6 @@ onnxscript = ["py.typed"] onnx = ["py.typed"] [tool.pytest.ini_options] -filterwarnings = ["ignore::UserWarning", "ignore::DeprecationWarning"] addopts = "-rsfEX --tb=short --color=yes" [tool.mypy] @@ -80,40 +80,6 @@ module = [ ] ignore_errors = true -# FIXME(#1378): Remove this overrides section -[[tool.mypy.overrides]] -module = [ - "onnxrewriter.rewriter.generic_pattern_test.*", -] -check_untyped_defs = false -disable_error_code = 'override,import-untyped,no-untyped-def,assignment' -disallow_incomplete_defs = true -disallow_untyped_defs = true -disallow_untyped_decorators = true -show_column_numbers = true -strict_optional = true -warn_incomplete_stub = true -warn_no_return = true -warn_unused_configs = true -warn_unused_ignores = false - -# FIXME(#1378): Remove this overrides section -[[tool.mypy.overrides]] -module = [ - "onnxrewriter.rewriter.generic_pattern.*", -] -check_untyped_defs = false -disable_error_code = 'override,import-untyped,no-untyped-def,assignment,union-attr,func-returns-value,annotation-unchecked,arg-type,index,name-defined,attr-defined' 
-disallow_incomplete_defs = true -disallow_untyped_defs = true -disallow_untyped_decorators = true -show_column_numbers = true -strict_optional = true -warn_incomplete_stub = true -warn_no_return = true -warn_unused_configs = true -warn_unused_ignores = false - [tool.pylint.messages_control] # NOTE: This list is for vscode. Add new disables in pyproject_pylint.toml for lintrunner # Exclude patterns should be modified in .lintrunner.toml diff --git a/requirements-dev.txt b/requirements-dev.txt index 355fce3bff..b689d9bad5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -17,9 +17,6 @@ sphinx>=6 myst_nb chardet -# Torch lib -beartype!=0.16.0 - # Testing expecttest==0.1.6 hypothesis diff --git a/requirements/ci/requirements-onnx-weekly.txt b/requirements/ci/requirements-onnx-weekly.txt index e2eda3baa9..d206c9fcd6 100644 --- a/requirements/ci/requirements-onnx-weekly.txt +++ b/requirements/ci/requirements-onnx-weekly.txt @@ -1 +1 @@ -onnx-weekly==1.19.0.dev20250602 +onnx-weekly==1.21.0.dev20251103 diff --git a/requirements/ci/requirements-ort-nightly.txt b/requirements/ci/requirements-ort-nightly.txt index 4ed908b4e2..f2e801846a 100644 --- a/requirements/ci/requirements-ort-nightly.txt +++ b/requirements/ci/requirements-ort-nightly.txt @@ -1,3 +1,3 @@ # https://aiinfra.visualstudio.com/PublicPackages/_artifacts/feed/ORT-Nightly/PyPI/onnxruntime/overview --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ -onnxruntime==1.23.0.dev20250517001 +onnxruntime==1.23.2 diff --git a/requirements/lintrunner/requirements.txt b/requirements/lintrunner/requirements.txt index 41e736dcb4..f140d45917 100644 --- a/requirements/lintrunner/requirements.txt +++ b/requirements/lintrunner/requirements.txt @@ -1,11 +1,11 @@ # This file is auto updated by dependabot lintrunner-adapters>=0.8.0 # RUFF, RUFF-FIX -ruff==0.12.10 +ruff==0.14.7 # MYPY mypy==1.10.1 -types-PyYAML==6.0.12.20250402 +types-PyYAML==6.0.12.20250915 # PYLINT 
-pylint==3.3.6 +pylint==3.3.9 # EDITORCONFIG-CHECKER -editorconfig-checker==3.2.0 +editorconfig-checker==3.4.1 diff --git a/tests/eager_mode_test.py b/tests/eager_mode_test.py index 566169f223..e4cb0ab313 100644 --- a/tests/eager_mode_test.py +++ b/tests/eager_mode_test.py @@ -6,7 +6,6 @@ import numpy as np import parameterized -import onnxscript import onnxscript.evaluator import onnxscript.tensor from onnxscript import opset17 as op diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index ab58bbc1a1..d344723408 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -1,10 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from __future__ import annotations -# TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo - +import math import unittest +import parameterized + +# TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo import torch from torch.onnx._internal.exporter import _testing @@ -76,6 +79,88 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) _testing.assert_onnx_program(onnx_program) + def test_repeat_interleave_integer_1(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.repeat_interleave(x, 3, dim=1) + + onnx_program = torch.onnx.export( + Model(), (torch.randn(2, 3),), dynamo=True, optimize=False + ) + _testing.assert_onnx_program(onnx_program) + + def test_repeat_interleave_integer_2(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.repeat_interleave(x, 3, dim=1) + + onnx_program = torch.onnx.export( + Model(), (torch.randn(2, 3, 4),), dynamo=True, optimize=False + ) + _testing.assert_onnx_program(onnx_program) + + def test_repeat_interleave_tensor(self): + class Model(torch.nn.Module): + def forward(self, x, ind): + return torch.repeat_interleave(x, ind, dim=0) + + onnx_program = torch.onnx.export( + Model(), + 
( + torch.arange(6, dtype=torch.float32).reshape((2, 3)), + torch.tensor([1, 2], dtype=torch.int64), + ), + dynamo=True, + optimize=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_repeat_interleave_tensor_none(self): + class Model(torch.nn.Module): + def forward(self, x, ind): + return torch.repeat_interleave(x, ind) + + inputs = ( + torch.arange(4, dtype=torch.float32).reshape((2, 2)), + torch.tensor([1, 2, 3, 2], dtype=torch.int64), + ) + onnx_program = torch.onnx.export( + Model(), + inputs, + dynamo=True, + optimize=False, + ) + onnx_program = torch.onnx.export( + Model(), + inputs, + input_names=["x", "ind"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_repeat_interleave_symbolic_tensor(self): + class Model(torch.nn.Module): + def forward(self, x, y): + return torch.repeat_interleave(x, y.shape[1], dim=1) * torch.repeat_interleave( + y, x.shape[1], dim=1 + ) + + inputs = ( + torch.arange(4, dtype=torch.float32).reshape((2, 2)), + torch.arange(6, dtype=torch.float32).reshape((2, 3)), + ) + onnx_program = torch.onnx.export( + Model(), + inputs, + input_names=["x", "y"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + def test_sdpa_with_bool_attn_mask(self): class ScaledDotProductAttention(torch.nn.Module): def forward(self, query, key, value, attn_mask): @@ -98,6 +183,694 @@ def forward(self, query, key, value, attn_mask): ) _testing.assert_onnx_program(onnx_program) + def test_dynamic_paddings(self): + class Model(torch.nn.Module): + def forward(self, x): + height = x.size(2) # height is SymInt + x = torch.nn.functional.pad(x, (0, 0, 0, height), mode="replicate") + return x + + onnx_program = torch.onnx.export( + Model(), + (torch.rand(1, 1, 1, 1),), + dynamo=True, + dynamic_shapes=({2: torch.export.Dim("H")},), + ) + _testing.assert_onnx_program(onnx_program) + + def 
test_enable_gqa_in_attention(self): + class Model(torch.nn.Module): + def forward(self, q, k, v): + return torch.nn.functional.scaled_dot_product_attention( # pylint: disable=not-callable + q, + k, + v, + enable_gqa=True, + ) + + model = Model() + + query = torch.randn(2, 4, 8, 16) + key = torch.randn(2, 2, 8, 16) + value = torch.randn(2, 2, 8, 16) + + onnx_program = torch.onnx.export( + model, + ( + query, + key, + value, + ), + input_names=["query", "key", "value"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_bitwise_and_scalar(self): + class Model(torch.nn.Module): + def forward(self, x): + return x & 3 + + onnx_program = torch.onnx.export( + Model(), + (torch.tensor([1, 2, 3, 4, 5]),), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_dft_axis_promoted_from_attribute_to_input(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.ops.aten._fft_r2c(x, [0], normalization=1, onesided=True) # pylint: disable=protected-access + + onnx_program = torch.onnx.export( + Model(), + (torch.randn(2, 3),), + opset_version=20, + dynamic_shapes=({0: "dim_x"},), + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_avg_pool(self): + class Model(torch.nn.Module): + def forward(self, x2d, x3d, x4d, x5d): + return ( + torch.nn.functional.avg_pool1d(x2d, 2), # pylint: disable=not-callable + torch.nn.functional.avg_pool1d(x3d, 2), # pylint: disable=not-callable + torch.nn.functional.avg_pool2d(x3d, 2), # pylint: disable=not-callable + torch.nn.functional.avg_pool2d(x4d, 2), # pylint: disable=not-callable + torch.nn.functional.avg_pool3d(x4d, 2), # pylint: disable=not-callable + torch.nn.functional.avg_pool3d(x5d, 2), # pylint: disable=not-callable + ) + + x2d = torch.randn(10, 10) + x3d = torch.randn(10, 10, 10) + x4d = torch.randn(10, 10, 10, 10) + x5d = torch.randn(10, 10, 10, 10, 10) + onnx_program = 
torch.onnx.export( + Model(), + (x2d, x3d, x4d, x5d), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_concat_with_empty_tensor(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.cat([x, torch.tensor([]), x], dim=0) + + onnx_program = torch.onnx.export( + Model(), + (torch.tensor([1, 2]),), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_concat_with_empty_tensor_single_element(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.cat([x, torch.tensor([])], dim=1) + + onnx_program = torch.onnx.export( + Model(), + (torch.tensor([[1, 2]]),), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_lstm_unidirectional(self): + class LSTMModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.lstm = torch.nn.LSTM( + input_size=10, hidden_size=20, num_layers=1, batch_first=True + ) + + def forward(self, x): + return self.lstm(x) + + model = LSTMModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_lstm_bidirectional(self): + class LSTMModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.lstm = torch.nn.LSTM( + input_size=10, + hidden_size=20, + num_layers=1, + batch_first=True, + bidirectional=True, + ) + + def forward(self, x): + return self.lstm(x) + + model = LSTMModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_lstm_multilayer(self): + class LSTMModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.lstm = torch.nn.LSTM( + input_size=10, hidden_size=20, num_layers=3, batch_first=True + ) + + def forward(self, x): + return self.lstm(x) + + model = 
LSTMModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_gru_unidirectional(self): + class GRUModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.gru = torch.nn.GRU( + input_size=10, hidden_size=20, num_layers=1, batch_first=True + ) + + def forward(self, x): + return self.gru(x) + + model = GRUModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_gru_bidirectional(self): + class GRUModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.gru = torch.nn.GRU( + input_size=10, + hidden_size=20, + num_layers=1, + batch_first=True, + bidirectional=True, + ) + + def forward(self, x): + return self.gru(x) + + model = GRUModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_gru_multilayer(self): + class GRUModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.gru = torch.nn.GRU( + input_size=10, hidden_size=20, num_layers=3, batch_first=True + ) + + def forward(self, x): + return self.gru(x) + + model = GRUModel() + x = torch.randn(5, 3, 10) # (batch, seq, input_size) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_aten_unique_consecutive(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.unique_consecutive(x) + + model = Model() + x = torch.tensor([0, 1, 2, 2, 3, 3, 0, 0], dtype=torch.int64) + onnx_program = torch.onnx.export( + model, + (x,), + dynamic_shapes=({0: "length"},), + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def 
test_aten_unique_consecutive_int32(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.unique_consecutive(x) + + model = Model() + x = torch.tensor([0, 1, 2, 2, 3, 3, 0, 0], dtype=torch.int32) + onnx_program = torch.onnx.export( + model, + (x,), + dynamic_shapes=({0: "length"},), + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_aten_unique_consecutive_return(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.unique_consecutive(x, return_inverse=True, return_counts=True) + + model = Model() + x = torch.tensor([0, 1, 2, 2, 3, 3, 3, 0, 0], dtype=torch.int64) + onnx_program = torch.onnx.export( + model, + (x,), + dynamic_shapes=({0: "length"},), + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_aten_stft_1(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.stft(x, n_fft=4, return_complex=True) + + x = torch.randn(4, 16, dtype=torch.float32) + + onnx_program = torch.onnx.export( + Model(), + (x,), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_aten_stft_2(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.stft(x, n_fft=4, return_complex=False) + + x = torch.randn(4, 16, dtype=torch.float32) + + onnx_program = torch.onnx.export( + Model(), + (x,), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_aten_stft_3(self): + class Model(torch.nn.Module): + def forward(self, x): + window = torch.ones(16, dtype=torch.float32) + return torch.ops.aten.stft(x, n_fft=16, window=window, return_complex=False) + + x = torch.randn(100, dtype=torch.float32) + + onnx_program = torch.onnx.export( + Model(), + (x,), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_aten_stft_4(self): + class Model(torch.nn.Module): + def forward(self, x): + return torch.ops.aten.stft( + x, + n_fft=4, 
+ hop_length=1, + win_length=4, + center=True, + onesided=True, + return_complex=True, + ) + + x = torch.randn(4, 16, dtype=torch.float32) + + onnx_program = torch.onnx.export( + Model(), + (x,), + dynamo=True, + verbose=False, + ) + _testing.assert_onnx_program(onnx_program) + + def test_unbind_dim0(self): + """Test unbind along dimension 0""" + + class UnbindModel(torch.nn.Module): + def forward(self, x): + tensors = torch.unbind(x, dim=0) + return sum(tensors) + + model = UnbindModel() + x = torch.randn(3, 4, 5) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_unbind_dim1(self): + """Test unbind along dimension 1""" + + class UnbindModel(torch.nn.Module): + def forward(self, x): + tensors = torch.unbind(x, dim=1) + return sum(tensors) + + model = UnbindModel() + x = torch.randn(2, 3, 4) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_unbind_negative_dim(self): + """Test unbind with negative dimension""" + + class UnbindModel(torch.nn.Module): + def forward(self, x): + tensors = torch.unbind(x, dim=-1) + return sum(tensors) + + model = UnbindModel() + x = torch.randn(2, 3, 4) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_unbind_size_one(self): + """Test unbind with dimension of size 1""" + + class UnbindModel(torch.nn.Module): + def forward(self, x): + tensors = torch.unbind(x, dim=0) + return tensors[0] + + model = UnbindModel() + x = torch.randn(1, 4, 5) + onnx_program = torch.onnx.export(model, (x,), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_unbind_with_lstm(self): + """Test unbind in LSTM context""" + + class LSTMDecoder(torch.nn.Module): + def __init__(self): + super().__init__() + self.embedding = torch.nn.Embedding(100, 64) + self.lstm = torch.nn.LSTM(64, 64, 
2, batch_first=True) # 2 layers + self.fc = torch.nn.Linear(64, 100) + + def forward(self, tokens, h, c): + embedded = self.embedding(tokens).unsqueeze(0) + output, (h_out, c_out) = self.lstm(embedded, (h, c)) + logits = self.fc(output.squeeze(0).squeeze(0)) + return logits, h_out, c_out + + model = LSTMDecoder() + model.eval() + tokens = torch.tensor([1]) + h = torch.randn(2, 1, 64) # 2 layers + c = torch.randn(2, 1, 64) # 2 layers + onnx_program = torch.onnx.export(model, (tokens, h, c), dynamo=True, verbose=False) + _testing.assert_onnx_program(onnx_program) + + def test_unbind_dynamic_dim0(self): + """Test unbind with dynamic dimension 0 - triggers SplitToSequence""" + + class UnbindModel(torch.nn.Module): + def forward(self, x): + tensors = torch.unbind(x, dim=0) + return sum(tensors) + + model = UnbindModel() + x = torch.randn(3, 4, 5) + onnx_program = torch.onnx.export( + model, (x,), dynamo=True, verbose=False, dynamic_shapes=({0: "batch_size"},) + ) + _testing.assert_onnx_program(onnx_program) + + def test_unbind_dynamic_dim1(self): + """Test unbind with dynamic dimension 1 - triggers SplitToSequence""" + + class UnbindModel(torch.nn.Module): + def forward(self, x): + tensors = torch.unbind(x, dim=1) + return sum(tensors) + + model = UnbindModel() + x = torch.randn(2, 3, 4) + onnx_program = torch.onnx.export( + model, (x,), dynamo=True, verbose=False, dynamic_shapes=({1: "seq_len"},) + ) + _testing.assert_onnx_program(onnx_program) + + @parameterized.parameterized.expand( + [ + # Multiple advanced indices, all 1D tensors. 
+ # Non-contiguous advanced indices: updates must be broadcastable to (2, 6) + ( + (6, 6, 6), + [[0, 1], None, [2, 3]], + (2, 6), + "non_contiguous_non_broadcast_indices_no_value_broadcast", + ), + ( + (6, 6, 6), + [[0, 1], None, [2, 3]], + (2, 1), + "non_contiguous_non_broadcast_indices_expand_dim2", + ), + ( + (6, 6, 6), + [[0, 1], None, [2, 3]], + (1, 6), + "non_contiguous_non_broadcast_indices_expand_dim1", + ), + ( + (6, 6, 6), + [[0, 1], None, [2, 3]], + (6,), + "non_contiguous_non_broadcast_indices_new_dim1", + ), + ( + (6, 6, 6), + [[0, 1], None, [2, 3]], + (), + "non_contiguous_non_broadcast_indices_scalar", + ), + # Contiguous advanced indices versions of above tests: updates must be broadcastable to (6, 2) + ( + (6, 6, 6), + [None, [0, 1], [2, 3]], + (6, 2), + "contiguous_non_broadcast_indices_no_value_broadcast", + ), + ( + (6, 6, 6), + [None, [0, 1], [2, 3]], + (6, 1), + "contiguous_non_broadcast_indices_expand_dim2", + ), + ( + (6, 6, 6), + [None, [0, 1], [2, 3]], + (1, 2), + "contiguous_non_broadcast_indices_expand_dim1", + ), + ( + (6, 6, 6), + [None, [0, 1], [2, 3]], + (2,), + "contiguous_non_broadcast_indices_new_dim1", + ), + ((6, 6, 6), [None, [0, 1], [2, 3]], (), "contiguous_non_broadcast_indices_scalar"), + # Multiple advanced indices, with broadcasting among indices. 
+ # Contiguous advanced indices: + # This produces index tuples [(0,2), (0, 3), (1,2), (1,3)] in shape (2,2) + # The update values must be broadcastable to (6,2,2) + ( + (6, 6, 6), + [None, [[0], [1]], [2, 3]], + (6, 2, 2), + "contiguous_broadcast_indices_no_value_broadcast", + ), + ( + (6, 6, 6), + [None, [[0], [1]], [2, 3]], + (6, 1, 1), + "contiguous_broadcast_indices_expand_dim2_dim3", + ), + ( + (6, 6, 6), + [None, [[0], [1]], [2, 3]], + (2,), + "contiguous_broadcast_indices_extend_dim1_dim2", + ), + # Non-contiguous advanced indices versions of above tests: + # Here, update values must be broadcastable to (2,2,6) + ( + (6, 6, 6), + [[[0], [1]], None, [2, 3]], + (2, 2, 6), + "non_contiguous_broadcast_indices_no_value_broadcast", + ), + ( + (6, 6, 6), + [[[0], [1]], None, [2, 3]], + (1, 1, 6), + "non_contiguous_broadcast_indices_expand_dim1_dim2", + ), + ( + (6, 6, 6), + [[[0], [1]], None, [2, 3]], + (6,), + "non_contiguous_broadcast_indices_extend_dim1_dim2", + ), + # Other test cases + ( + (4, 4, 4, 4), + [None, [0, 1], None, [2, 3]], + (2, 4, 4), + "non_contiguous_non_first", + ), + ((6, 6, 6), [0, None, None], (6, 6), "single_scalar_index"), + ((6, 6, 6), [0, None, [0, 1]], (2, 6), "non_contiguous_scalar_index_and_1d_index"), + ((6, 6, 6), [None, 0, [0, 1]], (6, 2), "contiguous_scalar_index_and_1d_index"), + # (TODO): Exporter doesn't yet support all None indices + # ((6, 6, 6), [None, None, None], (6, 6, 6), "all_none_indices"), + ] + ) + def test_index_put(self, x_shape, index_list, update_shape, _: str): + indices = [ + (torch.tensor(index, dtype=torch.int64) if index is not None else None) + for index in index_list + ] + + class Model(torch.nn.Module): + def forward(self, x, update): + return torch.ops.aten.index_put(x, indices, update, accumulate=True) + + x = torch.zeros(x_shape, dtype=torch.float32) + update = torch.randn(update_shape, dtype=torch.float32) + + onnx_program = torch.onnx.export( + Model(), + (x, update), + input_names=["x", "update"], 
+ output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_index_put_dynamic(self): + for dimension in [3, 4, 2]: + with self.subTest(dimension=dimension): + + class Model(torch.nn.Module): + def __init__(self, dimension): + super().__init__() + self.params = torch.zeros( + (4, 5) + if dimension == 2 + else ((2, 4, 5) if dimension == 3 else (1, 1, 4, 5)) + ) + self.dimension = dimension + + def forward(self, update, index1, index2): + copy = self.params.clone() + if self.dimension == 2: + copy[index1, index2] = update + elif self.dimension == 3: + copy[:, index1, index2] = update + else: + copy[:, :, index1, index2] = update + return copy + + update = (torch.arange(2) + 10).reshape((2,)).to(torch.float32) + index1 = torch.tensor([1, 2], dtype=torch.int64) + index2 = torch.tensor([3, 4], dtype=torch.int64) + feeds = dict(zip(["update", "index1", "index2"], (update, index1, index2))) + onnx_program = torch.onnx.export( + Model(dimension), + tuple(feeds.values()), + input_names=["update", "index1", "index2"], + output_names=["output"], + opset_version=18, + dynamo=True, + dynamic_shapes={ + "update": {0: "dn"}, + "index1": {0: "dn"}, + "index2": {0: "dn"}, + }, + ) + _testing.assert_onnx_program(onnx_program) + + def test_index_put_55_12_25(self): + class Model(torch.nn.Module): + def forward(self, x, index, update): + return torch.ops.aten.index_put(x, [index], update) + + x = torch.zeros((6, 5), dtype=torch.float32) + index = torch.tensor([[2, 1]], dtype=torch.int64) + update = (torch.arange(10) + 10).reshape((2, -1)).to(torch.float32) + onnx_program = torch.onnx.export( + Model(), + (x, index, update), + input_names=["x", "index", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_index_put_55_2_25(self): + class Model(torch.nn.Module): + def forward(self, x, index, update): + return torch.ops.aten.index_put(x, 
[index], update, accumulate=True) + + x = torch.ones((6, 5), dtype=torch.float32) + index = torch.tensor([4, 3], dtype=torch.int64) + update = (torch.arange(10) + 10).reshape((2, -1)).to(torch.float32) + onnx_program = torch.onnx.export( + Model(), + (x, index, update), + input_names=["x", "index", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + ) + _testing.assert_onnx_program(onnx_program) + + def test_index_put_scatter_nd(self): + class Model(torch.nn.Module): + def forward(self, x, index, update): + x = x.clone() + return torch.ops.aten.index_put(x, [None, index, None], update) + + shape = (2, 3, 2) + N = math.prod(shape) + x = torch.arange(N, dtype=torch.float32).reshape(shape) + update = (torch.arange(N, dtype=torch.float32).reshape(shape) + 1) * 100 + index = ((torch.arange(shape[-2])).to(torch.int64) + 1) % shape[-2] + + feeds = dict(zip(["x", "index", "update"], (x, index, update))) + onnx_program = torch.onnx.export( + Model(), + tuple(feeds.values()), + input_names=["x", "index", "update"], + output_names=["output"], + opset_version=18, + dynamo=True, + dynamic_shapes=({0: "a", 1: "b", 2: "c"}, {0: "d"}, {0: "e", 1: "f", 2: "g"}), + ) + _testing.assert_onnx_program(onnx_program) + if __name__ == "__main__": unittest.main() diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 3d81896187..6cc7ab6d35 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -37,6 +37,37 @@ def sample_inputs_scalar_tensor(op_info, device, dtype, requires_grad, **kwargs) yield opinfo_core.SampleInput(item, dtype=dtype) +def sample_inputs_bilinear(op_info, device, dtype, requires_grad, **kwargs): + """Sample inputs for bilinear operation.""" + del op_info + del kwargs + + make_arg = functools.partial( + torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + + # Test cases: (batch_size, in1_features, 
in2_features, out_features) + cases = [ + (2, 3, 4, 5), # Basic case + (1, 2, 2, 1), # Minimal case + (3, 5, 7, 4), # Different dimensions + (2, 1, 1, 3), # Single input features + ] + + for batch_size, in1_features, in2_features, out_features in cases: + input1 = make_arg((batch_size, in1_features)) + input2 = make_arg((batch_size, in2_features)) + weight = make_arg((out_features, in1_features, in2_features)) + bias = make_arg((out_features,)) + + # Test with bias + yield opinfo_core.SampleInput(input1, args=(input2, weight, bias)) + + # Test without bias (only for first case to avoid too many tests) + if batch_size == 2: + yield opinfo_core.SampleInput(input1, args=(input2, weight, None)) + + def sample_inputs_bernoulli_p(op_info, device, dtype, requires_grad, **kwargs): del op_info @@ -87,6 +118,35 @@ def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_gra yield opinfo_core.SampleInput(t, kwargs={"p": p}) +def sample_inputs_broadcast_in_dim(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + + # cases: (input_shape, target_shape, broadcast_dimensions) + # broadcast_dimensions maps each input dim to an axis in target_shape + cases = ( + # scalar -> 1-D tensor + ((), (3,), ()), + # identity (no-op broadcast) + ((3,), (3,), (0,)), + # rank-preserving broadcast where singleton dims expand + ((1, 3, 1), (2, 3, 4), (0, 1, 2)), + # input rank 2 -> output rank 3, input dims map to trailing axes + ((3, 1), (2, 3, 4), (1, 2)), + # add leading broadcast axis + ((3, 4), (1, 3, 4), (1, 2)), + # insert broadcasting in middle axis + ((3,), (2, 3, 1), (1,)), + ) + make_arg = functools.partial( + torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + + for shape, target_shape, broadcast_dimensions in cases: + tensor = make_arg(shape) + yield opinfo_core.SampleInput(tensor, args=(target_shape, broadcast_dimensions)) + + def sample_inputs_col2im(op_info, device, dtype, requires_grad, **kwargs): 
del op_info # input_shape, output_size, kernal, dilation, padding, stride @@ -719,6 +779,109 @@ def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_): ) +def sample_inputs_fake_quantize_per_tensor_affine( + op_info, device, dtype, requires_grad, **kwargs +): + del op_info, kwargs # Unused + make_arg = functools.partial( + opinfo_core.make_tensor, + device=device, + requires_grad=requires_grad, + ) + + # Test 1D, empty and scalar tensors (like sample_inputs_elementwise_unary) + shapes = [ + (S,), + (1, 0, 3), + (), + ] + + scale_zero_point_dtypes = [ + # default (float, int) + (None, None) + ] + [ + # tensor_qparams (tensor, tensor) + (t1, t2) + for t1 in common_dtype.all_types_and() + for t2 in common_dtype.all_types_and() + ] + + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + quant_vals = [(0, 255), (-128, 127), (0, 127)] + + cases = itertools.product(shapes, scale_zero_point_dtypes, quant_vals) + for shape, (scale_dtype, zero_point_dtype), (quant_min, quant_max) in cases: + scale = make_arg( + (), + dtype=scale_dtype or torch.float64, + ) + if scale_dtype is None: + scale = scale.item() + + zero_point = make_arg( + (), + dtype=zero_point_dtype or torch.int64, + # zero_point must be between quant_min and quant_max + low=quant_min, + high=quant_max, + ) + if zero_point_dtype is None: + zero_point = zero_point.item() + + args = (scale, zero_point, quant_min, quant_max) + yield opinfo_core.SampleInput(make_arg(shape, dtype=dtype), args=args) + + +def sample_inputs_fake_quantize_per_channel_affine( + op_info, device, dtype, requires_grad, **kwargs +): + del op_info, kwargs # Unused + make_arg = functools.partial( + opinfo_core.make_tensor, + device=device, + requires_grad=requires_grad, + ) + + # Test 1D, 2D, 4D and empty tensors (scalar tensors not supported) + 
axes_and_shapes = [ + # 1D, 2D, 4D + (axis, (S,) * dims) + for dims in (1, 2, 4) + for axis in range(dims) + ] + [ + # empty + (0, (1, 0, 3)), + (2, (1, 0, 3)), + # empty channel axis causes an error due to + # an internal zero_point.min() calculation + # (1, (1, 0, 3)), + ] + + # tensor_qparams + scale_dtype = torch.float + zero_point_dtypes = [torch.int32, torch.float, torch.half] + + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + quant_vals = [(0, 255), (-128, 127), (0, 127)] + + cases = itertools.product(axes_and_shapes, zero_point_dtypes, quant_vals) + for (axis, shape), zero_point_dtype, (quant_min, quant_max) in cases: + scale = make_arg((shape[axis],), dtype=scale_dtype) + + zero_point = make_arg( + (shape[axis],), + dtype=zero_point_dtype or torch.int64, + # zero_point must be between quant_min and quant_max + low=quant_min, + high=quant_max, + ) + + args = (scale, zero_point, axis, quant_min, quant_max) + yield opinfo_core.SampleInput(make_arg(shape, dtype=dtype), args=args) + + def _index_variable_bool(shape, max_indices, device): if not isinstance(shape, tuple): shape = (shape,) @@ -1336,6 +1499,109 @@ def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs) yield opinfo_core.SampleInput(input_, args=(src, *args)) +def sample_inputs_scatter_src(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + make_arg = functools.partial( + torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad + ) + + # Basic test cases for scatter.src + cases = [ + # (self_shape, index_shape, src_shape, dim) + ((5, 5), (2, 3), (2, 3), 0), # 2D scatter on dim=0 + ((5, 5), (3, 2), (3, 2), 1), # 2D scatter on dim=1 + ((3, 4, 5), (2, 2, 3), (2, 2, 3), 0), # 3D scatter on dim=0 + ((3, 4, 5), (2, 2, 3), (2, 2, 3), 1), # 3D 
scatter on dim=1 + ((3, 4, 5), (2, 2, 3), (2, 2, 3), 2), # 3D scatter on dim=2 + ((10,), (3,), (3,), 0), # 1D scatter + ] + + for self_shape, index_shape, src_shape, dim in cases: + self_tensor = make_arg(self_shape) + # Create valid indices for the given dimension without duplication + index_buffer_shape = list(index_shape) + index_buffer_shape[dim] = self_shape[dim] + index_tensor = torch.rand(index_buffer_shape, device=device).argsort(dim=dim)[ + tuple(slice(None, d, None) for d in index_shape) + ] + src_tensor = make_arg(src_shape) + yield opinfo_core.SampleInput(self_tensor, args=(dim, index_tensor, src_tensor)) + + # Additional test cases for scalar and single-element tensor combinations with dim=0 + # Test case: scalar index, scalar src (dim_size=5) + dim_size = 5 + data_1d = make_arg((dim_size,)) + valid_index = torch.randint(0, dim_size, (), device=device, dtype=torch.long) + scalar_src = make_arg(()) + yield opinfo_core.SampleInput(data_1d, args=(0, valid_index, scalar_src)) + + # Test case: single-element tensor index, scalar src (dim_size=7) + dim_size = 7 + data_1d = make_arg((dim_size,)) + valid_index_1d = torch.randint(0, dim_size, (1,), device=device, dtype=torch.long) + scalar_src = make_arg(()) + yield opinfo_core.SampleInput(data_1d, args=(0, valid_index_1d, scalar_src)) + + # Test case: scalar index, single-element tensor src (dim_size=3) + dim_size = 3 + data_1d = make_arg((dim_size,)) + valid_index = torch.randint(0, dim_size, (), device=device, dtype=torch.long) + src_1d = make_arg((1,)) + yield opinfo_core.SampleInput(data_1d, args=(0, valid_index, src_1d)) + + # Test case: single-element tensor index, single-element tensor src (dim_size=10) + dim_size = 10 + data_1d = make_arg((dim_size,)) + valid_index_1d = torch.randint(0, dim_size, (1,), device=device, dtype=torch.long) + src_1d = make_arg((1,)) + yield opinfo_core.SampleInput(data_1d, args=(0, valid_index_1d, src_1d)) + + +def sample_inputs_scatter_value(op_info, device, dtype, 
requires_grad, **kwargs): + del op_info + del kwargs + make_arg = functools.partial( + torch_testing.make_tensor, dtype=dtype, device=device, requires_grad=requires_grad + ) + + # Basic test cases for scatter.value + cases = [ + # (self_shape, index_shape, dim, value) + ((5, 5), (2, 3), 0, 1.0), # 2D scatter on dim=0 with scalar value + ((5, 5), (3, 2), 1, -2.5), # 2D scatter on dim=1 with scalar value + ((3, 4, 5), (2, 2, 3), 0, False), # 3D scatter on dim=0 with scalar value + ((3, 4, 5), (2, 2, 3), 1, 3.14), # 3D scatter on dim=1 with scalar value + ((3, 4, 5), (2, 2, 3), 2, -1), # 3D scatter on dim=2 with scalar value + ((10,), (3,), 0, 5.0), # 1D scatter with scalar value + ] + + for self_shape, index_shape, dim, value in cases: + self_tensor = make_arg(self_shape) + # Create valid indices for the given dimension without duplication + index_buffer_shape = list(index_shape) + index_buffer_shape[dim] = self_shape[dim] + index_tensor = torch.rand(index_buffer_shape, device=device).argsort(dim=dim)[ + tuple(slice(None, d, None) for d in index_shape) + ] + yield opinfo_core.SampleInput(self_tensor, args=(dim, index_tensor, value)) + + # Additional test cases for scalar and single-element tensor combinations with dim=0 + # Test case: scalar index with scalar value (dim_size=6, value_type=torch.long) + dim_size = 6 + data_1d = make_arg((dim_size,)) + valid_index = torch.randint(0, dim_size, (), device=device, dtype=torch.long) + random_value = torch.randint(0, 10, (), device=device, dtype=torch.long).item() + yield opinfo_core.SampleInput(data_1d, args=(0, valid_index, random_value)) + + # Test case: single-element tensor index with scalar value (dim_size=8, value_type=torch.float) + dim_size = 8 + data_1d = make_arg((dim_size,)) + valid_index_1d = torch.randint(0, dim_size, (1,), device=device, dtype=torch.long) + random_value = torch.rand((), device=device, dtype=torch.float).item() + yield opinfo_core.SampleInput(data_1d, args=(0, valid_index_1d, random_value)) + 
+ def sample_inputs__scaled_dot_product_flash_attention( op_info, device, dtype, requires_grad, **kwargs ): @@ -2151,6 +2417,13 @@ def __init__(self): # To avoid name duplication, it is possible to rename the OpInfo and specify # the `op` field explicitly. OP_DB: List[opinfo_core.OpInfo] = [ + opinfo_core.OpInfo( + "bilinear", + op=torch.nn.functional.bilinear, + dtypes=common_dtype.floating_types(), + sample_inputs_func=sample_inputs_bilinear, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.bernoulli.p", aten_name="bernoulli.p", @@ -2276,21 +2549,28 @@ def __init__(self): sample_inputs_func=sample_inputs__fft_r2c, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.fake_quantize_per_tensor_affine", + aten_name="fake_quantize_per_tensor_affine", + op=torch.fake_quantize_per_tensor_affine, + dtypes=common_dtype.floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_fake_quantize_per_tensor_affine, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.fake_quantize_per_channel_affine", + aten_name="fake_quantize_per_channel_affine", + op=torch.fake_quantize_per_channel_affine, + dtypes=common_dtype.floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_fake_quantize_per_channel_affine, + supports_out=False, + ), opinfo_core.BinaryUfuncInfo( "ops.aten.floor_divide", aten_name="floor_divide", - dtypes=common_dtype.floating_types_and_half(), + dtypes=common_dtype.all_types_and_half(), rhs_make_tensor_kwargs=dict(exclude_zero=True), ), - opinfo_core.BinaryUfuncInfo( - "ops.aten.floor_divide.int", - aten_name="floor_divide", - op=torch.ops.aten.floor_divide, - dtypes=common_dtype.integral_types(), - # Create only positive inputs - lhs_make_tensor_kwargs=dict(low=0), - rhs_make_tensor_kwargs=dict(exclude_zero=True, low=0), - ), opinfo_core.OpInfo( "ops.aten.hamming_window", aten_name="hamming_window", @@ -2551,6 +2831,22 @@ def __init__(self): sample_inputs_func=sample_inputs_slice_scatter, 
supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.scatter.src", + op=torch.ops.aten.scatter.src, + aten_name="scatter.src", + dtypes=common_dtype.all_types_and(torch.bfloat16, torch.half, torch.bool), + sample_inputs_func=sample_inputs_scatter_src, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.scatter.value", + op=torch.ops.aten.scatter.value, + aten_name="scatter.value", + dtypes=common_dtype.all_types_and(torch.bfloat16, torch.half, torch.bool), + sample_inputs_func=sample_inputs_scatter_value, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten._softmax", op=torch.ops.aten._softmax, # pylint: disable=protected-access @@ -2725,6 +3021,13 @@ def __init__(self): sample_inputs_func=sample_inputs_upsample_trilinear3d_vec, supports_out=False, ), + opinfo_core.ReductionOpInfo( + "ops.prims.broadcast_in_dim.default", + op=torch.ops.prims.broadcast_in_dim.default, + dtypes=common_dtype.all_types(), + sample_inputs_func=sample_inputs_broadcast_in_dim, + supports_out=False, + ), opinfo_core.ReductionOpInfo( "ops.prims.var.default", nan_policy="propagate", diff --git a/tests/function_libs/torch_lib/ops_test.py b/tests/function_libs/torch_lib/ops_test.py index 7ba6f9d37f..a45050fb22 100644 --- a/tests/function_libs/torch_lib/ops_test.py +++ b/tests/function_libs/torch_lib/ops_test.py @@ -39,6 +39,7 @@ from torch.utils import _pytree as pytree import onnxscript +from onnxscript._internal import version_utils from tests.function_libs.torch_lib import ( error_reproduction, ops_test_common, @@ -98,7 +99,7 @@ def _should_skip_xfail_test_sample( class TestFunctionValidity(unittest.TestCase): @parameterized.parameterized.expand( - [(info.op.name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS] + [(info.op_info_name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS] ) def test_script_function_passes_checker( self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo @@ -109,10 +110,12 @@ def test_script_function_passes_checker( 
onnx.checker.check_function(function_proto) # type: ignore[attr-defined] @parameterized.parameterized.expand( - [(info.op.name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS] + [(info.op_info_name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS] ) def test_function_has_op_schema(self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo): func = torchlib_op_info.op + if not hasattr(func, "op_schema"): + raise AssertionError(f"Function {func.__name__} does not have op_schema attribute") schema = func.op_schema self.assertIsNotNone(schema) self.assertEqual(schema.name, func.name) @@ -200,7 +203,7 @@ def run_test_output_match( reference_torch_outputs, _ = pytree.tree_flatten(torch_output) if ( op.name.startswith("split") - or op.name.startswith("unbind") + or (op.name.startswith("unbind") and version_utils.torch_older_than("2.7")) or op.name in {"atleast_1d_Sequence", "atleast_2d_Sequence", "atleast_3d_Sequence"} ): diff --git a/tests/function_libs/torch_lib/ops_test_common.py b/tests/function_libs/torch_lib/ops_test_common.py index decaddddf4..99594ee17e 100644 --- a/tests/function_libs/torch_lib/ops_test_common.py +++ b/tests/function_libs/torch_lib/ops_test_common.py @@ -26,7 +26,6 @@ import numpy as np import onnx -import onnx_ir.passes.common as common_passes import onnxruntime as ort import onnxruntime.capi.onnxruntime_pybind11_state import pytest @@ -37,7 +36,6 @@ import onnxscript import onnxscript.evaluator from onnxscript import ir -from onnxscript.function_libs.torch_lib.ops import common as common_ops from tests.function_libs.torch_lib import error_reproduction T = TypeVar("T") @@ -412,19 +410,6 @@ def _format_model_and_input_information(onnx_model, inputs): } -def add_torchlib_common_imports(model: ir.Model) -> None: - """Hack to add torchlib common imports to the model.""" - - model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1 - rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto()) - is_scalar_func = 
ir.serde.deserialize_function(common_ops.IsScalar.to_function_proto()) - model.functions[rank_func.identifier()] = rank_func - model.functions[is_scalar_func.identifier()] = is_scalar_func - removal_pass = common_passes.RemoveUnusedFunctionsPass() - assert removal_pass.in_place - removal_pass(model) - - def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) -> bool: """Checks if the dtype is compatible with the schema. @@ -593,7 +578,6 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args, proto = onnxscript_function.to_function_proto() ir_function = ir.serde.deserialize_function(proto) onnx_model.functions[identifier] = ir_function - add_torchlib_common_imports(onnx_model) # Make sure the model is valid model_proto = ir.to_proto(onnx_model) try: diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 183b23cc4c..01653f74fe 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -48,7 +48,6 @@ from torch.testing._internal.opinfo import definitions as opinfo_definitions from typing_extensions import Self -from onnxscript._internal import version_utils from onnxscript.function_libs.torch_lib import _flags from onnxscript.function_libs.torch_lib.ops import core as core_ops from onnxscript.function_libs.torch_lib.ops import fft as fft_ops @@ -478,40 +477,13 @@ def _where_input_wrangler( fft_ops.aten__fft_r2c, tolerance={torch.float64: (2e-6, 2e-6), torch.float32: (3e-2, 3e-4)}, ), + TorchLibOpInfo("ops.aten._local_scalar_dense", core_ops.aten__local_scalar_dense), TorchLibOpInfo( - "ops.aten._local_scalar_dense", - core_ops.aten__local_scalar_dense, - ), - TorchLibOpInfo("ops.aten._log_softmax", core_ops.aten__log_softmax), - TorchLibOpInfo( - "ops.aten._log_softmax_half", - core_ops.aten__log_softmax_half, + "ops.aten._log_softmax", + core_ops.aten__log_softmax, tolerance={torch.float16: (1e-3, 
1e-3)}, - ) - .xfail( - reason="PyTorch does not implement _log_softmax for float16 on CPU", - dtypes=(torch.float16,), - enabled_if=version_utils.torch_older_than("2.2"), - ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.17"), - dtypes=(torch.float16,), - reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", - test_class_name="TestOutputConsistencyFullGraph", ), TorchLibOpInfo("ops.aten._softmax", core_ops.aten__softmax), - TorchLibOpInfo("ops.aten._softmax_half", core_ops.aten__softmax_half) - .xfail( - reason="PyTorch does not implement _softmax for float16 on CPU", - dtypes=(torch.float16,), - enabled_if=version_utils.torch_older_than("2.2"), - ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.17"), - dtypes=(torch.float16,), - reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", - test_class_name="TestOutputConsistencyFullGraph", - ), TorchLibOpInfo("all_dim", core_ops.aten_all_dim).skip( matcher=lambda sample: not (len(sample.kwargs) > 0) or isinstance(sample.kwargs.get("dim"), tuple), @@ -522,10 +494,7 @@ def _where_input_wrangler( reason="this overload requires dim to be a tuple", ), TorchLibOpInfo("allclose", core_ops.aten_allclose), - TorchLibOpInfo( - "all", - core_ops.aten_all, - ).skip( + TorchLibOpInfo("all", core_ops.aten_all).skip( matcher=lambda sample: len(sample.kwargs) != 0, reason="this Aten overload only support one tensor as input by design", ), @@ -560,32 +529,14 @@ def _where_input_wrangler( reason="zero sized inputs cannot be compared", ), TorchLibOpInfo("addmv", core_ops.aten_addmv, tolerance={torch.float16: (2e-3, 2e-2)}), - TorchLibOpInfo( - "addr", - core_ops.aten_addr, - tolerance={torch.float16: (3e-3, 4e-3)}, - ), - TorchLibOpInfo( - "amax", - core_ops.aten_amax, - input_wrangler=_amin_amax_input_wrangler, - ), - TorchLibOpInfo( - "amin", - core_ops.aten_amin, - input_wrangler=_amin_amax_input_wrangler, - ), - TorchLibOpInfo( - "any", - 
core_ops.aten_any, - ).skip( + TorchLibOpInfo("addr", core_ops.aten_addr, tolerance={torch.float16: (3e-3, 4e-3)}), + TorchLibOpInfo("amax", core_ops.aten_amax, input_wrangler=_amin_amax_input_wrangler), + TorchLibOpInfo("amin", core_ops.aten_amin, input_wrangler=_amin_amax_input_wrangler), + TorchLibOpInfo("any", core_ops.aten_any).skip( matcher=lambda sample: len(sample.kwargs) != 0, reason="this Aten overload only support one tensor as input by design", ), - TorchLibOpInfo( - "any_dim", - core_ops.aten_any_dim, - ).skip( + TorchLibOpInfo("any_dim", core_ops.aten_any_dim).skip( matcher=lambda sample: not (len(sample.kwargs) > 0) or isinstance(sample.kwargs.get("dim"), tuple), reason="this Aten overload only support one tensor as input and {dim,keepdim} as kwargs by design. dim must be an integer", @@ -597,85 +548,58 @@ def _where_input_wrangler( TorchLibOpInfo("asin", core_ops.aten_asin), TorchLibOpInfo("asinh", core_ops.aten_asinh), TorchLibOpInfo("atan", core_ops.aten_atan), - TorchLibOpInfo("atan2", core_ops.aten_atan2, tolerance={torch.float16: (1e-3, 1e-3)}), + TorchLibOpInfo("atan2", core_ops.aten_atan2), TorchLibOpInfo("atanh", core_ops.aten_atanh), TorchLibOpInfo("atleast_1d", core_ops.aten_atleast_1d).skip( matcher=lambda sample: isinstance(sample.input, (list, tuple)), reason="takes single tensor as input", ), - TorchLibOpInfo( - "atleast_1d_Sequence", - core_ops.aten_atleast_1d_sequence, - ) + TorchLibOpInfo("atleast_1d_Sequence", core_ops.aten_atleast_1d_sequence) .skip( matcher=lambda sample: not isinstance(sample.input, (list, tuple)), reason="takes tensor sequences only", ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason=( - "fixme: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (_0x9370ed0_rank)." - "https://github.com/microsoft/onnxscript/issues/960" - ), - ) .xfail( reason=( "fixme: ORT shape inference failed." 
"https://github.com/microsoft/onnxscript/issues/1007" - ), + ) ), TorchLibOpInfo("atleast_2d", core_ops.aten_atleast_2d).skip( matcher=lambda sample: isinstance(sample.input, (list, tuple)), reason="takes single tensor as input", ), - TorchLibOpInfo( - "atleast_2d_Sequence", - core_ops.aten_atleast_2d_sequence, - ) + TorchLibOpInfo("atleast_2d_Sequence", core_ops.aten_atleast_2d_sequence) .skip( matcher=lambda sample: not isinstance(sample.input, (list, tuple)), reason="takes tensor sequences only", ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason=( - "fixme: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (_0x9370ed0_rank)." - "https://github.com/microsoft/onnxscript/issues/960" - ), - ) .xfail( reason=( "fixme: ORT shape inference failed." "https://github.com/microsoft/onnxscript/issues/1007" - ), + ) ), TorchLibOpInfo("atleast_3d", core_ops.aten_atleast_3d).skip( matcher=lambda sample: isinstance(sample.input, (list, tuple)), reason="takes single tensor as input", ), - TorchLibOpInfo( - "atleast_3d_Sequence", - core_ops.aten_atleast_3d_sequence, - ) + TorchLibOpInfo("atleast_3d_Sequence", core_ops.aten_atleast_3d_sequence) .skip( matcher=lambda sample: not isinstance(sample.input, (list, tuple)), reason="takes tensor sequences only", ) - .xfail( - enabled_if=version_utils.onnxruntime_older_than("1.16"), - reason=( - "fixme: [ONNXRuntimeError] : 1 : FAIL : This is an invalid model. Error: Duplicate definition of name (_0x9370ed0_rank)." - "https://github.com/microsoft/onnxscript/issues/960" - ), - ) .xfail( reason=( "fixme: ORT shape inference failed." 
"https://github.com/microsoft/onnxscript/issues/1007" - ), + ) ), TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm, tolerance={torch.float16: (1e-3, 1e-2)}), TorchLibOpInfo("bernoulli", core_ops.aten_bernoulli, nondeterministic=True), + TorchLibOpInfo( + "bilinear", core_ops.aten_bilinear, tolerance={torch.float32: (2e-5, 2e-5)} + ), TorchLibOpInfo( # This string is a unique ID. In extra_opinfo.py, we # also define test data for this ID with @@ -687,16 +611,10 @@ def _where_input_wrangler( ), TorchLibOpInfo("ops.aten.bernoulli.p_deterministic", core_ops.aten_bernoulli_p), TorchLibOpInfo("bitwise_and", core_ops.aten_bitwise_and), - TorchLibOpInfo("bitwise_left_shift_int16", core_ops.aten_bitwise_left_shift_int16), - TorchLibOpInfo("bitwise_left_shift_int32", core_ops.aten_bitwise_left_shift_int32), - TorchLibOpInfo("bitwise_left_shift_int64", core_ops.aten_bitwise_left_shift_int64), - TorchLibOpInfo("bitwise_left_shift_int8", core_ops.aten_bitwise_left_shift_int8), + TorchLibOpInfo("bitwise_left_shift", core_ops.aten_bitwise_left_shift), TorchLibOpInfo("bitwise_not", core_ops.aten_bitwise_not), TorchLibOpInfo("bitwise_or", core_ops.aten_bitwise_or), - TorchLibOpInfo("bitwise_right_shift_int16", core_ops.aten_bitwise_right_shift_int16), - TorchLibOpInfo("bitwise_right_shift_int32", core_ops.aten_bitwise_right_shift_int32), - TorchLibOpInfo("bitwise_right_shift_int64", core_ops.aten_bitwise_right_shift_int64), - TorchLibOpInfo("bitwise_right_shift_int8", core_ops.aten_bitwise_right_shift_int8), + TorchLibOpInfo("bitwise_right_shift", core_ops.aten_bitwise_right_shift), TorchLibOpInfo("bitwise_xor", core_ops.aten_bitwise_xor), TorchLibOpInfo("ops.aten.blackman_window", core_ops.aten_blackman_window), TorchLibOpInfo("bmm", core_ops.aten_bmm), @@ -714,10 +632,7 @@ def _where_input_wrangler( reason="fixme: ORT aborts with zero-dim tensors. 
https://github.com/microsoft/onnxruntime/issues/16619", ), TorchLibOpInfo("ceil", core_ops.aten_ceil), - TorchLibOpInfo("chunk", core_ops.aten_chunk).skip( - enabled_if=version_utils.torch_older_than("2.7"), - reason="Test for chunk is not configured for torch<2.7", - ), + TorchLibOpInfo("chunk", core_ops.aten_chunk), TorchLibOpInfo("clamp_max", core_ops.aten_clamp_max_tensor).skip( reason="Size 0 inputs are not handled by design", matcher=lambda sample: sample.input.numel() == 0, @@ -753,7 +668,6 @@ def _where_input_wrangler( TorchLibOpInfo("deg2rad", core_ops.aten_deg2rad), # TorchLibOpInfo("detach", core_ops.aten_detach), # detach is not in OP-TEST-DB TorchLibOpInfo("diagonal", core_ops.aten_diagonal), - TorchLibOpInfo("diagonal_bool", core_ops.aten_diagonal_bool), TorchLibOpInfo("div", core_ops.aten_div).skip( matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None, reason="this variation does not take the rounding_mode argument", @@ -771,7 +685,6 @@ def _where_input_wrangler( # Numbers match sometimes but not other times reason="fixme: off-by-one. 
https://github.com/microsoft/onnxscript/issues/990", ), - TorchLibOpInfo("div_mode_int", core_ops.aten_div_mode_int), TorchLibOpInfo("dot", core_ops.aten_dot), TorchLibOpInfo( "empty", @@ -781,8 +694,7 @@ def _where_input_wrangler( ), TorchLibOpInfo("einsum", core_ops.aten_einsum, input_wrangler=_einsum_input_wrangler) .xfail( - reason="fixme: PyTorch produces int64 output with int32 input", - dtypes=(torch.int32,), + reason="fixme: PyTorch produces int64 output with int32 input", dtypes=(torch.int32,) ) .xfail( reason="fixme: ONNX shape inference fails: https://github.com/onnx/onnx/issues/5739", @@ -805,6 +717,17 @@ def _where_input_wrangler( TorchLibOpInfo("special.erfcx", special_ops.aten_special_erfcx).xfail( reason="fixme: The implementation is numerically unstable: https://github.com/microsoft/onnxscript/issues/1223" ), + TorchLibOpInfo( + "ops.aten.fake_quantize_per_channel_affine", + core_ops.aten_fake_quantize_per_channel_affine, + ).xfail( + reason="fixme: ONNX (De)QuantizeLinear only supports integer zero_point values", + matcher=lambda sample: sample.args[1].dtype != torch.int32, + ), + TorchLibOpInfo( + "ops.aten.fake_quantize_per_tensor_affine", + core_ops.aten_fake_quantize_per_tensor_affine, + ), TorchLibOpInfo("fill", core_ops.aten_fill), TorchLibOpInfo("flip", core_ops.aten_flip).skip( reason="fixme: size 0 inputs are not handled yet", @@ -813,25 +736,18 @@ def _where_input_wrangler( TorchLibOpInfo("flatten", core_ops.aten_flatten), TorchLibOpInfo("floor", core_ops.aten_floor), TorchLibOpInfo("ops.aten.floor_divide", core_ops.aten_floor_divide), - TorchLibOpInfo("ops.aten.floor_divide.int", core_ops.aten_floor_divide_int), TorchLibOpInfo("fmod", core_ops.aten_fmod), TorchLibOpInfo("frac", core_ops.aten_frac), TorchLibOpInfo("full", core_ops.aten_full), - TorchLibOpInfo( - "full_like", - core_ops.aten_full_like, - ).skip( - enabled_if=ops_test_common.IS_MACOS, - reason="fixme: memory allocation issue on CI", + TorchLibOpInfo("full_like", 
core_ops.aten_full_like).skip( + enabled_if=ops_test_common.IS_MACOS, reason="fixme: memory allocation issue on CI" ), TorchLibOpInfo("gather", core_ops.aten_gather).skip( matcher=lambda sample: sample.input.numel() == 0 or sample.args[1].numel() == 0, reason="fixme: ORT does not support empty tensors as input", ), TorchLibOpInfo("ge", core_ops.aten_ge), - TorchLibOpInfo("ge_bool", core_ops.aten_ge_bool), TorchLibOpInfo("gt", core_ops.aten_gt), - TorchLibOpInfo("gt_bool", core_ops.aten_gt_bool), # TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB # TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index), @@ -845,9 +761,7 @@ def _where_input_wrangler( reason="this Aten overload only supports tensor(bool) as indices", ), TorchLibOpInfo( - "index_put", - core_ops.aten_index_put, - input_wrangler=_index_put_input_wrangler, + "index_put", core_ops.aten_index_put, input_wrangler=_index_put_input_wrangler ) .skip( matcher=lambda sample: sample.args[0][0].dtype != torch.int64, @@ -887,20 +801,13 @@ def _where_input_wrangler( dtypes=(torch.int64, torch.int32), reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854", ) - .xfail( - variant_name="tensor_overload", - dtypes=(torch.int64, torch.int32), + .skip( + matcher=lambda sample: sample.kwargs.get("dtype") in (torch.int64, torch.int32), reason="fixme: Results do not match with PyTorch. 
https://github.com/microsoft/onnxscript/issues/854", - enabled_if=not version_utils.torch_older_than("2.2"), ), TorchLibOpInfo("log", core_ops.aten_log), TorchLibOpInfo("le", core_ops.aten_le), - TorchLibOpInfo("le_bool", core_ops.aten_le_bool), - TorchLibOpInfo( - "lerp", - core_ops.aten_lerp, - tolerance={torch.float16: (2e-3, 2e-1)}, - ), + TorchLibOpInfo("lerp", core_ops.aten_lerp, tolerance={torch.float16: (2e-3, 2e-1)}), TorchLibOpInfo("log10", core_ops.aten_log10), TorchLibOpInfo("log1p", core_ops.aten_log1p), TorchLibOpInfo( @@ -939,7 +846,6 @@ def _where_input_wrangler( TorchLibOpInfo("logdet", core_ops.aten_logdet), TorchLibOpInfo("logsumexp", core_ops.aten_logsumexp), TorchLibOpInfo("lt", core_ops.aten_lt), - TorchLibOpInfo("lt_bool", core_ops.aten_lt_bool), TorchLibOpInfo("masked_fill", core_ops.aten_masked_fill).xfail( dtypes=(torch.bool,), reason="fixme: ORT does not have an implementation for Where with bool inputs.", @@ -955,19 +861,12 @@ def _where_input_wrangler( reason="values of matmul of [m, 0] and [0, n] matrices are undefined", ), TorchLibOpInfo("maximum", core_ops.aten_maximum), - TorchLibOpInfo("maximum_bool", core_ops.aten_maximum_bool), - TorchLibOpInfo( - "mean", - core_ops.aten_mean, - input_wrangler=_mean_input_wrangler, - ).skip( + TorchLibOpInfo("mean", core_ops.aten_mean, input_wrangler=_mean_input_wrangler).skip( matcher=lambda sample: sample.kwargs.get("dim") is not None, reason="this Aten overload only accept 1 inputs: self", ), TorchLibOpInfo( - "mean_dim", - core_ops.aten_mean_dim, - input_wrangler=_mean_input_wrangler, + "mean_dim", core_ops.aten_mean_dim, input_wrangler=_mean_input_wrangler ).skip( matcher=lambda sample: sample.kwargs.get("dim") is None, reason="this Aten overload can accept 2 inputs:(self, dim)", @@ -979,15 +878,11 @@ def _where_input_wrangler( or (len(sample.args) > 0 and not isinstance(sample.args[0], int)), reason="this ATen overload only support one tensor as input and another int as args", ), - 
TorchLibOpInfo( - "min", - core_ops.aten_min, - ).skip( + TorchLibOpInfo("min", core_ops.aten_min).skip( matcher=lambda sample: len(sample.args) > 0, reason="this ATen overload only supports one tensor as input by design", ), TorchLibOpInfo("minimum", core_ops.aten_minimum), - TorchLibOpInfo("minimum_bool", core_ops.aten_minimum_bool), TorchLibOpInfo("mm", core_ops.aten_mm).skip( matcher=lambda sample: torch.numel(sample.input) == 0, reason="values of matmul of [m, 0] and [0, n] matrices are undefined", @@ -996,39 +891,19 @@ def _where_input_wrangler( TorchLibOpInfo("mT", core_ops.aten_mT_complex, complex=True), TorchLibOpInfo("mul", core_ops.aten_mul), TorchLibOpInfo("mul", core_ops.aten_mul_complex, complex=True), - TorchLibOpInfo( - "mv", - core_ops.aten_mv, - tolerance={torch.float16: (3e-2, 1e-2)}, - ), + TorchLibOpInfo("mv", core_ops.aten_mv, tolerance={torch.float16: (3e-2, 1e-2)}), TorchLibOpInfo("narrow", core_ops.aten_narrow), TorchLibOpInfo("ops.aten.native_dropout", core_ops.aten_native_dropout), TorchLibOpInfo("ne", core_ops.aten_ne), TorchLibOpInfo("neg", core_ops.aten_neg), + TorchLibOpInfo("new_empty", core_ops.aten_new_empty, nondeterministic=True), TorchLibOpInfo( - "new_empty", - core_ops.aten_new_empty, - nondeterministic=True, - ), - TorchLibOpInfo( - "new_empty_strided", - core_ops.aten_new_empty_strided, - nondeterministic=True, - ), - TorchLibOpInfo( - "new_full", - core_ops.aten_new_full, - ), - TorchLibOpInfo( - "new_ones", - core_ops.aten_new_ones, - ), - TorchLibOpInfo( - "new_zeros", - core_ops.aten_new_zeros, + "new_empty_strided", core_ops.aten_new_empty_strided, nondeterministic=True ), + TorchLibOpInfo("new_full", core_ops.aten_new_full), + TorchLibOpInfo("new_ones", core_ops.aten_new_ones), + TorchLibOpInfo("new_zeros", core_ops.aten_new_zeros), TorchLibOpInfo("nn.functional.celu", nn_ops.aten_celu), - TorchLibOpInfo("nn.functional.celu_type_promoted", nn_ops.aten_celu_type_promoted), TorchLibOpInfo( "nn.functional.cross_entropy", 
# use cross_entropy as test case instead of cross_entropy_loss (not in OPS_DB) @@ -1041,9 +916,7 @@ def _where_input_wrangler( reason="ONNX SoftmaxCrossEntropyLoss op only accept argument[target] as int type", ), TorchLibOpInfo( - "nn.functional.dropout", - core_ops.aten_dropout, - input_wrangler=_dropout_input_wrangler, + "nn.functional.dropout", core_ops.aten_dropout, input_wrangler=_dropout_input_wrangler ).skip( matcher=lambda sample: len(sample.kwargs) == 0 or sample.kwargs.get("p", 0.0) > 0.0, reason="dropout is random so the result not match", @@ -1104,10 +977,7 @@ def _where_input_wrangler( tolerance={torch.float16: (5e-2, 1e-2)}, ), TorchLibOpInfo("nn.functional.pad", nn_ops.aten_pad) - .skip( - variant_name="circular", - reason="fixme: ORT does not support the circular mode", - ) + .skip(variant_name="circular", reason="fixme: ORT does not support the circular mode") .skip( variant_name="replicate_negative", reason="fixme: The implementation for negative paddings is not correct", @@ -1115,34 +985,21 @@ def _where_input_wrangler( TorchLibOpInfo( "nn.functional.pixel_shuffle", core_ops.aten_pixel_shuffle, - ) - .xfail( + ).xfail( dtypes=(torch.int32, torch.int64), reason="fixme: ONNX Runtime does not support int32/64 inputs", - ) - .xfail( - matcher=lambda sample: sample.input.numel() == 0, - reason="fixme: ORT does not support empty tensor as input", ), TorchLibOpInfo( "nn.functional.pixel_unshuffle", core_ops.aten_pixel_unshuffle, - ) - .xfail( + ).xfail( dtypes=(torch.int32, torch.int64), reason="fixme: ONNX Runtime does not support int32/64 inputs", - ) - .xfail( - matcher=lambda sample: sample.input.numel() == 0, - reason="fixme: ORT does not support empty tensor as input", ), TorchLibOpInfo( "ops.aten.reflection_pad1d", nn_ops.aten_reflection_pad1d, - ).xfail( - dtypes=(torch.int64,), - reason="Torch not implement reflection_pad1d for int64.", - ), + ).xfail(dtypes=(torch.int64,), reason="Torch not implement reflection_pad1d for int64."), 
TorchLibOpInfo( "nn.functional.reflection_pad2d", nn_ops.aten_reflection_pad2d, @@ -1151,26 +1008,9 @@ def _where_input_wrangler( matcher=lambda sample: not (len(sample.args) > 1 and sample.args[1] == "reflect"), reason="this Aten overload need args[1] == 'reflect' for pad mode", ), - TorchLibOpInfo( - "nn.functional.relu", - nn_ops.aten_relu, - ).xfail( - dtypes=(torch.int64,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT did not implement Relu for int64. https://github.com/microsoft/onnxruntime/issues/16654", - ), - TorchLibOpInfo( - "nn.functional.relu6", - nn_ops.aten_relu6, - ).xfail( - dtypes=(torch.int64,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT did not implement Relu for int64. https://github.com/microsoft/onnxruntime/issues/16654", - ), - TorchLibOpInfo( - "ops.aten.replication_pad1d", - nn_ops.aten_replication_pad1d, - ), + TorchLibOpInfo("nn.functional.relu", nn_ops.aten_relu), + TorchLibOpInfo("nn.functional.relu6", nn_ops.aten_relu6), + TorchLibOpInfo("ops.aten.replication_pad1d", nn_ops.aten_replication_pad1d), TorchLibOpInfo( "nn.functional.replication_pad2d", nn_ops.aten_replication_pad2d, @@ -1180,10 +1020,9 @@ def _where_input_wrangler( matcher=lambda sample: not (len(sample.args) > 1 and sample.args[1] == "replicate"), reason="this Aten overload need args[1] == 'replicate' for pad mode", ) - .xfail( + .skip( variant_name="replicate_negative", - enabled_if=not version_utils.torch_older_than("2.2"), - reason="fixme: negative padding is not implemented yet", + reason="fixme: The implementation for negative paddings is not correct. 
Potentially an ORT issue", ), TorchLibOpInfo( "nn.functional.replication_pad3d", @@ -1199,15 +1038,9 @@ def _where_input_wrangler( ), TorchLibOpInfo("nn.functional.selu", core_ops.aten_selu), TorchLibOpInfo( - "nn.functional.mse_loss", - nn_ops.aten_mse_loss, - input_wrangler=_mse_loss_input_wrangler, + "nn.functional.mse_loss", nn_ops.aten_mse_loss, input_wrangler=_mse_loss_input_wrangler ), - TorchLibOpInfo( - "nonzero", - core_ops.aten_nonzero, - input_wrangler=_nonzero_input_wrangler, - ) + TorchLibOpInfo("nonzero", core_ops.aten_nonzero, input_wrangler=_nonzero_input_wrangler) .xfail( matcher=lambda sample: sample.kwargs.get("as_tuple"), reason="as_tuple=True is not supported", @@ -1270,17 +1103,41 @@ def _where_input_wrangler( nondeterministic=True, ), TorchLibOpInfo("ops.aten.randn", core_ops.aten_randn, nondeterministic=True).xfail( - dtypes=(torch.float16,), - reason="fixme: Shape inference error", + dtypes=(torch.float16,), reason="fixme: Shape inference error" ), TorchLibOpInfo("ops.aten.randn_like", core_ops.aten_randn_like, nondeterministic=True), TorchLibOpInfo("rad2deg", core_ops.aten_rad2deg), TorchLibOpInfo("reciprocal", core_ops.aten_reciprocal), - TorchLibOpInfo( - "remainder", - core_ops.aten_remainder, - ), + TorchLibOpInfo("remainder", core_ops.aten_remainder), TorchLibOpInfo("repeat", core_ops.aten_repeat), + TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_self_int) + .skip( + matcher=lambda sample: not isinstance(sample.kwargs.get("repeats", None), int), + reason=("ignore cases when repeasts is a Tensor"), + ) + .skip(dtypes=(torch.bool,), reason="bool not supported") + .skip( + matcher=lambda sample: sample.kwargs.get("dim") is None, + reason="fixme: conversion not implemented if dim is None", + ) + .skip( + matcher=lambda sample: sample.input.numel() == 0, + reason="fixme: conversion not implemented when input tensor is empty", + ), + TorchLibOpInfo("repeat_interleave", core_ops.aten_repeat_interleave_Tensor) + .skip( + 
matcher=lambda sample: isinstance(sample.kwargs.get("repeats", None), int), + reason=("ignore cases when repeasts is an int"), + ) + .skip(dtypes=(torch.bool,), reason="bool not supported") + .skip( + matcher=lambda sample: sample.kwargs.get("dim") is None, + reason="fixme: conversion not implemented if dim is None", + ) + .skip( + matcher=lambda sample: sample.input.numel() == 0, + reason="fixme: conversion not implemented when input tensor is empty", + ), TorchLibOpInfo("reshape", core_ops.aten_reshape), TorchLibOpInfo("resolve_conj", core_ops.aten_resolve_conj), TorchLibOpInfo("resolve_neg", core_ops.aten_resolve_neg), @@ -1302,14 +1159,9 @@ def _where_input_wrangler( complex=True, ), TorchLibOpInfo( - "ops.aten.scalar_tensor", - core_ops.aten_scalar_tensor_complex, - complex=True, + "ops.aten.scalar_tensor", core_ops.aten_scalar_tensor_complex, complex=True ), - TorchLibOpInfo( - "scatter_add", - core_ops.aten_scatter_add, - ) + TorchLibOpInfo("scatter_add", core_ops.aten_scatter_add) .xfail( matcher=lambda sample: len(sample.input.shape) == 0, reason="fixme: Rank(0) input will lead ORT failed due to different rank(result) in if-else branch. https://github.com/onnx/onnx/issues/4986", @@ -1358,48 +1210,10 @@ def _where_input_wrangler( dtypes=(torch.float16,), reason="fixme: Tensor-likes are not close. 
Tests pass for float32.", ), - TorchLibOpInfo( - "split_with_sizes", - core_ops.aten_split_with_sizes, - ) - .xfail( - dtypes=(torch.float16,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT failed to produce the correct argument type: https://github.com/microsoft/onnxruntime/issues/16006", - ) - .xfail( - dtypes=(torch.bool,), - reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", - ), - TorchLibOpInfo( - "split", - core_ops.aten_split, - ) - .xfail( - dtypes=(torch.float16,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT failed to produce the correct argument type: https://github.com/microsoft/onnxruntime/issues/16006", - ) - .xfail( - variant_name="list_args", - dtypes=(torch.float16,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: ORT failed to produce the correct argument type: https://github.com/microsoft/onnxruntime/issues/16006", - ) - .xfail( - dtypes=(torch.bool,), - reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", - ) - .xfail( - variant_name="list_args", - dtypes=(torch.bool,), - reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", - ), + TorchLibOpInfo("split_with_sizes", core_ops.aten_split_with_sizes), + TorchLibOpInfo("split", core_ops.aten_split), TorchLibOpInfo("sqrt", core_ops.aten_sqrt), - TorchLibOpInfo( - "squeeze_dim", - core_ops.aten_squeeze_dim, - ) + TorchLibOpInfo("squeeze_dim", core_ops.aten_squeeze_dim) .skip( matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)), reason="this Aten overload only support one tensor as input and one int as args by design", @@ -1409,11 +1223,7 @@ def _where_input_wrangler( and sample.input.shape[sample.args[0]] != 1, reason="this Aten overload only support 
squeeze dim with size 1", ), - TorchLibOpInfo( - "squeeze_dim", - core_ops.aten_squeeze_dim_complex, - complex=True, - ) + TorchLibOpInfo("squeeze_dim", core_ops.aten_squeeze_dim_complex, complex=True) .skip( matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)), reason="this Aten overload only support one tensor as input and one int as args by design", @@ -1423,10 +1233,7 @@ def _where_input_wrangler( and sample.input.shape[sample.args[0]] != 1, reason="this Aten overload only support squeeze dim with size 1", ), - TorchLibOpInfo( - "squeeze", - core_ops.aten_squeeze, - ).skip( + TorchLibOpInfo("squeeze", core_ops.aten_squeeze).skip( matcher=lambda sample: len(sample.args) != 0, reason="this Aten overload only support one tensor as input by design", ), @@ -1435,20 +1242,14 @@ def _where_input_wrangler( TorchLibOpInfo("sub", core_ops.aten_sub, tolerance={torch.float16: (2e-3, 1e-3)}), TorchLibOpInfo("sub", core_ops.aten_sub_complex, complex=True), # TorchLibOpInfo("sym_size", core_ops.aten_sym_size), # no test case in OPS_DB - TorchLibOpInfo( - "t", - core_ops.aten_t, - ).xfail( + TorchLibOpInfo("t", core_ops.aten_t).xfail( enabled_if=not _flags.EXPERIMENTAL_PREFER_TRACING, reason="fixme: ORT Graph attribute inferencing failed on rank-1 input. 
https://github.com/onnx/onnx/issues/4986", test_class_name="TestOutputConsistencyFullGraph", ), TorchLibOpInfo("tan", core_ops.aten_tan), TorchLibOpInfo("tanh", core_ops.aten_tanh), - TorchLibOpInfo( - "tile", - core_ops.aten_tile, - ).skip( + TorchLibOpInfo("tile", core_ops.aten_tile).skip( matcher=lambda sample: any(dim == 0 for dim in sample.input.shape) or not sample.input.shape, reason="fixme: Logic not implemented for size 0 inputs in op.Reshape", @@ -1476,19 +1277,7 @@ def _where_input_wrangler( reason="fixme: ORT does not have an implementation of Trilu for int32.", ), TorchLibOpInfo("trunc", core_ops.aten_trunc), - TorchLibOpInfo( - "unbind", - core_ops.aten_unbind, - ) - .xfail( - dtypes=(torch.float16,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: SplitToSequence op inference failed. https://github.com/microsoft/onnxruntime/issues/16006", - ) - .xfail( - dtypes=(torch.bool,), - reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", - ), + TorchLibOpInfo("unbind", core_ops.aten_unbind), TorchLibOpInfo("unflatten", core_ops.aten_unflatten), TorchLibOpInfo("unfold", core_ops.aten_unfold), TorchLibOpInfo("ops.aten.unfold", core_ops.aten_unfold), @@ -1507,10 +1296,7 @@ def _where_input_wrangler( ), TorchLibOpInfo("xlogy", special_ops.aten_special_xlogy), TorchLibOpInfo("zeros", core_ops.aten_zeros), - TorchLibOpInfo( - "arange_start_step", - core_ops.aten_arange_start_step, - ) + TorchLibOpInfo("arange_start_step", core_ops.aten_arange_start_step) .skip( matcher=lambda sample: len(sample.args) != 2, reason="arange_start_step overload takes three arguments (input, start, step)", @@ -1520,10 +1306,7 @@ def _where_input_wrangler( reason="dtype needs to be specified for non-float tensors", dtypes=(torch.float16, torch.int64, torch.int32), ), - TorchLibOpInfo( - "arange_start", - core_ops.aten_arange_start, - ) + TorchLibOpInfo("arange_start", 
core_ops.aten_arange_start) .skip( matcher=lambda sample: len(sample.args) != 1, reason="arange_start overload takes two arguments (input, start)", @@ -1533,10 +1316,7 @@ def _where_input_wrangler( reason="dtype needs to be specified for non-float tensors", dtypes=(torch.float16, torch.int64, torch.int32), ), - TorchLibOpInfo( - "arange", - core_ops.aten_arange, - ) + TorchLibOpInfo("arange", core_ops.aten_arange) .xfail( dtypes=(torch.int32,), reason="fixme: output shape mismatch in edge cases. https://github.com/microsoft/onnxscript/issues/974", @@ -1559,10 +1339,7 @@ def _where_input_wrangler( TorchLibOpInfo( "as_strided", core_ops.aten_as_strided, - ).xfail( - variant_name="partial_views", - reason="ONNX doesn't have partial view for tensor", - ), + ).xfail(variant_name="partial_views", reason="ONNX doesn't have partial view for tensor"), TorchLibOpInfo("clamp", core_ops.aten_clamp_tensor), TorchLibOpInfo( "ops.aten.col2im", @@ -1582,19 +1359,13 @@ def _where_input_wrangler( tolerance={torch.float32: (2e-4, 9e-4)}, ), TorchLibOpInfo("empty_like", core_ops.aten_empty_like, nondeterministic=True), - TorchLibOpInfo( - "grid_sampler_2d", - core_ops.aten_grid_sampler_2d, - ) + TorchLibOpInfo("grid_sampler_2d", core_ops.aten_grid_sampler_2d) .skip( # Torch implemented this using the cubic convolution algorithm with alhpa=-0.75, might be different than ORT matcher=lambda sample: sample.args[1] == 2, reason="fixme: 'bicubic' mode in ORT implemented differently with Torch", ) - .skip( - dtypes=(torch.float16,), - reason="fixme: Accuracy is not high enough", - ), + .skip(dtypes=(torch.float16,), reason="fixme: Accuracy is not high enough"), TorchLibOpInfo( "nn.functional.group_norm", nn_ops.aten_group_norm, @@ -1638,6 +1409,10 @@ def _where_input_wrangler( dtypes=(torch.float32 if sys.platform != "linux" else torch.complex64,), reason="fixme: test is unstable on macosx, windows", ), + TorchLibOpInfo("logical_and", core_ops.aten_logical_and), + 
TorchLibOpInfo("logical_not", core_ops.aten_logical_not), + TorchLibOpInfo("logical_or", core_ops.aten_logical_or), + TorchLibOpInfo("logical_xor", core_ops.aten_logical_xor), TorchLibOpInfo("logit", core_ops.aten_logit, tolerance={torch.float16: (1e-1, 7e-4)}), TorchLibOpInfo("max_dim", core_ops.aten_max_dim) .xfail( @@ -1651,10 +1426,7 @@ def _where_input_wrangler( or (len(sample.args) > 0 and not isinstance(sample.args[0], int)), reason="this ATen overload only support one tensor as input and another int as args", ), - TorchLibOpInfo( - "max", - core_ops.aten_max, - ).skip( + TorchLibOpInfo("max", core_ops.aten_max).skip( matcher=lambda sample: len(sample.args) > 0, reason="this ATen overload only supports one tensor as input by design", ), @@ -1712,8 +1484,7 @@ def _where_input_wrangler( reason="fixme: ORT only supports BatchNorm less than opset14", ), TorchLibOpInfo( - "ops.aten._native_batch_norm_legit.no_stats", - core_ops.aten__native_batch_norm_no_stats, + "ops.aten._native_batch_norm_legit.no_stats", core_ops.aten__native_batch_norm_no_stats ), TorchLibOpInfo( "ops.aten._native_batch_norm_legit_functional", @@ -1734,10 +1505,6 @@ def _where_input_wrangler( "ops.aten.native_group_norm", core_ops.aten_native_group_norm, tolerance={torch.float16: (1e-2, 7e-3)}, - ).xfail( - dtypes=(torch.float16,), - reason="fixme: 'GroupNormKernelImpl' not implemented for 'Half' in nightly and weekly", - enabled_if=version_utils.torch_older_than("2.2"), ), TorchLibOpInfo( "native_layer_norm", @@ -1819,9 +1586,7 @@ def _where_input_wrangler( tolerance={torch.float16: (1e-2, 1e-3)}, ), TorchLibOpInfo( - "ops.aten.conv3d", - core_ops.aten_conv3d, - tolerance={torch.float32: (3.7e-5, 1.8e-4)}, + "ops.aten.conv3d", core_ops.aten_conv3d, tolerance={torch.float32: (3.7e-5, 1.8e-4)} ), TorchLibOpInfo("nn.functional.gelu", nn_ops.aten_gelu), TorchLibOpInfo("nn.functional.glu", nn_ops.aten_glu), @@ -1902,11 +1667,6 @@ def _where_input_wrangler( 
nn_ops.aten_scaled_dot_product_attention, tolerance={torch.float32: (3e-4, 1.5e-5)}, ) - .skip( - matcher=lambda sample: (attn_mask := sample.kwargs.get("attn_mask")) is not None - and attn_mask.dtype == torch.bool, - reason="this overload takes a non-boolean mask", - ) .skip( matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0, reason="dropout is random so the results do not match", @@ -1915,6 +1675,12 @@ def _where_input_wrangler( dtypes=(torch.float16,), reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", test_class_name="TestOutputConsistencyFullGraph", + ) + .xfail( + matcher=lambda sample: len(sample.input.shape) != 4 + or len(sample.args[0].shape) != 4 + or len(sample.args[1].shape) != 4, + reason="torch sdpa is expected to pass in 4d q, k, and v.", ), TorchLibOpInfo( "ops.aten._scaled_dot_product_flash_attention", @@ -1923,15 +1689,7 @@ def _where_input_wrangler( # Output[0] is OK, but other outputs just have the same shape with zero values nondeterministic=True, compare_shape_only_for_output=(1, 2, 3, 4, 5, 6, 7, 8), - ) - .skip( - enabled_if=version_utils.torch_older_than("2.1"), - reason="The operator is not supported in older version.", - ) - .skip( - device_type="cpu", - reason="_scaled_dot_product_flash_attention only supports CUDA", - ), + ).skip(device_type="cpu", reason="_scaled_dot_product_flash_attention only supports CUDA"), TorchLibOpInfo( "ops.aten._scaled_dot_product_efficient_attention", nn_ops.aten__scaled_dot_product_efficient_attention, @@ -1939,34 +1697,10 @@ def _where_input_wrangler( # Output[0] is OK, but other outputs just have the same shape with zero values nondeterministic=True, compare_shape_only_for_output=(1, 2, 3), - ) - .skip( - enabled_if=version_utils.torch_older_than("2.1"), - reason="The operator is not supported in older version.", - ) - .skip( + ).skip( enabled_if=not torch.cuda.is_available(), reason="_scaled_dot_product_efficient_attention only supports CUDA", ), - 
TorchLibOpInfo( - "nn.functional.scaled_dot_product_attention_bool_mask", - nn_ops.aten_scaled_dot_product_attention_bool_mask, - tolerance={torch.float32: (3e-4, 1.5e-5)}, - ) - .skip( - matcher=lambda sample: (attn_mask := sample.kwargs.get("attn_mask")) is not None - and attn_mask.dtype != torch.bool, - reason="this overload takes a boolean mask", - ) - .skip( - matcher=lambda sample: sample.kwargs.get("dropout_p") != 0.0, - reason="dropout is random so the results do not match", - ) - .xfail( - dtypes=(torch.float16,), - reason="fixme: ORT failed. https://github.com/microsoft/onnxruntime/issues/16438", - test_class_name="TestOutputConsistencyFullGraph", - ), TorchLibOpInfo( "ops.aten.upsample_bilinear2d.default", nn_ops.aten_upsample_bilinear2d, @@ -1986,10 +1720,7 @@ def _where_input_wrangler( # Shape-only comparison is the appropriate testing approach for this case. compare_shape_only_for_output=(0,), ), - TorchLibOpInfo( - "ops.aten.upsample_bilinear2d.vec", - nn_ops.aten_upsample_bilinear2d_vec, - ), + TorchLibOpInfo("ops.aten.upsample_bilinear2d.vec", nn_ops.aten_upsample_bilinear2d_vec), TorchLibOpInfo( "ops.aten.upsample_bicubic2d.default", nn_ops.aten_upsample_bicubic2d, @@ -2009,10 +1740,7 @@ def _where_input_wrangler( # Shape-only comparison is the appropriate testing approach for this case. 
compare_shape_only_for_output=(0,), ), - TorchLibOpInfo( - "ops.aten.upsample_bicubic2d.vec", - nn_ops.aten_upsample_bicubic2d_vec, - ), + TorchLibOpInfo("ops.aten.upsample_bicubic2d.vec", nn_ops.aten_upsample_bicubic2d_vec), TorchLibOpInfo( "ops.aten.upsample_linear1d", nn_ops.aten_upsample_linear1d, @@ -2021,38 +1749,14 @@ def _where_input_wrangler( and sample.kwargs.get("scales") is not None, reason="fixme: align_corners=False output mismatch when scales are provided", ), - TorchLibOpInfo( - "ops.aten.upsample_nearest1d", - nn_ops.aten_upsample_nearest1d, - ), - TorchLibOpInfo( - "ops.aten.upsample_nearest1d.vec", - nn_ops.aten_upsample_nearestnd_vec, - ), - TorchLibOpInfo( - "ops.aten.upsample_nearest2d", - nn_ops.aten_upsample_nearest2d, - ), - TorchLibOpInfo( - "ops.aten.upsample_nearest2d.vec", - nn_ops.aten_upsample_nearestnd_vec, - ), - TorchLibOpInfo( - "ops.aten.upsample_nearest3d", - nn_ops.aten_upsample_nearest3d, - ), - TorchLibOpInfo( - "ops.aten.upsample_nearest3d.vec", - nn_ops.aten_upsample_nearestnd_vec, - ), - TorchLibOpInfo( - "ops.aten.upsample_trilinear3d.default", - nn_ops.aten_upsample_trilinear3d, - ), - TorchLibOpInfo( - "ops.aten.upsample_trilinear3d.vec", - nn_ops.aten_upsample_trilinear3d_vec, - ), + TorchLibOpInfo("ops.aten.upsample_nearest1d", nn_ops.aten_upsample_nearest1d), + TorchLibOpInfo("ops.aten.upsample_nearest1d.vec", nn_ops.aten_upsample_nearestnd_vec), + TorchLibOpInfo("ops.aten.upsample_nearest2d", nn_ops.aten_upsample_nearest2d), + TorchLibOpInfo("ops.aten.upsample_nearest2d.vec", nn_ops.aten_upsample_nearestnd_vec), + TorchLibOpInfo("ops.aten.upsample_nearest3d", nn_ops.aten_upsample_nearest3d), + TorchLibOpInfo("ops.aten.upsample_nearest3d.vec", nn_ops.aten_upsample_nearestnd_vec), + TorchLibOpInfo("ops.aten.upsample_trilinear3d.default", nn_ops.aten_upsample_trilinear3d), + TorchLibOpInfo("ops.aten.upsample_trilinear3d.vec", nn_ops.aten_upsample_trilinear3d_vec), TorchLibOpInfo("ones_like", core_ops.aten_ones_like), 
TorchLibOpInfo( "roll", @@ -2070,10 +1774,7 @@ def _where_input_wrangler( core_ops.aten_scatter_reduce, input_wrangler=_scatter_reduce_input_wrangler, ) - .xfail( - variant_name="mean", - reason="ONNX doesn't support reduce='mean' option", - ) + .xfail(variant_name="mean", reason="ONNX doesn't support reduce='mean' option") .xfail( variant_name="prod", dtypes=(torch.float16, torch.float64), @@ -2100,8 +1801,18 @@ def _where_input_wrangler( reason="onnxruntime does not support ml_dtypes.bfloat16", ), TorchLibOpInfo("ops.aten.slice_scatter", core_ops.aten_slice_scatter), + TorchLibOpInfo("ops.aten.scatter.src", core_ops.aten_scatter_src), + TorchLibOpInfo("ops.aten.scatter.value", core_ops.aten_scatter_value), TorchLibOpInfo("slice", core_ops.aten_slice), TorchLibOpInfo("slice", core_ops.aten_slice_complex, complex=True), + TorchLibOpInfo( + "ops.aten.stft", # Custom from extra_opinfo + core_ops.aten_stft, + tolerance={torch.float32: (3.7e-5, 1.8e-4)}, + ).xfail( + dtypes=(torch.float16,), + reason="RuntimeError: MKL FFT doesn't support tensors of type: Half", + ), TorchLibOpInfo( "sum", core_ops.aten_sum_dim_IntList, @@ -2131,6 +1842,7 @@ def _where_input_wrangler( "Our implementation is based on that for CUDA" ), ), + TorchLibOpInfo("ops.prims.broadcast_in_dim.default", prims_ops.prims_broadcast_in_dim), TorchLibOpInfo( "ops.prims.var.default", prims_ops.prims_var, tolerance={torch.float16: (1e-3, 5e-2)} ), @@ -2144,40 +1856,13 @@ def _where_input_wrangler( ops_test_common.duplicate_opinfo(OPS_DB, "atleast_1d", ("atleast_1d_Sequence",)) ops_test_common.duplicate_opinfo(OPS_DB, "atleast_2d", ("atleast_2d_Sequence",)) ops_test_common.duplicate_opinfo(OPS_DB, "atleast_3d", ("atleast_3d_Sequence",)) -ops_test_common.duplicate_opinfo( - OPS_DB, - "bitwise_left_shift", - ( - "bitwise_left_shift_int8", - "bitwise_left_shift_int16", - "bitwise_left_shift_int32", - "bitwise_left_shift_int64", - ), -) -ops_test_common.duplicate_opinfo( - OPS_DB, - "bitwise_right_shift", - ( 
- "bitwise_right_shift_int8", - "bitwise_right_shift_int16", - "bitwise_right_shift_int32", - "bitwise_right_shift_int64", - ), -) ops_test_common.duplicate_opinfo(OPS_DB, "cat", ("concat", "concatenate")) ops_test_common.duplicate_opinfo(OPS_DB, "clone", ("lift_fresh_copy",)) -ops_test_common.duplicate_opinfo(OPS_DB, "diagonal", ("diagonal_bool",)) -ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode", "div_mode_int")) -ops_test_common.duplicate_opinfo(OPS_DB, "ge", ("ge_bool",)) -ops_test_common.duplicate_opinfo(OPS_DB, "gt", ("gt_bool",)) +ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode",)) ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",)) -ops_test_common.duplicate_opinfo(OPS_DB, "le", ("le_bool",)) -ops_test_common.duplicate_opinfo(OPS_DB, "lt", ("lt_bool",)) ops_test_common.duplicate_opinfo(OPS_DB, "max", ("max_dim",)) -ops_test_common.duplicate_opinfo(OPS_DB, "maximum", ("maximum_bool",)) ops_test_common.duplicate_opinfo(OPS_DB, "mean", ("mean_dim",)) ops_test_common.duplicate_opinfo(OPS_DB, "min", ("min_dim",)) -ops_test_common.duplicate_opinfo(OPS_DB, "minimum", ("minimum_bool",)) ops_test_common.duplicate_opinfo( OPS_DB, "nn.functional.pad", @@ -2187,20 +1872,6 @@ def _where_input_wrangler( "nn.functional.replication_pad3d", ), ) -ops_test_common.duplicate_opinfo( - OPS_DB, - "nn.functional.scaled_dot_product_attention", - ("nn.functional.scaled_dot_product_attention_bool_mask",), -) -ops_test_common.duplicate_opinfo( - OPS_DB, - "nn.functional.celu", - ("nn.functional.celu_type_promoted",), -) -ops_test_common.duplicate_opinfo( - OPS_DB, "ops.aten._log_softmax", ("ops.aten._log_softmax_half",) -) -ops_test_common.duplicate_opinfo(OPS_DB, "ops.aten._softmax", ("ops.aten._softmax_half",)) ops_test_common.duplicate_opinfo(OPS_DB, "prod", ("prod_dim_int",)) ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",)) ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",))