-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathpreprocess_robonet.py
More file actions
127 lines (98 loc) · 4.59 KB
/
preprocess_robonet.py
File metadata and controls
127 lines (98 loc) · 4.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import h5py
import cv2
import imageio
import io
import hashlib
import numpy as np
import os
from tqdm import tqdm
import argparse
from robonet.metadata_helper import load_metadata
def load_camera_imgs(cam_index, file_pointer, file_metadata, target_dims, start_time=0, n_load=None):
    """Load a window of frames from one camera stream as a uint8 tensor.

    Args:
        cam_index: Camera index; selects the ``cam{cam_index}_video`` group
            under ``file_pointer['env']``.
        file_pointer: Open HDF5 file (or a nested mapping with the same layout).
        file_metadata: Mapping providing ``frame_dim`` (height, width),
            ``img_T``, ``img_encoding`` ('mp4' or 'jpg') and ``image_format``
            ('RGB' or 'BGR').
        target_dims: Unused here; frames are returned at their stored
            resolution. Kept so the signature stays compatible with callers.
        start_time: Index of the first frame to load.
        n_load: Number of frames to load; defaults to ``img_T``.

    Returns:
        Array of shape ``(n_load, height, width, 3)`` and dtype uint8. The
        channel axis is reversed when ``image_format`` is 'BGR'. Slots for
        which no frame was decoded stay zero.

    Raises:
        ValueError: If ``img_encoding`` is neither 'mp4' nor 'jpg'.
        NotImplementedError: If ``image_format`` is neither 'RGB' nor 'BGR'.
    """
    cam_group = file_pointer['env']['cam{}_video'.format(cam_index)]
    old_height, old_width = file_metadata['frame_dim']
    encoding = file_metadata['img_encoding']
    image_format = file_metadata['image_format']

    if n_load is None:
        n_load = file_metadata['img_T']
    stop_time = start_time + n_load

    images = np.zeros((n_load, old_height, old_width, 3), dtype=np.uint8)
    if encoding == 'mp4':
        # ndarray.tostring() was removed in NumPy 2.0; tobytes() yields the same bytes.
        buf = io.BytesIO(cam_group['frames'][:].tobytes())
        # Keep only the requested window. Without this filter a video holding
        # more than n_load frames overflows `images`, and start_time is ignored.
        img_buffer = [img for t, img in enumerate(imageio.get_reader(buf, format='mp4'))
                      if start_time <= t < stop_time]
    elif encoding == 'jpg':
        # cv2.imdecode yields BGR; the [:, :, ::-1] flips the channel axis.
        img_buffer = [cv2.imdecode(cam_group['frame{}'.format(t)][:], cv2.IMREAD_COLOR)[:, :, ::-1]
                      for t in range(start_time, stop_time)]
    else:
        raise ValueError("encoding not supported")

    for t, img in enumerate(img_buffer):
        images[t] = img

    if image_format == 'RGB':
        return images
    if image_format == 'BGR':
        return images[:, :, :, ::-1]
    raise NotImplementedError
def load_actions(file_pointer, meta_data):
    """Load the action sequence and bring it to 5 action dimensions.

    Args:
        file_pointer: Open HDF5 file (or nested mapping) exposing
            ``['policy']['actions']`` and, for the autograsp case,
            ``['env']['state']``.
        meta_data: Mapping providing ``action_T`` and ``adim``; the autograsp
            case additionally reads ``primitives``, ``high_bound`` and
            ``low_bound``.

    Returns:
        Array with 5 columns:
          * ``adim == 5``: the stored actions, unchanged.
          * ``adim == 4`` with ``primitives == 'autograsp'``: the stored
            actions plus one column reconstructed from the last state
            dimension at the following time step (presumably the gripper;
            verify against the dataset). The column is ``high_bound[-1]``
            where that state is above the midpoint of the bounds and
            ``low_bound[-1]`` otherwise.
          * ``adim < 4``: the stored actions zero-padded to 5 columns.
          * ``adim > 5``: the first 5 columns of the stored actions.

    Raises:
        ValueError: If ``adim == 4`` and the primitives are not 'autograsp'.
            No conversion is defined for that case (it previously returned
            ``None`` silently).
    """
    a_T, adim = meta_data['action_T'], meta_data['adim']

    if adim == 5:
        return file_pointer['policy']['actions'][:]

    if adim == 4 and meta_data['primitives'] == 'autograsp':
        old_actions = file_pointer['policy']['actions'][:]
        # Last state dimension, shifted by one step so row t holds state t+1.
        next_state = file_pointer['env']['state'][:][1:, -1]
        high_val, low_val = meta_data['high_bound'][-1], meta_data['low_bound'][-1]
        midpoint = (high_val + low_val) / 2.0

        action_append = np.zeros((a_T, 1))
        # Rows beyond the available next-states keep their zero value.
        action_append[:len(next_state), 0] = np.where(next_state > midpoint, high_val, low_val)
        return np.concatenate((old_actions, action_append), axis=-1)

    if adim < 4:
        pad = np.zeros((a_T, 5 - adim), dtype=np.float32)
        return np.concatenate((file_pointer['policy']['actions'][:], pad), axis=-1)

    if adim > 5:
        return file_pointer['policy']['actions'][:][:, :5]

    raise ValueError(
        "cannot convert actions with adim={} and primitives={!r} to 5 dimensions".format(
            adim, meta_data.get('primitives')))
