Skip to content

Commit 5bc7de5

Browse files
xaduprejustinchuby
andauthored
Make test test_smollm 20% faster (#2107)
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 489e6b7 commit 5bc7de5

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

onnxscript/rewriter/ort_fusions/_test_utils.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22
# Licensed under the MIT License.
33
from __future__ import annotations
44

5-
import os
6-
import tempfile
7-
85
import numpy as np
96
import onnx
107
import onnxruntime
@@ -27,13 +24,13 @@ def _save(model, modelpath):
2724

2825
def ort_run(model_name: str, model, inputs):
2926
providers = ["CPUExecutionProvider"]
30-
with tempfile.TemporaryDirectory() as temp_dir:
31-
model_path = os.path.join(temp_dir, f"{model_name}.onnx")
32-
_save(model, model_path)
33-
# Run model
34-
session = onnxruntime.InferenceSession(model_path, providers=providers)
35-
ort_outputs = session.run(None, inputs)
36-
return ort_outputs
27+
model_proto = ir.serde.serialize_model(model)
28+
options = onnxruntime.SessionOptions()
29+
options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
30+
session = onnxruntime.InferenceSession(
31+
model_proto.SerializeToString(), options, providers=providers
32+
)
33+
return session.run(None, inputs)
3734

3835

3936
def assert_allclose(outputs, expected_outputs, rtol=1e-2, atol=1e-2):

0 commit comments

Comments
 (0)