-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
128 lines (97 loc) · 4.12 KB
/
test.py
File metadata and controls
128 lines (97 loc) · 4.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# Compare an output to the reference test data
import argparse
import json
import os
import time
from tqdm import tqdm
from indexer import LectureVideoIndexer, Stage, CropRegion
from utils.Intersector import Intersector
# Module-level stage marker read by handle_progress to decide whether the
# progress bar needs a new description.
current_stage = None
# Directory containing the test videos, taken from the TEST_VIDEO_PATH
# environment variable (None when the variable is not set).
test_videos_path = os.getenv('TEST_VIDEO_PATH')
# Weight subtracted from the score for every extra (unmatched) detected
# keyframe when compare_index computes precision.
EXTRA_FRAME_RATIO = 0.5
def list_diff(li1, li2):
    """Return the symmetric difference of two iterables as a list.

    Elements present only in ``li1`` come first, followed by elements present
    only in ``li2``. Duplicates are dropped, and the order inside each part
    follows set iteration order.
    """
    left, right = set(li1), set(li2)
    return [*(left - right), *(right - left)]
def handle_progress(bar, stage: Stage, progress: float):
    """Mirror an indexer progress callback onto a tqdm progress bar.

    Args:
        bar: the tqdm bar created by compare_index.
        stage: the indexer stage currently being reported; the bar's
            description is changed only when this differs from the last one.
        progress: absolute progress value. The bar is advanced by the
            difference from its current position ``bar.n``. Presumably on the
            same 0-100 scale as the bar's ``total`` — TODO confirm.
    """
    global current_stage
    # Relabel only on a stage change. Previously current_stage was never
    # updated, so the description was reset on every single callback.
    if current_stage != stage:
        current_stage = stage
        bar.set_description(str(stage))
    # tqdm.update() expects an increment, so convert the absolute value.
    bar.update(progress - bar.n)


def compare_index(ref_video, config):
    """Index one reference video and score the result against its reference keyframes.

    Args:
        ref_video: reference entry with ``'name'`` (video file name inside
            ``test_videos_path``), ``'index'`` (reference keyframe seconds) and
            an optional ``'cropRegion'`` (four coordinates passed to CropRegion).
        config: indexer configuration; ``config['frame_step']`` is also used as
            the matching tolerance (in seconds) between detected and reference
            keyframes.

    Returns:
        dict with keys ``'video'``, ``'precision'``, ``'missingTimestamps'`` and
        ``'extraTimestamps'``.
    """
    global current_stage
    print(f"Creating index for {ref_video['name']}")
    # Reset the shared stage marker so the new bar is labelled on its first callback.
    current_stage = None

    crop_region = None
    if 'cropRegion' in ref_video:
        coordinates = ref_video['cropRegion']
        crop_region = CropRegion(coordinates[0], coordinates[1], coordinates[2], coordinates[3])

    # The context manager closes the bar even if indexing raises.
    with tqdm(total=100) as bar:
        indexer = LectureVideoIndexer(
            config=config,
            progress_callback=lambda stage, progress: handle_progress(bar, stage, progress))
        index = indexer.index(video_path=os.path.join(test_videos_path, ref_video['name']),
                              crop_region=crop_region)
    seconds = [entry['second'] for entry in index]

    detected = set(seconds)
    reference = set(ref_video['index'])

    # Intersection with a custom equivalence metric considering error in seconds.
    intersector = Intersector(lambda x, y: abs(x - y) <= config['frame_step'])
    # Materialise once: the result is used twice below, and an iterator would
    # silently yield nothing the second time (making the match count 0).
    intersection = list(intersector.intersect(detected, reference))
    # NOTE(review): assumes each match is a (detected, reference) pair. Only the
    # detected side is collected here, so reference seconds matched within the
    # tolerance (but not exactly equal) still appear in missing_frames — confirm.
    intersection_set = {match[0] for match in intersection}
    extra_frames = list(detected - intersection_set)
    missing_frames = list(reference - intersection_set)

    intersection_cnt = len(intersection)
    extra_frames_cnt = len(seconds) - intersection_cnt
    # Each extra keyframe costs EXTRA_FRAME_RATIO of a match; normalise by the
    # number of reference keyframes and clamp the score at 0.
    precision = (intersection_cnt - extra_frames_cnt * EXTRA_FRAME_RATIO) / len(ref_video['index'])
    precision = max(precision, 0)

    print("Missing keyframes: ", missing_frames)
    print("Extra keyframes: ", extra_frames)
    print(f"Precision: {precision}\n")
    return {
        'video': ref_video['name'],
        'precision': precision,
        'missingTimestamps': missing_frames,
        'extraTimestamps': extra_frames,
    }
if __name__ == '__main__':
    # Sweep one config parameter over the given values. Each value is scored on
    # every reference video, and aggregate precision/timing is written to
    # output/<param>_results.json.
    parser = argparse.ArgumentParser()
    parser.add_argument('--param', dest="param", help="Name of the observed config param", required=True)
    parser.add_argument('--values', dest="values", help="Comma separated values", required=True)
    args = parser.parse_args()

    # Fail fast with a clear message instead of a TypeError from os.path.join
    # deep inside compare_index.
    if test_videos_path is None:
        parser.error("TEST_VIDEO_PATH environment variable is not set")

    observed_param = args.param
    # frame_step and hash_size are integers; every other parameter is parsed as a float.
    cast = int if observed_param in ('frame_step', 'hash_size') else float
    values = [cast(x) for x in args.values.split(',')]

    results = []
    config = {
        'frame_step': 2,
        'hash_size': 16,
        'image_similarity_threshold': 0.9,
        # NOTE(review): 'treshold' spelling kept as-is; confirm it matches the
        # key LectureVideoIndexer actually reads.
        'text_similarity_treshold': 0.85,
    }
    with open('test_data/reference.json', encoding='utf-8') as json_data:
        test_data = json.load(json_data)

    # Create the output directory up front so a long run is not lost at the end.
    os.makedirs('output', exist_ok=True)

    for value in values:
        print(f"{observed_param} value: {value}")
        precisions = []
        times = []
        local_results = []
        config[observed_param] = value
        # Every video is indexed twice per value; all runs feed the statistics.
        for _ in range(2):
            for video in test_data:
                start = time.perf_counter()
                result = compare_index(video, config)
                end = time.perf_counter()
                precisions.append(result['precision'])
                times.append(end - start)
                local_results.append(result)
        results.append({
            'value': value,
            'avg_precision': round(sum(precisions) / len(precisions), 3),
            'max_precision': round(max(precisions), 3),
            'min_precision': round(min(precisions), 3),
            'avg_time': round(sum(times) / len(times)),
            'max_time': round(max(times)),
            'min_time': round(min(times)),
            'results': local_results,
        })

    with open(f'output/{observed_param}_results.json', 'w', encoding='utf-8') as output_file:
        # Write directly instead of rebinding `json`, which previously shadowed the module.
        json.dump(results, output_file, indent=4)