|
6 | 6 | from __future__ import annotations |
7 | 7 |
|
8 | 8 | import os |
9 | | -import sys |
10 | 9 | import unittest |
11 | 10 |
|
12 | 11 | import torch |
|
15 | 14 | import onnxscript.testing |
16 | 15 | from onnxscript import FLOAT, evaluator |
17 | 16 | from onnxscript import opset18 as op |
18 | | -from onnxscript._internal import version_utils |
19 | 17 | from onnxscript.function_libs.torch_lib import graph_building, ops |
20 | 18 |
|
21 | 19 | IS_WINDOWS = os.name == "nt" |
@@ -157,79 +155,5 @@ def test_add_initializer_allows_adding_the_same_tensor_twice_using_same_name(sel |
157 | 155 | graph.add_initializer("x", x_tensor) |
158 | 156 |
|
159 | 157 |
|
160 | | -class _MLP(torch.nn.Module): |
161 | | - def __init__(self, input_size, hidden_size, output_size): |
162 | | - super().__init__() |
163 | | - self.fc1 = torch.nn.Linear(input_size, hidden_size) |
164 | | - self.fc2 = torch.nn.Linear(hidden_size, output_size) |
165 | | - self.relu = torch.nn.ReLU() |
166 | | - |
167 | | - def forward(self, x): |
168 | | - out = self.fc1(x) |
169 | | - out = self.relu(out) |
170 | | - out = self.fc2(out) |
171 | | - return out |
172 | | - |
173 | | - |
174 | | -@unittest.skipIf( |
175 | | - IS_WINDOWS and version_utils.torch_older_than("2.3"), |
176 | | - "dynamo_export not supported on Windows in PyTorch<2.3", |
177 | | -) |
178 | | -@unittest.skipIf( |
179 | | - sys.version_info > (3, 11), |
180 | | - "dynamo_export not supported due to torch.compile not functional for python>3.11", |
181 | | -) |
182 | | -class TestModelSaving(unittest.TestCase): |
183 | | - def test_save_initializer_to_files_for_large_model(self): |
184 | | - # # of model parameters: |
185 | | - # input_size x hidden_size + hidden_size + |
186 | | - # hidden_size x output_size + output_size |
187 | | - # ~= 3GB below |
188 | | - batch_size, input_size, hidden_size, output_size = 1, 4, 50000000, 10 |
189 | | - model = _MLP(input_size, hidden_size, output_size) |
190 | | - x = torch.randn(batch_size, input_size) |
191 | | - |
192 | | - model_proto = torch.onnx.dynamo_export(model, x).model_proto |
193 | | - # Assert model is larger than 2GB (~=3GB) |
194 | | - self.assertGreater(model_proto.ByteSize(), 2**31) |
195 | | - |
196 | | - def test_input_output_and_initializer_are_not_stored_in_value_info(self): |
197 | | - batch_size, input_size, hidden_size, output_size = 1, 4, 5, 10 |
198 | | - model = _MLP(input_size, hidden_size, output_size) |
199 | | - x = torch.randn(batch_size, input_size) |
200 | | - |
201 | | - model_proto = torch.onnx.dynamo_export(model, x).model_proto |
202 | | - v_names = {v.name for v in model_proto.graph.value_info} |
203 | | - |
204 | | - for i in model_proto.graph.input: |
205 | | - self.assertNotIn(i.name, v_names) |
206 | | - for o in model_proto.graph.output: |
207 | | - self.assertNotIn(o.name, v_names) |
208 | | - for i in model_proto.graph.initializer: |
209 | | - self.assertNotIn(i.name, v_names) |
210 | | - |
211 | | - @unittest.skipIf( |
212 | | - not version_utils.torch_older_than("2.4"), |
213 | | - "PyTorch 2.4-preview optimizes the functions away", |
214 | | - ) |
215 | | - def test_experimental_function_value_info_are_stored_in_graph_value_info(self): |
216 | | - batch_size, input_size, hidden_size, output_size = 1, 4, 5, 10 |
217 | | - model = _MLP(input_size, hidden_size, output_size) |
218 | | - x = torch.randn(batch_size, input_size) |
219 | | - |
220 | | - model_proto = torch.onnx.dynamo_export(model, x).model_proto |
221 | | - v_names = {v.name for v in model_proto.graph.value_info} |
222 | | - torch_functions = [ |
223 | | - f for f in model_proto.functions if f.domain.startswith("pkg.torch") |
224 | | - ] |
225 | | - self.assertNotEqual(len(torch_functions), 0) |
226 | | - for f in torch_functions: |
227 | | - for n in f.node: |
228 | | - for i in n.input: |
229 | | - self.assertIn(f"{f.domain}::{f.name}/{i}", v_names) |
230 | | - for o in n.output: |
231 | | - self.assertIn(f"{f.domain}::{f.name}/{o}", v_names) |
232 | | - |
233 | | - |
234 | 158 | if __name__ == "__main__": |
235 | 159 | unittest.main() |
0 commit comments