-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathPPO_training.py
More file actions
53 lines (41 loc) · 1.18 KB
/
Copy pathPPO_training.py
File metadata and controls
53 lines (41 loc) · 1.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback
import os
# 1. Create the Lunar Lander environment that the agent will be trained on.
env_id = "LunarLander-v3"
env = gym.make(env_id)
# 2. Set up the TensorBoard log directory and the checkpoint directory.
# Only the checkpoint directory is created here; the TensorBoard directory is
# left to the logger (presumably created by Stable-Baselines3 when logging
# starts — confirm if you change the logging setup).
log_dir = "./DRL_logs/tboard/"
model_dir = "./DRL_logs/models/"
# exist_ok=True makes this idempotent and avoids the check-then-create race
# of a separate os.path.exists() test.
os.makedirs(model_dir, exist_ok=True)
# 3. Periodically snapshot the model into `model_dir` while it trains.
# Files are named with the "ppo_lunarlander" prefix. Per the SB3 docs,
# save_freq counts callback calls (vectorized steps), so with the single
# environment created above this is one checkpoint per 5000 timesteps —
# scale it down if you switch to several parallel environments.
checkpoint_callback = CheckpointCallback(
    name_prefix="ppo_lunarlander",
    save_path=model_dir,
    save_freq=5000,
)
# 4. Build the PPO agent: an MLP policy with hidden layer sizes [256, 256],
# writing TensorBoard logs under `log_dir` and printing nothing itself
# (verbose=0; progress is shown by the progress bar during training).
ppo_hyperparams = {
    "learning_rate": 0.001,
    "batch_size": 128,
    "gamma": 0.99,
}
model = PPO(
    policy="MlpPolicy",
    env=env,
    verbose=0,
    tensorboard_log=log_dir,
    policy_kwargs={"net_arch": [256, 256]},
    **ppo_hyperparams,
)
# 5. Train the model.
# To watch training progress, run TensorBoard in a separate terminal:
#     tensorboard --logdir ./DRL_logs/tboard/
# and open the URL it prints in your browser. Pass tb_log_name="..." to
# learn() if you want a custom run name in TensorBoard.
try:
    model.learn(
        total_timesteps=300_000,
        callback=checkpoint_callback,
        progress_bar=True,
    )
    # 6. Save the final model. The path is relative to the current working
    # directory (not under model_dir), and saving happens only if training
    # completed; intermediate checkpoints are in model_dir regardless.
    model.save("ppo_final_model")
finally:
    # Release the environment even if training fails or is interrupted
    # (e.g. Ctrl-C), which the original straight-line code skipped.
    env.close()