Phase 50.4: MetaOptimizer — Learned Optimizers & Learning Rate Adaptation #973

@web3guru888

Description

Phase 50.4 — MetaOptimizer

Summary

Implement the MetaOptimizer module providing learned optimization algorithms — neural networks that learn to produce parameter updates, replacing hand-designed optimizers like SGD and Adam. This includes LSTM-based update rules, meta-SGD with learned per-parameter learning rates, and learning rate adaptation strategies.

Motivation

Hand-designed optimizers use fixed update rules that may be suboptimal for specific task distributions. Learned optimizers can discover update rules tailored to the meta-learning setting, potentially achieving faster convergence, better generalization, and adaptive behavior that traditional optimizers cannot express. Meta-SGD and LSTM-based optimizers learn per-parameter learning rates and update directions from data.
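As a concrete illustration of the Meta-SGD idea, the sketch below applies the update θ' = θ − α ⊙ ∇L(θ) with a learned per-parameter learning-rate vector α on a toy quadratic. All names and numbers here are illustrative, not part of the module: when α_i matches the curvature of coordinate i (α_i = 1/(2c_i)), a single step reaches the minimum, which no single scalar learning rate can do when curvatures differ.

```python
def meta_sgd_update(theta, grad, alpha):
    """One Meta-SGD adaptation step: theta' = theta - alpha ⊙ grad."""
    return [t - a * g for t, g, a in zip(theta, grad, alpha)]

# Toy quadratic loss L(theta) = sum(c_i * theta_i^2), gradient 2*c_i*theta_i.
c = [0.5, 4.0]                                    # per-coordinate curvatures
theta = [2.0, 2.0]
grad = [2 * ci * ti for ci, ti in zip(c, theta)]  # [2.0, 16.0]

# Per-parameter rates matched to curvature reach the minimum in one step;
# a scalar rate would have to compromise between the two coordinates.
alpha = [1 / (2 * ci) for ci in c]                # [1.0, 0.125]
print(meta_sgd_update(theta, grad, alpha))        # → [0.0, 0.0]
```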

Architecture

┌─────────────────────────────────────────────────────┐
│                   MetaOptimizer                     │
│                                                     │
│  ┌───────────────────────────────────────────────┐  │
│  │              Optimizer Selection              │  │
│  │  LSTM-updater │ Meta-SGD │ Warp-Grad │ L2O    │  │
│  └───────┬───────────┬───────────┬───────────────┘  │
│          │           │           │                  │
│  ┌───────▼───┐ ┌─────▼─────┐ ┌───▼────────────┐     │
│  │ LSTM      │ │ Meta-SGD  │ │ Warp-Grad /    │     │
│  │ Optimizer │ │           │ │ L2O Framework  │     │
│  │           │ │ α_i = f(θ)│ │                │     │
│  │ h_t, c_t  │ │ per-param │ │ preconditioning│     │
│  │ update =  │ │ learned   │ │ & warping      │     │
│  │ LSTM(∇,h) │ │ lr & dir  │ │                │     │
│  └─────┬─────┘ └─────┬─────┘ └───────┬────────┘     │
│        │             │               │              │
│  ┌─────▼─────────────▼───────────────▼───────────┐  │
│  │            Parameter Update Engine            │  │
│  │  θ_{t+1} = θ_t + Δθ(∇L, state, meta-params)   │  │
│  └───────────────────────────────────────────────┘  │
│  ┌───────────────────────────────────────────────┐  │
│  │        Learning Rate Schedule Manager         │  │
│  │     warmup │ cosine │ learned │ cyclical      │  │
│  └───────────────────────────────────────────────┘  │
└─────────────────────────────────────────────────────┘

Core API

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch.nn as nn
from torch import Tensor

class MetaOptimizer:
    def __init__(self, base_model: nn.Module, config: MetaOptimizerConfig):
        """Initialize meta-optimizer for the given base model."""
    
    def compute_update(self, gradients: Dict[str, Tensor],
                      state: OptimizerState) -> Tuple[Dict[str, Tensor], OptimizerState]:
        """Compute parameter updates using the learned optimizer."""
    
    def meta_train_step(self, episodes: List[Episode],
                       base_model: nn.Module) -> MetaOptTrainResult:
        """Train the meta-optimizer on a batch of episodes."""
    
    def get_learned_lr(self) -> Dict[str, float]:
        """Return per-parameter learned learning rates (Meta-SGD)."""
    
    def reset_optimizer_state(self) -> None:
        """Reset hidden states for new task adaptation."""

@dataclass
class MetaOptimizerConfig:
    method: str = "meta_sgd"          # 'lstm', 'meta_sgd', 'warp_grad', 'l2o'
    hidden_size: int = 20             # LSTM hidden size
    num_layers: int = 2               # LSTM layers
    preconditioning: bool = True      # learned preconditioning
    coordinatewise: bool = True       # per-coordinate LSTM (scalable)
    meta_lr: float = 0.001            # learning rate for meta-optimizer params
    unroll_steps: int = 20            # BPTT unroll length for LSTM training
    gradient_clipping: float = 1.0

@dataclass
class OptimizerState:
    hidden: Optional[Tensor] = None   # LSTM hidden state
    cell: Optional[Tensor] = None     # LSTM cell state
    step_count: int = 0
    running_loss: float = 0.0
    per_param_lr: Optional[Dict[str, Tensor]] = None

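To make the proposed compute_update contract concrete, here is a minimal, hypothetical sketch of the method='meta_sgd' branch using plain floats in place of Tensors. ToyState and the 0.01 fallback learning rate are stand-ins for illustration, not the real implementation:

```python
from dataclasses import dataclass, field
from typing import Dict, Tuple

@dataclass
class ToyState:
    """Simplified stand-in for OptimizerState (floats instead of Tensors)."""
    step_count: int = 0
    per_param_lr: Dict[str, float] = field(default_factory=dict)

def compute_update(gradients: Dict[str, float],
                   state: ToyState) -> Tuple[Dict[str, float], ToyState]:
    """Return per-parameter-scaled negative gradients as the update."""
    updates = {name: -state.per_param_lr.get(name, 0.01) * g
               for name, g in gradients.items()}
    state.step_count += 1
    return updates, state

state = ToyState(per_param_lr={"w": 0.5, "b": 0.1})
updates, state = compute_update({"w": 0.2, "b": -1.0}, state)
print(updates)  # → {'w': -0.1, 'b': 0.1}
```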
Key Features

  • LSTM-based learned optimizer — coordinatewise LSTM producing update rules
  • Meta-SGD — learnable per-parameter learning rates and update directions
  • Warp-Grad — learned preconditioning of task loss surfaces
  • L2O (Learning to Optimize) — generalized framework for learned optimization
  • Unrolled differentiation — backprop through optimizer for meta-training
  • Truncated BPTT for scalable meta-optimizer training
  • Transfer across architectures — optimizer generalizes to unseen models
  • Learning rate scheduling — learned warmup and decay strategies
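The unrolled-differentiation idea can be seen in miniature below: a "learned optimizer" reduced to a single learnable step size is meta-trained through a T-step unroll on a toy quadratic, with finite differences standing in for backprop through the unroll. Everything here is an illustrative toy, not the module's design:

```python
def unrolled_loss(eta, theta0=5.0, T=20):
    """Run T inner steps theta <- theta - eta * grad; return the final loss."""
    theta = theta0
    for _ in range(T):
        theta -= eta * 2 * theta        # grad of L(theta) = theta^2 is 2*theta
    return theta ** 2                   # meta-loss: task loss after the unroll

# Outer loop: adjust eta to minimize the post-unroll loss (finite differences
# play the role of backprop through the unrolled computation graph).
eta, meta_lr, eps = 0.1, 0.01, 1e-4
for _ in range(50):
    meta_grad = (unrolled_loss(eta + eps) - unrolled_loss(eta - eps)) / (2 * eps)
    eta -= meta_lr * meta_grad

# eta moves toward the curvature-matched value 0.5, improving the unrolled loss.
assert unrolled_loss(eta) < unrolled_loss(0.1)
```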

Acceptance Criteria

  • LSTM optimizer converges 20%+ faster than Adam on held-out tasks
  • Meta-SGD per-parameter LRs improve over scalar LR by ≥ 5% accuracy
  • Coordinatewise LSTM scales to models with 10M+ parameters
  • Learned optimizer transfers to unseen architectures (test on 3 architectures)
  • Unrolled training stable for ≥ 20 inner steps without gradient explosion
  • Memory overhead ≤ 2x base model parameters for LSTM states
  • Unit tests with ≥ 95% coverage, integration with Phase 50.2
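One ingredient behind the stability criterion is gradient clipping by global norm (the gradient_clipping config field). The pure-Python sketch below just shows the math; in practice torch.nn.utils.clip_grad_norm_ would be the natural implementation:

```python
import math

def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale gradients so their joint L2 norm never exceeds max_norm."""
    total = math.sqrt(sum(g * g for g in grads))
    if total > max_norm:
        scale = max_norm / total
        grads = [g * scale for g in grads]
    return grads

clipped = clip_by_global_norm([3.0, 4.0])  # norm 5 → rescaled to norm 1
small = clip_by_global_norm([0.1, 0.1])    # already under the cap: unchanged
```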

Academic References

  • Andrychowicz et al. (2016) — "Learning to Learn by Gradient Descent by Gradient Descent"
  • Ravi & Larochelle (2017) — "Optimization as a Model for Few-Shot Learning"
  • Li & Malik (2017) — "Learning to Optimize" (L2O framework)
  • Flennerhag et al. (2020) — "Meta-Learning with Warped Gradient Descent" (Warp-Grad)
  • Li, Zhou, Chen, & Li (2017) — "Meta-SGD: Learning to Learn Quickly for Few-Shot Learning"
  • Chen et al. (2022) — "Learning to Optimize: A Primer and A Benchmark"

Dependencies

  • Phase 50.2 (GradientMetaLearner) for inner-loop integration
  • Phase 50.1 (TaskDistributionSampler) for training episodes
  • Phase 43.3 (HyperparameterTuner) for search space integration

Metadata

Labels: enhancement (New feature or request)