
Phase 13.1 — WorldModel: predictive environment model for model-based planning #368

@web3guru888

This issue tracks the specification and implementation of WorldModel, the first sub-phase of Phase 13: World Modeling & Model-Based Planning.

Phase 13 context — Discussion #367 opened the question of which direction to take ASI-Build next. World Modeling was selected as Phase 13 because it is foundational to model-based planning, dream rollouts, and surprise-driven exploration — capabilities that naturally complement the multi-agent coalition infrastructure built in Phase 12.


Motivation

A WorldModel enables the agent to simulate "what happens if" before committing to an action. Instead of relying solely on trial-and-error in the environment, the agent rolls out imagined trajectories through its internal model, selects the best plan, and executes with far greater sample efficiency.

Key benefits:

  • Dyna-style planning: interleave real and imagined experience (see the sketch below)
  • Zero-shot transfer: the learned world model generalises across tasks that share the same dynamics
  • Surprise detection: high prediction error flags distributional shift
  • Scalable: the model is a lightweight neural net, not a full simulator
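
A minimal sketch of the Dyna-style planning loop, assuming the WorldModel protocol and ModelInput dataclass specified later in this issue; env and policy are hypothetical stand-ins with the obvious interfaces:

async def dyna_loop(env, wm: "WorldModel", policy, horizon: int = 5) -> None:
    """Hypothetical sketch: plan in imagination, act once for real, then
    train the model on the observed transition."""
    obs, done = env.reset(), False
    while not done:
        # score each candidate action by its imagined return
        candidates = policy.candidate_actions(obs)
        rollouts = [await wm.dream_rollout(obs, [a] * horizon) for a in candidates]
        best = max(range(len(candidates)),
                   key=lambda i: rollouts[i].total_imagined_reward)
        # execute the winning action for real
        next_obs, reward, done = env.step(candidates[best])
        # train the model on the real transition (Dyna: real + imagined data)
        await wm.update(ModelInput(observation=obs, action=candidates[best]),
                        next_obs, reward, done)
        obs = next_obs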

Enumerations

from enum import Enum, auto

class TransitionBackend(Enum):
    """Backend used for world-model forward pass."""
    MLP = auto()           # simple feed-forward MLP (fast, low-capacity)
    LSTM = auto()          # recurrent — tracks temporal context
    TRANSFORMER = auto()   # attention over observation history window
    ENSEMBLE = auto()      # ensemble of N MLPs for epistemic uncertainty

class PredictionTarget(Enum):
    """What the model predicts on each step."""
    NEXT_OBS = auto()      # predict next observation vector
    REWARD = auto()        # predict scalar reward
    DONE = auto()          # predict episode-termination flag
    ALL = auto()           # joint head — obs + reward + done

Data Classes

from dataclasses import dataclass
from typing import Sequence

@dataclass(frozen=True)
class ModelInput:
    """Single (obs, action) pair passed to WorldModel."""
    observation: tuple[float, ...]   # current obs vector
    action: int                      # discrete action index

@dataclass(frozen=True)
class ModelOutput:
    """Predicted next state."""
    next_observation: tuple[float, ...]  # predicted next obs
    predicted_reward: float              # predicted reward
    predicted_done: bool                 # predicted termination
    prediction_error: float              # MSE vs last ground-truth (0 if unavailable)

@dataclass(frozen=True)
class DreamRollout:
    """Result of an imagined trajectory through the world model."""
    steps: tuple[ModelOutput, ...]   # ordered sequence of imagined steps
    total_imagined_reward: float     # sum of predicted rewards
    surprise_score: float            # mean prediction_error across steps

@dataclass(frozen=True)
class WorldModelConfig:
    backend: TransitionBackend = TransitionBackend.MLP
    obs_dim: int = 64
    action_dim: int = 8
    hidden_dim: int = 256
    ensemble_size: int = 5           # used when backend == ENSEMBLE
    rollout_horizon: int = 10        # steps per dream rollout
    learning_rate: float = 3e-4
    surprise_threshold: float = 0.05 # error above this triggers re-planning
    max_buffer_size: int = 100_000   # replay buffer for model training
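
For illustration, a hypothetical configuration overriding the defaults for an ensemble backend (the specific values are made up):

cfg = WorldModelConfig(
    backend=TransitionBackend.ENSEMBLE,
    obs_dim=128,
    ensemble_size=7,            # seven heads for epistemic-uncertainty estimates
    rollout_horizon=15,
    surprise_threshold=0.1,     # hypothetical: looser threshold for noisier obs
)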

Protocol

from typing import Protocol, runtime_checkable

@runtime_checkable
class WorldModel(Protocol):
    """Predictive model of environment dynamics."""

    async def predict(self, inp: ModelInput) -> ModelOutput:
        """One-step forward pass: (obs, action) → next state prediction."""
        ...

    async def dream_rollout(
        self,
        initial_obs: tuple[float, ...],
        action_sequence: Sequence[int],
    ) -> DreamRollout:
        """Simulate a trajectory through the model."""
        ...

    async def update(
        self,
        inp: ModelInput,
        actual_next_obs: tuple[float, ...],
        actual_reward: float,
        actual_done: bool,
    ) -> float:
        """Online update from a real transition; returns scalar loss."""
        ...

    async def surprise(self, inp: ModelInput, actual_next_obs: tuple[float, ...]) -> float:
        """Return prediction error for surprise detection."""
        ...

    async def snapshot(self) -> dict:
        """Return serialisable model weights + config."""
        ...

InMemoryWorldModel Implementation

import asyncio
import random
from collections import deque

# The _wm_* metric objects referenced below are defined in the
# "Prometheus Metrics" section of this issue.

class InMemoryWorldModel:
    """NumPy/pure-Python reference implementation (no torch dependency)."""

    def __init__(self, config: WorldModelConfig) -> None:
        self._cfg = config
        self._weights: dict = {}   # placeholder for weight tensors
        self._buffer: deque = deque(maxlen=config.max_buffer_size)
        self._step_count = 0
        self._lock = asyncio.Lock()

    # ------------------------------------------------------------------
    # Protocol implementation
    # ------------------------------------------------------------------

    async def predict(self, inp: ModelInput) -> ModelOutput:
        async with self._lock:
            return await self._forward(inp)

    async def dream_rollout(
        self,
        initial_obs: tuple[float, ...],
        action_sequence: Sequence[int],
    ) -> DreamRollout:
        obs = initial_obs
        steps: list[ModelOutput] = []
        async with self._lock:
            for action in action_sequence:
                out = await self._forward(ModelInput(observation=obs, action=action))
                steps.append(out)
                if out.predicted_done:
                    break
                obs = out.next_observation
        total_reward = sum(s.predicted_reward for s in steps)
        surprise = sum(s.prediction_error for s in steps) / max(len(steps), 1)
        _wm_rollout_steps_total.inc(len(steps))  # metric defined in the Metrics section
        return DreamRollout(
            steps=tuple(steps),
            total_imagined_reward=total_reward,
            surprise_score=surprise,
        )

    async def update(
        self,
        inp: ModelInput,
        actual_next_obs: tuple[float, ...],
        actual_reward: float,
        actual_done: bool,
    ) -> float:
        async with self._lock:
            self._buffer.append((inp, actual_next_obs, actual_reward, actual_done))
            loss = await self._sgd_step(inp, actual_next_obs, actual_reward, actual_done)
            self._step_count += 1
            _wm_update_total.inc()
            _wm_buffer_size.set(len(self._buffer))
            return loss

    async def surprise(self, inp: ModelInput, actual_next_obs: tuple[float, ...]) -> float:
        out = await self.predict(inp)
        err = _mse(out.next_observation, actual_next_obs)
        _wm_surprise_score.set(err)
        return err

    async def snapshot(self) -> dict:
        return {
            "config": vars(self._cfg),
            "weights": dict(self._weights),  # empty in this stub
            "step_count": self._step_count,
            "buffer_size": len(self._buffer),
        }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _forward(self, inp: ModelInput) -> ModelOutput:
        """Stub forward pass — implementer replaces with a real net."""
        noise = [random.gauss(0, 0.01) for _ in range(self._cfg.obs_dim)]
        next_obs = tuple(o + n for o, n in zip(inp.observation, noise))
        assert len(next_obs) == self._cfg.obs_dim  # obs_dim-safe (see mypy table)
        return ModelOutput(
            next_observation=next_obs,
            predicted_reward=0.0,
            predicted_done=False,
            prediction_error=0.0,
        )

    async def _sgd_step(
        self,
        inp: ModelInput,
        actual_next_obs: tuple[float, ...],
        actual_reward: float,
        actual_done: bool,
    ) -> float:
        """Stub SGD — implementer replaces with autograd update."""
        pred = await self._forward(inp)
        loss = _mse(pred.next_observation, actual_next_obs)
        return loss

def _mse(a: tuple[float, ...], b: tuple[float, ...]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b)) / max(len(a), 1)
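
A quick sanity check of the helper:

assert _mse((1.0, 2.0), (1.5, 2.0)) == 0.125  # (0.25 + 0.0) / 2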

def build_world_model(config: WorldModelConfig | None = None) -> WorldModel:
    return InMemoryWorldModel(config or WorldModelConfig())
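
A minimal end-to-end usage sketch, assuming the Prometheus metric objects from the Metrics section below are defined (the transition values are made up):

import asyncio

async def demo() -> None:
    wm = build_world_model(WorldModelConfig(obs_dim=4))
    assert isinstance(wm, WorldModel)  # runtime_checkable Protocol
    obs = (0.0, 0.0, 0.0, 0.0)
    # one-step prediction
    out = await wm.predict(ModelInput(observation=obs, action=2))
    # train on a fabricated real transition
    loss = await wm.update(ModelInput(observation=obs, action=2),
                           out.next_observation, 1.0, False)
    # imagine a 3-step trajectory
    rollout = await wm.dream_rollout(obs, action_sequence=[2, 2, 1])
    print(f"loss={loss:.4f} imagined_reward={rollout.total_imagined_reward}")

asyncio.run(demo())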

CognitiveCycle Integration

class CognitiveCycle:
    def __init__(self, ..., world_model: WorldModel, world_model_cfg: WorldModelConfig) -> None:
        self._world_model = world_model
        self._world_model_cfg = world_model_cfg

    async def _model_based_step(
        self,
        obs: tuple[float, ...],
        candidate_actions: list[int],
    ) -> int:
        """Select best action via dream rollout."""
        best_action, best_reward = candidate_actions[0], float("-inf")
        for action in candidate_actions:
            rollout = await self._world_model.dream_rollout(
                initial_obs=obs,
                action_sequence=[action] * self._world_model_cfg.rollout_horizon,
            )
            if rollout.total_imagined_reward > best_reward:
                best_reward = rollout.total_imagined_reward
                best_action = action
            if rollout.surprise_score > self._world_model_cfg.surprise_threshold:
                await self._trigger_replan()
        return best_action

    async def _update_world_model(
        self,
        inp: ModelInput,
        actual_next_obs: tuple[float, ...],
        actual_reward: float,
        actual_done: bool,
    ) -> None:
        loss = await self._world_model.update(inp, actual_next_obs, actual_reward, actual_done)
        _wm_train_loss.set(loss)

Prometheus Metrics

| Metric | Type | Description |
| --- | --- | --- |
| wm_update_total | Counter | Total update() calls (real transitions consumed) |
| wm_train_loss | Gauge | Most recent SGD loss |
| wm_surprise_score | Gauge | Most recent surprise score |
| wm_buffer_size | Gauge | Current replay-buffer fill level |
| wm_rollout_steps_total | Counter | Cumulative dream-rollout steps executed |

from prometheus_client import Counter, Gauge

_wm_update_total      = Counter("wm_update_total", "Real transitions consumed")
_wm_train_loss        = Gauge("wm_train_loss", "Most recent SGD loss")
_wm_surprise_score    = Gauge("wm_surprise_score", "Most recent surprise score")
_wm_buffer_size       = Gauge("wm_buffer_size", "Replay buffer fill level")
_wm_rollout_steps_total = Counter("wm_rollout_steps_total", "Dream rollout steps executed")

PromQL examples:

# Surprise spikes (potential distribution shift)
wm_surprise_score > 0.05

# Model update throughput (per-second rate over the last minute)
rate(wm_update_total[1m])

# Training convergence
wm_train_loss

mypy Compliance

| Class/Function | Issue | Resolution |
| --- | --- | --- |
| ModelInput.observation | tuple[float, ...] vs list at call sites | always wrap in tuple() at the boundary |
| DreamRollout.steps | built internally as a mutable list | cast to tuple(steps) before freezing |
| InMemoryWorldModel._forward | stub must return exactly obs_dim values | assert len(next_obs) == cfg.obs_dim |
| build_world_model | return type is the WorldModel Protocol | # type: ignore[return-value] if needed |
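
For the first row, a minimal sketch of the tuple()-at-the-boundary rule (to_model_input is a hypothetical helper, not part of the spec):

def to_model_input(obs: list[float] | tuple[float, ...], action: int) -> ModelInput:
    # mypy-safe: ModelInput.observation is tuple[float, ...], so coerce here
    return ModelInput(observation=tuple(obs), action=action)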

Test Targets (12 minimum)

  1. test_predict_returns_model_output — output type + field types
  2. test_predict_obs_dim_preserved — len(next_obs) == obs_dim
  3. test_dream_rollout_length — len(steps) == min(horizon, done_step)
  4. test_dream_rollout_terminates_on_done — early-exit when predicted_done=True
  5. test_update_returns_scalar_loss — isinstance(loss, float)
  6. test_update_increments_buffer — buffer grows after each real transition
  7. test_surprise_zero_on_perfect_prediction — surprise(inp, predicted_obs) ≈ 0
  8. test_surprise_high_on_bad_prediction — large error vector → high surprise
  9. test_snapshot_keys — snapshot contains config, step_count, buffer_size
  10. test_build_world_model_returns_protocol — isinstance(wm, WorldModel)
  11. test_concurrent_predict_safe — 50 concurrent predict() calls, no race
  12. test_dream_rollout_accumulates_reward — summed predicted_reward correct
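
A minimal pytest sketch covering targets 1 and 12 above, assuming pytest-asyncio is installed for the async test support:

import pytest

@pytest.mark.asyncio
async def test_predict_returns_model_output() -> None:
    wm = build_world_model(WorldModelConfig(obs_dim=4))
    out = await wm.predict(ModelInput(observation=(0.0,) * 4, action=1))
    assert isinstance(out, ModelOutput)
    assert isinstance(out.predicted_reward, float)
    assert isinstance(out.predicted_done, bool)

@pytest.mark.asyncio
async def test_dream_rollout_accumulates_reward() -> None:
    wm = build_world_model(WorldModelConfig(obs_dim=4))
    rollout = await wm.dream_rollout((0.0,) * 4, action_sequence=[0, 1, 2])
    assert rollout.total_imagined_reward == pytest.approx(
        sum(s.predicted_reward for s in rollout.steps)
    )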

Implementation Order (14 steps)

  1. Define TransitionBackend + PredictionTarget enums
  2. Define ModelInput, ModelOutput, DreamRollout, WorldModelConfig dataclasses
  3. Define WorldModel Protocol
  4. Implement _mse() helper + build_world_model() factory
  5. Implement InMemoryWorldModel.__init__() with asyncio lock + deque buffer
  6. Implement _forward() stub (noise-perturbed copy, obs_dim-safe)
  7. Implement predict() — lock + _forward()
  8. Implement update() — buffer append + _sgd_step() + metrics
  9. Implement _sgd_step() stub
  10. Implement dream_rollout() — iterative _forward() with early-exit on done
  11. Implement surprise() — predict + MSE
  12. Implement snapshot()
  13. Add 5 Prometheus metrics + PromQL examples
  14. Write 12 test targets

Phase 13 Roadmap

| Sub-phase | Component | Status |
| --- | --- | --- |
| 13.1 | WorldModel | 🟡 This issue |
| 13.2 | DreamRolloutPlanner | ⏳ Planned |
| 13.3 | ModelBasedPolicyOptimizer | ⏳ Planned |
| 13.4 | SurpriseDetector | ⏳ Planned |
| 13.5 | WorldModelDashboard | ⏳ Planned |

Phase 12 recap — coalition infrastructure complete.


Related: Discussion #367 (Phase 13 direction) | Discussion #369 (Show & Tell) | Discussion #370 (Q&A)
Wiki: Phase-13-World-Model (coming soon)
