Skip to content

Commit 89696da

Browse files
authored
Add DaSiamRPN for object tracking (#15)
* impl wrapper & demo * add data for object tracking benchmark * impl benchmark for DaSiamRPN * update benchmark results for DaSiamRPN
1 parent 3783c22 commit 89696da

16 files changed

+481
-20
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ Hardware Setup:
3636
| [PP-ResNet](./models/image_classification_ppresnet) | 224x224 | 56.05 | 602.58 | 98.64 |
3737
| [PP-HumanSeg](./models/human_segmentation_pphumanseg) | 192x192 | 19.92 | 105.32 | 67.97 |
3838
| [WeChatQRCode](./models/qrcode_wechatqrcode) | 100x100 | 7.04 | 37.68 | --- |
39+
| [DaSiamRPN](./models/object_tracking_dasiamrpn) | 1280x720 | 36.15 | 705.48 | 76.82 |
3940

4041
## License
4142

benchmark/benchmark.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,14 @@ def run(self, model):
8585
model.setBackend(self._backend)
8686
model.setTarget(self._target)
8787

88-
if 'video' in self._dataloader.name.lower():
89-
model.init(self._dataloader.getROI())
90-
91-
for data in self._dataloader:
92-
filename, img = data[:2]
93-
size = [img.shape[1], img.shape[0]]
88+
for idx, data in enumerate(self._dataloader):
89+
filename, input_data = data[:2]
9490
if filename not in self._benchmark_results:
9591
self._benchmark_results[filename] = dict()
92+
if isinstance(input_data, np.ndarray):
93+
size = [input_data.shape[1], input_data.shape[0]]
94+
else:
95+
size = input_data.getFrameSize()
9696
self._benchmark_results[filename][str(size)] = self._metric.forward(model, *data[1:])
9797

9898
def printResults(self):
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
Benchmark:
2+
name: "Object Tracking Benchmark"
3+
type: "Tracking"
4+
data:
5+
type: "TrackingVideoLoader"
6+
path: "benchmark/data/object_tracking"
7+
files: ["throw_cup.mp4"]
8+
metric:
9+
type: "Tracking"
10+
reduction: "gmean"
11+
backend: "default"
12+
target: "cpu"
13+
14+
Model:
15+
name: "DaSiamRPN"
16+
model_path: "models/object_tracking_dasiamrpn/object_tracking_dasiamrpn_model_2021nov.onnx"
17+
kernel_cls1_path: "models/object_tracking_dasiamrpn/object_tracking_dasiamrpn_kernel_cls1_2021nov.onnx"
18+
kernel_r1_path: "models/object_tracking_dasiamrpn/object_tracking_dasiamrpn_kernel_r1_2021nov.onnx"

benchmark/download_data.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,11 @@ def get_confirm_token(response): # in case of large files
184184
qrcode=Downloader(name='qrcode',
185185
url='https://drive.google.com/u/0/uc?id=1_OXB7eiCIYO335ewkT6EdAeXyriFlq_H&export=download',
186186
sha='ac01c098934a353ca1545b5266de8bb4f176d1b3',
187-
filename='qrcode.zip')
187+
filename='qrcode.zip'),
188+
object_tracking=Downloader(name='object_tracking',
189+
url='https://drive.google.com/u/0/uc?id=1_cw5pUmTF-XmQVcQAI8fIp-Ewi2oMYIn&export=download',
190+
sha='0bdb042632a245270013713bc48ad35e9221f3bb',
191+
filename='object_tracking.zip')
188192
)
189193

190194
if __name__ == '__main__':

benchmark/utils/dataloaders/base_dataloader.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __iter__(self):
3434
class _VideoStream:
3535
def __init__(self, filepath):
    """Store the path of the video file and open a ``cv.VideoCapture`` on it.

    Args:
        filepath: location of the video, passed unchanged to
            ``cv.VideoCapture``.
    """
    self._filepath = filepath
    # Open from the stored attribute, the same expression reload() uses.
    self._video = cv.VideoCapture(self._filepath)
3838

3939
def __iter__(self):
4040
while True:
@@ -44,8 +44,21 @@ def __iter__(self):
4444
else:
4545
break
4646

47+
def __next__(self):
48+
while True:
49+
has_frame, frame = self._video.read()
50+
if has_frame:
51+
return frame
52+
else:
53+
break
54+
4755
def reload(self):
    """Replace the current capture with a newly opened one on the stored path."""
    reopened = cv.VideoCapture(self._filepath)
    self._video = reopened
57+
58+
def getFrameSize(self):
    """Return the capture's frame dimensions as ``[width, height]``.

    Both values are read from the capture's ``CAP_PROP_FRAME_WIDTH`` and
    ``CAP_PROP_FRAME_HEIGHT`` properties and converted to ``int``.
    """
    width = self._video.get(cv.CAP_PROP_FRAME_WIDTH)
    height = self._video.get(cv.CAP_PROP_FRAME_HEIGHT)
    return [int(width), int(height)]
