Phase 33.1 — ElasticWeightConsolidator: Fisher Information-based parameter protection for continual learning #691

@web3guru888

Overview

The ElasticWeightConsolidator implements Elastic Weight Consolidation (EWC) and related regularization-based continual learning methods. It computes Fisher Information matrices to identify task-critical parameters and applies quadratic penalties during subsequent task learning, preventing catastrophic forgetting while allowing plasticity on less important weights.
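
For reference, the quadratic penalty described above is usually written in the following form (Kirkpatrick et al., 2017). The symbols are chosen here for illustration: $\theta^{*}$ denotes the parameters after training on the previous task, $F_i$ the diagonal Fisher entry for parameter $i$, and $\lambda$ is assumed to correspond to `lambda_ewc`.

```latex
\mathcal{L}(\theta)
  = \mathcal{L}_{\text{new}}(\theta)
  + \frac{\lambda}{2} \sum_i F_i \,\bigl(\theta_i - \theta^{*}_i\bigr)^2,
\qquad
F_i = \mathbb{E}_{x \sim \mathcal{D},\; y \sim p_{\theta^{*}}(\cdot \mid x)}
      \left[ \left( \left. \frac{\partial \log p_{\theta}(y \mid x)}{\partial \theta_i}
      \right|_{\theta = \theta^{*}} \right)^{2} \right]
```

Parameters with large $F_i$ are held close to $\theta^{*}_i$, while parameters with small $F_i$ remain free to adapt to the new task.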

Key Responsibilities

  • Fisher Information computation — Diagonal, block-diagonal, and K-FAC Fisher approximations from task-specific data
  • Parameter importance mapping — Per-parameter importance estimation using both EWC (Kirkpatrick et al., 2017) and the online path-integral method of Synaptic Intelligence (Zenke et al., 2017)
  • Task-specific regularization — Quadratic penalty terms anchoring important parameters to their post-training values
  • Online Fisher updates — Incremental Fisher matrix accumulation across tasks without storing full task datasets
  • Memory Aware Synapses (MAS) — Unsupervised importance estimation from the gradient magnitude of the model's output norm with respect to each parameter, requiring no labels (Aljundi et al., 2018)
  • Multi-task consolidation — Merging importance maps across the task sequence with configurable decay
  • Importance visualization — Heatmaps of parameter importance across layers and tasks
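
The first and third responsibilities above can be sketched in a few lines. This is a minimal illustration in pure Python on a logistic-regression model, not the component's implementation. The function names `diagonal_fisher` and `ewc_penalty` are hypothetical, and the sketch uses the empirical Fisher (observed labels) rather than labels sampled from the model.

```python
import math
from typing import List, Sequence, Tuple

Example = Tuple[Sequence[float], int]  # (features, label in {0, 1})


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def diagonal_fisher(weights: Sequence[float], data: Sequence[Example]) -> List[float]:
    """Empirical diagonal Fisher for p(y=1|x) = sigmoid(w . x).

    Assumption: observed labels are used instead of labels sampled
    from the model, which is a common simplification.
    """
    fisher = [0.0] * len(weights)
    for x, y in data:
        p = sigmoid(sum(w * xi for w, xi in zip(weights, x)))
        for i, xi in enumerate(x):
            grad = (y - p) * xi  # d/dw_i of log p(y|x)
            fisher[i] += grad * grad
    return [f / len(data) for f in fisher]


def ewc_penalty(
    weights: Sequence[float],
    anchor: Sequence[float],
    fisher: Sequence[float],
    lambda_ewc: float = 5000.0,
) -> float:
    """Quadratic penalty (lambda / 2) * sum_i F_i * (w_i - anchor_i)^2."""
    return 0.5 * lambda_ewc * sum(
        f * (w - a) ** 2 for w, a, f in zip(weights, anchor, fisher)
    )


if __name__ == "__main__":
    # The second feature is always zero, so its Fisher entry is zero
    # and moving that weight incurs no penalty.
    data = [([1.0, 0.0], 1), ([2.0, 0.0], 0)]
    anchor = [0.0, 0.0]
    fisher = diagonal_fisher(anchor, data)
    print(fisher)
    print(ewc_penalty([1.0, 3.0], anchor, fisher, lambda_ewc=2.0))
```

In the usage example, the weight on the unused feature moves far from its anchor at no cost, which is the plasticity on less important weights that the overview describes.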

Interfaces

Inputs

  • model: nn.Module — Neural network whose parameters are being consolidated
  • task_data: DataLoader — Task-specific data for Fisher computation
  • task_id: str — Unique task identifier
  • lambda_ewc: float — Regularization strength (default: 5000.0)
  • fisher_method: Literal["diagonal", "block_diagonal", "kfac"] — Fisher approximation type
  • online: bool — Whether to use online (running) Fisher estimation
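
One possible way to group the scalar inputs above is a small validated configuration object. The class name `ConsolidatorConfig` is hypothetical, and the defaults for `fisher_method` and `online` are assumptions, since the issue only specifies a default for `lambda_ewc`. `model` and `task_data` are left out because they would presumably be passed at call time.

```python
from dataclasses import dataclass
from typing import Literal, get_args

FisherMethod = Literal["diagonal", "block_diagonal", "kfac"]


@dataclass(frozen=True)
class ConsolidatorConfig:
    """Hypothetical container for the scalar inputs listed in the issue."""

    task_id: str
    lambda_ewc: float = 5000.0  # default taken from the issue
    fisher_method: FisherMethod = "diagonal"  # assumed default
    online: bool = False  # assumed default

    def __post_init__(self) -> None:
        if not self.task_id:
            raise ValueError("task_id must be a non-empty string")
        if self.lambda_ewc < 0:
            raise ValueError("lambda_ewc must be non-negative")
        if self.fisher_method not in get_args(FisherMethod):
            raise ValueError(f"unknown fisher_method: {self.fisher_method!r}")


if __name__ == "__main__":
    print(ConsolidatorConfig(task_id="split-mnist-0"))
```

Validating `fisher_method` at construction time catches typos early, because `Literal` annotations are not enforced at runtime.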

Outputs

  • ConsolidationResult — Contains penalty loss, per-layer importance stats, forgetting risk score
  • ImportanceMap — Dict[str, Tensor] mapping parameter names to importance scores
  • FisherSnapshot — Serializable Fisher matrix for checkpointing
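
As a rough sketch of what "serializable" could mean for `FisherSnapshot`, the example below round-trips a snapshot through JSON. The field names, the use of plain float lists in place of tensors, and the `to_json`/`from_json` methods are all assumptions made for illustration. A real implementation would more likely store tensors in a binary format.

```python
import json
from dataclasses import asdict, dataclass
from typing import Dict, List


@dataclass
class FisherSnapshot:
    """Hypothetical snapshot layout: one flattened Fisher vector per parameter."""

    task_id: str
    fisher_method: str
    values: Dict[str, List[float]]  # parameter name -> flattened diagonal Fisher

    def to_json(self) -> str:
        return json.dumps(asdict(self), sort_keys=True)

    @classmethod
    def from_json(cls, payload: str) -> "FisherSnapshot":
        return cls(**json.loads(payload))


if __name__ == "__main__":
    snapshot = FisherSnapshot(
        task_id="split-mnist-0",
        fisher_method="diagonal",
        values={"fc.weight": [0.625, 0.0], "fc.bias": [0.1]},
    )
    restored = FisherSnapshot.from_json(snapshot.to_json())
    print(restored == snapshot)
```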

Acceptance Criteria

  • Fisher Information matrix computed within 2x forward-pass time
  • EWC penalty keeps accuracy degradation on previous tasks to at most 5% (Split-MNIST benchmark)
  • Online Fisher updates require O(1) additional memory per task
  • MAS importance estimation works without task labels
  • Supports models with >100M parameters efficiently
  • Importance maps serializable and loadable across sessions
  • Integration tests with Phase 33.5 ContinualOrchestrator
  • 95%+ test coverage, all public methods typed
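
The O(1)-memory criterion for online Fisher updates can be illustrated with a running, decayed sum: only one importance map is kept, whatever the number of tasks. This is a sketch of one common formulation, not necessarily the one this component will use. The function name `merge_online_fisher` and the `decay` parameter are hypothetical.

```python
from typing import Dict, List, Optional

ImportanceMap = Dict[str, List[float]]  # plain lists stand in for tensors


def merge_online_fisher(
    running: Optional[ImportanceMap],
    new: ImportanceMap,
    decay: float = 0.9,
) -> ImportanceMap:
    """Return decay * running + new, element-wise per parameter.

    Only the merged map is kept, so storage does not grow with the
    number of tasks seen so far.
    """
    if not 0.0 <= decay <= 1.0:
        raise ValueError("decay must lie in [0, 1]")
    if running is None:
        # First task: nothing to decay yet.
        return {name: list(vals) for name, vals in new.items()}
    if running.keys() != new.keys():
        raise ValueError("parameter names differ between importance maps")
    return {
        name: [decay * r + n for r, n in zip(running[name], new[name])]
        for name in new
    }


if __name__ == "__main__":
    fisher = merge_online_fisher(None, {"w": [1.0, 2.0]})
    fisher = merge_online_fisher(fisher, {"w": [0.5, 0.0]}, decay=0.5)
    print(fisher)
```

With `decay` below 1, the contribution of older tasks shrinks geometrically, which is one way to read the "configurable decay" mentioned under multi-task consolidation.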

Labels

  • enhancement — New feature or request
  • phase-33 — Phase 33: Continual Learning & Catastrophic Forgetting Prevention
