99import onnxruntime
1010import torch
1111
12- import onnxscript .optimizer
13- import onnxscript .rewriter
14- import onnxscript .tools .training_helper
1512import onnxscript .tools .transformers_models
1613import onnxscript .tools .transformers_models .mistral
1714from onnxscript ._internal .version_utils import (
@@ -36,13 +33,7 @@ def test_mistral_export_cpu(self):
3633 )
3734 input_tensors = input_tensors_many [0 ]
3835 expected = model (* input_tensors )
39- try :
40- proto = onnxscript .tools .transformers_models .export_to_onnx (model , * input_tensors )
41- except torch ._export .verifier .SpecViolationError as e : # pylint: disable=protected-access
42- # see https://github.com/pytorch/pytorch/issues/128394
43- if "Node.meta _enter_autocast is missing val field." in str (e ):
44- raise unittest .SkipTest (str (e ))
45- raise
36+ proto = onnxscript .tools .transformers_models .export_to_onnx (model , * input_tensors )
4637 names = [i .name for i in proto .graph .input ]
4738 np_input_tensors = [x .numpy () for x in input_tensors ]
4839 feeds = dict (zip (names , np_input_tensors ))
@@ -65,15 +56,9 @@ def test_mistral_export_cpu_export_api(self):
6556 )
6657 input_tensors = input_tensors_many [0 ]
6758 expected = model (* input_tensors )
68- try :
69- proto = onnxscript .tools .transformers_models .export_to_onnx (
70- model , * input_tensors , export_api = True
71- )
72- except torch ._export .verifier .SpecViolationError as e : # pylint: disable=protected-access
73- # see https://github.com/pytorch/pytorch/issues/128394
74- if "Node.meta _enter_autocast is missing val field." in str (e ):
75- raise unittest .SkipTest (str (e ))
76- raise
59+ proto = onnxscript .tools .transformers_models .export_to_onnx (
60+ model , * input_tensors , export_api = True
61+ )
7762 names = [i .name for i in proto .graph .input ]
7863 np_input_tensors = [x .numpy () for x in input_tensors ]
7964 feeds = dict (zip (names , np_input_tensors ))
@@ -95,13 +80,7 @@ def test_phi_export_cuda(self):
9580 model = model .to ("cuda" )
9681 input_tensors = [i .to ("cuda" ) for i in input_tensors_cpu ]
9782 expected = model (* input_tensors )
98- try :
99- proto = onnxscript .tools .transformers_models .export_to_onnx (model , * input_tensors )
100- except torch ._export .verifier .SpecViolationError as e : # pylint: disable=protected-access
101- # see https://github.com/pytorch/pytorch/issues/128394
102- if "Node.meta _enter_autocast is missing val field." in str (e ):
103- raise unittest .SkipTest (str (e ))
104- raise
83+ proto = onnxscript .tools .transformers_models .export_to_onnx (model , * input_tensors )
10584 names = [i .name for i in proto .graph .input ]
10685 np_input_tensors = [x .detach ().cpu ().numpy () for x in input_tensors ]
10786 feeds = dict (zip (names , np_input_tensors ))
0 commit comments