
Phase 6.4 — Incremental/Online Fisher Matrix Updates: FisherAccumulator and rolling-window estimation #252

@web3guru888

Summary

Phase 6.3 wired EWCRegulariser into STDPOnlineLearner — after each LEARNING boundary the regulariser applies a penalty proportional to the Fisher-weighted parameter deviation from the task checkpoint.

Phase 6.4 closes a key gap. Today the Fisher matrix is computed once per SLEEP cycle, in a full pass over recent episodes. That is accurate but expensive, and it leaves the importance weights stale between SLEEP cycles. Phase 6.4 introduces online (incremental) Fisher estimation so that the importance weights are updated on every STDP update step as new spikes arrive.


Background

The empirical Fisher for parameter θ_i is:

F_i = E[(∂ log p(y|x,θ) / ∂θ_i)²]

Full-batch estimation at SLEEP_PHASE time approximates this expectation over the episode buffer. Online estimation maintains a rolling exponential moving average (EMA):

F_i(t) = α · g_i(t)² + (1 - α) · F_i(t-1)

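where g_i(t) is the gradient at step t and α ∈ (0, 1] is the EMA weight given to the newest sample (the freshness vs. stability trade-off): a larger α tracks recent gradients more closely, and a smaller α gives a smoother, more stable estimate.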
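The recurrence is easy to check by hand. Below is a plain-Python sketch for per-parameter lists, using a hypothetical helper `ema_fisher_step` that is not part of the spec. For a constant gradient and F(0) = 0, n steps give F_i(n) = g_i² · (1 - (1 - α)^n). With the default α = 0.01, the estimate reaches only about 9.6% of its steady value after 10 steps, and older samples fade with an effective memory of roughly 1/α = 100 steps. That start-up bias may be worth keeping in mind when choosing `min_samples`.

```python
def ema_fisher_step(prev: list[float], grad: list[float], alpha: float) -> list[float]:
    """One online Fisher step per parameter: F_i(t) = alpha * g_i(t)**2 + (1 - alpha) * F_i(t-1)."""
    if not 0.0 < alpha <= 1.0:
        raise ValueError("alpha must lie in (0, 1]")
    if len(prev) != len(grad):
        raise ValueError("prev and grad must have the same length")
    return [alpha * g * g + (1.0 - alpha) * f for f, g in zip(prev, grad)]


# Constant gradient g = [1.0, 2.0], EMA started from F(0) = 0, default alpha = 0.01.
fisher = [0.0, 0.0]
for _ in range(10):
    fisher = ema_fisher_step(fisher, [1.0, 2.0], alpha=0.01)

# Closed form for a constant signal: F_i(n) = g_i**2 * (1 - (1 - alpha)**n)
expected = [g * g * (1 - 0.99**10) for g in (1.0, 2.0)]
```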


New Components

1. FisherAccumulator (learning/fisher_online.py)

```python
from __future__ import annotations

from dataclasses import dataclass

import numpy as np
import numpy.typing as npt

# FisherMatrixBase and FisherSnapshot are the Phase 6.2 backend types (#245).


@dataclass
class FisherAccumulatorConfig:
    alpha: float = 0.01             # EMA weight on the newest sample; lower = smoother
    min_samples: int = 10           # steps before the first snapshot may be written
    snapshot_every: int = 50        # steps between store writes
    max_staleness_steps: int = 500  # warn if no snapshot written within this many steps


@dataclass
class AccumulatorStats:
    steps: int
    snapshots_written: int
    last_snapshot_step: int
    max_delta: float                # max |F_new - F_old| across params


class FisherAccumulator:
    """Maintains a per-parameter EMA of squared gradients.

    Called once per STDP update step. Writes a snapshot to the FisherMatrixBase
    store every snapshot_every steps, once min_samples steps have accumulated.
    """

    def __init__(
        self,
        config: FisherAccumulatorConfig,
        store: FisherMatrixBase,
        task_id: str,
    ) -> None: ...

    async def update(
        self,
        gradients: dict[str, npt.NDArray[np.float32]],
    ) -> None:
        """One EMA step. Writes a snapshot when due."""
        ...

    async def force_snapshot(self) -> FisherSnapshot:
        """Used by SLEEP_PHASE consolidation; resets the snapshot step counter."""
        ...

    def stats(self) -> AccumulatorStats: ...
```
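Here is a minimal in-memory sketch of the accumulator's core loop. It rests on several assumptions. The store is a hypothetical `InMemoryStore` whose async `save(task_id, fisher)` just records its arguments; the real `FisherMatrixBase` interface from Phase 6.2 may differ. Snapshots are plain dicts of arrays rather than `FisherSnapshot` objects. Each parameter's EMA starts from zero. Within those limits, the sketch shows four behaviours: one EMA step per `update()` call, a store write only once `min_samples` steps have passed and `snapshot_every` steps have elapsed since the last write, `force_snapshot()` resetting that interval, and `max_delta` measured against the previous snapshot. Staleness handling is left out.

```python
from __future__ import annotations

import asyncio
from dataclasses import dataclass

import numpy as np
import numpy.typing as npt

Fisher = dict[str, npt.NDArray[np.float32]]


@dataclass
class SketchConfig:
    """Mirrors the snapshot-related fields of FisherAccumulatorConfig."""
    alpha: float = 0.01
    min_samples: int = 10
    snapshot_every: int = 50


class InMemoryStore:
    """Hypothetical stand-in for a FisherMatrixBase backend."""

    def __init__(self) -> None:
        self.saved: list[tuple[str, Fisher]] = []

    async def save(self, task_id: str, fisher: Fisher) -> None:
        self.saved.append((task_id, fisher))


class SketchFisherAccumulator:
    def __init__(self, config: SketchConfig, store: InMemoryStore, task_id: str) -> None:
        self.config = config
        self.store = store
        self.task_id = task_id  # each instance owns its own EMA dict, so tasks stay isolated
        self.ema: Fisher = {}
        self.last_snapshot: Fisher = {}
        self.steps = 0
        self.steps_since_snapshot = 0
        self.snapshots_written = 0
        self.max_delta = 0.0

    async def update(self, gradients: Fisher) -> None:
        """One EMA step: F = alpha * g**2 + (1 - alpha) * F_prev, with F_prev = 0 initially."""
        a = self.config.alpha
        for name, g in gradients.items():
            sq = np.square(np.asarray(g, dtype=np.float32))
            prev = self.ema.get(name, np.zeros_like(sq))
            self.ema[name] = (a * sq + (1.0 - a) * prev).astype(np.float32)
        self.steps += 1
        self.steps_since_snapshot += 1
        if (
            self.steps >= self.config.min_samples
            and self.steps_since_snapshot >= self.config.snapshot_every
        ):
            await self.force_snapshot()

    async def force_snapshot(self) -> Fisher:
        """Write the current EMA regardless of the step counter, then reset the interval."""
        snapshot = {name: f.copy() for name, f in self.ema.items()}
        # max |F_new - F_old| across all parameters since the previous snapshot
        self.max_delta = max(
            (
                float(np.max(np.abs(f - self.last_snapshot.get(name, np.zeros_like(f)))))
                for name, f in snapshot.items()
            ),
            default=0.0,
        )
        await self.store.save(self.task_id, snapshot)
        self.last_snapshot = snapshot
        self.snapshots_written += 1
        self.steps_since_snapshot = 0
        return snapshot


async def demo() -> tuple[SketchFisherAccumulator, InMemoryStore]:
    store = InMemoryStore()
    cfg = SketchConfig(alpha=0.5, min_samples=2, snapshot_every=3)
    acc = SketchFisherAccumulator(cfg, store, "task-a")
    grads = {"weights": np.array([1.0, 2.0], dtype=np.float32)}
    for _ in range(7):
        await acc.update(grads)  # snapshots land at steps 3 and 6
    return acc, store


acc, store = asyncio.run(demo())
```

With α = 0.5 the values stay exact in float32, which makes them convenient for asserting EMA values in tests.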

2. Modified STDPOnlineLearner.update()

The existing update() method gains an optional accumulator parameter:

```python
async def update(
    self,
    pre: npt.NDArray[np.float32],
    post: npt.NDArray[np.float32],
    weights: npt.NDArray[np.float32],
    regulariser: EWCRegulariser | None = None,
    task_id: str | None = None,
    accumulator: FisherAccumulator | None = None,  # NEW: last, so positional calls still bind as before
) -> WeightUpdate:
    ...  # existing STDP rule computes weight_delta
    if accumulator is not None:
        # Pass pseudo-gradients (weight_delta) as the surrogate Fisher signal.
        await accumulator.update({"weights": weight_delta})
    ...
```
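The backwards-compatibility contract (test targets 7 and 8) can be illustrated with a toy, self-contained stand-in. `toy_update`, its placeholder Hebbian rule `lr * outer(post, pre)` (not the project's STDP rule), and `RecordingAccumulator` are all hypothetical. Only the pattern matters here: the accumulator is an optional trailing argument, the weight delta is forwarded as the surrogate gradient, and the returned weights are identical with or without it.

```python
from __future__ import annotations

import asyncio
from typing import Protocol

import numpy as np
import numpy.typing as npt


class AccumulatorLike(Protocol):
    async def update(self, gradients: dict[str, npt.NDArray[np.float32]]) -> None: ...


class RecordingAccumulator:
    """Hypothetical test double: records every gradient dict it receives."""

    def __init__(self) -> None:
        self.calls: list[dict[str, npt.NDArray[np.float32]]] = []

    async def update(self, gradients: dict[str, npt.NDArray[np.float32]]) -> None:
        self.calls.append(gradients)


async def toy_update(
    pre: npt.NDArray[np.float32],
    post: npt.NDArray[np.float32],
    weights: npt.NDArray[np.float32],
    lr: float = 0.01,
    accumulator: AccumulatorLike | None = None,
) -> npt.NDArray[np.float32]:
    # Placeholder Hebbian rule standing in for the real STDP update.
    weight_delta = (lr * np.outer(post, pre)).astype(np.float32)
    if accumulator is not None:
        # The weight delta is the surrogate "gradient" fed to the Fisher EMA.
        await accumulator.update({"weights": weight_delta})
    return weights + weight_delta


pre = np.array([1.0, 0.0], dtype=np.float32)
post = np.array([0.0, 1.0, 1.0], dtype=np.float32)
w0 = np.zeros((3, 2), dtype=np.float32)

recorder = RecordingAccumulator()
w_with = asyncio.run(toy_update(pre, post, w0, accumulator=recorder))
w_without = asyncio.run(toy_update(pre, post, w0))  # existing call sites: no accumulator
```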

3. OnlineFisherSnapshot (extended FisherSnapshot)

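```python
@dataclass
class OnlineFisherSnapshot(FisherSnapshot):
    """Adds online-estimation provenance fields."""

    # If FisherSnapshot declares any field with a default, these non-default
    # fields need defaults too (or @dataclass(kw_only=True), Python 3.10+).
    alpha: float
    steps_accumulated: int
    estimated_at_step: int
```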

Acceptance Criteria

  • FisherAccumulator can be instantiated with a FisherAccumulatorConfig and any FisherMatrixBase backend
  • update() performs one EMA step per call and does not write to the store on every call
  • Snapshot written to store every snapshot_every steps (configurable)
  • force_snapshot() callable from SLEEP_PHASE hook without waiting for step counter
  • STDPOnlineLearner.update() accepts optional accumulator kwarg (fully backwards-compatible)
  • Weight deltas used as surrogate gradient signal when true gradients unavailable
  • AccumulatorStats.max_delta tracks the largest per-parameter Fisher change, max |F_new - F_old| over all parameters, between consecutive snapshots
  • Prometheus metrics exported: ewc_fisher_ema_steps_total, ewc_fisher_snapshot_age_steps, ewc_fisher_max_delta
  • max_staleness_steps breach triggers warning log + Prometheus counter increment
  • All new code passes mypy --strict; numpy arrays typed as npt.NDArray[np.float32]
  • 10 test targets met (see below)
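The spec names three metrics but not the counter that a staleness breach should increment. The sketch below therefore accepts any object with an `inc()` method, such as a `prometheus_client` `Counter`. It also makes an assumed design choice: the warning and increment fire once, on the first step past `max_staleness_steps`, rather than on every stale step. `check_staleness`, `FakeCounter`, and the logger name are hypothetical.

```python
import logging
from typing import Protocol

log = logging.getLogger("ewc.fisher_online")  # hypothetical logger name


class CounterLike(Protocol):
    def inc(self, amount: float = 1) -> None: ...


def check_staleness(steps_since_snapshot: int, max_staleness_steps: int, counter: CounterLike) -> bool:
    """Warn and bump the counter once, on the step the staleness limit is first exceeded."""
    if steps_since_snapshot == max_staleness_steps + 1:
        log.warning(
            "Fisher snapshot is %d steps old (limit %d)", steps_since_snapshot, max_staleness_steps
        )
        counter.inc()
        return True
    return False


class FakeCounter:
    """Test double with the same inc() shape as a Prometheus counter."""

    def __init__(self) -> None:
        self.value = 0.0

    def inc(self, amount: float = 1) -> None:
        self.value += amount


counter = FakeCounter()
fired = [check_staleness(s, 500, counter) for s in range(498, 505)]
```

Firing once per breach keeps the counter meaningful as "number of staleness incidents"; a gauge such as ewc_fisher_snapshot_age_steps is the natural place to expose the continuously growing age.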

Test Targets

| # | Test | What it guards |
|---|------|----------------|
| 1 | test_accumulator_ema_decay | F(t) = α·g² + (1-α)·F(t-1) applied correctly |
| 2 | test_accumulator_no_snapshot_before_min_samples | guard on min_samples |
| 3 | test_snapshot_written_at_interval | store.save called every snapshot_every steps |
| 4 | test_force_snapshot_bypasses_counter | force_snapshot ignores step counter |
| 5 | test_force_snapshot_resets_counter | step counter resets after force |
| 6 | test_max_staleness_warning | log warning + Prometheus counter when stale |
| 7 | test_stdp_learner_accumulator_kwarg | accumulator.update called once per STDP step |
| 8 | test_stdp_backwards_compat_no_accumulator | no accumulator → existing behaviour unchanged |
| 9 | test_online_fisher_snapshot_fields | OnlineFisherSnapshot carries provenance fields |
| 10 | test_accumulator_multi_task_isolation | separate task_id paths don't cross-contaminate |
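Targets 2 through 5 reduce to call-count checks on a mocked FisherMatrixBase.save, as Implementation Order step 8 suggests. The sketch below uses `unittest.mock.AsyncMock` for `save` and a hypothetical `CadenceOnlyAccumulator` that reproduces only the snapshot timing (no EMA), so the expected counts can be read off directly. In the real suite, the same assertions would target FisherAccumulator.

```python
import asyncio
from typing import Any
from unittest.mock import AsyncMock, Mock


class CadenceOnlyAccumulator:
    """Hypothetical stand-in: only FisherAccumulator's snapshot timing, no EMA."""

    def __init__(self, store: Any, min_samples: int, snapshot_every: int) -> None:
        self.store = store
        self.min_samples = min_samples
        self.snapshot_every = snapshot_every
        self.steps = 0
        self.steps_since_snapshot = 0

    async def update(self) -> None:
        self.steps += 1
        self.steps_since_snapshot += 1
        if self.steps >= self.min_samples and self.steps_since_snapshot >= self.snapshot_every:
            await self.force_snapshot()

    async def force_snapshot(self) -> None:
        await self.store.save()
        self.steps_since_snapshot = 0


def make_accumulator() -> CadenceOnlyAccumulator:
    store = Mock()
    store.save = AsyncMock()  # mocked FisherMatrixBase.save
    return CadenceOnlyAccumulator(store, min_samples=10, snapshot_every=5)


async def run_targets() -> tuple[int, int, int, int]:
    a = make_accumulator()
    for _ in range(9):
        await a.update()
    before_min = a.store.save.await_count   # target 2: nothing before min_samples

    b = make_accumulator()
    for _ in range(20):
        await b.update()
    at_interval = b.store.save.await_count  # target 3: saves at steps 10, 15, 20

    c = make_accumulator()
    for _ in range(12):
        await c.update()                    # save at step 10
    await c.force_snapshot()                # targets 4/5: immediate save, interval reset
    for _ in range(4):
        await c.update()                    # steps 13-16: only 4 steps since reset
    after_reset = c.store.save.await_count
    await c.update()                        # step 17: 5 steps since reset, save again
    return before_min, at_interval, after_reset, c.store.save.await_count


results = asyncio.run(run_targets())
```

Without the reset in force_snapshot, a save would already fire at step 15. That would make the third count 3 instead of 2, so this check distinguishes target 5 from target 3.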

Implementation Order

  1. Add FisherAccumulatorConfig and AccumulatorStats dataclasses to learning/fisher_online.py
  2. Implement FisherAccumulator.__init__() (store ref, internal EMA dict, step counter)
  3. Implement FisherAccumulator.update() (EMA step + conditional snapshot)
  4. Implement FisherAccumulator.force_snapshot() (used by SLEEP_PHASE)
  5. Add OnlineFisherSnapshot dataclass extending FisherSnapshot
  6. Modify STDPOnlineLearner.update() signature to accept optional accumulator kwarg
  7. Wire Prometheus metrics (ewc_fisher_ema_steps_total etc.) — pre-init in module body
  8. Write tests (mock FisherMatrixBase.save, assert call counts + EMA values)
  9. Update EWCConfig if needed (e.g. online_fisher: bool = False feature flag)
  10. Document in wiki — Phase-6-Online-Fisher.md

Phase 6 Tracker

| Sub-phase | Issue | Status |
|-----------|-------|--------|
| 6.1 EWC Foundation | #241 | ✅ Spec filed |
| 6.2 Fisher Backends | #245 | ✅ Spec filed |
| 6.3 EWCRegulariser Integration | #249 | ✅ Spec filed |
| 6.4 Online Fisher Updates | this | 🔲 Open |
