Skip to content

Commit 99a8a4d

Browse files
committed
add norm onnx
1 parent c1b9ac6 commit 99a8a4d

File tree

16 files changed

+109
-57
lines changed

16 files changed

+109
-57
lines changed

onnx/README.md

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ ONNX(Open Neural Network Exchange),开放神经网络交换,用于在各种
3030

3131
* `./bash.sh` 在本地进入容器的 bash,方便调试
3232

33-
* `./export.sh` 运行容器,导出 onnx
33+
* `./export.sh` 运行容器,下载 pytorch 模型,然后转换为 onnx
3434

3535
设置环境变量 MODEL,可以配置导出、测试脚本运行的模型。
3636

@@ -39,6 +39,17 @@ ONNX(Open Neural Network Exchange),开放神经网络交换,用于在各种
3939
* AltCLIP-XLMR-L
4040
* AltCLIP-XLMR-L-m9
4141

42+
运行后将生成 4 个 onnx 文件和很多权重文件
43+
44+
* onnx/AltCLIP-XLMR-L-m18/onnx/Img.onnx
45+
* onnx/AltCLIP-XLMR-L-m18/onnx/ImgNorm.onnx
46+
* onnx/AltCLIP-XLMR-L-m18/onnx/Txt.onnx
47+
* onnx/AltCLIP-XLMR-L-m18/onnx/TxtNorm.onnx
48+
49+
其中 Norm 代表输出归一化的向量,如果想把生成的文本向量和图片向量存入向量数据库,进行相似性搜索,请用归一化的向量。
50+
51+
具体用法见下文的 onnx 模型测试脚本。
52+
4253
* `./dist.sh` 运行容器,导出以上 3 个模型的 onnx,并打包放到 dist 目录下。
4354

