Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
"packaging",
"protobuf",
)
ONNX_IR = "onnx_ir==0.1.7"
ONNX_IR = "onnx_ir==0.1.9"
ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir"


Expand Down
154 changes: 2 additions & 152 deletions onnxscript/ir/__init__.py
Original file line number Diff line number Diff line change
@@ -1,154 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""In-memory intermediate representation for ONNX graphs."""

__all__ = [
# Modules
"serde",
"traversal",
"convenience",
"external_data",
"tape",
# IR classes
"Tensor",
"ExternalTensor",
"StringTensor",
"LazyTensor",
"SymbolicDim",
"Shape",
"TensorType",
"OptionalType",
"SequenceType",
"SparseTensorType",
"TypeAndShape",
"Value",
"Attr",
"RefAttr",
"Node",
"Function",
"Graph",
"GraphView",
"Model",
# Constructors
"AttrFloat32",
"AttrFloat32s",
"AttrGraph",
"AttrGraphs",
"AttrInt64",
"AttrInt64s",
"AttrSparseTensor",
"AttrSparseTensors",
"AttrString",
"AttrStrings",
"AttrTensor",
"AttrTensors",
"AttrTypeProto",
"AttrTypeProtos",
"Input",
# Protocols
"ArrayCompatible",
"DLPackCompatible",
"TensorProtocol",
"ValueProtocol",
"ModelProtocol",
"NodeProtocol",
"GraphProtocol",
"GraphViewProtocol",
"AttributeProtocol",
"ReferenceAttributeProtocol",
"SparseTensorProtocol",
"SymbolicDimProtocol",
"ShapeProtocol",
"TypeProtocol",
"MapTypeProtocol",
"FunctionProtocol",
# Enums
"AttributeType",
"DataType",
# Types
"OperatorIdentifier",
# Protobuf compatible types
"TensorProtoTensor",
# Conversion functions
"from_proto",
"from_onnx_text",
"to_proto",
# Convenience constructors
"tensor",
"node",
# Pass infrastructure
"passes",
# IO
"load",
"save",
]

from onnx_ir import (
ArrayCompatible,
Attr,
AttrFloat32,
AttrFloat32s,
AttrGraph,
AttrGraphs,
AttributeProtocol,
AttributeType,
AttrInt64,
AttrInt64s,
AttrSparseTensor,
AttrSparseTensors,
AttrString,
AttrStrings,
AttrTensor,
AttrTensors,
AttrTypeProto,
AttrTypeProtos,
DataType,
DLPackCompatible,
ExternalTensor,
Function,
FunctionProtocol,
Graph,
GraphProtocol,
GraphView,
GraphViewProtocol,
Input,
LazyTensor,
MapTypeProtocol,
Model,
ModelProtocol,
Node,
NodeProtocol,
OperatorIdentifier,
OptionalType,
RefAttr,
ReferenceAttributeProtocol,
SequenceType,
Shape,
ShapeProtocol,
SparseTensorProtocol,
SparseTensorType,
StringTensor,
SymbolicDim,
SymbolicDimProtocol,
Tensor,
TensorProtocol,
TensorProtoTensor,
TensorType,
TypeAndShape,
TypeProtocol,
Value,
ValueProtocol,
convenience,
external_data,
from_onnx_text,
from_proto,
load,
node,
passes,
save,
serde,
tape,
tensor,
to_proto,
traversal,
)
# pylint: disable=wildcard-import,unused-wildcard-import
from onnx_ir import * # type: ignore # noqa: F403
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

class Bfloat16ConversionTest(unittest.TestCase):
def setUp(self) -> None:
self.v0 = ir.Input(name="v0", shape=ir.Shape([2, 3, 4]))
self.v0 = ir.val(name="v0", shape=ir.Shape([2, 3, 4]))
self.v0.dtype = ir.DataType.BFLOAT16
self.v1 = ir.Input(name="v1", shape=ir.Shape([2, 3, 4]))
self.v1 = ir.val(name="v1", shape=ir.Shape([2, 3, 4]))
self.v1.dtype = ir.DataType.BFLOAT16
self.v2 = ir.Input(name="v2", shape=ir.Shape([2, 3, 4]))
self.v2 = ir.val(name="v2", shape=ir.Shape([2, 3, 4]))
self.v2.dtype = ir.DataType.BFLOAT16

self.add_node = ir.Node("", "Add", inputs=(self.v0, self.v1), num_outputs=1)
Expand Down
10 changes: 5 additions & 5 deletions onnxscript/rewriter/rules/common/_basic_rules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,14 +421,14 @@ def _convert_shape(shape, name):
if isinstance(shape, np.ndarray):
shape = tape.initializer(ir.Tensor(shape, name=name))
elif isinstance(shape, (list, tuple)):
shape = ir.Input(name, ir.Shape(shape), ir.TensorType(ir.DataType.INT64))
shape = ir.val(name, ir.Shape(shape), ir.TensorType(ir.DataType.INT64))
tape.graph_like.inputs.append(shape)
else:
raise TypeError(f"Unsupported type {type(shape)} for shape.")
return shape

x = ir.Input("X", ir.Shape(input_shape), ir.TensorType(ir.DataType.FLOAT))
y = ir.Input("Y", type=ir.TensorType(ir.DataType.FLOAT))
x = ir.val("X", ir.Shape(input_shape), ir.TensorType(ir.DataType.FLOAT))
y = ir.val("Y", type=ir.TensorType(ir.DataType.FLOAT))
tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20}))

# Build the graph.
Expand Down Expand Up @@ -554,8 +554,8 @@ def test_unsupported_reshape_reshape(self, shape2, error_msg):
class Flatten2ReshapeTest(unittest.TestCase):
@staticmethod
def create_model(input_shape, axis=1):
x = ir.Input("X", ir.Shape(input_shape), ir.TensorType(ir.DataType.FLOAT))
y = ir.Input("Y", type=ir.TensorType(ir.DataType.FLOAT))
x = ir.val("X", ir.Shape(input_shape), ir.TensorType(ir.DataType.FLOAT))
y = ir.val("Y", type=ir.TensorType(ir.DataType.FLOAT))
tape = ir.tape.Tape(ir.Graph([x], [y], nodes=[], opset_imports={"": 20}))

# Build the graph.
Expand Down
8 changes: 4 additions & 4 deletions onnxscript/rewriter/rules/common/_fuse_pad_into_conv_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,13 @@ def build_model(

# Register operations in the tape
idtype = ir.DataType.UINT8 if op_type == "ConvInteger" else ir.DataType.FLOAT
x = ir.Input("X", shape=input_shape, type=ir.TensorType(idtype))
x = ir.val("X", shape=input_shape, type=ir.TensorType(idtype))
y = tape.op("Pad", inputs=[x, *pad_inputs], attributes=pad_attributes)
y = tape.op(
op_type,
inputs=[y, self.get_conv_weights(weight_shape, tape)],
attributes=conv_attributes,
output=ir.Input("Y", shape=output_shape, type=ir.TensorType(x.dtype)),
output=ir.val("Y", shape=output_shape, type=ir.TensorType(x.dtype)),
)
if op_type == "ConvInteger":
y.dtype = ir.DataType.INT32
Expand Down Expand Up @@ -290,12 +290,12 @@ def build_model(
raise ValueError(f"Unsupported type for pad input ({x}): {type(x)}.")

# Register operations in the tape
x = ir.Input("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
x = ir.val("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
y = tape.op(
"Conv",
inputs=[x, *conv_inputs],
attributes=conv_attributes,
output=ir.Input("Y", shape=output_shape, type=x.type),
output=ir.val("Y", shape=output_shape, type=x.type),
)

# Build the model
Expand Down
8 changes: 4 additions & 4 deletions onnxscript/rewriter/rules/common/_matmul_add_to_gemm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def get_test_model(
bias_shape = weight_shape[0] if transB else weight_shape[-1]
output_shape = ir.Shape(("?",) * input_shape.rank())

x = ir.Input("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))
x = ir.val("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT))

if weight_as_inputs:
w = ir.Input("W", shape=weight_shape, type=ir.TensorType(ir.DataType.FLOAT))
w = ir.val("W", shape=weight_shape, type=ir.TensorType(ir.DataType.FLOAT))
inputs.append(w)
else:
w = ir.tensor(
Expand All @@ -58,7 +58,7 @@ def get_test_model(
w = tape.initializer(w)

if bias_as_inputs:
b = ir.Input(
b = ir.val(
"B", shape=ir.Shape([bias_shape]), type=ir.TensorType(ir.DataType.FLOAT)
)
inputs.append(b)
Expand All @@ -77,7 +77,7 @@ def get_test_model(
y = tape.op(
"Add",
inputs=[y, b],
output=ir.Input("Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)),
output=ir.val("Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)),
)

# Build the model
Expand Down
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ classifiers = [
dependencies = [
"ml_dtypes",
"numpy",
"onnx_ir>=0.1.7,<2", # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range.
"onnx_ir>=0.1.9,<2", # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range.
"onnx>=1.16",
"packaging",
"typing_extensions>=4.10",
Expand All @@ -41,7 +41,6 @@ onnxscript = ["py.typed"]
onnx = ["py.typed"]

[tool.pytest.ini_options]
filterwarnings = ["ignore::UserWarning", "ignore::DeprecationWarning"]
addopts = "-rsfEX --tb=short --color=yes"

[tool.mypy]
Expand Down
Loading