|
4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
7 | | -from typing import Protocol, runtime_checkable |
| 7 | +"""Type definitions and trainer protocol for the Forge API.""" |
| 8 | + |
| 9 | +from dataclasses import dataclass |
| 10 | +from typing import Any, Callable, Protocol, runtime_checkable, TypeAlias |
8 | 11 |
|
9 | 12 | import torch |
10 | 13 |
|
11 | | -from forge.api.types import ( |
12 | | - ForwardBackwardResult, |
13 | | - LossFn, |
14 | | - OptimStepResult, |
15 | | - TextTrainBatch, |
16 | | - TrainerConfig, |
17 | | - TrainerStatus, |
18 | | -) |
| 14 | + |
| 15 | +# Loss function signature: takes model outputs (as dict) and batch, returns scalar loss |
| 16 | +# The dict will typically contain logits, but may include other keys depending on use case. |
| 17 | +LossFn: TypeAlias = Callable[[dict[str, Any], "TextTrainBatch"], torch.Tensor] |
| 18 | + |
| 19 | + |
| 20 | +@dataclass |
| 21 | +class TextTrainBatch: |
| 22 | + """A batch of text training data for forward_backward. |
| 23 | +
|
| 24 | + This dataclass defines the standard format for text training batches across all |
| 25 | + Forge text trainers. |
| 26 | +
|
| 27 | + Attributes: |
| 28 | + input_ids: Input token IDs. Shape: [batch_size, seq_len] |
| 29 | + target_ids: Target token IDs for loss computation. Shape: [batch_size, seq_len] |
| 30 | + target_mask: Mask indicating which tokens to compute loss on. |
| 31 | + Shape: [batch_size, seq_len]. Values are 0 (ignore) or 1 (compute loss). |
| 32 | + If None, computes loss on all tokens. |
| 33 | + target_weights: Per-token weights for loss computation. |
| 34 | + Shape: [batch_size, seq_len]. Used for importance weighting, such as |
| 35 | + advantages in RL (GRPO, PPO) or custom loss weighting schemes. |
| 36 | + If None, all tokens have weight 1.0. |
| 37 | +
|
| 38 | + Example: |
| 39 | + >>> batch = TextTrainBatch( |
| 40 | +        ...     input_ids=torch.tensor([[1, 2, 3, 4, 5]]),
| 41 | +        ...     target_ids=torch.tensor([[2, 3, 4, 5, 6]]),
| 42 | +        ...     target_mask=torch.tensor([[0, 0, 1, 1, 1]]),  # Only predict last 3 tokens
| 43 | +        ...     target_weights=torch.tensor([[0, 0, 1.0, 0.8, 1.2]]),  # Weight by advantage
| 44 | +        ... )
| 45 | + >>> result = await trainer.forward_backward(batch) |
| 46 | + """ |
| 47 | + |
| 48 | + input_ids: torch.Tensor |
| 49 | + target_ids: torch.Tensor |
| 50 | + target_mask: torch.Tensor | None = None |
| 51 | + target_weights: torch.Tensor | None = None |
| 52 | + |
| 53 | + |
| 54 | +@dataclass |
| 55 | +class ForwardBackwardResult: |
| 56 | + """Result from a forward_backward pass. |
| 57 | +
|
| 58 | + Attributes: |
| 59 | + loss: Loss value computed for the batch |
| 60 | + metrics: Additional metrics computed during training (e.g., perplexity, |
| 61 | + accuracy, KL divergence). May be empty if no additional metrics are tracked. |
| 62 | + Values can be scalars, tensors, or other structured data depending on the loss. |
| 63 | +
|
| 64 | + Example: |
| 65 | + >>> result = await trainer.forward_backward(batch) |
| 66 | + >>> result.loss |
| 67 | + 0.3542 |
| 68 | + >>> result.metrics |
| 69 | + {"perplexity": 1.42, "kl_divergence": 0.05} |
| 70 | + """ |
| 71 | + |
| 72 | + loss: float |
| 73 | + metrics: dict[str, Any] |
| 74 | + |
| 75 | + |
| 76 | +@dataclass |
| 77 | +class OptimStepResult: |
| 78 | + """Result from an optimizer step. |
| 79 | +
|
| 80 | + Attributes: |
| 81 | + step: Training step number after this optimizer step |
| 82 | + learning_rate: Current learning rate used for this step |
| 83 | + accumulated_microbatches: Number of forward_backward calls that were |
| 84 | + accumulated before this optimizer step. Useful for tracking gradient |
| 85 | + accumulation behavior. |
| 86 | +
|
| 87 | + Example: |
| 88 | + >>> result = await trainer.optim_step() |
| 89 | + >>> result.step |
| 90 | + 1000 |
| 91 | + >>> result.learning_rate |
| 92 | + 0.0001 |
| 93 | + >>> result.accumulated_microbatches |
| 94 | + 4 |
| 95 | + """ |
| 96 | + |
| 97 | + step: int |
| 98 | + learning_rate: float |
| 99 | + accumulated_microbatches: int |
| 100 | + |
| 101 | + |
| 102 | +@dataclass |
| 103 | +class ParallelismConfig: |
| 104 | + """Parallelism configuration for distributed training. |
| 105 | +
|
| 106 | + Attributes: |
| 107 | + dp_degree: Data parallel degree (number of data parallel replicas) |
| 108 | + tp_degree: Tensor parallel degree (model sharding across devices) |
| 109 | + pp_degree: Pipeline parallel degree (model sharding across pipeline stages) |
| 110 | + cp_degree: Context parallel degree (sequence parallelism for long contexts) |
| 111 | + ep_degree: Expert parallel degree (for MoE models) |
| 112 | + world_size: Total number of processes in the distributed training job |
| 113 | + dp_rank: Current data parallel rank (0 to dp_degree-1) |
| 114 | + tp_rank: Current tensor parallel rank (0 to tp_degree-1) |
| 115 | + device: Device identifier (e.g., "cuda:0", "cuda:1") |
| 116 | +
|
| 117 | + Example: |
| 118 | + >>> config = await trainer.get_config() |
| 119 | + >>> config.parallelism.dp_degree |
| 120 | + 4 |
| 121 | + >>> config.parallelism.tp_degree |
| 122 | + 2 |
| 123 | + >>> config.parallelism.pp_degree |
| 124 | + 1 |
| 125 | + >>> config.parallelism.cp_degree |
| 126 | + 1 |
| 127 | + >>> config.parallelism.ep_degree |
| 128 | + 1 |
| 129 | + >>> config.parallelism.device |
| 130 | + "cuda:0" |
| 131 | + """ |
| 132 | + |
| 133 | + dp_degree: int |
| 134 | + tp_degree: int |
| 135 | + pp_degree: int |
| 136 | + cp_degree: int |
| 137 | + ep_degree: int |
| 138 | + world_size: int |
| 139 | + dp_rank: int |
| 140 | + tp_rank: int |
| 141 | + device: str |
| 142 | + |
| 143 | + |
| 144 | +@dataclass |
| 145 | +class TrainerConfig: |
| 146 | + """Static trainer and model configuration. |
| 147 | +
|
| 148 | + This contains configuration information that doesn't change during training. |
| 149 | +
|
| 150 | + Attributes: |
| 151 | + model_name: Name or path of the model being trained |
| 152 | + model_config: Model architecture configuration. Common keys include: |
| 153 | + - vocab_size: int - Size of the vocabulary |
| 154 | + - hidden_size: int - Hidden dimension size |
| 155 | + - num_layers: int - Number of transformer layers |
| 156 | + - num_attention_heads: int - Number of attention heads |
| 157 | + - max_seq_len: int - Maximum sequence length |
| 158 | + parallelism: Parallelism configuration for distributed training |
| 159 | +
|
| 160 | + Example: |
| 161 | + >>> config = await trainer.get_config() |
| 162 | + >>> config.model_name |
| 163 | + "Qwen/Qwen2.5-7B" |
| 164 | + >>> config.model_config["vocab_size"] |
| 165 | + 151936 |
| 166 | + >>> config.parallelism.dp_degree |
| 167 | + 4 |
| 168 | + """ |
| 169 | + |
| 170 | + model_name: str |
| 171 | + model_config: dict[str, Any] |
| 172 | + parallelism: ParallelismConfig |
| 173 | + |
| 174 | + |
| 175 | +@dataclass |
| 176 | +class TrainerStatus: |
| 177 | + """Runtime status of the trainer. |
| 178 | +
|
| 179 | + This contains dynamic information about the trainer's current state that |
| 180 | + changes during training. |
| 181 | +
|
| 182 | + Attributes: |
| 183 | + step: Current training step |
| 184 | + accumulated_microbatches: Number of batches accumulated since the last |
| 185 | + optim_step. Will be 0 if gradients were just applied/cleared. |
| 186 | +
|
| 187 | + Example: |
| 188 | + >>> status = await trainer.get_status() |
| 189 | + >>> status.step |
| 190 | + 1000 |
| 191 | + >>> status.accumulated_microbatches |
| 192 | + 2 |
| 193 | + """ |
| 194 | + |
| 195 | + step: int |
| 196 | + accumulated_microbatches: int |
19 | 197 |
|
20 | 198 |
|
21 | 199 | @runtime_checkable |
@@ -57,7 +235,7 @@ async def forward_backward( |
57 | 235 |
|
58 | 236 | Args: |
59 | 237 | batch: TextTrainBatch containing input_ids, target_ids, and optional |
60 | | - target_mask/target_weights. See forge.api.types.TextTrainBatch for details. |
| 238 | + target_mask/target_weights. See forge.api.trainer.TextTrainBatch for details. |
61 | 239 | loss_fn: Optional custom loss function. If None, uses the loss function |
62 | 240 | configured at trainer creation. Signature: (outputs, batch) -> loss |
63 | 241 | where outputs is a dict with at least "logits" key. |
|
0 commit comments