## Summary

Phase 6.3 wired `EWCRegulariser` into `STDPOnlineLearner`: after each LEARNING boundary the regulariser applies a penalty proportional to the Fisher-weighted parameter deviation from the task checkpoint.

Phase 6.4 closes a key gap: currently the Fisher matrix is computed once per SLEEP cycle (a full pass over recent episodes). This is accurate but expensive, and the estimate goes stale between SLEEP cycles. Phase 6.4 introduces online (incremental) Fisher estimation, so that importance weights are updated continuously as new spikes arrive.
## Background
The empirical Fisher for parameter θ_i is:
F_i = E[(∂ log p(y|x,θ) / ∂θ_i)²]
Full-batch estimation at SLEEP_PHASE time approximates this expectation over the episode buffer. Online estimation maintains a rolling exponential moving average (EMA):
F_i(t) = α · g_i(t)² + (1 - α) · F_i(t-1)
where g_i(t) is the gradient at step t and α ∈ (0, 1] controls the freshness vs. stability trade-off (higher α weights new gradients more heavily; lower α gives a smoother, slower-moving estimate).
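As a concrete illustration of the EMA formula above (hypothetical function and parameter names, not project code), the update is a one-liner per parameter:

```python
import numpy as np

def ema_fisher_step(
    fisher: np.ndarray,   # F_i(t-1), the running Fisher estimate
    grad: np.ndarray,     # g_i(t), the current gradient
    alpha: float = 0.01,  # weight on the newest squared gradient
) -> np.ndarray:
    """One online Fisher step: F(t) = alpha * g(t)^2 + (1 - alpha) * F(t-1)."""
    return alpha * grad**2 + (1.0 - alpha) * fisher

# Small alpha: the estimate moves slowly (stable); large alpha: it tracks
# recent gradients closely. With a constant gradient the EMA converges to g^2.
f = np.zeros(3)
for _ in range(100):
    f = ema_fisher_step(f, np.array([1.0, 0.0, 2.0]), alpha=0.1)
# f is now close to [1.0, 0.0, 4.0]
```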
## New Components
### 1. `FisherAccumulator` (`learning/fisher_online.py`)
```python
from dataclasses import dataclass

import numpy as np
import numpy.typing as npt


@dataclass
class FisherAccumulatorConfig:
    alpha: float = 0.01             # EMA decay — lower = smoother
    min_samples: int = 10           # steps before first snapshot
    snapshot_every: int = 50        # steps between FisherStore writes
    max_staleness_steps: int = 500  # warn if snapshot not written


@dataclass
class AccumulatorStats:
    steps: int
    snapshots_written: int
    last_snapshot_step: int
    max_delta: float  # max |F_new - F_old| across params


class FisherAccumulator:
    """Maintains a per-parameter EMA of squared gradients.

    Called once per STDP update step. Writes to FisherMatrixStore
    every ``snapshot_every`` steps.
    """

    def __init__(
        self,
        config: FisherAccumulatorConfig,
        store: FisherMatrixBase,
        task_id: str,
    ) -> None: ...

    async def update(
        self,
        gradients: dict[str, npt.NDArray[np.float32]],
    ) -> None:
        """One EMA step. Writes snapshot when due."""

    async def force_snapshot(self) -> FisherSnapshot:
        """Used by SLEEP_PHASE consolidation."""

    def stats(self) -> AccumulatorStats: ...
```
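To make the contract above concrete, here is a minimal synchronous sketch of the accumulator logic with an in-memory stand-in for the store. The real class is async and writes through a `FisherMatrixBase` backend; everything here other than the config field names from the spec is an illustrative assumption:

```python
from dataclasses import dataclass, field

import numpy as np


@dataclass
class InMemoryFisherStore:
    """Stand-in for a FisherMatrixBase backend: records every snapshot."""
    snapshots: list[dict[str, np.ndarray]] = field(default_factory=list)

    def save(self, task_id: str, fisher: dict[str, np.ndarray]) -> None:
        self.snapshots.append({k: v.copy() for k, v in fisher.items()})


class SketchFisherAccumulator:
    def __init__(self, store: InMemoryFisherStore, task_id: str,
                 alpha: float = 0.01, min_samples: int = 10,
                 snapshot_every: int = 50) -> None:
        self._store, self._task_id = store, task_id
        self._alpha, self._min_samples = alpha, min_samples
        self._snapshot_every = snapshot_every
        self._ema: dict[str, np.ndarray] = {}
        self._steps = 0

    def update(self, gradients: dict[str, np.ndarray]) -> None:
        """One EMA step; writes a snapshot when the step counter is due."""
        for name, g in gradients.items():
            prev = self._ema.get(name, np.zeros_like(g))
            self._ema[name] = self._alpha * g**2 + (1 - self._alpha) * prev
        self._steps += 1
        if self._steps >= self._min_samples and self._steps % self._snapshot_every == 0:
            self._store.save(self._task_id, self._ema)
```

With `snapshot_every=50`, 120 `update()` calls produce two snapshots (at steps 50 and 100); the EMA itself advances on every call.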
### 2. Modified `STDPOnlineLearner.update()`
The existing `update()` method gains an optional `accumulator` parameter.
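The existing signature is not reproduced in this issue, so the following is only a shape sketch (the gradients dict and the class internals are assumptions): `update()` keeps its current behaviour, and when an accumulator is supplied it is fed the step's gradients.

```python
import numpy as np

class STDPOnlineLearnerSketch:
    """Simplified stand-in for STDPOnlineLearner (the real method is async)."""

    def update(
        self,
        gradients: dict[str, np.ndarray],
        accumulator=None,  # new optional kwarg; default None keeps old behaviour
    ) -> None:
        # ... existing STDP weight update using `gradients` ...
        if accumulator is not None:
            # One EMA step per learner update; snapshot cadence is the
            # accumulator's concern, not the learner's.
            accumulator.update(gradients)
```

Because the kwarg defaults to `None`, existing call sites compile and behave exactly as before, which is what the backwards-compatibility acceptance criterion requires.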
### 3. `OnlineFisherSnapshot` (extended `FisherSnapshot`)

## Acceptance Criteria
- `FisherAccumulator` instantiates with `FisherAccumulatorConfig` + any `FisherMatrixBase` backend
- `update()` performs one EMA step per call; does not write to the store every call
- Snapshots written every `snapshot_every` steps (configurable)
- `force_snapshot()` callable from the SLEEP_PHASE hook without waiting for the step counter
- `STDPOnlineLearner.update()` accepts an optional `accumulator` kwarg (fully backwards-compatible)
- `AccumulatorStats.max_delta` tracks maximum cross-parameter Fisher drift per snapshot cycle
- Prometheus metrics: `ewc_fisher_ema_steps_total`, `ewc_fisher_snapshot_age_steps`, `ewc_fisher_max_delta`
- `max_staleness_steps` breach triggers a warning log + Prometheus counter increment
- Passes `mypy --strict`; numpy arrays typed as `npt.NDArray[np.float32]`
## Test Targets

- `test_accumulator_ema_decay`
- `test_accumulator_no_snapshot_before_min_samples` (respects `min_samples`)
- `test_snapshot_written_at_interval` (`snapshot_every`)
- `test_force_snapshot_bypasses_counter`
- `test_force_snapshot_resets_counter`
- `test_max_staleness_warning`
- `test_stdp_learner_accumulator_kwarg`
- `test_stdp_backwards_compat_no_accumulator`
- `test_online_fisher_snapshot_fields`: `OnlineFisherSnapshot` carries provenance fields
- `test_accumulator_multi_task_isolation`: `task_id` paths don't cross-contaminate
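As an illustration of `test_snapshot_written_at_interval`, a fake store can simply count `save` calls. The snapshot-cadence logic is inlined so the snippet is self-contained; a real test would import the project classes instead:

```python
import numpy as np

class CountingStore:
    """Fake FisherMatrixBase backend: counts snapshot writes."""
    def __init__(self) -> None:
        self.saves = 0

    def save(self, task_id, fisher) -> None:
        self.saves += 1

def test_snapshot_written_at_interval() -> None:
    store = CountingStore()
    ema: dict[str, np.ndarray] = {}
    steps, min_samples, snapshot_every = 0, 10, 50
    for _ in range(175):
        g = {"w": np.array([1.0], dtype=np.float32)}
        # Inline stand-in for the accumulator's EMA step + cadence check.
        for k, v in g.items():
            ema[k] = 0.01 * v**2 + 0.99 * ema.get(k, np.zeros_like(v))
        steps += 1
        if steps >= min_samples and steps % snapshot_every == 0:
            store.save("task-a", ema)
    assert store.saves == 175 // snapshot_every  # snapshots at steps 50, 100, 150

test_snapshot_written_at_interval()
```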
## Implementation Order

1. Add `FisherAccumulatorConfig` and `AccumulatorStats` dataclasses to `learning/fisher_online.py`
2. `FisherAccumulator.__init__()` (store ref, internal EMA dict, step counter)
3. `FisherAccumulator.update()` (EMA step + conditional snapshot)
4. `FisherAccumulator.force_snapshot()` (used by SLEEP_PHASE)
5. `OnlineFisherSnapshot` dataclass extending `FisherSnapshot`
6. `STDPOnlineLearner.update()` signature extended to accept an optional `accumulator` kwarg
7. Prometheus metrics (`ewc_fisher_ema_steps_total` etc.), pre-initialised in the module body
8. Unit tests (mock `FisherMatrixBase.save`, assert call counts + EMA values)
9. Extend `EWCConfig` if needed (e.g. an `online_fisher: bool = False` feature flag)
10. Update the Phase 6 Tracker
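The fields of `OnlineFisherSnapshot` (item 5 and section 3) are not specified in this issue, so the following is a purely hypothetical shape carrying the kind of provenance data `test_online_fisher_snapshot_fields` implies; the base `FisherSnapshot` is stubbed for illustration:

```python
from dataclasses import dataclass

@dataclass
class FisherSnapshot:  # stub of the existing class, for illustration only
    task_id: str

@dataclass
class OnlineFisherSnapshot(FisherSnapshot):
    """Extended snapshot with provenance: how the Fisher estimate was produced."""
    source: str = "online"  # "online" (EMA) vs. "full" (SLEEP_PHASE batch pass)
    ema_steps: int = 0      # number of EMA updates folded into this estimate
    alpha: float = 0.01     # decay used, so consumers can judge freshness
```

Because it subclasses `FisherSnapshot`, existing consumers of the store keep working; the extra fields only matter to code that inspects provenance.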
## Related