-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
34 lines (26 loc) · 1.17 KB
/
Copy pathmain.py
File metadata and controls
34 lines (26 loc) · 1.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from __future__ import annotations

import argparse
import importlib
import sys
# Lookup table from every accepted ``--algorithm`` value to the dotted path of
# the agent module that implements it. Each module is reachable under two
# names: a short alias (e.g. "ppo") and a longer "ppo_*" name.
AGENT_MODULES = {
    alias: module_path
    for module_path, aliases in (
        ("baselines.agents.ppo_rnn", ("ppo", "ppo_rnn")),
        ("baselines.agents.ppo_rnd", ("rnd", "ppo_rnd")),
        ("baselines.agents.ppo_icm", ("icm", "ppo_icm")),
    )
    for alias in aliases
}
def main(argv: list[str] | None = None) -> None:
    """Run training for the agent selected by ``--algorithm`` and print a summary.

    Only ``--algorithm`` is parsed here. All other arguments are forwarded
    unchanged to the parser returned by the selected module's
    ``make_parser()``, and the parsed namespace is passed to its ``train()``.

    Args:
        argv: Command-line arguments without the program name. When ``None``,
            ``sys.argv[1:]`` is used.
    """
    argv = list(sys.argv[1:] if argv is None else argv)

    # add_help=False: ``-h/--help`` is not consumed by this pre-parser, so it
    # lands in ``remaining`` and is handled by the agent module's own parser.
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument(
        "--algorithm", choices=sorted(AGENT_MODULES), default="ppo"
    )
    args, remaining = parser.parse_known_args(argv)

    # importlib.import_module returns the leaf module of a dotted path, which
    # is what the previous __import__(..., fromlist=[...]) call achieved; the
    # Python docs recommend it over calling __import__ directly.
    module = importlib.import_module(AGENT_MODULES[args.algorithm])
    result = module.train(module.make_parser().parse_args(remaining))

    # ``result`` is whatever the agent's train() returns; the attributes read
    # below are the ones this script relies on.
    print(f"compile_time_s: {result.compile_time_s:.6f}")
    print(f"training_time_s: {result.training_time_s:.6f}")
    print(f"steps_per_second_excluding_compile: {result.steps_per_second:.2f}")
    print(f"aggregate_completed_timesteps: {result.aggregate_completed_timesteps}")
    print(f"Saved results to {result.results_path}")
if __name__ == "__main__":
    # Executed only when run as a script: argv=None makes main() read sys.argv.
    main(None)