-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathpredict.py
More file actions
126 lines (101 loc) · 5.59 KB
/
predict.py
File metadata and controls
126 lines (101 loc) · 5.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import argparse
import os
import random
import imageio
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoConfig
from safetensors.torch import load_file
from ivideogpt.vq_model import CompressiveVQModel
from ivideogpt.transformer import HeadModelWithAction
from utils import NPZParser
# Device used for all tensors and models; assumes a CUDA-capable GPU is present.
device = 'cuda'
def set_seed(seed: int):
    """Seed every RNG in use (stdlib, NumPy, torch CPU and all CUDA devices).

    Ensures reproducible sampling across runs for a given seed.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
def parse_args():
    """Build and parse the command-line arguments for the prediction script.

    Returns:
        argparse.Namespace with model path, input/output paths, dataset name,
        frame/segment settings, conditioning flags, sampling repeats, and seed.
    """
    p = argparse.ArgumentParser()
    # Required paths / dataset identification.
    p.add_argument('--pretrained_model_name_or_path', type=str, required=True, help="path to pretrained model")
    p.add_argument('--input_path', type=str, required=True, help="path to input npz file")
    p.add_argument('--dataset_name', type=str, required=True, help="dataset name")
    p.add_argument('--output_path', type=str, default='outputs', help="path to save predicted video")
    # Sequence layout: how many frames are given vs. predicted.
    p.add_argument("--context_length", type=int, default=2, help="number of init context frames")
    p.add_argument("--segment_length", type=int, default=16,
                   help="number of frames in total, including context and future frames")
    p.add_argument('--resolution', type=int, default=64, help="resolution of frames")
    # Conditioning mode (mutually exclusive; enforced in main()).
    p.add_argument('--goal_conditioned', default=False, action='store_true', help="goal-conditioned prediction")
    p.add_argument('--action_conditioned', default=False, action='store_true', help="action-conditioned prediction")
    p.add_argument('--action_dim', default=4, type=int)
    # Sampling controls.
    p.add_argument('--repeat_times', default=5, type=int, help="number of times to repeat prediction")
    p.add_argument("--seed", type=int, default=0, help="random seed")
    return p.parse_args()
@torch.no_grad()
def predict(args, tokenizer, model, input, actions=None):
    """Tokenize context frames, autoregressively sample future tokens, and save
    side-by-side ground-truth/prediction GIFs (one per repeated sample).

    Args:
        args: parsed CLI namespace (context_length, segment_length,
            repeat_times, output_path).
        tokenizer: VQ tokenizer exposing ``tokenize``/``detokenize``.
        model: causal LM (optionally action-conditioned) with ``generate``.
        input: video tensor; assumed (T, C, H, W) in [0, 1] — a batch dim is
            added here. TODO confirm against NPZParser output.
        actions: optional per-frame action tensor for action conditioning.
    """
    # Move everything onto the device; video gets a leading batch dimension.
    pixel_values = input.to(device, non_blocking=True).unsqueeze(0)
    if actions is not None:
        actions = actions.to(device, non_blocking=True)

    tokens, labels = tokenizer.tokenize(pixel_values, args.context_length)
    # Keep only the context portion as the generation prompt.
    context_token_count = args.context_length * (16 * 16 + 1)  # TODO: magic number
    gen_input = tokens[:, :context_token_count]

    # Token budget covering every future frame.
    new_token_budget = (1 + 4 * 4) * (args.segment_length - args.context_length) - 1
    sampling_kwargs = {
        'do_sample': True,
        'temperature': 1.0,
        'top_k': 100,
        'max_new_tokens': new_token_budget,
    }
    action_kwargs = {'action': actions.repeat(args.repeat_times, 1, 1)} if actions is not None else {}
    generated_tokens = model.generate(
        gen_input.repeat(args.repeat_times, 1),
        **sampling_kwargs,
        pad_token_id=50256,  # this is out of vocabulary but suppressing warning
        **action_kwargs,
    )

    # generated_tokens will include gen_input
    recon_output = tokenizer.detokenize(generated_tokens, args.context_length).clamp(0.0, 1.0)

    os.makedirs(args.output_path, exist_ok=True)

    def as_uint8_frames(video, batch_idx):
        # (C, H, W) float in [0, 1] -> list of (H, W, C) uint8 frames.
        return [(video[batch_idx, t].permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8)
                for t in range(video.shape[1])]

    # Ground truth is the same for every sample, so extract it once.
    gt_frames = as_uint8_frames(pixel_values, 0)
    for sample_idx in range(args.repeat_times):
        pred_frames = as_uint8_frames(recon_output, sample_idx)
        combined = [np.concatenate([gt_frames[t], pred_frames[t]], axis=1)
                    for t in range(len(gt_frames))]
        imageio.mimsave(f"{args.output_path}/pred-samples-{sample_idx}.gif", combined, fps=4, loop=0)
def main():
    """Entry point: parse args, load tokenizer + transformer, and run prediction
    on one sample parsed from an .npz file."""
    args = parse_args()
    if args.seed is not None:
        set_seed(args.seed)

    # Goal and action conditioning are mutually exclusive modes.
    assert not (args.goal_conditioned and args.action_conditioned), "Cannot be both goal and action conditioned"

    # --- VQ tokenizer ---
    tokenizer = CompressiveVQModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder='tokenizer', low_cpu_mem_usage=False).to(device)
    assert args.context_length == tokenizer.context_length

    # --- transformer (optionally wrapped with an action-conditioning head) ---
    if args.action_conditioned:
        config = AutoConfig.from_pretrained(args.pretrained_model_name_or_path, subfolder='transformer')
        llm = AutoModelForCausalLM.from_config(config)
        # Tokens before the first dynamics token, and tokens per dynamics step.
        prelude_tokens_num = (256 + 1) * args.context_length - 1  # TODO: magic number
        tokens_per_dyna = 16  # TODO: magic number
        model = HeadModelWithAction(llm, action_dim=args.action_dim,
                                    prelude_tokens_num=prelude_tokens_num,
                                    tokens_num_per_dyna=tokens_per_dyna,
                                    context=args.context_length,
                                    segment_length=args.segment_length).to(device)
        # Weights are loaded manually since the wrapper is not a HF model.
        state_dict = load_file(os.path.join(args.pretrained_model_name_or_path, 'transformer', 'model.safetensors'))
        model.load_state_dict(state_dict, strict=True)
        assert model.llm.config.vocab_size == tokenizer.num_vq_embeddings + tokenizer.num_dyn_embeddings + 2
    else:
        model = AutoModelForCausalLM.from_pretrained(
            args.pretrained_model_name_or_path, subfolder='transformer', low_cpu_mem_usage=False).to(device)
        assert model.config.vocab_size == tokenizer.num_vq_embeddings + tokenizer.num_dyn_embeddings + 2

    # --- sample data ---
    npz_parser = NPZParser(args.segment_length, args.resolution)
    input, actions = npz_parser.parse(args.input_path, args.dataset_name, load_action=args.action_conditioned)
    if args.goal_conditioned:
        # Move the final (goal) frame to the front so it acts as extra context.
        input = torch.concat([input[-1:], input[:-1]], dim=0)

    predict(args, tokenizer, model, input, actions)
# Script entry point: run prediction when executed directly (not on import).
if __name__ == "__main__":
    main()