Merged

124 commits
07f3e4c
chore(deps): bump ruff from 0.12.10 to 0.12.11 in /requirements/lintr…
dependabot[bot] Sep 2, 2025
8974f5e
Implements repeat_interleave (#2477)
xadupre Sep 2, 2025
7b04774
[torchlib] Modify aten_unbind to use None for split_sizes (#2536)
justinchuby Sep 3, 2025
54de741
Refactor rewrite rules into the rewriter.rules namespace (#2531)
justinchuby Sep 3, 2025
a925acc
[torchlib] Improve pixel_shuffle (#2537)
justinchuby Sep 3, 2025
456a6bc
Update constant folding behavior for large tensors (#2488)
justinchuby Sep 4, 2025
fc792e4
[torchlib] Improve handling of SymInt[] (#2522)
justinchuby Sep 4, 2025
d98e3dd
[torch] Fix incorrect Concat when processing dynamic paddings (#2540)
kistenklaus Sep 4, 2025
1934901
Add test for dynamic padding (#2541)
justinchuby Sep 4, 2025
e76bfe0
[Reland] Update SplitToSequence in constant folding (#2544)
titaiwangms Sep 5, 2025
5762a69
[Rewriter]: add fusion rules for successive Min/Max patterns (#2500)
AyoubMDL Sep 5, 2025
f5f9e6a
Update onnx-weekly version to 1.20.0 (#2545)
justinchuby Sep 5, 2025
d0fb218
[rewriter] Unify reshape flatten ops (#2518)
Johansmm Sep 5, 2025
9036fab
[Rewriter] Support specifying node name in rewrites (#2474)
AyoubMDL Sep 5, 2025
cec5396
Do not try to fold op.SplitToSequence when split is `None` (#2550)
titaiwangms Sep 8, 2025
647b22a
Bump version to 0.5.0 (#2538)
justinchuby Sep 9, 2025
0e79b62
[Rewriter] Add fuse batchnorm to default rules (#2553)
AyoubMDL Sep 9, 2025
821015a
Add Conv-Affine(Mul+Add) and hardswish fusion (#2472)
Stonesjtu Sep 10, 2025
710d597
Fix rewriter and CI tests for the latest onnx-ir version (#2554)
justinchuby Sep 10, 2025
50d7e87
[torchlib] Mark atan2 as trace_only and map NaN to 0 (#2557)
justinchuby Sep 11, 2025
366f7be
[torchlib] Fix repeat_interleave when repeats is a symbolic tensor (#…
xadupre Sep 12, 2025
8ed3521
Support `enable_gqa` and only support 4D Q, K, and V (#2558)
titaiwangms Sep 12, 2025
39f1015
[torchlib] Implement torch.ops.prims.broadcast_in_dim.default (#2382)
Copilot Sep 12, 2025
8944f04
Bump version from 0.5.0 to 0.5.1 (#2559)
justinchuby Sep 12, 2025
92633a6
Remove CheckerPass from ort_fusion (#2560)
justinchuby Sep 12, 2025
a70ee8d
Use ir.val to replace ir.Input (#2556)
justinchuby Sep 12, 2025
ea79022
chore(deps): bump ruff from 0.12.11 to 0.13.0 in /requirements/lintru…
dependabot[bot] Sep 15, 2025
f529292
Bump version from 0.5.1 to 0.5.2 (#2565)
justinchuby Sep 16, 2025
3156bed
[torchlib] Fix aten_gather to correctly handle scalar indices (#2566)
linshokaku Sep 18, 2025
79afb87
[rewriter] Remove generic pattern matcher (#2567)
justinchuby Sep 19, 2025
27c7f09
chore(deps): bump ruff from 0.13.0 to 0.13.1 in /requirements/lintrun…
dependabot[bot] Sep 22, 2025
f54cf47
Add GQA fusion to ONNX fusions (#2524)
gramalingam Sep 23, 2025
e67eeef
[torchlib] Simplify linalg_vector_norm to remove the redundant Abs (#…
justinchuby Sep 23, 2025
7e45333
[torchlib] Add trace_only flag to aten_copy, aten_tril, aten_triu (#2…
justinchuby Sep 24, 2025
168fd8a
Bump version from 0.5.2 to 0.5.3 (#2571)
justinchuby Sep 24, 2025
dddf0c2
Fix Onnx 23 Rotary Fusion (#2576)
gramalingam Sep 26, 2025
df8f706
[torchlib] Support integers in logical_and/or ops and update other lo…
justinchuby Sep 29, 2025
94fb24f
Record names of contributing values in the constant folding pass (#2575)
justinchuby Sep 30, 2025
3a26097
Merge output shape with input shape instead of override (#2578)
wodesuck Sep 30, 2025
3505420
[torchlib] Add back operator and/or (#2590)
justinchuby Sep 30, 2025
9b54ad5
Extend utilities for checking a scalar value (#2587)
gramalingam Sep 30, 2025
7227655
Merge input and output shape when removing identity (#2588)
wodesuck Sep 30, 2025
a1db753
Add NaN handling in softmax pattern in SDPA fusion (#2593)
gramalingam Oct 1, 2025
09bbd27
Remove usages of ir.Input in test (#2591)
justinchuby Oct 1, 2025
88b03d8
Improve aten_floor_divide for int inputs (#2592)
justinchuby Oct 1, 2025
149d567
Fix collapse slices rewrite rules to handle unknown dims (#2583)
justinchuby Oct 1, 2025
929a7f2
Expose the should_fold option to optimize() (#2594)
justinchuby Oct 1, 2025
81f8444
Bump version from 0.5.3 to 0.5.4 (#2595)
justinchuby Oct 1, 2025
b7ccc86
Update torch api error message to include value names (#2599)
justinchuby Oct 3, 2025
30ae54b
Remove beartype (#2603)
justinchuby Oct 3, 2025
897345d
Separated implementation of aten::scatter overloads (#2605)
linshokaku Oct 6, 2025
aa2cf4a
chore(deps): bump onnx-weekly from 1.20.0.dev20250901 to 1.20.0.dev20…
dependabot[bot] Oct 6, 2025
6718ef0
Enhanced type annotations and simplified implementation of scatter.va…
linshokaku Oct 7, 2025
7f3325b
support for scalar args to aten::scatter (#2613)
linshokaku Oct 7, 2025
a106bad
chore(deps): bump ruff from 0.13.1 to 0.13.2 in /requirements/lintrun…
dependabot[bot] Oct 7, 2025
8e4d41d
[torchlib] Implement aten_bilinear function using Einsum (#2574)
Copilot Oct 7, 2025
e8d906a
chore(deps): bump actions/setup-python from 5 to 6 (#2551)
dependabot[bot] Oct 7, 2025
256be11
chore(deps): bump editorconfig-checker from 3.2.0 to 3.4.0 in /requir…
dependabot[bot] Oct 7, 2025
8e449da
chore(deps): bump types-pyyaml from 6.0.12.20250402 to 6.0.12.2025091…
dependabot[bot] Oct 7, 2025
4eaf36d
chore(deps): bump pylint from 3.3.6 to 3.3.9 in /requirements/lintrun…
dependabot[bot] Oct 7, 2025
075fc4d
Simplify aten_unbind when shape is static (#2597)
justinchuby Oct 7, 2025
9ab7527
Consolidate overloads in torchlib (#2604)
justinchuby Oct 8, 2025
cb6f873
chore(deps): bump onnxruntime from 1.23.0.dev20250517001 to 1.23.1 in…
dependabot[bot] Oct 9, 2025
59c3d32
[torchlib] Fix implementations for bitwise_* overloads (#2618)
justinchuby Oct 9, 2025
28a8f56
Fix constant in constant folding (#2622)
titaiwangms Oct 10, 2025
071ff1e
Create helper for comparing semantic equivalence of shapes (#2620)
justinchuby Oct 10, 2025
32a61f4
[torchlib] Deprecate Rank and IsScalar (#2624)
justinchuby Oct 13, 2025
dd14682
[torchlib] Fix operator add (#2630)
justinchuby Oct 13, 2025
f44b314
Allow `opset_version` to be set explicitly when exporting (#2615)
NoRaincheck Oct 14, 2025
b6a2d02
Remove redundant registration of operator::add (#2631)
justinchuby Oct 14, 2025
811937c
Merge shapes only in identity op and nodel-level shape inference (#2623)
titaiwangms Oct 14, 2025
75b3d42
Fix GQA fusion to produce present key/value (#2634)
justinchuby Oct 15, 2025
8089bc7
Add RMS Normalization rule variant (#2638)
gramalingam Oct 16, 2025
dd8cb69
[DRAFT] Extend GQA fusion for Gemma3 (#2639)
gramalingam Oct 16, 2025
55f5b82
Bump version to 0.5.5 (#2640)
titaiwangms Oct 16, 2025
80f28c9
Add Gemma3 GQA fusion test case (#2642)
gramalingam Oct 17, 2025
8a94ad6
[Rewriter]: introduce remove_optional_bias (#2635)
AyoubMDL Oct 18, 2025
04a9da4
Unsqueeze unbatched input of avg_pool (#2646)
wodesuck Oct 27, 2025
8c0b72b
Add a verbose mode to torch api for external data save (#2643)
justinchuby Oct 27, 2025
bb75e2b
Support math trunc (#2653)
titaiwangms Oct 27, 2025
3334ba1
chore(deps): bump actions/upload-artifact from 4 to 5 (#2656)
dependabot[bot] Oct 27, 2025
b84d595
chore(deps): bump onnx-weekly from 1.20.0.dev20251006 to 1.20.0.dev20…
dependabot[bot] Oct 28, 2025
ad83914
chore(deps): bump ruff from 0.13.2 to 0.14.2 in /requirements/lintrun…
dependabot[bot] Oct 28, 2025
69025f7
[version converter] Fix DFT opset 20 (#2659)
titaiwangms Oct 28, 2025
45b5189
[torchlib] Fix concat when input tensor has shape `(0,)` (#2661)
justinchuby Oct 29, 2025
9e0366c
Create initializers not constant nodes in constant folding pass (#2650)
titaiwangms Oct 29, 2025
fe50b83
Add support for traced if statements in onnxscript script (#2644)
gramalingam Oct 29, 2025
647754f
Extend GQA fusion for Qwen (#2662)
gramalingam Oct 29, 2025
ee9a6e8
Declare support for Python 3.14 in pyproject.toml (#2663)
justinchuby Oct 29, 2025
3846705
Clear initializers in constant folding pass (#2668)
justinchuby Oct 31, 2025
8a7de40
Add GQA fusion test cases (#2669)
gramalingam Oct 31, 2025
9b699ae
Improve constant folding error messages and allow Identity to skip sh…
justinchuby Oct 31, 2025
5be9d3b
Fix scalar constant check (#2672)
gramalingam Oct 31, 2025
93783ee
Capture rewrite rule name as metadata (#2675)
gramalingam Nov 3, 2025
1a27df1
feat: implement LSTM and GRU operators for torchlib (#2674)
ombrdr47 Nov 4, 2025
d80575d
Keep creating constants when constants are folded inside ir.Function …
titaiwangms Nov 5, 2025
8845fb2
Avoid initializer name collision in _fuse_batchnorm.py (#2680)
titaiwangms Nov 5, 2025
971f9bb
Merge metadata props in rewriter (#2682)
gramalingam Nov 5, 2025
478acf7
[torchlib] Fix unbind.int if num_outputs=1 (#2684)
sebimarkgraf Nov 6, 2025
ea8cb3e
Add option to clear metadata in ort fusion (#2685)
gramalingam Nov 7, 2025
70e751a
Implement SDPA via MHA (#2683)
gramalingam Nov 7, 2025
a1be5c8
[torchlib] Fix mod on SymInt (#2686)
justinchuby Nov 7, 2025
10e541e
Implement aten.stft (#2645)
moatom Nov 9, 2025
1dd9d04
Add converter for unique_consecutive (#2694)
xadupre Nov 10, 2025
4042df3
Add missing output_size kwarg to repeat_interleave (#2691)
yuanyao-nv Nov 10, 2025
cfb52e2
chore(deps): bump ruff from 0.14.2 to 0.14.3 in /requirements/lintrun…
dependabot[bot] Nov 10, 2025
f1a6ec4
chore(deps): bump editorconfig-checker from 3.4.0 to 3.4.1 in /requir…
dependabot[bot] Nov 10, 2025
cba1325
chore(deps): bump onnx-weekly from 1.20.0.dev20251027 to 1.21.0.dev20…
dependabot[bot] Nov 10, 2025
97513c7
Bump version (#2702)
gramalingam Nov 12, 2025
c1bfdfc
Utility and example for custom op expansion (#2701)
gramalingam Nov 12, 2025
53af800
add converter for aten::sym_storage_offset (#2697)
xadupre Nov 18, 2025
6247ac1
Implement ONNX export for `fake_quantize_per_*_affine` (#2696)
ruro Nov 18, 2025
9dbf685
Provide inplace replacement util (#2708)
gramalingam Nov 18, 2025
597d5f7
Fix aten_unbind for torch >= 2.7 dynamo export (#2719)
afshin-paydar Dec 1, 2025
7dab831
chore(deps): bump ruff from 0.14.3 to 0.14.6 in /requirements/lintrun…
dependabot[bot] Dec 1, 2025
c8bfe71
chore(deps): bump actions/checkout from 5 to 6 (#2715)
dependabot[bot] Dec 1, 2025
45ba02d
chore(deps): bump onnxruntime from 1.23.1 to 1.23.2 in /requirements/…
dependabot[bot] Dec 1, 2025
3364ada
chore(deps): bump github/codeql-action from 3 to 4 (#2626)
dependabot[bot] Dec 1, 2025
3e7d9fb
chore(deps): bump ruff from 0.14.6 to 0.14.7 in /requirements/lintrun…
dependabot[bot] Dec 3, 2025
bbe9c2b
Don't constant fold Quantize/DequantizeLinear nodes by default (#2713)
ruro Dec 3, 2025
5583f96
support opset23 (#2725)
titaiwangms Dec 9, 2025
a3883a6
Update aten_index_put implementation (#2712)
gramalingam Dec 11, 2025
da967e3
[torchlib] Fix and implement overloads for aten::remainder (#2727)
justinchuby Dec 12, 2025
1f84cc7
Merge branch 'fix-2219' into main
crypto-a Dec 14, 2025
8 changes: 4 additions & 4 deletions .github/workflows/codeql-analysis.yml
@@ -41,11 +41,11 @@ jobs:

steps:
- name: Checkout repository
- uses: actions/checkout@v5
+ uses: actions/checkout@v6

# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
- uses: github/codeql-action/init@v3
+ uses: github/codeql-action/init@v4
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
@@ -59,7 +59,7 @@ jobs:
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
# If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild
- uses: github/codeql-action/autobuild@v3
+ uses: github/codeql-action/autobuild@v4

# ℹ️ Command-line programs to run using the OS shell.
# 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
@@ -72,4 +72,4 @@ jobs:
# ./location_of_script_within_repo/buildscript.sh

- name: Perform CodeQL Analysis
- uses: github/codeql-action/analyze@v3
+ uses: github/codeql-action/analyze@v4
8 changes: 4 additions & 4 deletions .github/workflows/lint.yaml
@@ -20,7 +20,7 @@ jobs:
pull-requests: write

steps:
- - uses: actions/checkout@v5
+ - uses: actions/checkout@v6
- name: misspell # Check spelling
uses: reviewdog/action-misspell@v1
with:
@@ -43,9 +43,9 @@
permissions:
security-events: write
steps:
- - uses: actions/checkout@v5
+ - uses: actions/checkout@v6
- name: Setup Python
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
with:
# Version range or exact version of Python to use, using SemVer's version range syntax. Reads from .python-version if unset.
python-version: "3.10"
@@ -78,7 +78,7 @@ jobs:
# To toggle linter comments in the files page, press `i` on the keyboard
if: always()
continue-on-error: true
- uses: github/codeql-action/upload-sarif@v3
+ uses: github/codeql-action/upload-sarif@v4
with:
# Path to SARIF file relative to the root of the repository
sarif_file: lintrunner.sarif
14 changes: 7 additions & 7 deletions .github/workflows/main.yaml
@@ -57,9 +57,9 @@ jobs:
nox-tag: test-onnx-ir-git
runs-on: ${{ matrix.os }}
steps:
- - uses: actions/checkout@v5
+ - uses: actions/checkout@v6
- name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
- name: Install nox
@@ -83,7 +83,7 @@
token: ${{ secrets.CODECOV_TOKEN }}
- name: Upload torchlib error reports
if: always()
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v5
with:
name: Error reports (${{ matrix.name }}-${{ matrix.os }})
path: error_reports
@@ -95,9 +95,9 @@
os: [ubuntu-latest, windows-latest]
runs-on: ${{ matrix.os }}
steps:
- - uses: actions/checkout@v5
+ - uses: actions/checkout@v6
- name: Setup Python
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
with:
python-version: "3.10"
cache: pip
@@ -119,9 +119,9 @@
update_readme:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v5
+ - uses: actions/checkout@v6
- name: Setup Python
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
- name: Update readme
run: |
python docs/update_readme.py
6 changes: 3 additions & 3 deletions .github/workflows/pages.yaml
@@ -25,14 +25,14 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
- uses: actions/checkout@v5
+ uses: actions/checkout@v6
- name: Setup Pages
uses: actions/configure-pages@v4
- name: Setup Python
- uses: actions/setup-python@v5
+ uses: actions/setup-python@v6
with:
python-version: "3.10"
- - uses: actions/checkout@v5
+ - uses: actions/checkout@v6
- name: Install dependencies
run: |
python -m pip install --upgrade pip setuptools wheel
2 changes: 1 addition & 1 deletion .lintrunner.toml
@@ -39,6 +39,7 @@ include_patterns = [
exclude_patterns = [
'tests/**', # Skip linting test files for speed
# FIXME: Fix typing annotations in these files
+ 'examples/custom_op_expansion.py',
'onnxscript/converter_test.py',
'onnxscript/converter.py',
'onnxscript/evaluator_test.py',
@@ -57,7 +58,6 @@ exclude_patterns = [
'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME
'onnxscript/tools/function_unittest_producer.py', # FIXME
'onnxscript/rewriter/onnxruntime/transformers/layernorm.py', # FIXME
- 'onnxscript/rewriter/generic_pattern.py', # FIXME
]
command = [
'python',
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
- 0.4.1
+ 0.5.7
1 change: 0 additions & 1 deletion docs/api/optimizer.md
@@ -15,5 +15,4 @@
optimizer.inline
optimizer.basic_constant_propagation
optimizer.fold_constants
- optimizer.remove_unused_nodes
```
61 changes: 61 additions & 0 deletions examples/custom_op_expansion.py
@@ -0,0 +1,61 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ruff: noqa

"""A utility and an example showing how onnxscript functions can be used to define function expansions
and be used with the inliner to replace calls to the custom function with an expanded subgraph.
This is useful to perform certain classes of graph surgery easily.
"""

import onnx

import onnxscript
import onnxscript.utils.replace as replace

script = onnxscript.script
FLOAT = onnxscript.FLOAT
op = onnxscript.values.opset22
local = onnxscript.values.Opset("local", 1)


# Example Model: Actual models can come from ModelBuilder or Exporter or any other source.
# Models can contain calls to custom operations (from a custom domain like 'local' here or
# even "com.microsoft" etc.)
@script()
def model_script(X: FLOAT["N"], Y: FLOAT["N"]) -> FLOAT["N"]:
    DoubleX = op.Add(X, X)
    YSquare = op.Mul(Y, Y)
    # Example call to a custom operation
    Temp1 = local.CustomOp1(DoubleX, YSquare)
    # Another call to a custom operation with an attribute
    Temp2 = local.CustomOp2(Temp1, alp=0.9)
    return Temp2


# Define expansions for custom operations as onnxscript functions
@script(opset=local)
def CustomOp1(X: FLOAT["N"], Y: FLOAT["N"]) -> FLOAT["N"]:
    Temp1 = op.Sub(X, Y)
    return op.Div(Temp1, X)


@script(opset=local)
def CustomOp2(X: FLOAT["N"], alp: float) -> FLOAT["N"]:
    Temp2 = op.Elu(X, alpha=alp)
    return op.Mul(Temp2, Temp2)


# Now, we can replace the custom operations in the model with their expansions:

functions = [CustomOp1.to_function_proto(), CustomOp2.to_function_proto()]

model = model_script.to_model_proto()

print("Original Model with custom operations:")
print(onnx.printer.to_text(model))


updated_model = replace.replace_functions(model, functions)

print("\nUpdated Model after replacing custom operations with their expansions:")
print(onnx.printer.to_text(updated_model))
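Conceptually, a replacement utility like the one used above first finds every call site whose `(domain, op_type)` matches one of the supplied function definitions, then inlines the matching body. Here is a minimal, self-contained sketch of just the matching step; it deliberately avoids the real `onnxscript.utils.replace` API, and the dict-based node representation and the name `find_expandable` are illustrative only:

```python
# Hypothetical sketch: index expansions by (domain, op_type), then scan nodes.
def find_expandable(nodes, functions):
    """Return indices of node calls that have a matching function expansion."""
    table = {(f["domain"], f["name"]): f for f in functions}
    return [i for i, n in enumerate(nodes) if (n["domain"], n["op_type"]) in table]


# Toy stand-ins for the model graph and the expansion FunctionProtos above.
nodes = [
    {"domain": "", "op_type": "Add"},            # standard op: left alone
    {"domain": "local", "op_type": "CustomOp1"}, # matches an expansion
    {"domain": "local", "op_type": "CustomOp2"}, # matches an expansion
]
functions = [
    {"domain": "local", "name": "CustomOp1"},
    {"domain": "local", "name": "CustomOp2"},
]

print(find_expandable(nodes, functions))  # → [1, 2]
```

Only the custom-domain calls are selected; the actual inlining step would then splice each expansion body in at those indices.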
25 changes: 0 additions & 25 deletions examples/pattern_rewriting.py
@@ -141,28 +141,3 @@ def rotary_apply_pattern(op, x, pos_ids, axis):
rule = pattern.RewriteRule(rotary_match_pattern, rotary_apply_pattern, verbose=10)

rule.apply_to_model(ir_model)

# TODO(rama): Update the following, the trace-printed looks different now.

######################################
# The logs shows every time the algorithm rejected a pattern.
# We can see the following:
#
# ::
#
# [OnnxGenericPattern.match] NONE - line: 673:onnxscript.rewriter.generic_pattern, op_type=Cast
# --hint--: BACKWARD: different node types
# --pattern
# ConcatTraining(transpose, transpose) -> (output, length)
# -- model
# ConcatTrainingBad(_onx_transpose0, _onx_transpose0) -> (_onx_concattraining0, _onx_concattraining1)
# iteration=1
# --marked-- #2
# Cast(_onx_cos0) ~ Cast(cos) [140186194226496-140186194222320]
# Cos(_onx_concattraining0) ~ Cos(output) [140186194230816-140186194223472]
# len(stacked)=0:[]
#
# Line 673 in file `generic_pattern.py`, the match was rejected.
# It says while comparing two nodes in the backward direction,
# node types do not match.
# It also says that two nodes were actually matched.
9 changes: 4 additions & 5 deletions noxfile.py
@@ -12,7 +12,6 @@


COMMON_TEST_DEPENDENCIES = (
- "beartype==0.17.2",
"expecttest==0.1.6",
"hypothesis",
"numpy",
Expand All @@ -30,9 +29,9 @@
"ml-dtypes",
)
ONNX = "onnx==1.17"
- ONNX_RUNTIME = "onnxruntime==1.20.1"
- PYTORCH = "torch==2.5.1"
- TORCHVISON = "torchvision==0.20.1"
+ ONNX_RUNTIME = "onnxruntime==1.23.0"
+ PYTORCH = "torch==2.7.1"
+ TORCHVISON = "torchvision==0.22.1"
TRANSFORMERS = "transformers==4.37.2"
ONNX_RUNTIME_NIGHTLY_DEPENDENCIES = (
"flatbuffers",
Expand All @@ -42,7 +41,7 @@
"packaging",
"protobuf",
)
- ONNX_IR = "onnx_ir==0.1.7"
+ ONNX_IR = "onnx_ir==0.1.12"
ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir"


2 changes: 2 additions & 0 deletions onnxscript/__init__.py
@@ -55,6 +55,7 @@
"opset20",
"opset21",
"opset22",
+ "opset23",
"opset_ai_onnx_ml1",
"opset_ai_onnx_ml2",
"opset_ai_onnx_ml3",
@@ -92,6 +93,7 @@
opset20,
opset21,
opset22,
+ opset23,
opset_ai_onnx_ml1,
opset_ai_onnx_ml2,
opset_ai_onnx_ml3,
45 changes: 37 additions & 8 deletions onnxscript/_framework_apis/torch_2_5.py
@@ -13,6 +13,7 @@
]

import dataclasses
+ import importlib.util
import os
import pathlib
from typing import Callable
@@ -63,20 +64,48 @@ def check_model(model: ir.Model) -> None:
     del model  # Unused yet


-def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike) -> None:
+def save_model_with_external_data(
+    model: ir.Model, model_path: str | os.PathLike, verbose: bool = False
+) -> None:
     """Save the model with external data. The model is unchanged after saving."""

     # TODO(#1835): Decide if we want to externalize large attributes as well
-    for value in model.graph.initializers.values():
-        if value.const_value is None:
-            raise ValueError(
-                "The model contains uninitialized initializer values. "
-                "Please make sure all initializer values are initialized."
-            )
+    uninitialized_values = [
+        value.name for value in model.graph.initializers.values() if value.const_value is None
+    ]
+    if uninitialized_values:
+        raise ValueError(
+            f"The model contains uninitialized initializer values ({uninitialized_values}). "
+            "Please make sure all initializer values are initialized."
+        )
     destination_path = pathlib.Path(model_path)
     data_path = f"{destination_path.name}.data"

-    ir.save(model, model_path, external_data=data_path)
+    # Show a progress bar if verbose is True and tqdm is installed
+    use_tqdm = verbose and importlib.util.find_spec("tqdm") is not None
+
+    if use_tqdm:
+        import tqdm  # pylint: disable=import-outside-toplevel
+
+        with tqdm.tqdm() as pbar:
+            total_set = False
+
+            def callback(
+                tensor: ir.TensorProtocol, metadata: ir.external_data.CallbackInfo
+            ) -> None:
+                nonlocal total_set
+                if not total_set:
+                    pbar.total = metadata.total
+                    total_set = True
+
+                pbar.update()
+                pbar.set_description(
+                    f"Saving {tensor.name} ({tensor.dtype.short_name()}, {tensor.shape}) at offset {metadata.offset}"
+                )
+
+            ir.save(model, model_path, external_data=data_path, callback=callback)
+    else:
+        ir.save(model, model_path, external_data=data_path)


 def get_torchlib_ops() -> list[_OnnxFunctionMeta]:
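The verbose save path in the diff above relies on a small callback pattern: the progress bar's total is unknown up front, so the callback learns it from the first event it receives and then ticks once per tensor written. The same pattern can be sketched without `tqdm` or the `onnx_ir` API; `SaveEvent` and `make_progress_callback` below are illustrative stand-ins, not real library names:

```python
# Sketch of a "lazy total" progress callback: the total count is captured
# from the first event, then every event produces one progress tick.
from dataclasses import dataclass


@dataclass
class SaveEvent:
    """Stand-in for the per-tensor metadata a saver would pass to its callback."""
    total: int   # total number of tensors being written
    offset: int  # byte offset of this tensor in the data file


def make_progress_callback(report):
    """Build a callback that reports the total once, then one line per tensor."""
    total_set = False

    def callback(name, event):
        nonlocal total_set
        if not total_set:
            report(f"total={event.total}")
            total_set = True
        report(f"wrote {name} at offset {event.offset}")

    return callback


events = []
cb = make_progress_callback(events.append)
# Simulate a saver writing three tensors, 16 bytes apart.
for i, name in enumerate(["weight_a", "weight_b", "weight_c"]):
    cb(name, SaveEvent(total=3, offset=i * 16))

print(events[0])   # → total=3
print(events[-1])  # → wrote weight_c at offset 32
```

In the real diff, `report` is replaced by updating the `tqdm` bar (`pbar.total`, `pbar.update()`, `pbar.set_description(...)`), but the control flow is the same: state captured via `nonlocal`, initialized on the first invocation.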