# Phase 13.1 — WorldModel: predictive environment model for model-based planning
This issue tracks the specification and implementation of `WorldModel`, the first sub-phase of Phase 13: World Modeling & Model-Based Planning.

**Phase 13 context** — Discussion #367 opened the question of which direction to take ASI-Build next. World Modeling was selected as Phase 13 because it is foundational to model-based planning, dream rollouts, and surprise-driven exploration — capabilities that naturally complement the multi-agent coalition infrastructure built in Phase 12.
## Motivation

A `WorldModel` enables the agent to simulate "what happens if" before committing to an action. Instead of relying solely on trial-and-error in the environment, the agent rolls out imagined trajectories through its internal model, selects the best plan, and executes with far greater sample efficiency.

Key benefits (a sketch of the Dyna-style loop follows the list):
- Dyna-style planning: interleave real and imagined experience
- Zero-shot transfer: updated world model generalises across tasks
- Surprise detection: high prediction error flags distributional shift
- Scalable: model is a lightweight neural net, not a full simulator
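To make the Dyna-style interleaving concrete, here is a minimal sketch of one cycle: act once for real, learn the dynamics, then train on imagined rollouts. The `env` and `policy` objects and the 1:4 real-to-imagined ratio are illustrative assumptions, not part of this spec; only the `WorldModel` API defined below is.

```python
# Illustrative Dyna-style loop. `env`, `policy`, and the imagined-step ratio
# are assumptions for this sketch; the WorldModel calls match the Protocol below.
async def dyna_step(wm: WorldModel, env, policy, obs: tuple[float, ...]) -> tuple[float, ...]:
    action = policy.select(obs)                  # hypothetical policy: act for real
    next_obs, reward, done = env.step(action)    # hypothetical gym-style env
    await wm.update(ModelInput(obs, action), next_obs, reward, done)  # learn dynamics
    for _ in range(4):                           # imagined experience (ratio assumed)
        rollout = await wm.dream_rollout(next_obs, policy.plan(next_obs))
        policy.learn_from(rollout)               # hypothetical policy update
    return env.reset() if done else next_obs
```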
## Enumerations

```python
from enum import Enum, auto


class TransitionBackend(Enum):
    """Backend used for world-model forward pass."""

    MLP = auto()          # simple feed-forward MLP (fast, low-capacity)
    LSTM = auto()         # recurrent — tracks temporal context
    TRANSFORMER = auto()  # attention over observation history window
    ENSEMBLE = auto()     # ensemble of N MLPs for epistemic uncertainty


class PredictionTarget(Enum):
    """What the model predicts on each step."""

    NEXT_OBS = auto()  # predict next observation vector
    REWARD = auto()    # predict scalar reward
    DONE = auto()      # predict episode-termination flag
    ALL = auto()       # joint head — obs + reward + done
```
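As one hedged illustration of why `ENSEMBLE` is listed: epistemic uncertainty can be read off the spread of the members' next-observation predictions. The helper below is illustrative only; this issue does not prescribe any particular disagreement measure.

```python
# Illustrative only: mean per-dimension variance across ensemble members'
# predictions, usable as an epistemic-uncertainty signal for an ENSEMBLE backend.
def ensemble_disagreement(member_predictions: list[tuple[float, ...]]) -> float:
    n = len(member_predictions)
    variances = []
    for col in zip(*member_predictions):  # one column per observation dimension
        mean = sum(col) / n
        variances.append(sum((x - mean) ** 2 for x in col) / n)
    return sum(variances) / max(len(variances), 1)
```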
## Data Classes

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class ModelInput:
    """Single (obs, action) pair passed to WorldModel."""

    observation: tuple[float, ...]  # current obs vector
    action: int                     # discrete action index


@dataclass(frozen=True)
class ModelOutput:
    """Predicted next state."""

    next_observation: tuple[float, ...]  # predicted next obs
    predicted_reward: float              # predicted reward
    predicted_done: bool                 # predicted termination
    prediction_error: float              # MSE vs last ground truth (0.0 if unavailable)


@dataclass(frozen=True)
class DreamRollout:
    """Result of an imagined trajectory through the world model."""

    steps: tuple[ModelOutput, ...]  # ordered sequence of imagined steps
    total_imagined_reward: float    # sum of predicted rewards
    surprise_score: float           # mean prediction_error across steps


@dataclass(frozen=True)
class WorldModelConfig:
    backend: TransitionBackend = TransitionBackend.MLP
    obs_dim: int = 64
    action_dim: int = 8
    hidden_dim: int = 256
    ensemble_size: int = 5            # used when backend == ENSEMBLE
    rollout_horizon: int = 10         # steps per dream rollout
    learning_rate: float = 3e-4
    surprise_threshold: float = 0.05  # error above this triggers re-planning
    max_buffer_size: int = 100_000    # replay buffer for model training
```
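A quick usage sketch (the values here are arbitrary): configs are plain frozen dataclasses, so overrides happen at construction time and instances are immutable and safe to share.

```python
# Arbitrary example values — any field not passed keeps its default.
cfg = WorldModelConfig(backend=TransitionBackend.ENSEMBLE, obs_dim=32, rollout_horizon=20)
inp = ModelInput(observation=(0.0,) * cfg.obs_dim, action=3)

# frozen=True: mutation raises dataclasses.FrozenInstanceError.
# inp.action = 4  # would raise
```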
## Protocol

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class WorldModel(Protocol):
    """Predictive model of environment dynamics."""

    async def predict(self, inp: ModelInput) -> ModelOutput:
        """One-step forward pass: (obs, action) → next state prediction."""
        ...

    async def dream_rollout(
        self,
        initial_obs: tuple[float, ...],
        action_sequence: Sequence[int],
    ) -> DreamRollout:
        """Simulate a trajectory through the model."""
        ...

    async def update(
        self,
        inp: ModelInput,
        actual_next_obs: tuple[float, ...],
        actual_reward: float,
        actual_done: bool,
    ) -> float:
        """Online update from a real transition; returns scalar loss."""
        ...

    async def surprise(self, inp: ModelInput, actual_next_obs: tuple[float, ...]) -> float:
        """Return prediction error for surprise detection."""
        ...

    async def snapshot(self) -> dict:
        """Return serialisable model weights + config."""
        ...
```
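Because the protocol is `@runtime_checkable`, structural conformance can be asserted at wiring time. Note that `isinstance` against a runtime-checkable Protocol only checks method presence, not signatures:

```python
# Structural check at composition time (method names only; parameter and
# return types are not verified by a runtime_checkable Protocol).
wm = build_world_model()
assert isinstance(wm, WorldModel)
```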
## `InMemoryWorldModel` Implementation

```python
import asyncio
import random
from collections import deque


class InMemoryWorldModel:
    """Pure-Python reference implementation (no numpy/torch dependency)."""

    def __init__(self, config: WorldModelConfig) -> None:
        self._cfg = config
        self._weights: dict = {}  # placeholder for weight tensors
        self._buffer: deque = deque(maxlen=config.max_buffer_size)
        self._step_count = 0
        self._lock = asyncio.Lock()

    # ------------------------------------------------------------------
    # Protocol implementation
    # ------------------------------------------------------------------
    async def predict(self, inp: ModelInput) -> ModelOutput:
        async with self._lock:
            return await self._forward(inp)

    async def dream_rollout(
        self,
        initial_obs: tuple[float, ...],
        action_sequence: Sequence[int],
    ) -> DreamRollout:
        obs = initial_obs
        steps: list[ModelOutput] = []
        async with self._lock:
            for action in action_sequence:
                out = await self._forward(ModelInput(observation=obs, action=action))
                steps.append(out)
                if out.predicted_done:
                    break
                obs = out.next_observation
        _wm_rollout_steps_total.inc(len(steps))
        total_reward = sum(s.predicted_reward for s in steps)
        surprise = sum(s.prediction_error for s in steps) / max(len(steps), 1)
        return DreamRollout(
            steps=tuple(steps),
            total_imagined_reward=total_reward,
            surprise_score=surprise,
        )

    async def update(
        self,
        inp: ModelInput,
        actual_next_obs: tuple[float, ...],
        actual_reward: float,
        actual_done: bool,
    ) -> float:
        async with self._lock:
            self._buffer.append((inp, actual_next_obs, actual_reward, actual_done))
            loss = await self._sgd_step(inp, actual_next_obs, actual_reward, actual_done)
            self._step_count += 1
            _wm_update_total.inc()
            _wm_buffer_size.set(len(self._buffer))
            return loss

    async def surprise(self, inp: ModelInput, actual_next_obs: tuple[float, ...]) -> float:
        out = await self.predict(inp)
        err = _mse(out.next_observation, actual_next_obs)
        _wm_surprise_score.set(err)
        return err

    async def snapshot(self) -> dict:
        return {
            "config": vars(self._cfg),
            "weights": dict(self._weights),
            "step_count": self._step_count,
            "buffer_size": len(self._buffer),
        }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    async def _forward(self, inp: ModelInput) -> ModelOutput:
        """Stub forward pass — implementer replaces with a real net."""
        noise = [random.gauss(0, 0.01) for _ in range(self._cfg.obs_dim)]
        next_obs = tuple(o + n for o, n in zip(inp.observation, noise))
        assert len(next_obs) == self._cfg.obs_dim  # obs_dim safety (see mypy notes)
        return ModelOutput(
            next_observation=next_obs,
            predicted_reward=0.0,
            predicted_done=False,
            prediction_error=0.0,
        )

    async def _sgd_step(
        self,
        inp: ModelInput,
        actual_next_obs: tuple[float, ...],
        actual_reward: float,
        actual_done: bool,
    ) -> float:
        """Stub SGD — implementer replaces with an autograd update."""
        pred = await self._forward(inp)
        return _mse(pred.next_observation, actual_next_obs)


def _mse(a: tuple[float, ...], b: tuple[float, ...]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b)) / max(len(a), 1)


def build_world_model(config: WorldModelConfig | None = None) -> WorldModel:
    return InMemoryWorldModel(config or WorldModelConfig())
```
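A minimal end-to-end smoke run showing the intended call pattern (arbitrary dimensions and actions; assumes the module also defines the Prometheus metrics from the section below):

```python
# Smoke run with arbitrary values — exercises predict / update / dream_rollout.
async def main() -> None:
    wm = build_world_model(WorldModelConfig(obs_dim=4, rollout_horizon=5))
    obs = (0.0, 0.0, 0.0, 0.0)

    out = await wm.predict(ModelInput(observation=obs, action=1))
    loss = await wm.update(ModelInput(observation=obs, action=1),
                           out.next_observation, 1.0, False)
    rollout = await wm.dream_rollout(obs, action_sequence=[1, 0, 1, 0, 1])
    print(loss, rollout.total_imagined_reward, len(rollout.steps))

asyncio.run(main())
```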
## `CognitiveCycle` Integration

```python
class CognitiveCycle:
    def __init__(self, ..., world_model: WorldModel, world_model_cfg: WorldModelConfig) -> None:
        # `...` stands for the cycle's other dependencies, elided in this issue.
        self._world_model = world_model
        self._world_model_cfg = world_model_cfg

    async def _model_based_step(
        self,
        obs: tuple[float, ...],
        candidate_actions: list[int],
    ) -> int:
        """Select best action via dream rollout, one rollout per candidate."""
        best_action, best_reward = candidate_actions[0], float("-inf")
        for action in candidate_actions:
            rollout = await self._world_model.dream_rollout(
                initial_obs=obs,
                action_sequence=[action] * self._world_model_cfg.rollout_horizon,
            )
            if rollout.total_imagined_reward > best_reward:
                best_reward = rollout.total_imagined_reward
                best_action = action
            if rollout.surprise_score > self._world_model_cfg.surprise_threshold:
                await self._trigger_replan()
        return best_action

    async def _update_world_model(
        self,
        inp: ModelInput,
        actual_next_obs: tuple[float, ...],
        actual_reward: float,
        actual_done: bool,
    ) -> None:
        loss = await self._world_model.update(inp, actual_next_obs, actual_reward, actual_done)
        _wm_train_loss.set(loss)
```
## Prometheus Metrics

| Metric | Type | Description |
| --- | --- | --- |
| `wm_update_total` | Counter | Total `update()` calls (real transitions consumed) |
| `wm_train_loss` | Gauge | Most recent SGD loss |
| `wm_surprise_score` | Gauge | Most recent surprise score |
| `wm_buffer_size` | Gauge | Current replay-buffer fill level |
| `wm_rollout_steps_total` | Counter | Cumulative dream-rollout steps executed |
```python
from prometheus_client import Counter, Gauge

_wm_update_total = Counter("wm_update_total", "Real transitions consumed")
_wm_train_loss = Gauge("wm_train_loss", "Most recent SGD loss")
_wm_surprise_score = Gauge("wm_surprise_score", "Most recent surprise score")
_wm_buffer_size = Gauge("wm_buffer_size", "Replay buffer fill level")
_wm_rollout_steps_total = Counter("wm_rollout_steps_total", "Dream rollout steps executed")
```
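For local scraping, the standard `prometheus_client` HTTP exporter is enough; the port below is an arbitrary choice for this sketch:

```python
# Expose /metrics for Prometheus to scrape (port 8000 is arbitrary).
from prometheus_client import start_http_server

start_http_server(8000)
```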
PromQL examples:

```promql
# Surprise spikes (potential distribution shift)
wm_surprise_score > 0.05

# Model throughput (transitions per second, averaged over the last minute)
rate(wm_update_total[1m])

# Training convergence
wm_train_loss
```
## mypy Compliance

| Class/Function | Issues | Resolution |
| --- | --- | --- |
| `ModelInput.observation` | `tuple[float, ...]` vs `list` at call sites | always wrap in `tuple()` at the boundary |
| `DreamRollout.steps` | mutable `list` built internally | cast via `tuple(steps)` before freezing |
| `InMemoryWorldModel._forward` | stub must return fixed `obs_dim` | `assert len(next_obs) == cfg.obs_dim` |
| `build_world_model` | return type is the `WorldModel` Protocol | `# type: ignore[return-value]` if needed |
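The first row in practice (the `raw_obs` list here is hypothetical, e.g. arriving from a gym-style environment):

```python
# Hypothetical boundary code: observations arrive as lists, the dataclasses
# require tuples — wrap once at the edge and mypy stays quiet downstream.
raw_obs: list[float] = [0.1, 0.2, 0.3, 0.4]
inp = ModelInput(observation=tuple(raw_obs), action=2)
```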
## Test Targets (12 minimum)

- `test_predict_returns_model_output` — output type + field types
- `test_predict_obs_dim_preserved` — `len(next_obs) == obs_dim`
- `test_dream_rollout_length` — `steps == min(horizon, done_step)`
- `test_dream_rollout_terminates_on_done` — early exit when `predicted_done=True`
- `test_update_returns_scalar_loss` — `isinstance(loss, float)`
- `test_update_increments_buffer` — buffer grows after each real transition
- `test_surprise_zero_on_perfect_prediction` — `surprise(inp, predicted_obs) ≈ 0`
- `test_surprise_high_on_bad_prediction` — large error vector → high surprise
- `test_snapshot_keys` — snapshot contains `config`, `step_count`, `buffer_size`
- `test_build_world_model_returns_protocol` — `isinstance(wm, WorldModel)`
- `test_concurrent_predict_safe` — 50 concurrent `predict()` calls, no race
- `test_dream_rollout_accumulates_reward` — summed `predicted_reward` correct
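Two of these targets sketched with `pytest` plus `pytest-asyncio` (the plugin choice is an assumption; any async test runner works):

```python
import asyncio
import pytest


@pytest.mark.asyncio
async def test_concurrent_predict_safe() -> None:
    wm = build_world_model()
    inp = ModelInput(observation=(0.0,) * 64, action=0)  # 64 == default obs_dim
    outs = await asyncio.gather(*(wm.predict(inp) for _ in range(50)))
    assert all(len(o.next_observation) == 64 for o in outs)


@pytest.mark.asyncio
async def test_update_increments_buffer() -> None:
    wm = build_world_model()
    inp = ModelInput(observation=(0.0,) * 64, action=0)
    before = (await wm.snapshot())["buffer_size"]
    await wm.update(inp, (0.0,) * 64, 0.0, False)
    assert (await wm.snapshot())["buffer_size"] == before + 1
```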
## Implementation Order (14 steps)

1. Define `TransitionBackend` + `PredictionTarget` enums
2. Define `ModelInput`, `ModelOutput`, `DreamRollout`, `WorldModelConfig` dataclasses
3. Define the `WorldModel` Protocol
4. Implement `_mse()` helper + `build_world_model()` factory
5. Implement `InMemoryWorldModel.__init__()` with asyncio lock + deque buffer
6. Implement `_forward()` stub (noise-perturbed copy, obs_dim-safe)
7. Implement `predict()` — lock + `_forward()`
8. Implement `update()` — buffer append + `_sgd_step()` + metrics
9. Implement `_sgd_step()` stub
10. Implement `dream_rollout()` — iterative `_forward()` with early exit on done
11. Implement `surprise()` — predict + MSE
12. Implement `snapshot()`
13. Add 5 Prometheus metrics + PromQL examples
14. Write 12 test targets
## Phase 13 Roadmap

| Sub-phase | Component | Status |
| --- | --- | --- |
| 13.1 | `WorldModel` | 🟡 This issue |
| 13.2 | `DreamRolloutPlanner` | ⏳ Planned |
| 13.3 | `ModelBasedPolicyOptimizer` | ⏳ Planned |
| 13.4 | `SurpriseDetector` | ⏳ Planned |
| 13.5 | `WorldModelDashboard` | ⏳ Planned |
**Phase 12 recap** — Coalition infrastructure complete.
Related: Discussion #367 (Phase 13 direction) | Discussion #369 (Show & Tell) | Discussion #370 (Q&A)
Wiki: Phase-13-World-Model (coming soon)