Phase 42.4 — CheckpointManager
Overview
The CheckpointManager implements distributed checkpointing and recovery for fault-tolerant AI training. It uses the Chandy-Lamport algorithm to capture consistent distributed snapshots without pausing computation, and combines those snapshots with write-ahead logging for durability and incremental checkpointing for efficiency, enabling rollback recovery after node failures.
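As a taste of the incremental path, one way a delta checkpoint can be derived is by hashing per-key state blobs against the parent checkpoint's hashes. This sketch is illustrative only; it does not prescribe the actual delta encoding used by the IncrementalCheckpointer described below.

import hashlib

def incremental_delta(current: dict[str, bytes],
                      parent_hashes: dict[str, str]) -> dict[str, bytes]:
    """Keep only blobs whose content hash changed since the parent checkpoint."""
    delta = {}
    for key, blob in current.items():
        digest = hashlib.sha256(blob).hexdigest()
        if parent_hashes.get(key) != digest:
            delta[key] = blob  # new or modified since the parent checkpoint
    return delta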
Academic Foundation
- Chandy & Lamport (1985) — Distributed snapshots: determining global states of distributed systems via marker-based consistent cuts
- Elnozahy et al. (2002) — Survey of rollback-recovery protocols for message-passing systems
- Mohan et al. (1992) — ARIES: write-ahead logging and recovery algorithm
- Plank, Beck, Kingsley & Li (1995) — Libckpt: transparent checkpointing under Unix for long-running computations
- Rajbhandari et al. (2020) — ZeRO: memory optimizations, including partitioned activation checkpointing, for training trillion-parameter models
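A classic first-order result on checkpoint interval optimization, usually attributed to Young (1974), puts the optimum interval near the square root of twice the checkpoint cost times the mean time between failures; it underlies settings like checkpoint_interval_steps and the CheckpointFrequency metric below. A minimal sketch, with both quantities assumed to be in seconds:

import math

def optimal_checkpoint_interval(checkpoint_cost_s: float, mtbf_s: float) -> float:
    """Young's first-order approximation: T_opt ≈ sqrt(2 · C · MTBF)."""
    return math.sqrt(2.0 * checkpoint_cost_s * mtbf_s)

# Example: 30 s checkpoints with a 6-hour cluster MTBF suggest checkpointing
# roughly every 19 minutes (~1138 s).
interval_s = optimal_checkpoint_interval(30.0, 6 * 3600.0)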
Architecture
CheckpointManager
├── SnapshotProtocol
│ ├── ChandyLamportExecutor # Marker-based consistent cut algorithm (see sketch below)
│ ├── ConsistentCutValidator # Verify snapshot consistency (no orphan messages)
│ ├── ChannelStateRecorder # Capture in-flight messages during snapshot
│ └── SnapshotCoordinator # Coordinate snapshot initiation across nodes
├── WriteAheadLog
│ ├── WALWriter # Sequential log entry writing
│ ├── WALReader # Log replay for recovery
│ ├── LogSegmentManager # Log file segmentation and rotation
│ └── LogCompactor # Remove entries before last checkpoint
├── CheckpointStorage
│ ├── LocalCheckpointer # Local disk checkpoint storage
│ ├── DistributedCheckpointer # Distributed storage (S3/GCS/HDFS)
│ ├── IncrementalCheckpointer # Delta-based incremental checkpoints
│ └── CompressionEngine # Checkpoint compression (LZ4/Zstd)
├── RecoveryManager
│ ├── RollbackExecutor # Roll back to consistent checkpoint
│ ├── ForwardRecovery # Replay WAL entries after checkpoint
│ ├── PartialRecovery # Recover individual nodes without full rollback
│ └── ConsistencyVerifier # Post-recovery consistency validation
├── ActivationCheckpointer
│ ├── SelectiveRecompute # Recompute activations instead of storing
│ ├── MemoryBudgetManager # Balance memory vs recomputation cost
│ └── GradientCheckpointer # Checkpoint gradients for distributed backward pass
└── CheckpointMetrics
├── CheckpointDuration # Time to create checkpoint
├── RecoveryTime # Time to recover from checkpoint
├── StorageOverhead # Checkpoint storage usage
└── CheckpointFrequency # Optimal interval tracking
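To make the SnapshotProtocol components concrete, here is a minimal sketch of the per-node marker rule that ChandyLamportExecutor and ChannelStateRecorder imply. The NodeSnapshotState class and its method names are illustrative assumptions, not part of the interface below; marker broadcast and message transport are omitted.

from dataclasses import dataclass, field
from typing import Any, Callable, Optional

@dataclass
class NodeSnapshotState:
    """Per-node bookkeeping for one Chandy-Lamport snapshot (illustrative)."""
    node_id: str
    local_state: Optional[dict] = None                             # captured on first marker
    recording: dict[str, list] = field(default_factory=dict)      # channels still recording
    channel_state: dict[str, list] = field(default_factory=dict)  # finalized in-flight messages

    def on_marker(self, channel: str, in_channels: list[str],
                  capture_state: Callable[[], dict]) -> None:
        if self.local_state is None:
            # First marker seen: snapshot local state immediately, record the
            # marker's own channel as empty, and start recording all other inputs.
            self.local_state = capture_state()
            self.channel_state[channel] = []
            for ch in in_channels:
                if ch != channel:
                    self.recording[ch] = []
        elif channel in self.recording:
            # Subsequent marker: messages buffered since the local snapshot are
            # exactly the in-flight state of this channel.
            self.channel_state[channel] = self.recording.pop(channel)

    def on_message(self, channel: str, message: Any) -> None:
        # Application messages on still-recording channels belong to the cut.
        if channel in self.recording:
            self.recording[channel].append(message)

    def complete(self) -> bool:
        # The snapshot is done once every incoming channel has delivered a marker.
        return self.local_state is not None and not self.recording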
Interface Specification
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional
from pathlib import Path
import asyncio


class CheckpointType(Enum):
    FULL = "full"                # Complete state snapshot
    INCREMENTAL = "incremental"  # Delta from last checkpoint
    ACTIVATION = "activation"    # Training activation checkpoint


class StorageBackend(Enum):
    LOCAL = "local"
    S3 = "s3"
    GCS = "gcs"
    HDFS = "hdfs"


class CompressionAlgo(Enum):
    NONE = "none"
    LZ4 = "lz4"
    ZSTD = "zstd"
    SNAPPY = "snappy"


@dataclass
class CheckpointConfig:
    checkpoint_dir: str = "/checkpoints"
    storage_backend: StorageBackend = StorageBackend.LOCAL
    checkpoint_type: CheckpointType = CheckpointType.INCREMENTAL
    compression: CompressionAlgo = CompressionAlgo.LZ4
    max_checkpoints_kept: int = 5
    wal_enabled: bool = True
    wal_sync_interval_ms: int = 100
    wal_segment_size_mb: int = 64
    activation_checkpointing: bool = True
    checkpoint_interval_steps: int = 1000
    async_checkpoint: bool = True


@dataclass
class CheckpointMetadata:
    checkpoint_id: str
    timestamp: float
    step: int
    node_states: dict[str, bytes]    # node_id → state hash
    channel_states: dict[str, list]  # channel → in-flight messages
    size_bytes: int
    is_consistent: bool
    parent_checkpoint_id: Optional[str] = None  # for incremental


class CheckpointManager:
    """
    Distributed checkpointing with Chandy-Lamport consistent snapshots,
    write-ahead logging, and rollback recovery for fault-tolerant AI training.
    """

    def __init__(self, config: CheckpointConfig):
        self.config = config
        self.checkpoints: list[CheckpointMetadata] = []
        self.wal_position = 0

    async def initiate_snapshot(self, initiator_id: str) -> CheckpointMetadata:
        """Initiate a Chandy-Lamport distributed snapshot from the initiator node."""
        ...

    async def receive_marker(self, sender_id: str, snapshot_id: str) -> None:
        """Handle a marker message in the Chandy-Lamport protocol."""
        ...

    async def save_checkpoint(self, state: dict, step: int,
                              checkpoint_type: Optional[CheckpointType] = None) -> CheckpointMetadata:
        """Save a checkpoint (full or incremental) to the storage backend.

        Falls back to config.checkpoint_type when checkpoint_type is None.
        """
        ...

    async def load_checkpoint(self, checkpoint_id: str) -> dict:
        """Load and reconstruct state from a checkpoint."""
        ...

    async def rollback(self, checkpoint_id: str) -> bool:
        """Roll back all nodes to the specified consistent checkpoint."""
        ...

    async def forward_recover(self, checkpoint_id: str) -> dict:
        """Recover by replaying WAL entries recorded after the checkpoint."""
        ...

    def write_wal_entry(self, entry: dict) -> int:
        """Append an entry to the write-ahead log and return its log position."""
        ...

    def replay_wal(self, from_position: int, to_position: Optional[int] = None) -> list:
        """Replay WAL entries for recovery."""
        ...

    def enable_activation_checkpointing(self, model: Any,
                                        budget_mb: float) -> Any:
        """Enable activation checkpointing on the model within a memory budget."""
        ...

    async def cleanup_old_checkpoints(self) -> int:
        """Remove checkpoints beyond the retention limit; return count removed."""
        ...

    def get_metrics(self) -> dict:
        """Return checkpoint performance metrics."""
        ...
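A possible call sequence for the interface above, assuming the stubbed methods are implemented; the training step and state extraction are placeholders:

async def training_loop(manager: CheckpointManager) -> None:
    # Hypothetical driver: WAL entry every step, checkpoint every N steps.
    for step in range(1, 5001):
        state = {"weights": b"...", "step": step}  # stand-in for real training state
        if manager.config.wal_enabled:
            manager.write_wal_entry({"step": step, "event": "step_done"})
        if step % manager.config.checkpoint_interval_steps == 0:
            await manager.save_checkpoint(state, step)
            await manager.cleanup_old_checkpoints()

# asyncio.run(training_loop(CheckpointManager(CheckpointConfig())))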
Testing Requirements
class TestCheckpointManager:
    async def test_chandy_lamport_consistent_cut(self): ...
    async def test_no_orphan_messages_in_snapshot(self): ...
    async def test_full_checkpoint_save_load_roundtrip(self): ...
    async def test_incremental_checkpoint_delta(self): ...
    async def test_wal_write_and_replay(self): ...
    async def test_rollback_restores_consistent_state(self): ...
    async def test_forward_recovery_from_wal(self): ...
    async def test_partial_node_recovery(self): ...
    async def test_compression_reduces_size(self): ...
    async def test_async_checkpoint_no_training_pause(self): ...
    async def test_activation_checkpointing_memory_savings(self): ...
    async def test_checkpoint_cleanup_retention(self): ...
    async def test_concurrent_snapshots_isolation(self): ...
    async def test_recovery_after_multiple_node_failures(self): ...
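As one illustration of the requirements above, the WAL round-trip test might look like this, assuming pytest-asyncio and a tmp-dir backed implementation (both assumptions, not mandated by this spec):

import pytest

@pytest.mark.asyncio
async def test_wal_write_and_replay(tmp_path):
    manager = CheckpointManager(CheckpointConfig(checkpoint_dir=str(tmp_path)))
    entries = [{"step": i, "event": "step_done"} for i in range(3)]
    positions = [manager.write_wal_entry(e) for e in entries]
    assert positions == sorted(positions)               # positions increase monotonically
    assert manager.replay_wal(positions[0]) == entries  # replay preserves order and content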
Acceptance Criteria
Dependencies
- Phase 42.1 — ConsensusEngine for coordinating checkpoint initiation
- Phase 42.3 — FaultDetector for triggering recovery on failure
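How these dependencies might be wired together is sketched below; the ConsensusEngine method name (propose) is hypothetical, since its interface belongs to Phase 42.1, and the trigger is assumed to come from the Phase 42.3 FaultDetector:

async def recover_on_failure(manager: CheckpointManager,
                             consensus: "ConsensusEngine",
                             failed_node: str) -> None:
    # Agree cluster-wide on which checkpoint to roll back to before acting,
    # then replay the WAL forward from that point (hypothetical wiring).
    latest = manager.checkpoints[-1].checkpoint_id
    agreed_id = await consensus.propose({"action": "rollback", "checkpoint": latest})
    if await manager.rollback(agreed_id):
        await manager.forward_recover(agreed_id)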
References