forked from microsoft/onnxscript
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtorch_2_5.py
More file actions
146 lines (115 loc) · 4.53 KB
/
torch_2_5.py
File metadata and controls
146 lines (115 loc) · 4.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Stable APIs for PyTorch 2.5."""
from __future__ import annotations
# Public API of this module: the names exported by a star-import.
__all__ = [
"check_model",
"convert_version",
"get_torchlib_ops",
"optimize",
"save_model_with_external_data",
]
import dataclasses
import importlib.util
import os
import pathlib
from typing import Callable
from onnxscript import ir, optimizer, version_converter
from onnxscript.function_libs.torch_lib import registration
@dataclasses.dataclass(frozen=True)
class _OnnxFunctionMeta:
    """Immutable record pairing an onnx-script function with its metadata.

    Attributes:
        qualified_name: The qualified name of the aten operator.
        function: The onnx-script function.
        domain: The domain of the function.
        name: The name of the function.
        is_complex: Whether the function is a complex function. Defaults to
            ``False``.
    """

    qualified_name: str
    function: Callable
    domain: str
    name: str
    is_complex: bool = dataclasses.field(default=False)
def optimize(model: ir.Model) -> ir.Model:
    """Optimize the model when the opt-in environment flag is set.

    The optimizer is invoked only when the environment variable
    ``TORCH_ONNX_ENABLE_OPTIMIZATION`` is exactly ``"1"``; for any other
    value (or when it is unset) the model is returned without being touched.

    Args:
        model: The IR model to optimize.

    Returns:
        The same ``model`` object that was passed in. The return value of
        ``optimizer.optimize_ir`` is not used.
    """
    # Internal flag. Will go away.
    if os.environ.get("TORCH_ONNX_ENABLE_OPTIMIZATION") != "1":
        return model
    optimizer.optimize_ir(model)
    return model
def convert_version(model: ir.Model, target_version: int) -> ir.Model:
    """Convert the model to the specified ONNX opset version when opted in.

    The conversion is attempted only when the environment variable
    ``TORCH_ONNX_ENABLE_VERSION_CONVERSION`` is exactly ``"1"``; for any other
    value (or when it is unset) the model is returned without being touched.

    Args:
        model: The IR model to convert.
        target_version: The opset version handed to the version converter.

    Returns:
        The same ``model`` object that was passed in. The return value of
        ``version_converter.convert_version`` is not used.
    """
    # Internal flag. Will go away.
    if os.environ.get("TORCH_ONNX_ENABLE_VERSION_CONVERSION") != "1":
        return model
    version_converter.convert_version(model, target_version)
    return model
def check_model(model: ir.Model) -> None:
    """Check the model.

    Currently a placeholder: the argument is accepted and ignored, and no
    validation is performed.

    Args:
        model: The IR model that would be checked.
    """
    _ = model  # Unused yet
def save_model_with_external_data(
    model: ir.Model, model_path: str | os.PathLike, verbose: bool = False
) -> None:
    """Save the model with external data. The model is unchanged after saving.

    The external data location handed to ``ir.save`` is the file name of
    ``model_path`` with ``.data`` appended.

    Args:
        model: The IR model to save. Every initializer in ``model.graph`` must
            have a ``const_value``.
        model_path: Destination path of the model file.
        verbose: When true and ``tqdm`` is importable, a progress bar is shown
            while tensors are being saved.

    Raises:
        ValueError: If any initializer has ``const_value`` set to ``None``.
    """
    # TODO(#1835): Decide if we want to externalize large attributes as well
    uninitialized_values = []
    for initializer in model.graph.initializers.values():
        if initializer.const_value is None:
            uninitialized_values.append(initializer.name)
    if uninitialized_values:
        raise ValueError(
            f"The model contains uninitialized initializer values ({uninitialized_values}). "
            "Please make sure all initializer values are initialized."
        )

    external_name = f"{pathlib.Path(model_path).name}.data"

    # Plain save unless a progress bar was requested and tqdm is installed.
    if not verbose or importlib.util.find_spec("tqdm") is None:
        ir.save(model, model_path, external_data=external_name)
        return

    import tqdm  # pylint: disable=import-outside-toplevel

    with tqdm.tqdm() as progress:
        total_known = False

        def report_progress(
            tensor: ir.TensorProtocol, metadata: ir.external_data.CallbackInfo
        ) -> None:
            nonlocal total_known
            if not total_known:
                # Set the bar's total once, from the first callback's metadata.
                progress.total = metadata.total
                total_known = True
            progress.update()
            progress.set_description(
                f"Saving {tensor.name} ({tensor.dtype.short_name()}, {tensor.shape}) at offset {metadata.offset}"
            )

        ir.save(model, model_path, external_data=external_name, callback=report_progress)
def get_torchlib_ops() -> list[_OnnxFunctionMeta]:
    """Collect metadata for the functions in the default torchlib registry.

    Registry entries whose qualified name starts with ``internal::`` are
    skipped. For every other entry, its regular overloads are listed first
    (``is_complex=False``), followed by its complex overloads
    (``is_complex=True``).

    Returns:
        A list of ``_OnnxFunctionMeta`` records, in registry iteration order.
    """
    # Trigger op registration
    from onnxscript.function_libs.torch_lib import (  # pylint: disable=import-outside-toplevel
        ops,
    )

    del ops  # Unused

    def describe(op_name: str, func: Callable, is_complex: bool) -> _OnnxFunctionMeta:
        # Build the metadata record for one registered function.
        return _OnnxFunctionMeta(
            qualified_name=op_name,
            function=func,
            domain=func.function_ir.domain,
            name=func.name,
            is_complex=is_complex,
        )

    collected = []
    for qualified_name, overload_group in registration.default_registry.items():
        # Skip the custom defined internal functions
        if qualified_name.startswith("internal::"):
            continue
        collected.extend(
            describe(qualified_name, func, False) for func in overload_group.overloads
        )
        collected.extend(
            describe(qualified_name, func, True) for func in overload_group.complex
        )
    return collected