-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_gaze.py
More file actions
executable file
·118 lines (88 loc) · 3.34 KB
/
test_gaze.py
File metadata and controls
executable file
·118 lines (88 loc) · 3.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import time
import pathlib
import numpy as np
import torch
from tqdm import tqdm
from utils import load_config, angular_error
from models import create_model
from datasets import create_gaze_dataloader
from datasets.one_data import create_one_loader
from utils import compute_angle_error, get_output_dir
def test_xgaze(model, test_loader, config):
    """Run ``model`` over every batch of ``test_loader`` and collect predictions.

    Args:
        model: torch module; it is switched to eval mode and called with each
            batch dict under ``torch.no_grad()``.
        test_loader: sized iterable of batch dicts. When ``config.gpu_id >= 0``
            every value is moved to GPU ``config.gpu_id`` with ``.cuda``.
        config: reads ``gpu_id`` (negative keeps batches where they are) and
            ``test_data`` (``'eth'`` means no ``'gaze'`` label is read).

    Returns:
        When ``config.test_data == 'eth'``: ``predictions``, the per-batch model
        outputs converted to numpy and concatenated along axis 0.
        Otherwise: the tuple ``(predictions, err)`` where ``err`` is the mean of
        ``angular_error(predictions, gt_gazes)`` and ``gt_gazes`` are the
        concatenated ``data['gaze']`` labels.
    """
    gpu = config.gpu_id
    # Only non-'eth' test sets carry a 'gaze' label to score against.
    has_labels = config.test_data != 'eth'
    model.eval()
    predictions = []
    gt_gazes = []
    with torch.no_grad():
        for data in tqdm(test_loader, total=len(test_loader)):
            if gpu >= 0:
                # Put the batch on the same device the model was moved to
                # (the caller uses ``model.cuda(config.gpu_id)``); a bare
                # ``.cuda()`` targets the current default device instead and
                # mismatches the model whenever gpu_id != 0.
                data = {k: v.cuda(gpu) for k, v in data.items()}
            pred_gaze = model(data)
            predictions.append(pred_gaze.detach().cpu().numpy())
            if has_labels:
                gt_gazes.append(data['gaze'].detach().cpu().numpy())
    predictions = np.concatenate(predictions, axis=0)
    if not has_labels:
        return predictions
    gt_gazes = np.concatenate(gt_gazes, axis=0)
    err = angular_error(predictions, gt_gazes).mean()
    return predictions, err
def main():
    """Evaluate a trained gaze model on one test set and save the results.

    Steps:
      * Read the configuration with ``load_config()`` and force
        ``is_train = False``.
      * Load the checkpoint from ``config.test_model``, or from
        ``<output_dir>/best_ckpt.pth.tar`` when ``test_model`` is empty.
      * Run ``test_xgaze`` on the loader built by ``create_one_loader``.
      * Write the predictions as CSV to
        ``<output_dir>/<test_data>/within_eva_results.txt``.

    For every test set other than ``'eth'``, the mean angular error is also
    appended to ``<output_dir>/../<test_data>/errors.txt``. The line is tagged
    with the zero-padded ids from ``config.test_ids``.
    """
    config = load_config()
    config.is_train = False

    if config.test_output == '':
        output_dir = get_output_dir(config)
    else:
        output_dir = config.test_output
    print(output_dir)

    if config.test_data == '':
        config.test_data = config.data_type
    test_data = config.test_data
    eth_test = test_data == 'eth'

    if config.test_model == '':
        test_model = os.path.join(output_dir, 'best_ckpt.pth.tar')
    else:
        test_model = config.test_model
    checkpoint = torch.load(test_model, map_location='cpu')
    model = create_model(config)
    if config.gpu_id >= 0:
        model = model.cuda(config.gpu_id)
    # strict=False: state-dict keys that do not match the model are tolerated.
    model.load_state_dict(checkpoint['model_state'], strict=False)
    print('Load eval model from %s' % test_model)

    # NOTE(review): the original first read config.test_ids and then
    # immediately overwrote it with [], so the loader receives an empty id
    # list while config.test_ids only feeds the errors.txt tag. Behaviour is
    # preserved here -- confirm that an empty list is what create_one_loader
    # should get.
    test_ids = []
    test_tag = '_'.join('%02d' % test_id for test_id in config.test_ids)
    gaze_test_data = create_one_loader(
        config.gaze_data, config.input_size, config.batch_size,
        test_data, test_ids, False, eth_test,
    )
    print('Start test on %d samples' % len(gaze_test_data.dataset))

    # test_xgaze returns (predictions, err) only for labelled (non-'eth') data.
    result = test_xgaze(model, gaze_test_data, config)
    if eth_test:
        predictions = result
    else:
        predictions, err = result
    print('Tested number of samples %d' % predictions.shape[0])

    out_dir = os.path.join(output_dir, test_data)
    os.makedirs(out_dir, exist_ok=True)
    res_path = os.path.join(out_dir, 'within_eva_results.txt')
    np.savetxt(res_path, predictions, delimiter=',')
    print('Save results in %s' % (res_path))

    if not eth_test:
        # The error log lives next to output_dir, shared across runs.
        terr_dir = os.path.join(output_dir, '../%s' % (test_data))
        os.makedirs(terr_dir, exist_ok=True)
        terr_path = os.path.join(terr_dir, 'errors.txt')
        with open(terr_path, 'a+') as f:
            f.write('%s,%.6f\n' % (test_tag, err))
        print('Gaze error: %.4f saved in %s' % (err, terr_path))


if __name__ == '__main__':
    main()