def load_data(f_name, file_metadata):
    """Load the camera-0 frames and the actions of one trajectory file.

    The file is read fully into memory, checked against the SHA-256 recorded
    in its metadata, and then parsed as HDF5 from that in-memory buffer.

    Args:
        f_name: Path to the trajectory file.
        file_metadata: Mapping providing ``sha256``, ``state_T``, ``img_T``
            and ``action_T``, plus whatever ``load_camera_imgs`` and
            ``load_actions`` read from it.

    Returns:
        Tuple ``(images, actions, None)``: ``images`` has shape
        ``(n_states, 1, height, width, 3)`` (singleton camera axis) and
        ``actions`` is float32 with ``n_states - 1`` rows, where
        ``n_states = min(state_T, img_T, action_T + 1)``.

    Raises:
        FileNotFoundError: If ``f_name`` is not an existing regular file.
        ValueError: If the file hash differs from the metadata, or fewer than
            two states are available.
    """
    # Explicit raises instead of asserts: asserts are stripped under `python -O`.
    if not os.path.isfile(f_name):
        raise FileNotFoundError("invalid f_name")

    with open(f_name, 'rb') as f:
        buf = f.read()
    if hashlib.sha256(buf).hexdigest() != file_metadata['sha256']:
        raise ValueError("file hash doesn't match meta-data. maybe delete pkl and re-generate?")

    start_time = 0
    # One more state than actions is usable, hence action_T + 1.
    n_states = min(file_metadata['state_T'], file_metadata['img_T'], file_metadata['action_T'] + 1)
    if n_states <= 1:
        raise ValueError("must be more than one state in loaded tensor!")

    # Open read-only; the default mode differs between h5py versions.
    with h5py.File(io.BytesIO(buf), 'r') as hf:
        frames = load_camera_imgs(0, hf, file_metadata, None, start_time, n_states)
        # Insert a camera axis of length 1 after the time axis.
        images = frames[:, None]
        actions = load_actions(hf, file_metadata).astype(np.float32)[start_time:start_time + n_states - 1]

    return images, actions, None
if __name__ == '__main__':
    # Convert every trajectory file in --hdf5_path into a compressed .npz
    # holding 'image' and 'action', split into train/ and test/ under
    # --save_path according to a list of test-set file names.
    parser = argparse.ArgumentParser()
    parser.add_argument('--hdf5_path', type=str, required=True)
    parser.add_argument('--save_path', type=str, required=True)
    parser.add_argument('--test_filenames_path', type=str,
                        default="datasets/robonet/robonet_testset_filenames.txt",
                        help="text file listing one test-set file name per line")
    args = parser.parse_args()

    train_save_path = os.path.join(args.save_path, "train")
    test_save_path = os.path.join(args.save_path, "test")
    os.makedirs(train_save_path, exist_ok=True)
    os.makedirs(test_save_path, exist_ok=True)

    # A set gives O(1) membership tests inside the per-file loop.
    with open(args.test_filenames_path, 'r') as f:
        test_file_names = {line.strip() for line in f}

    # Wrap the list itself (not enumerate) so tqdm knows the total.
    for file_name in tqdm(os.listdir(args.hdf5_path)):
        if ".pkl" in file_name:
            continue
        # Check the file name, not the joined path: a directory called
        # "hdf5" would otherwise make this check pass for every entry.
        if 'hdf5' not in file_name:
            raise ValueError("unexpected non-hdf5 file in input directory: {}".format(file_name))

        file_save_path = test_save_path if file_name in test_file_names else train_save_path
        save_name = os.path.join(file_save_path, file_name.split('.')[0] + '.npz')
        hdf5_file_name = os.path.join(args.hdf5_path, file_name)

        meta_data = load_metadata(hdf5_file_name)
        imgs, actions, _ = load_data(hdf5_file_name, meta_data.get_file_metadata(hdf5_file_name))

        # Collapse all leading axes so the frames are stored as (N, H, W, C).
        imgs = imgs.reshape((-1,) + imgs.shape[-3:])
        np.savez_compressed(save_name, image=imgs, action=actions)