33from PIL import Image
44from os import makedirs
55from os .path import join
6- from misc .clip_model import TXT , IMG
6+ from misc .clip_model import TXT , IMG , TXT_NORM , IMG_NORM
77from misc .config import ONNX_FP , opset_version , IMG_DIR
88from misc .proc import transform , tokenizer
99import torch
1717image = torch .tensor (image )
1818
1919
20- def onnx_export (outdir , model , args , ** kwds ):
21- name = f'{ outdir } .onnx'
20+ def onnx_export (model , args , ** kwds ):
21+ name = f'{ model . __class__ . __name__ } .onnx'
2222 fp = join (ONNX_FP , name )
2323 torch .onnx .export (
2424 model ,
@@ -35,26 +35,28 @@ def onnx_export(outdir, model, args, **kwds):
3535
3636
3737# 参考 https://github.com/OFA-Sys/Chinese-CLIP/blob/master/cn_clip/deploy/pytorch_to_onnx.py
38-
39- onnx_export ('txt' ,
40- TXT ,
41- tokenizer (['a photo of cat' , 'a image of cat' ], ),
42- input_names = ['input' , 'attention_mask' ],
43- dynamic_axes = {
44- 'input' : {
45- 0 : 'batch' ,
46- 1 : 'batch' ,
47- },
48- 'attention_mask' : {
49- 0 : 'batch' ,
50- 1 : 'batch' ,
51- }
52- })
53-
54- onnx_export ('img' ,
55- IMG ,
56- image ,
57- input_names = ['input' ],
58- dynamic_axes = {'input' : {
59- 0 : 'batch'
60- }})
38+ def export (txt , img ):
39+ onnx_export (txt ,
40+ tokenizer (['a photo of cat' , 'a image of cat' ], ),
41+ input_names = ['input' , 'attention_mask' ],
42+ dynamic_axes = {
43+ 'input' : {
44+ 0 : 'batch' ,
45+ 1 : 'batch' ,
46+ },
47+ 'attention_mask' : {
48+ 0 : 'batch' ,
49+ 1 : 'batch' ,
50+ }
51+ })
52+
53+ onnx_export (img ,
54+ image ,
55+ input_names = ['input' ],
56+ dynamic_axes = {'input' : {
57+ 0 : 'batch'
58+ }})
59+
60+
61+ export (TXT , IMG )
62+ export (TXT_NORM , IMG_NORM )
0 commit comments