4455
### 目录
@@ -67,17 +78,10 @@ onnxruntime 有很多版本可以选择,见[onnxruntime](https://onnxruntime.a
6778

6879
请先安装 [direnv](https://github.com/direnv/direnv/blob/master/README.md) 并在本目录下 `direnv allow` 或者手工 `source .envrc` 来设置 PYTHONPATH 环境变量。
6980

70-
* [./test/onnx/onnx_img.py](./test/onnx/onnx_img.py) 生成图片向量
81+
* [./test/onnx/onnx_img.py](./test/onnx/onnx_img.py) 生成图片向量 (norm 代表归一化的向量,可用于向量搜索)
7182
* [./test/onnx/onnx_txt.py](./test/onnx/onnx_txt.py) 生成文本向量
7283
* [./test/onnx/onnx_test.py](./test/onnx/onnx_test.py) 匹配图片向量和文本向量,进行零样本分类
7384

74-
如果想把生成的文本向量和图片向量存入数据库,进行相似性搜索,请先对特征进行归一化。
75-
76-
```python
77-
image_features /= image_features.norm(dim=-1, keepdim=True)
78-
text_features /= text_features.norm(dim=-1, keepdim=True)
79-
```
80-
8185
可借助向量数据库,提升零样本分类的准确性,参见[ECCV 2022 | 无需下游训练,Tip-Adapter 大幅提升 CLIP 图像分类准确率](https://cloud.tencent.com/developer/article/2126102)
8286

8387
#### pytorch 模型

onnx/export/onnx_export.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from PIL import Image
44
from os import makedirs
55
from os.path import join
6-
from misc.clip_model import TXT, IMG
6+
from misc.clip_model import TXT, IMG, TXT_NORM, IMG_NORM
77
from misc.config import ONNX_FP, opset_version, IMG_DIR
88
from misc.proc import transform, tokenizer
99
import torch
@@ -17,8 +17,8 @@
1717
image = torch.tensor(image)
1818

1919

20-
def onnx_export(outdir, model, args, **kwds):
21-
name = f'{outdir}.onnx'
20+
def onnx_export(model, args, **kwds):
21+
name = f'{model.__class__.__name__}.onnx'
2222
fp = join(ONNX_FP, name)
2323
torch.onnx.export(
2424
model,
@@ -35,26 +35,28 @@ def onnx_export(outdir, model, args, **kwds):
3535

3636

3737
# 参考 https://github.com/OFA-Sys/Chinese-CLIP/blob/master/cn_clip/deploy/pytorch_to_onnx.py
38-
39-
onnx_export('txt',
40-
TXT,
41-
tokenizer(['a photo of cat', 'a image of cat'], ),
42-
input_names=['input', 'attention_mask'],
43-
dynamic_axes={
44-
'input': {
45-
0: 'batch',
46-
1: 'batch',
47-
},
48-
'attention_mask': {
49-
0: 'batch',
50-
1: 'batch',
51-
}
52-
})
53-
54-
onnx_export('img',
55-
IMG,
56-
image,
57-
input_names=['input'],
58-
dynamic_axes={'input': {
59-
0: 'batch'
60-
}})
38+
def export(txt, img):
    """Export one text/image model pair to ONNX.

    Runs onnx_export twice: once for the text encoder (token ids plus
    attention mask, both with dynamic batch/sequence axes) and once for
    the image encoder (dynamic batch axis only).
    """
    # Both text inputs share the same dynamic-axis layout.
    text_axes = {
        'input': {
            0: 'batch',
            1: 'batch',
        },
        'attention_mask': {
            0: 'batch',
            1: 'batch',
        },
    }
    sample_text = tokenizer(['a photo of cat', 'a image of cat'], )
    onnx_export(txt,
                sample_text,
                input_names=['input', 'attention_mask'],
                dynamic_axes=text_axes)

    onnx_export(img,
                image,
                input_names=['input'],
                dynamic_axes={'input': {0: 'batch'}})


# Export the raw-feature models and their L2-normalized variants.
export(TXT, IMG)
export(TXT_NORM, IMG_NORM)

onnx/export/tar.bz2.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77

88

99
def txz(src, to):
    """Pack *src* (a file or directory) into a bzip2-compressed tar at *to*.

    The entry is stored under its basename only, so unpacking does not
    recreate the full source path.

    Context managers guarantee both the tar writer and the underlying
    BZ2File are closed even when tar.add raises — the original closed
    the stream unconditionally only on the success path, leaking the
    file handle (and a possibly truncated archive) on error.
    """
    with bz2.BZ2File(to, 'w') as stream:
        with tarfile.TarFile(fileobj=stream, mode='w') as tar:
            tar.add(src, arcname=basename(src))
1616

1717

1818
txz(ONNX_DIR, ONNX_DIR + '.tar.bz2')

onnx/misc/clip_model.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python
22

33
import torch
4+
from misc.norm import norm
45
import torch.nn as nn
56
from .device import DEVICE
67
from .config import MODEL_FP
@@ -12,10 +13,10 @@
1213
MODEL.to(DEVICE)
1314

1415

15-
class ImgModel(nn.Module):
16+
class Img(nn.Module):
1617

1718
def __init__(self):
18-
super(ImgModel, self).__init__()
19+
super(Img, self).__init__()
1920
self.model = MODEL
2021

2122
def forward(self, image):
@@ -24,10 +25,16 @@ def forward(self, image):
2425
return self.model.get_image_features(image)
2526

2627

27-
class TxtModel(nn.Module):
28+
class ImgNorm(Img):
    """Img variant whose output features are passed through norm()."""

    def forward(self, image):
        features = super().forward(image)
        return norm(features)
32+
33+
34+
class Txt(nn.Module):
2835

2936
def __init__(self):
30-
super(TxtModel, self).__init__()
37+
super(Txt, self).__init__()
3138
self.model = MODEL
3239

3340
def forward(self, text, attention_mask):
@@ -37,5 +44,14 @@ def forward(self, text, attention_mask):
3744
return self.model.get_text_features(text, attention_mask=attention_mask)
3845

3946

40-
IMG = ImgModel()
41-
TXT = TxtModel()
47+
class TxtNorm(Txt):
    """Txt variant whose output features are passed through norm()."""

    def forward(self, text, attention_mask):
        features = super().forward(text, attention_mask)
        return norm(features)
51+
52+
53+
IMG = Img()
54+
IMG_NORM = ImgNorm()
55+
56+
TXT = Txt()
57+
TXT_NORM = TxtNorm()

onnx/misc/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
ROOT = dirname(dirname(abspath(__file__)))
99

10-
IMG_DIR = join(ROOT, 'img')
10+
IMG_DIR = join(ROOT, 'test/img')
1111
MODEL_DIR = join(ROOT, 'model')
1212

1313
MODEL_FP = join(MODEL_DIR, MODEL_NAME)

onnx/misc/norm.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#!/usr/bin/env python
2+
import torch
3+
4+
5+
# Normalize feature vectors (for similarity search in a vector database).
def norm(vec):
    """Return *vec* scaled to unit L2 norm along its last dimension.

    Out-of-place: the caller's tensor is left untouched. The original
    used in-place `vec /= ...`, which silently mutated the argument —
    a surprising side effect for callers that keep using the raw vector.
    """
    with torch.no_grad():
        return vec / vec.norm(dim=-1, keepdim=True)

onnx/test/clip/clip_img.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,12 @@ def img2vec(img):
1313

1414
if __name__ == "__main__":
1515
from misc.config import IMG_DIR
16+
from misc.norm import norm
1617
from os.path import join
17-
fp = join(IMG_DIR, 'cat.jpg')
1818
from PIL import Image
19+
20+
fp = join(IMG_DIR, 'cat.jpg')
1921
img = Image.open(fp)
20-
print(img2vec(img))
22+
vec = img2vec(img)
23+
print('vec', vec)
24+
print('norm', norm(vec))

onnx/test/clip/clip_txt.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@ def txt2vec(li):
1313
from glob import glob
1414
from misc.config import ROOT
1515
from test_txt import TEST_TXT
16+
from misc.norm import norm
1617

1718
li = glob(join(ROOT, 'jpg/*.jpg'))
1819
for li in TEST_TXT:
1920
r = txt2vec(li)
20-
for txt, i in zip(li, r):
21+
for txt, vec in zip(li, r):
2122
print(txt)
22-
print(i)
23+
print('vec', vec)
24+
print('norm', norm(vec))
2325
print('\n')
File renamed without changes.

0 commit comments

Comments
 (0)