
jaxenstein_baselines

Baselines for the JAXENSTEIN environment.

Recurrent PPO Timing Runs

SB3 ViZDoom baseline:

uv run python -m scripts.train_sb3_recurrent_ppo_vizdoom \
  --env VizdoomMyWayHome-v1 \
  --total-timesteps 1000000 \
  --n-envs 8 \
  --device cuda \
  --update-log-freq 1 \
  --run-dir runs/sb3_recurrent_ppo_vizdoom_1m \
  --study-name craftax_worker_scaling

JAXENSTEIN baseline:

uv run python main.py --algorithm ppo \
  --maze my-way-home \
  --total-timesteps 1000000 \
  --n-envs 8 \
  --num-seeds 1 \
  --update-log-freq 1 \
  --run-dir runs/jax_ppo_rnn_jaxenstein_1m \
  --study-name craftax_worker_scaling

Simplified intrinsic-reward variants (PPO+RND and PPO+ICM) use the same JAXENSTEIN command shape:

uv run python main.py --algorithm rnd \
  --maze my-way-home \
  --total-timesteps 1000000 \
  --n-envs 8 \
  --num-seeds 1

uv run python main.py --algorithm icm \
  --maze my-way-home \
  --total-timesteps 1000000 \
  --n-envs 8 \
  --num-seeds 1
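For orientation, the sketch below shows the textbook intrinsic rewards these algorithms are named after. RND scores an observation by how badly a trainable predictor matches a fixed, randomly initialised target network. ICM scores a transition by its forward model's prediction error in feature space, scaled by eta / 2. This stdlib-only illustration uses linear maps as stand-in "networks". The helper names (linear, rnd_bonus, icm_bonus) and the eta default are hypothetical, and the repository's simplified variants may compute the bonus differently.

```python
import random


def linear(weights, x):
    """Apply a linear map given as a list of rows (stand-in for a network)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def random_matrix(rows, cols, rng):
    """Gaussian-initialised weights; in RND the target stays frozen after this."""
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def squared_error(a, b):
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def rnd_bonus(obs, target_w, predictor_w):
    """RND: error of the (trainable) predictor against the frozen random target."""
    return squared_error(linear(predictor_w, obs), linear(target_w, obs))


def icm_bonus(next_features, predicted_next_features, eta=1.0):
    """ICM: (eta / 2) * ||phi_hat(s_{t+1}) - phi(s_{t+1})||^2 from the forward model."""
    return 0.5 * eta * squared_error(predicted_next_features, next_features)


if __name__ == "__main__":
    rng = random.Random(0)
    target = random_matrix(4, 3, rng)
    predictor = random_matrix(4, 3, rng)
    obs = [0.1, -0.2, 0.3]
    print("RND bonus:", rnd_bonus(obs, target, predictor))
    print("ICM bonus:", icm_bonus([1.0, 2.0], [1.0, 0.0]))
```

Both bonuses shrink as the learned model gets better at predicting familiar inputs, which is what makes rarely visited parts of the maze comparatively rewarding.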

Pass --action-concat to condition the policy and intrinsic-reward models on the previous action, encoded as a one-hot vector. That vector is concatenated to the output of the convolutional image encoder, not to the raw pixels.
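A minimal sketch of that concatenation, using plain Python lists in place of arrays. The function names are hypothetical, and encoding "no previous action" at episode start as an all-zeros vector is an assumption, not something the runner is documented to do.

```python
def one_hot(action, num_actions):
    """One-hot encode a discrete action index."""
    if not 0 <= action < num_actions:
        raise ValueError(f"action {action} out of range for {num_actions} actions")
    vec = [0.0] * num_actions
    vec[action] = 1.0
    return vec


def concat_prev_action(encoder_features, prev_action, num_actions):
    """Append the previous action (one-hot) to the conv-encoder output.

    Assumption: prev_action is None at the start of an episode and is then
    encoded as an all-zeros vector.
    """
    if prev_action is None:
        action_vec = [0.0] * num_actions
    else:
        action_vec = one_hot(prev_action, num_actions)
    return list(encoder_features) + action_vec


if __name__ == "__main__":
    features = [0.5, -1.0]  # stand-in for the flattened conv-encoder output
    print(concat_prev_action(features, 2, 4))  # [0.5, -1.0, 0.0, 0.0, 1.0, 0.0]
```

Concatenating after the encoder keeps the convolutional stack unchanged and only widens the input of the layers that follow it.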

Each JAXENSTEIN run writes a compact results.json plus a results.npy sidecar containing the po_exploration-style payload: the config, the raw logged metrics, and per-seed summary entries. The logged metrics include both returned_episode_returns and returned_discounted_episode_returns.

Use --steps-log-freq to subsample rollout steps within a logged update and --update-log-freq to subsample training updates. Progress messages are printed regardless of --debug; --progress-log-freq sets how many updates elapse between percent-finished messages.
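To make the two return metrics and the log-frequency flags concrete, here is a small sketch. It assumes returned_episode_returns is the plain sum of an episode's rewards and returned_discounted_episode_returns applies a discount factor gamma per step. It also assumes a log frequency of N keeps every N-th item starting at index 0. The helpers and the gamma value are hypothetical, not taken from the runner.

```python
def episode_returns(rewards, gamma):
    """Return (undiscounted, discounted) return for one finished episode.

    undiscounted = sum_t r_t
    discounted   = sum_t gamma**t * r_t  (accumulated backwards in time)
    """
    undiscounted = float(sum(rewards))
    discounted = 0.0
    for r in reversed(rewards):
        discounted = r + gamma * discounted
    return undiscounted, discounted


def logged_indices(count, log_freq):
    """Indices kept when logging every log_freq-th item (assumed to start at 0)."""
    if log_freq < 1:
        raise ValueError("log_freq must be >= 1")
    return list(range(0, count, log_freq))


if __name__ == "__main__":
    # A sparse-reward episode: the goal is reached on the third step.
    print(episode_returns([0.0, 0.0, 1.0], gamma=0.9))
    print(logged_indices(10, 3))  # [0, 3, 6, 9]
```

For a sparse goal reward, the discounted return also reflects how quickly the goal was reached, while the undiscounted return only records whether it was reached.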

Experiment Jobs

Hyperparameter configs live under scripts/hyperparams/<maze>/. The current sets cover key-door, my-way-home, and dmlab-static-01, with PPO, PPO+RND, and PPO+ICM configs for each.

uv run python scripts/write_jobs.py scripts/hyperparams/key-door
uv run python scripts/best_hyperparams_batch.py results/key-door
MPLCONFIGDIR=/tmp/mpl uv run python scripts/plot/plot_env_best_hyperparams.py results/key-door
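The three commands above target key-door. A small sketch prints the same pipeline for every maze with a config set. It assumes the other mazes follow the same scripts/hyperparams/<maze> and results/<maze> layout as key-door. The helper name is hypothetical, and it only builds command strings; it does not run anything.

```python
MAZES = ("key-door", "my-way-home", "dmlab-static-01")


def pipeline_commands(maze):
    """Write jobs, pick the best hyperparameters, then plot, for one maze."""
    return [
        f"uv run python scripts/write_jobs.py scripts/hyperparams/{maze}",
        f"uv run python scripts/best_hyperparams_batch.py results/{maze}",
        "MPLCONFIGDIR=/tmp/mpl uv run python "
        f"scripts/plot/plot_env_best_hyperparams.py results/{maze}",
    ]


if __name__ == "__main__":
    for maze in MAZES:
        for command in pipeline_commands(maze):
            print(command)
```

The second and third steps presumably only make sense once the jobs written by the first step have produced results under results/<maze>.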

For launch helpers, see scripts/launch/README.md. Use --dry-run to preview a launch before creating TPU resources:

bash scripts/launch/launch_gcp_experiment.sh --dry-run scripts/hyperparams/key-door

With --study-name set, outputs are nested as <run-dir>/<study-name>/n_envs_<N>/results.json, with the matching results.npy alongside. For SB3 runs, the model and monitor files are saved in the same n_envs_<N> directory.
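The layout can be reproduced with pathlib, for example to check which worker counts already have results. The helper names are hypothetical. The rule "completed means results.json exists" mirrors how the worker-scaling sweep's skip behaviour is described, not its actual code.

```python
from pathlib import Path


def study_output_paths(run_dir, study_name, n_envs):
    """Return (results.json, results.npy) for one worker count in a study."""
    out_dir = Path(run_dir) / study_name / f"n_envs_{n_envs}"
    return out_dir / "results.json", out_dir / "results.npy"


def is_completed(run_dir, study_name, n_envs):
    """Treat a worker count as done once its results.json exists."""
    results_json, _ = study_output_paths(run_dir, study_name, n_envs)
    return results_json.exists()


if __name__ == "__main__":
    run_dir = "runs/jax_ppo_rnn_jaxenstein_1m"
    study = "craftax_worker_scaling"
    for n in (1, 2, 4, 8, 16, 32, 64):
        json_path, _ = study_output_paths(run_dir, study, n)
        status = "done" if is_completed(run_dir, study, n) else "missing"
        print(json_path, status)
```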

Craftax-style worker-scaling sweep over 1M timesteps:

uv run python -m scripts.run_worker_scaling_sweep \
  --study-name craftax_worker_scaling \
  --n-envs 1 2 4 8 16 32 64 \
  --total-timesteps 1000000 \
  --n-steps 128 \
  --vizdoom-device cuda

This sweep alternates JAXENSTEIN and ViZDoom runs and skips worker counts whose results.json already exists. Each benchmark stops at its first failing worker count. At each worker count, both implementations perform the same number of PPO updates. The sweep writes runs/craftax_worker_scaling_worker_scaling.csv and runs/craftax_worker_scaling_worker_scaling.png.
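Two parts of that description can be sketched. The first is the update count. This sketch assumes each PPO update consumes n_envs * n_steps transitions and that the count uses floor division, as in PureJaxRL-style loops. SB3's learn loop keeps collecting rollouts while fewer than total_timesteps have been gathered, so its count can be one higher. The second part is the skip-and-stop control flow. Both functions are hypothetical illustrations, not the code of scripts.run_worker_scaling_sweep.

```python
def num_updates(total_timesteps, n_envs, n_steps):
    """PPO updates when each update consumes n_envs * n_steps transitions."""
    return total_timesteps // n_steps // n_envs


def run_sweep(worker_counts, benchmarks, is_done, run):
    """Alternate benchmarks per worker count; stop a benchmark at its first failure.

    is_done(name, n_envs) -> bool: True if results.json already exists.
    run(name, n_envs) -> bool: True on success.
    Returns the (name, n_envs) pairs that were actually launched.
    """
    failed = set()
    launched = []
    for n_envs in worker_counts:
        for name in benchmarks:
            if name in failed or is_done(name, n_envs):
                continue
            launched.append((name, n_envs))
            if not run(name, n_envs):
                failed.add(name)  # skip larger worker counts for this benchmark
    return launched


if __name__ == "__main__":
    for n_envs in (1, 2, 4, 8, 16, 32, 64):
        print(n_envs, num_updates(1_000_000, n_envs, 128))
```

Under the floor-division assumption, 1M timesteps with --n-steps 128 gives 7812 updates at 1 worker, 976 at 8, and 122 at 64.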

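Approximately equivalent manual commands. Unlike the sweep script, these run all JAXENSTEIN worker counts before the ViZDoom ones instead of alternating, do not skip completed runs, and do not pass --n-steps explicitly: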

study_name=craftax_worker_scaling

for n_envs in 1 2 4 8 16 32 64; do
  uv run python main.py --algorithm ppo \
    --maze my-way-home \
    --total-timesteps 1000000 \
    --n-envs "${n_envs}" \
    --num-seeds 1 \
    --update-log-freq 1 \
    --run-dir runs/jax_ppo_rnn_jaxenstein_1m \
    --study-name "${study_name}" || break
done

for n_envs in 1 2 4 8 16 32 64; do
  uv run python -m scripts.train_sb3_recurrent_ppo_vizdoom \
    --env VizdoomMyWayHome-v1 \
    --total-timesteps 1000000 \
    --n-envs "${n_envs}" \
    --device cuda \
    --update-log-freq 1 \
    --run-dir runs/sb3_recurrent_ppo_vizdoom_1m \
    --study-name "${study_name}" || break
done

Random Policy Timing Runs

Miniworld:

uv run python -m scripts.random_policy_miniworld \
  --env MiniWorld-Hallway-v0 \
  --steps 10000 \
  --obs-width 64 \
  --obs-height 64 \
  --save-gif runs/miniworld_random.gif

JAXENSTEIN:

uv run python -m scripts.random_policy_jaxenstein \
  --maze simple \
  --steps 10000 \
  --save-gif runs/jaxenstein_random.gif

ViZDoom:

uv run python -m scripts.random_policy_vizdoom \
  --env VizdoomMyWayHome-v1 \
  --steps 10000 \
  --save-gif runs/vizdoom_random.gif
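The section heading indicates these scripts are used for timing. Below is a minimal, simulator-agnostic sketch of a random-policy throughput measurement. It assumes a hypothetical step(action) callable that wraps one environment step, and it is not the scripts' own code. For a JAX environment, dispatch is asynchronous. A fair timing would therefore also call .block_until_ready() on the step output and exclude the first, compiling, calls via the warmup steps.

```python
import random
import time


def time_random_policy(step, num_actions, steps, warmup=10, seed=0):
    """Return environment steps per second under a uniform random policy.

    step(action) is any callable that advances the environment by one step.
    The first `warmup` steps are excluded from timing (e.g. JIT compilation).
    """
    rng = random.Random(seed)
    for _ in range(warmup):
        step(rng.randrange(num_actions))
    start = time.perf_counter()
    for _ in range(steps):
        step(rng.randrange(num_actions))
    elapsed = time.perf_counter() - start
    return steps / elapsed if elapsed > 0 else float("inf")


if __name__ == "__main__":
    state = {"t": 0}

    def dummy_step(action):
        # Stand-in environment: just counts steps.
        state["t"] += 1

    sps = time_random_policy(dummy_step, num_actions=3, steps=10_000)
    print(f"{sps:,.0f} steps/s")
```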