Skip to content

Commit 77e0adb

Browse files
authored
Add tools for quantization and quantized models (#36)
* add scripts for quantization * update path to pp-resnet50 * add quantized models * rename dict to models * add requirements and readme * fix typos
1 parent c8812a7 commit 77e0adb

11 files changed

+210
-0
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:8550b936bd9fb6362ece6e16b25a4e88d681c244e8d8187acf1265b96a371187
3+
size 120297
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:2b0e941e6f16cc048c20aee0c8e31f569118f65d702914540f7bfdc14048d78a
3+
size 9896933
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:d7f4ae288d187383b616938e2d5d481e8f040d35534856ba834b0121aee08cb2
3+
size 1625110
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:574bc954869eef09b40a3968bb19157c8faf4999419dca13cfaa3ee56ab5ecd4
3+
size 25692063
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:4757c4cb759b79030a9870abf29c064c2ee51e079a05700690800c81b16cf245
3+
size 26763574
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:63b37da9f35d1861fb1af40ab82313794291ad49c950374dc4ed232b56e1b656
3+
size 27710536
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:8b0e04e37882bb1850c91b84b077296153d366f460873658bf8c7c8294d6e0df
3+
size 16363464

tools/quantize/README.md

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Quantization with ONNXRUNTIME
2+
3+
ONNXRUNTIME is used for quantization in the Zoo.
4+
5+
Install dependencies before trying quantization:
6+
```shell
7+
pip install -r requirements.txt
8+
```
9+
10+
## Usage
11+
12+
Quantize all models in the Zoo:
13+
```shell
14+
python quantize.py
15+
```
16+
17+
Quantize one of the models in the Zoo:
18+
```shell
19+
# python quantize.py <key_in_models>
20+
python quantize.py yunet
21+
```
22+
23+
Customizing quantization configs:
24+
```python
25+
# add model into `models` dict in quantize.py
26+
models = dict(
27+
# ...
28+
model1=Quantize(model_path='/path/to/model1.onnx'
29+
calibration_image_dir='/path/to/images',
30+
transforms=Compose([''' transforms ''']), # transforms can be found in transforms.py
31+
per_channel=False, # set False to quantize in per-tensor style
32+
act_type='int8', # available types: 'int8', 'uint8'
33+
wt_type='int8' # available types: 'int8', 'uint8'
34+
)
35+
)
36+
# quantize the added models
37+
python quantize.py model1
38+
```

tools/quantize/quantize.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# This file is part of OpenCV Zoo project.
2+
# It is subject to the license terms in the LICENSE file found in the same directory.
3+
#
4+
# Copyright (C) 2021, Shenzhen Institute of Artificial Intelligence and Robotics for Society, all rights reserved.
5+
# Third party copyrights are property of their respective owners.
6+
7+
import os
8+
import sys
9+
import numpy as ny
10+
import cv2 as cv
11+
12+
import onnx
13+
from onnx import version_converter
14+
import onnxruntime
15+
from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType
16+
17+
from transform import Compose, Resize, ColorConvert
18+
19+
class DataReader(CalibrationDataReader):
20+
def __init__(self, model_path, image_dir, transforms):
21+
model = onnx.load(model_path)
22+
self.input_name = model.graph.input[0].name
23+
self.transforms = transforms
24+
self.data = self.get_calibration_data(image_dir)
25+
self.enum_data_dicts = iter([{self.input_name: x} for x in self.data])
26+
27+
def get_next(self):
28+
return next(self.enum_data_dicts, None)
29+
30+
def get_calibration_data(self, image_dir):
31+
blobs = []
32+
for image_name in os.listdir(image_dir):
33+
if not image_name.endswith('jpg'):
34+
continue
35+
img = cv.imread(os.path.join(image_dir, image_name))
36+
img = self.transforms(img)
37+
blob = cv.dnn.blobFromImage(img)
38+
blobs.append(blob)
39+
return blobs
40+
41+
class Quantize:
42+
def __init__(self, model_path, calibration_image_dir, transforms=Compose(), per_channel=False, act_type='int8', wt_type='int8'):
43+
self.type_dict = {"uint8" : QuantType.QUInt8, "int8" : QuantType.QInt8}
44+
45+
self.model_path = model_path
46+
self.calibration_image_dir = calibration_image_dir
47+
self.transforms = transforms
48+
self.per_channel = per_channel
49+
self.act_type = act_type
50+
self.wt_type = wt_type
51+
52+
# data reader
53+
self.dr = DataReader(self.model_path, self.calibration_image_dir, self.transforms)
54+
55+
def check_opset(self, convert=True):
56+
model = onnx.load(self.model_path)
57+
if model.opset_import[0].version != 11:
58+
print('\tmodel opset version: {}. Converting to opset 11'.format(model.opset_import[0].version))
59+
# convert opset version to 11
60+
model_opset11 = version_converter.convert_version(model, 11)
61+
# save converted model
62+
output_name = '{}-opset11.onnx'.format(self.model_path[:-5])
63+
onnx.save_model(model_opset11, output_name)
64+
# update model_path for quantization
65+
self.model_path = output_name
66+
67+
def run(self):
68+
print('Quantizing {}: act_type {}, wt_type {}'.format(self.model_path, self.act_type, self.wt_type))
69+
self.check_opset()
70+
output_name = '{}-act_{}-wt_{}-quantized.onnx'.format(self.model_path[:-5], self.act_type, self.wt_type)
71+
quantize_static(self.model_path, output_name, self.dr,
72+
per_channel=self.per_channel,
73+
weight_type=self.type_dict[self.wt_type],
74+
activation_type=self.type_dict[self.act_type])
75+
os.remove('augmented_model.onnx')
76+
os.remove('{}-opt.onnx'.format(self.model_path[:-5]))
77+
print('\tQuantized model saved to {}'.format(output_name))
78+
79+
80+
models=dict(
81+
yunet=Quantize(model_path='../../models/face_detection_yunet/face_detection_yunet_2021dec.onnx',
82+
calibration_image_dir='../../benchmark/data/face_detection'),
83+
sface=Quantize(model_path='../../models/face_recognition_sface/face_recognition_sface_2021dec.onnx',
84+
calibration_image_dir='../../benchmark/data/face_recognition',
85+
transforms=Compose([Resize(size=(112, 112))])),
86+
pphumenseg=Quantize(model_path='../../models/human_segmentation_pphumanseg/human_segmentation_pphumanseg_2021oct.onnx',
87+
calibration_image_dir='../../benchmark/data/human_segmentation',
88+
transforms=Compose([Resize(size=(192, 192))])),
89+
ppresnet50=Quantize(model_path='../../models/image_classification_ppresnet/image_classification_ppresnet50_2022jan.onnx',
90+
calibration_image_dir='../../benchmark/data/image_classification',
91+
transforms=Compose([Resize(size=(224, 224))])),
92+
# TBD: DaSiamRPN
93+
youtureid=Quantize(model_path='../../models/person_reid_youtureid/person_reid_youtu_2021nov.onnx',
94+
calibration_image_dir='../../benchmark/data/person_reid',
95+
transforms=Compose([Resize(size=(128, 256))])),
96+
# TBD: DB-EN & DB-CN
97+
crnn_en=Quantize(model_path='../../models/text_recognition_crnn/text_recognition_CRNN_EN_2021sep.onnx',
98+
calibration_image_dir='../../benchmark/data/text',
99+
transforms=Compose([Resize(size=(100, 32)), ColorConvert(ctype=cv.COLOR_BGR2GRAY)])),
100+
crnn_cn=Quantize(model_path='../../models/text_recognition_crnn/text_recognition_CRNN_CN_2021nov.onnx',
101+
calibration_image_dir='../../benchmark/data/text',
102+
transforms=Compose([Resize(size=(100, 32))]))
103+
)
104+
105+
if __name__ == '__main__':
106+
selected_models = []
107+
for i in range(1, len(sys.argv)):
108+
selected_models.append(sys.argv[i])
109+
if not selected_models:
110+
selected_models = list(models.keys())
111+
print('Models to be quantized: {}'.format(str(selected_models)))
112+
113+
for selected_model_name in selected_models:
114+
q = models[selected_model_name]
115+
q.run()
116+

tools/quantize/requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
opencv-python>=4.5.4.58
2+
onnx
3+
onnxruntime

0 commit comments

Comments
 (0)