4962

5063

5164
class _BaseVideoLoader:
@@ -56,6 +69,10 @@ def __init__(self, **kwargs):
5669
self._files = kwargs.pop('files', None)
5770
assert self._files,'Benchmark[\'data\'][\'files\'] cannot be empty.'
5871

72+
self._streams = dict()
73+
for filename in self._files:
74+
self._streams[filename] = _VideoStream(os.path.join(self._path, filename))
75+
5976
@property
6077
def name(self):
6178
return self.__class__.__name__
@@ -64,4 +81,4 @@ def __len__(self):
6481
return len(self._files)
6582

6683
def __getitem__(self, idx):
67-
return self._files[idx], _VideoStream(os.path.join(self._path, self._files[idx]))
84+
return self._files[idx], self._streams[idx]
Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import numpy as np
23

34
from .base_dataloader import _BaseVideoLoader
@@ -8,18 +9,19 @@ class TrackingVideoLoader(_BaseVideoLoader):
89
def __init__(self, **kwargs):
    """Set up the base loader, then cache first frames and ROIs per video.

    Args:
        **kwargs: forwarded unchanged to the base loader's ``__init__``.
    """
    super().__init__(**kwargs)

    # One next() call per stream; the frame it returns is kept so it can be
    # handed out together with the stream in __getitem__.
    self._first_frames = {
        filename: next(self._streams[filename]) for filename in self._files
    }

    self._rois = self._load_roi()
1718

1819
def _load_roi(self):
1920
rois = dict.fromkeys(self._files, None)
2021
for filename in self._files:
21-
rois[filename] = np.loadtxt(os.path.join(self._path, '{}.txt'.format(filename[:-4])), ndmin=2)
22+
rois[filename] = np.loadtxt(os.path.join(self._path, '{}.txt'.format(filename[:-4])), dtype=np.int32, ndmin=2)
2223
return rois
2324

24-
def getROI(self):
25-
return self._rois
25+
def __getitem__(self, idx):
26+
filename = self._files[idx]
27+
return filename, self._streams[filename], self._first_frames[filename], self._rois[filename]
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .base import Base
22
from .detection import Detection
33
from .recognition import Recognition
4+
from .tracking import Tracking
45

5-
__all__ = ['Base', 'Detection', 'Recognition']
6+
__all__ = ['Base', 'Detection', 'Recognition', 'Tracking']
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import cv2 as cv
2+
3+
from .base_metric import BaseMetric
4+
from ..factory import METRICS
5+
6+
@METRICS.register
class Tracking(BaseMetric):
    """Metric that times a tracker's ``infer`` call on each frame of a video."""

    def __init__(self, **kwargs):
        """Pass all options to ``BaseMetric`` and report the ones that are ignored."""
        super().__init__(**kwargs)

        # The values are only reported here; nothing is changed.
        if self._warmup or self._repeat:
            print('warmup and repeat in metric for tracking do not function.')

    def forward(self, model, *args, **kwargs):
        """Run the tracker over the stream once per ROI and return the timing result.

        Args:
            model: object exposing ``init(frame, roi)`` and ``infer(frame)``.
            *args: exactly three values, ``(stream, first_frame, rois)``;
                ``stream`` must be iterable and provide ``reload()``.
            **kwargs: accepted but not used.

        Returns:
            Whatever ``self._getResult()`` returns (defined outside this class).
        """
        video, init_frame, boxes = args

        for box in boxes:
            video.reload()
            model.init(init_frame, tuple(box))
            # NOTE(review): reset() runs once per ROI; if it clears the recorded
            # timings, only the last ROI contributes to the result - confirm intent.
            self._timer.reset()
            for image in video:
                # Only the infer() call sits between start() and stop().
                self._timer.start()
                model.infer(image)
                self._timer.stop()

        return self._getResult()

models/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .image_classification_ppresnet.ppresnet import PPResNet
66
from .human_segmentation_pphumanseg.pphumanseg import PPHumanSeg
77
from .qrcode_wechatqrcode.wechatqrcode import WeChatQRCode
8+
from .object_tracking_dasiamrpn.dasiamrpn import DaSiamRPN
89

910
class Registery:
1011
def __init__(self, name):
@@ -24,4 +25,5 @@ def register(self, item):
2425
MODELS.register(SFace)
2526
MODELS.register(PPResNet)
2627
MODELS.register(PPHumanSeg)
27-
MODELS.register(WeChatQRCode)
28+
MODELS.register(WeChatQRCode)
29+
MODELS.register(DaSiamRPN)

0 commit comments

Comments
 (0)