-
Notifications
You must be signed in to change notification settings - Fork 41
Expand file tree
/
Copy patheval.py
More file actions
91 lines (74 loc) · 2.36 KB
/
eval.py
File metadata and controls
91 lines (74 loc) · 2.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import sys
import subprocess
from termcolor import cprint
from omegaconf import DictConfig, ListConfig, OmegaConf
def get_config():
    """Build the run configuration from the command line and a YAML file.

    The command line is parsed with ``OmegaConf.from_cli()``; its ``config``
    entry names the YAML file to load.  The YAML config and the command-line
    config are then merged, with the command-line config passed last to
    ``OmegaConf.merge``.

    Returns:
        The merged OmegaConf configuration object.
    """
    overrides = OmegaConf.from_cli()
    base = OmegaConf.load(overrides.config)
    return OmegaConf.merge(base, overrides)
# Sampling script (launched from the ``sample/`` directory) for each
# supported model base.
_SAMPLE_SCRIPTS = {
    "dream": "dream_sample.py",
    "llada": "llada_sample.py",
    "sdar": "sdar_sample.py",
    "trado": "trado_sample.py",
}


def _sample_script(model_base):
    """Return the sampling script name registered for *model_base*.

    Raises:
        ValueError: If *model_base* is not a supported model base.  Failing
            here keeps the later stages from running without a sampling
            step having taken place.
    """
    script = _SAMPLE_SCRIPTS.get(model_base)
    if script is None:
        supported = ", ".join(sorted(_SAMPLE_SCRIPTS))
        raise ValueError(
            f"Unsupported model_base {model_base!r}; expected one of: {supported}"
        )
    return script


def _stage_command(script, project):
    """Return the argument list that launches one stage script.

    Stages are started with ``sample/`` or ``reward/`` as the working
    directory, which is why the config path begins with ``../``.
    """
    return ["python", script, f"config=../configs/{project}.yaml"]


def _run_stage(script, workdir, project):
    """Run *script* inside *workdir* and wait for it to finish.

    An argument list is used instead of a shell string so the project name
    is handed to the child process verbatim, without shell interpretation.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero
            status (``check=True``).
    """
    subprocess.run(_stage_command(script, project), cwd=workdir, check=True)


if __name__ == "__main__":
    config = get_config()
    project_name = config.experiment.project
    eval_type = config.dataset.data_type

    def begin_with(file_name):
        """Create *file_name* empty, truncating it if it already exists."""
        with open(file_name, "w") as f:
            f.write("")

    def sample(model_base):
        """Run the sampling script that matches *model_base*.

        Raises:
            ValueError: If *model_base* is not a supported model base.
        """
        cprint("This is sampling.", color="green")
        _run_stage(_sample_script(model_base), "sample", project_name)

    def reward():
        """Run ``reward.py`` from the ``reward/`` directory."""
        cprint("This is the rewarding.", color="green")
        _run_stage("reward.py", "reward", project_name)

    def execute():
        """Run ``execute.py`` from the ``reward/`` directory."""
        cprint("This is the execution.", color="green")
        _run_stage("execute.py", "reward", project_name)

    os.makedirs(f"{project_name}/results", exist_ok=True)

    sample(config.model_base)
    # The execution stage is only run for the "code" data type.
    if eval_type == "code":
        execute()
    # NOTE(review): reward() is assumed to run for every data type, not only
    # "code" -- confirm against the intended indentation of the original.
    reward()