forked from microsoft/onnxscript
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path e2e_ops_tests.py
More file actions
123 lines (99 loc) · 4.11 KB
/
e2e_ops_tests.py
File metadata and controls
123 lines (99 loc) · 4.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# TODO(pytorch/pytorch#129279): Migrate these tests to the PyTorch repo
import unittest
import torch
from torch.onnx._internal.exporter import _testing
class TorchLibe2eTest(unittest.TestCase):
def test_investigate_one_particular_model(self):
"""This test can be used to investigate a particular issue."""
red, include, stype = "amin", False, "int32"
dtype = getattr(torch, stype)
class Model(torch.nn.Module):
def __init__(self, include, red):
super().__init__()
self.include = include
self.red = red
def forward(self, x, indices, updates):
x = x.clone()
return x.scatter_reduce(
0, indices, updates, self.red, include_self=self.include
)
model = Model(include, red)
xs = (
torch.tensor([[-2, 0, 2], [2, -2, 0]], dtype=dtype),
torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=torch.int64),
torch.tensor([[-1, -1, -1], [-1, -1, -1]], dtype=dtype),
)
onnx_program = torch.onnx.export(model, xs, dynamo=True)
_testing.assert_onnx_program(onnx_program)
def test_pow_tensor_scalar_int_float(self):
class PowModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x**0.5
onnx_program = torch.onnx.export(
PowModel(), (torch.tensor(2),), dynamo=True, optimize=False
)
_testing.assert_onnx_program(onnx_program)
def test_pow_tensor_scalar_int_int(self):
class PowModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x**2
onnx_program = torch.onnx.export(
PowModel(), (torch.tensor(2),), dynamo=True, optimize=False
)
_testing.assert_onnx_program(onnx_program)
def test_pow_tensor_scalar_float16_int(self):
class PowModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x**2
onnx_program = torch.onnx.export(
PowModel(), (torch.tensor(0.5, dtype=torch.float16),), dynamo=True, optimize=False
)
_testing.assert_onnx_program(onnx_program)
def test_pow_tensor_scalar_float16_float(self):
class PowModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x**0.5
onnx_program = torch.onnx.export(
PowModel(), (torch.tensor(0.5, dtype=torch.float16),), dynamo=True, optimize=False
)
_testing.assert_onnx_program(onnx_program)
def test_repeat_interleave_integer(self):
class Model(torch.nn.Module):
def forward(self, x):
return torch.repeat_interleave(x, 3, dim=1)
onnx_program = torch.onnx.export(
Model(), (torch.randn(2, 3),), dynamo=True, optimize=False
)
_testing.assert_onnx_program(onnx_program)
def test_repeat_interleave_tensor(self):
class Model(torch.nn.Module):
def forward(self, x, ind):
return torch.repeat_interleave(x, ind, dim=0)
onnx_program = torch.onnx.export(
Model(),
(
torch.arange(6, dtype=torch.float32).reshape((2, 3)),
torch.tensor([1, 2], dtype=torch.int64),
),
dynamo=True,
optimize=False,
)
_testing.assert_onnx_program(onnx_program)
def test_repeat_interleave_tensor_none(self):
class Model(torch.nn.Module):
def forward(self, x, ind):
return torch.repeat_interleave(x, ind)
onnx_program = torch.onnx.export(
Model(),
(
torch.arange(4, dtype=torch.float32).reshape((2, 2)),
torch.tensor([1, 2, 3, 2], dtype=torch.int64),
),
dynamo=True,
optimize=False,
)
_testing.assert_onnx_program(onnx_program)
if __name__ == "__main__":
    # Running this file directly discovers and runs the TestCase classes
    # defined in this module ("__main__" is unittest.main's default module).
    unittest.main(module="__main__")