
Phase 42.4 — CheckpointManager: Distributed Checkpointing & Rollback Recovery #844

@web3guru888

Phase 42.4 — CheckpointManager

Overview

The CheckpointManager implements distributed checkpointing and rollback recovery for fault-tolerant AI training. It uses the Chandy-Lamport algorithm to record consistent distributed snapshots without pausing computation; these snapshots serve as rollback targets after a failure. Write-ahead logging makes operations committed after the last snapshot durable, and incremental checkpointing reduces the cost of each snapshot.
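To make the marker rules concrete, here is a minimal single-process simulation. It is a sketch, not the ChandyLamportExecutor itself: `Process`, `System`, and the round-robin scheduler are hypothetical, channels are assumed reliable and FIFO (a precondition of Chandy-Lamport), and the workload is a toy token transfer whose total any consistent cut must conserve.

```python
from collections import deque
from dataclasses import dataclass, field
from typing import Optional

MARKER = "MARKER"  # control message; every other message is a token amount (int)


@dataclass
class Process:
    pid: str
    balance: int
    recorded: Optional[int] = None                      # local state captured at the cut
    channel_state: dict = field(default_factory=dict)   # incoming channel -> in-flight msgs
    closed: set = field(default_factory=set)            # incoming channels whose marker arrived


class System:
    def __init__(self, balances: dict[str, int]):
        self.procs = {pid: Process(pid, b) for pid, b in balances.items()}
        # One reliable FIFO channel per ordered pair (src, dst).
        self.chan = {(a, b): deque() for a in balances for b in balances if a != b}

    def send(self, src: str, dst: str, amount: int) -> None:
        self.procs[src].balance -= amount
        self.chan[(src, dst)].append(amount)

    def _record(self, p: Process) -> None:
        # Marker-sending rule: record local state, start recording every incoming
        # channel, then put a marker on every outgoing channel before other messages.
        p.recorded = p.balance
        for key in self.chan:
            if key[1] == p.pid:
                p.channel_state[key] = []
        for key, q in self.chan.items():
            if key[0] == p.pid:
                q.append(MARKER)

    def initiate(self, pid: str) -> None:
        self._record(self.procs[pid])

    def deliver(self, src: str, dst: str) -> None:
        key, p = (src, dst), self.procs[dst]
        msg = self.chan[key].popleft()
        if msg == MARKER:
            if p.recorded is None:
                self._record(p)   # first marker: this channel's recorded state stays empty
            p.closed.add(key)     # marker-receiving rule: stop recording this channel
        else:
            p.balance += msg
            if p.recorded is not None and key not in p.closed:
                p.channel_state[key].append(msg)  # sent before the cut, received after it

    def run(self) -> None:
        # Deterministic round-robin delivery; FIFO order is preserved per channel.
        while any(self.chan.values()):
            for key, q in self.chan.items():
                if q:
                    self.deliver(*key)

    def snapshot_complete(self) -> bool:
        return all(p.recorded is not None and p.closed == set(p.channel_state)
                   for p in self.procs.values())

    def snapshot_total(self) -> int:
        local = sum(p.recorded for p in self.procs.values())
        in_flight = sum(sum(msgs) for p in self.procs.values()
                        for msgs in p.channel_state.values())
        return local + in_flight


s = System({"A": 100, "B": 50, "C": 25})
s.send("A", "B", 30)   # in flight when A starts the snapshot
s.initiate("A")
s.send("B", "C", 10)   # B has not seen a marker yet
s.send("C", "A", 5)
s.run()
print(s.snapshot_complete(), s.snapshot_total())  # → True 175
```

The recorded local balances plus the recorded in-flight transfers sum to the initial total, so the cut has neither orphan nor lost messages. A real executor would persist each process's recorded state rather than summing integers.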

Academic Foundation

  • Chandy & Lamport (1985): distributed snapshots, determining global states of distributed systems via marker-based consistent cuts
  • Elnozahy et al. (2002): survey of rollback-recovery protocols for message-passing systems
  • Mohan et al. (1992): ARIES, a write-ahead logging and recovery algorithm
  • Plank, Beck & Kingsley (1995): reducing checkpoint size and overhead for long-running computations
  • Rajbhandari et al. (2020): ZeRO, partitioning optimizer states, gradients, and parameters for memory-efficient training, with partitioned activation checkpointing (ZeRO-R)

Architecture

CheckpointManager
├── SnapshotProtocol
│   ├── ChandyLamportExecutor    # Marker-based consistent cut algorithm
│   ├── ConsistentCutValidator   # Verify snapshot consistency (no orphan messages)
│   ├── ChannelStateRecorder     # Capture in-flight messages during snapshot
│   └── SnapshotCoordinator      # Coordinate snapshot initiation across nodes
├── WriteAheadLog
│   ├── WALWriter                # Sequential log entry writing
│   ├── WALReader                # Log replay for recovery
│   ├── LogSegmentManager        # Log file segmentation and rotation
│   └── LogCompactor             # Remove entries before last checkpoint
├── CheckpointStorage
│   ├── LocalCheckpointer        # Local disk checkpoint storage
│   ├── DistributedCheckpointer  # Distributed storage (S3/GCS/HDFS)
│   ├── IncrementalCheckpointer  # Delta-based incremental checkpoints
│   └── CompressionEngine        # Checkpoint compression (LZ4/Zstd)
├── RecoveryManager
│   ├── RollbackExecutor         # Roll back to consistent checkpoint
│   ├── ForwardRecovery          # Replay WAL entries after checkpoint
│   ├── PartialRecovery          # Recover individual nodes without full rollback
│   └── ConsistencyVerifier      # Post-recovery consistency validation
├── ActivationCheckpointer
│   ├── SelectiveRecompute       # Recompute activations instead of storing
│   ├── MemoryBudgetManager      # Balance memory vs recomputation cost
│   └── GradientCheckpointer     # Gradient (activation) checkpointing for the distributed backward pass
└── CheckpointMetrics
    ├── CheckpointDuration       # Time to create checkpoint
    ├── RecoveryTime             # Time to recover from checkpoint
    ├── StorageOverhead          # Checkpoint storage usage
    └── CheckpointFrequency      # Optimal interval tracking
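`CheckpointFrequency` tracks an optimal interval, but the issue does not say how that optimum is computed. One common first-order choice, shown here as an assumption rather than the specified method, is Young's approximation T ≈ sqrt(2CM), where C is the blocking cost of one checkpoint and M is the mean time between failures; it assumes C is much smaller than M and that failures arrive roughly as a Poisson process. The function names are hypothetical.

```python
import math


def young_interval(checkpoint_cost_s: float, mtbf_s: float) -> float:
    """Young's first-order optimum interval between checkpoints: sqrt(2 * C * M)."""
    if checkpoint_cost_s <= 0 or mtbf_s <= 0:
        raise ValueError("checkpoint cost and MTBF must be positive")
    return math.sqrt(2.0 * checkpoint_cost_s * mtbf_s)


def expected_waste(interval_s: float, checkpoint_cost_s: float, mtbf_s: float) -> float:
    """First-order fraction of time lost: C/T writing checkpoints plus T/(2M) of rework."""
    return checkpoint_cost_s / interval_s + interval_s / (2.0 * mtbf_s)


def interval_in_steps(checkpoint_cost_s: float, mtbf_s: float, step_time_s: float) -> int:
    """Express the interval as a checkpoint_interval_steps value (at least one step)."""
    return max(1, round(young_interval(checkpoint_cost_s, mtbf_s) / step_time_s))


# Hypothetical numbers: 50 s blocking checkpoint, 10,000 s cluster MTBF, 1 s per step.
t = young_interval(50, 10_000)
print(t, interval_in_steps(50, 10_000, 1.0), expected_waste(t, 50, 10_000))
# interval of 1000 s, i.e. 1000 steps, with roughly 10% of time lost
```

At the optimum the two waste terms are equal. Because T grows with sqrt(C), async checkpointing, which shrinks the blocking part of C, also shortens the optimal interval.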

Interface Specification

from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional


class CheckpointType(Enum):
    FULL = "full"                # Complete state snapshot
    INCREMENTAL = "incremental"  # Delta from last checkpoint
    ACTIVATION = "activation"    # Training activation checkpoint


class StorageBackend(Enum):
    LOCAL = "local"
    S3 = "s3"
    GCS = "gcs"
    HDFS = "hdfs"


class CompressionAlgo(Enum):
    NONE = "none"
    LZ4 = "lz4"
    ZSTD = "zstd"
    SNAPPY = "snappy"


@dataclass
class CheckpointConfig:
    checkpoint_dir: str = "/checkpoints"
    storage_backend: StorageBackend = StorageBackend.LOCAL
    checkpoint_type: CheckpointType = CheckpointType.INCREMENTAL
    compression: CompressionAlgo = CompressionAlgo.LZ4
    max_checkpoints_kept: int = 5
    wal_enabled: bool = True
    wal_sync_interval_ms: int = 100
    wal_segment_size_mb: int = 64
    activation_checkpointing: bool = True
    checkpoint_interval_steps: int = 1000
    async_checkpoint: bool = True


@dataclass
class CheckpointMetadata:
    checkpoint_id: str
    timestamp: float
    step: int
    node_states: dict[str, bytes]  # node_id → state hash
    channel_states: dict[str, list]  # channel → in-flight messages
    size_bytes: int
    is_consistent: bool
    parent_checkpoint_id: Optional[str] = None  # for incremental


class CheckpointManager:
    """
    Distributed checkpointing with Chandy-Lamport consistent snapshots,
    write-ahead logging, and rollback recovery for fault-tolerant AI training.
    """
    
    def __init__(self, config: CheckpointConfig):
        self.config = config
        self.checkpoints: list[CheckpointMetadata] = []
        self.wal_position = 0
    
    async def initiate_snapshot(self, initiator_id: str) -> CheckpointMetadata:
        """Initiate Chandy-Lamport distributed snapshot from initiator node."""
        ...
    
    async def receive_marker(self, sender_id: str, snapshot_id: str) -> None:
        """Handle marker message in Chandy-Lamport protocol."""
        ...
    
    async def save_checkpoint(self, state: dict, step: int,
                              checkpoint_type: Optional[CheckpointType] = None) -> CheckpointMetadata:
        """Save a checkpoint (full or incremental) to storage backend."""
        ...
    
    async def load_checkpoint(self, checkpoint_id: str) -> dict:
        """Load and reconstruct state from checkpoint."""
        ...
    
    async def rollback(self, checkpoint_id: str) -> bool:
        """Roll back all nodes to specified consistent checkpoint."""
        ...
    
    async def forward_recover(self, checkpoint_id: str) -> dict:
        """Load the checkpoint, then replay WAL entries written after it."""
        ...
    
    def write_wal_entry(self, entry: dict) -> int:
        """Append entry to write-ahead log, return log position."""
        ...
    
    def replay_wal(self, from_position: int, to_position: Optional[int] = None) -> list:
        """Replay WAL entries for recovery."""
        ...
    
    def enable_activation_checkpointing(self, model: Any, 
                                         budget_mb: float) -> Any:
        """Enable activation checkpointing on model within memory budget."""
        ...
    
    async def cleanup_old_checkpoints(self) -> int:
        """Remove checkpoints beyond retention limit, return count removed."""
        ...
    
    def get_metrics(self) -> dict:
        """Return checkpoint performance metrics."""
        ...
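The `parent_checkpoint_id` field implies that an incremental checkpoint is a delta applied on top of its parent. A minimal in-memory sketch of that chain follows, assuming deltas are computed per named tensor by comparing SHA-256 hashes (the issue does not fix the granularity); `IncrementalStore` and `Checkpoint` are hypothetical names, not part of the interface above.

```python
import hashlib
import uuid
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Checkpoint:
    checkpoint_id: str
    step: int
    parent_checkpoint_id: Optional[str]       # None marks a full checkpoint
    changed: dict[str, bytes]                 # tensors written by this checkpoint
    deleted: set[str] = field(default_factory=set)


class IncrementalStore:
    """In-memory stand-in for an incremental checkpointer (hypothetical)."""

    def __init__(self) -> None:
        self.checkpoints: dict[str, Checkpoint] = {}
        self.last_id: Optional[str] = None
        self.last_hashes: dict[str, str] = {}

    def save(self, state: dict[str, bytes], step: int, full: bool = False) -> Checkpoint:
        hashes = {name: hashlib.sha256(blob).hexdigest() for name, blob in state.items()}
        if full or self.last_id is None:
            changed, deleted, parent = dict(state), set(), None
        else:
            # Store only tensors whose content hash differs from the previous checkpoint.
            changed = {n: b for n, b in state.items() if self.last_hashes.get(n) != hashes[n]}
            deleted = set(self.last_hashes) - set(state)
            parent = self.last_id
        ckpt = Checkpoint(uuid.uuid4().hex, step, parent, changed, deleted)
        self.checkpoints[ckpt.checkpoint_id] = ckpt
        self.last_id, self.last_hashes = ckpt.checkpoint_id, hashes
        return ckpt

    def load(self, checkpoint_id: str) -> dict[str, bytes]:
        # Walk back to the nearest full checkpoint, then apply deltas oldest first.
        chain = []
        cid: Optional[str] = checkpoint_id
        while cid is not None:
            chain.append(self.checkpoints[cid])
            cid = chain[-1].parent_checkpoint_id
        state: dict[str, bytes] = {}
        for ckpt in reversed(chain):
            state.update(ckpt.changed)
            for name in ckpt.deleted:
                state.pop(name, None)
        return state


store = IncrementalStore()
s0 = {"layer1.weight": b"\x00" * 1024, "layer2.weight": b"\x01" * 1024}
c0 = store.save(s0, step=0)
s1 = {**s0, "layer2.weight": b"\x02" * 1024}   # only layer2 changed
c1 = store.save(s1, step=1000)
print(sorted(c1.changed))                      # → ['layer2.weight']
assert store.load(c1.checkpoint_id) == s1
```

With dense optimizers every trainable tensor changes between checkpoints, so tensor-level deltas mainly save space for frozen or sparsely updated parameters; finer-grained diffs are a separate design choice. Loading also slows as the chain grows, which is one reason to interleave periodic full checkpoints.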

Testing Requirements

class TestCheckpointManager:
    async def test_chandy_lamport_consistent_cut(self): ...
    async def test_no_orphan_messages_in_snapshot(self): ...
    async def test_full_checkpoint_save_load_roundtrip(self): ...
    async def test_incremental_checkpoint_delta(self): ...
    async def test_wal_write_and_replay(self): ...
    async def test_rollback_restores_consistent_state(self): ...
    async def test_forward_recovery_from_wal(self): ...
    async def test_partial_node_recovery(self): ...
    async def test_compression_reduces_size(self): ...
    async def test_async_checkpoint_no_training_pause(self): ...
    async def test_activation_checkpointing_memory_savings(self): ...
    async def test_checkpoint_cleanup_retention(self): ...
    async def test_concurrent_snapshots_isolation(self): ...
    async def test_recovery_after_multiple_node_failures(self): ...
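`test_wal_write_and_replay` and `test_forward_recovery_from_wal` need a WAL whose committed entries survive a crash. The following is a minimal sketch under stated assumptions: a single log file, JSON entries framed with a length and a CRC32 so a torn final record can be detected and discarded, and positions counted as 0-based entry indices. `SimpleWAL` is hypothetical and ignores segmentation and compaction.

```python
import json
import os
import struct
import tempfile
import zlib
from typing import Optional

# Record layout (an assumption, not specified by the issue): 4-byte payload
# length, 4-byte CRC32 of the payload, then the JSON payload itself.
HEADER = struct.Struct("<II")


class SimpleWAL:
    """Toy single-file WAL; positions are 0-based entry indices."""

    def __init__(self, path: str):
        self.path = path
        open(path, "ab").close()              # create the file if it is missing
        entries, valid_end = self._scan()
        os.truncate(path, valid_end)          # drop a torn tail before appending
        self.next_position = len(entries)
        self.f = open(path, "ab")

    def _scan(self) -> tuple[list, int]:
        """Return all intact entries and the byte offset just past the last one."""
        entries, offset = [], 0
        with open(self.path, "rb") as f:
            while True:
                header = f.read(HEADER.size)
                if len(header) < HEADER.size:
                    break
                length, crc = HEADER.unpack(header)
                payload = f.read(length)
                if len(payload) < length or zlib.crc32(payload) != crc:
                    break                     # torn or corrupt record: never committed
                entries.append(json.loads(payload))
                offset = f.tell()
        return entries, offset

    def write_wal_entry(self, entry: dict) -> int:
        payload = json.dumps(entry, sort_keys=True).encode()
        self.f.write(HEADER.pack(len(payload), zlib.crc32(payload)) + payload)
        self.f.flush()
        os.fsync(self.f.fileno())             # the entry counts as committed only after this
        position = self.next_position
        self.next_position += 1
        return position

    def replay_wal(self, from_position: int, to_position: Optional[int] = None) -> list:
        """Return committed entries with from_position <= index < to_position."""
        entries, _ = self._scan()
        return entries[from_position:to_position]

    def close(self) -> None:
        self.f.close()


# Forward recovery: a checkpoint records the WAL position it covers (hypothetical
# value 3 here), so only the entries after it need to be replayed.
with tempfile.TemporaryDirectory() as d:
    wal = SimpleWAL(os.path.join(d, "wal.log"))
    for step in range(5):
        wal.write_wal_entry({"step": step, "op": "optimizer_update"})
    print([e["step"] for e in wal.replay_wal(from_position=3)])  # → [3, 4]
    wal.close()
```

This sketch calls `os.fsync` after every entry. The `wal_sync_interval_ms` setting in `CheckpointConfig` suggests batched syncs instead, which raise throughput but widen the window of entries a crash can lose.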

Acceptance Criteria

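  • Chandy-Lamport snapshots produce a provably consistent global state (assuming reliable FIFO channels, as the algorithm requires)
  • Incremental checkpoints reduce storage by >60% compared to full checkpoints
  • WAL replay recovers every operation whose WAL entry was durably synced before the failure, with zero data loss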
  • Rollback recovery completes within 2× checkpoint load time
  • Async checkpointing overhead < 5% of training step time
  • Activation checkpointing reduces peak memory by >40%
  • LZ4 compression achieves >2× compression ratio on model states
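For the first criterion, a test can check the two properties that make a cut consistent: no orphan messages (received inside the cut but sent outside it), and channel states that hold exactly the messages sent before the sender's snapshot but received after the receiver's. This sketch assumes each process stores per-channel send and receive counters alongside its local state; `CutRecord`, `find_orphans`, and `is_consistent` are hypothetical, and the check compares message counts, not contents.

```python
from dataclasses import dataclass


@dataclass
class CutRecord:
    """Counters a process saves with its local snapshot (hypothetical format)."""
    sent: dict[str, int]       # destination -> messages sent before the local snapshot
    received: dict[str, int]   # source -> messages received before the local snapshot


def find_orphans(cut: dict[str, CutRecord]) -> list[tuple[str, str]]:
    """Channels (src, dst) where dst received more messages than src had sent in the cut."""
    orphans = []
    for dst, rec in cut.items():
        for src, n_received in rec.received.items():
            if n_received > cut[src].sent.get(dst, 0):
                orphans.append((src, dst))
    return sorted(orphans)


def is_consistent(cut: dict[str, CutRecord],
                  channel_states: dict[tuple[str, str], list]) -> bool:
    """No orphans, and each recorded channel holds exactly the sent-but-unreceived messages."""
    if find_orphans(cut):
        return False
    for src, rec in cut.items():
        for dst, n_sent in rec.sent.items():
            n_received = cut[dst].received.get(src, 0)
            if len(channel_states.get((src, dst), [])) != n_sent - n_received:
                return False
    return True


good = {"A": CutRecord(sent={"B": 2}, received={}),
        "B": CutRecord(sent={}, received={"A": 1})}
print(find_orphans(good), is_consistent(good, {("A", "B"): ["msg2"]}))  # → [] True

bad = {"A": CutRecord(sent={"B": 1}, received={}),
       "B": CutRecord(sent={}, received={"A": 2})}
print(find_orphans(bad))  # → [('A', 'B')]
```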

Dependencies

  • Phase 42.1 — ConsensusEngine for coordinating checkpoint initiation
  • Phase 42.3 — FaultDetector for triggering recovery on failure

Labels: enhancement (New feature or request), phase-42 (Phase 42: Distributed Systems & Fault-Tolerant AI Infrastructure)