Skip to content

Commit 21f20ca

Browse files
authored
Move all trainer types into trainer.py (#684)
1 parent 8fe8742 commit 21f20ca

File tree

3 files changed

+190
-208
lines changed

3 files changed

+190
-208
lines changed

src/forge/api/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
This module defines the public interfaces that all Forge implementations conform to.
1010
"""
1111

12-
from forge.api.trainer import Trainer
13-
from forge.api.types import (
12+
from forge.api.trainer import (
1413
ForwardBackwardResult,
1514
LossFn,
1615
OptimStepResult,
1716
ParallelismConfig,
1817
TextTrainBatch,
18+
Trainer,
1919
TrainerConfig,
2020
TrainerStatus,
2121
)

src/forge/api/trainer.py

Lines changed: 188 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,196 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import Protocol, runtime_checkable
7+
"""Type definitions and trainer protocol for the Forge API."""
8+
9+
from dataclasses import dataclass
10+
from typing import Any, Callable, Protocol, runtime_checkable, TypeAlias
811

912
import torch
1013

11-
from forge.api.types import (
12-
ForwardBackwardResult,
13-
LossFn,
14-
OptimStepResult,
15-
TextTrainBatch,
16-
TrainerConfig,
17-
TrainerStatus,
18-
)
14+
15+
# Loss function signature: takes model outputs (as dict) and batch, returns scalar loss
16+
# The dict will typically contain logits, but may include other keys depending on use case.
17+
LossFn: TypeAlias = Callable[[dict[str, Any], "TextTrainBatch"], torch.Tensor]
18+
19+
20+
@dataclass
21+
class TextTrainBatch:
22+
"""A batch of text training data for forward_backward.
23+
24+
This dataclass defines the standard format for text training batches across all
25+
Forge text trainers.
26+
27+
Attributes:
28+
input_ids: Input token IDs. Shape: [batch_size, seq_len]
29+
target_ids: Target token IDs for loss computation. Shape: [batch_size, seq_len]
30+
target_mask: Mask indicating which tokens to compute loss on.
31+
Shape: [batch_size, seq_len]. Values are 0 (ignore) or 1 (compute loss).
32+
If None, computes loss on all tokens.
33+
target_weights: Per-token weights for loss computation.
34+
Shape: [batch_size, seq_len]. Used for importance weighting, such as
35+
advantages in RL (GRPO, PPO) or custom loss weighting schemes.
36+
If None, all tokens have weight 1.0.
37+
38+
Example:
39+
>>> batch = TextTrainBatch(
40+
>>> input_ids=torch.tensor([[1, 2, 3, 4, 5]]),
41+
>>> target_ids=torch.tensor([[2, 3, 4, 5, 6]]),
42+
>>> target_mask=torch.tensor([[0, 0, 1, 1, 1]]), # Only predict last 3 tokens
43+
>>> target_weights=torch.tensor([[0, 0, 1.0, 0.8, 1.2]]), # Weight by advantage
44+
>>> )
45+
>>> result = await trainer.forward_backward(batch)
46+
"""
47+
48+
input_ids: torch.Tensor
49+
target_ids: torch.Tensor
50+
target_mask: torch.Tensor | None = None
51+
target_weights: torch.Tensor | None = None
52+
53+
54+
@dataclass
class ForwardBackwardResult:
    """Result of one forward_backward pass.

    Attributes:
        loss: Scalar loss value computed for the batch.
        metrics: Extra metrics produced during training (e.g. perplexity,
            accuracy, KL divergence). May be empty when the loss tracks nothing
            extra; values may be scalars, tensors, or other structured data.

    Example:
        >>> result = await trainer.forward_backward(batch)
        >>> result.loss
        0.3542
        >>> result.metrics
        {"perplexity": 1.42, "kl_divergence": 0.05}
    """

    loss: float
    metrics: dict[str, Any]
74+
75+
76+
@dataclass
class OptimStepResult:
    """Result of one optimizer step.

    Attributes:
        step: Training step number after applying this optimizer step.
        learning_rate: Learning rate that was used for this step.
        accumulated_microbatches: How many forward_backward calls were
            accumulated before this step — useful for tracking gradient
            accumulation behavior.

    Example:
        >>> result = await trainer.optim_step()
        >>> result.step
        1000
        >>> result.learning_rate
        0.0001
        >>> result.accumulated_microbatches
        4
    """

    step: int
    learning_rate: float
    accumulated_microbatches: int
100+
101+
102+
@dataclass
class ParallelismConfig:
    """Parallelism configuration for distributed training.

    Attributes:
        dp_degree: Data parallel degree (number of data-parallel replicas).
        tp_degree: Tensor parallel degree (model sharding across devices).
        pp_degree: Pipeline parallel degree (sharding across pipeline stages).
        cp_degree: Context parallel degree (sequence parallelism for long contexts).
        ep_degree: Expert parallel degree (for MoE models).
        world_size: Total number of processes in the distributed job.
        dp_rank: This process's data parallel rank, in [0, dp_degree).
        tp_rank: This process's tensor parallel rank, in [0, tp_degree).
        device: Device identifier, e.g. "cuda:0" or "cuda:1".

    Example:
        >>> config = await trainer.get_config()
        >>> config.parallelism.dp_degree
        4
        >>> config.parallelism.tp_degree
        2
        >>> config.parallelism.pp_degree
        1
        >>> config.parallelism.cp_degree
        1
        >>> config.parallelism.ep_degree
        1
        >>> config.parallelism.device
        "cuda:0"
    """

    dp_degree: int
    tp_degree: int
    pp_degree: int
    cp_degree: int
    ep_degree: int
    world_size: int
    dp_rank: int
    tp_rank: int
    device: str
142+
143+
144+
@dataclass
class TrainerConfig:
    """Static trainer and model configuration.

    Holds configuration that does not change over the course of training.

    Attributes:
        model_name: Name or path of the model being trained.
        model_config: Model architecture configuration. Common keys include:
            - vocab_size: int - Size of the vocabulary
            - hidden_size: int - Hidden dimension size
            - num_layers: int - Number of transformer layers
            - num_attention_heads: int - Number of attention heads
            - max_seq_len: int - Maximum sequence length
        parallelism: Parallelism configuration for distributed training.

    Example:
        >>> config = await trainer.get_config()
        >>> config.model_name
        "Qwen/Qwen2.5-7B"
        >>> config.model_config["vocab_size"]
        151936
        >>> config.parallelism.dp_degree
        4
    """

    model_name: str
    model_config: dict[str, Any]
    parallelism: ParallelismConfig
173+
174+
175+
@dataclass
class TrainerStatus:
    """Runtime status of the trainer.

    Dynamic information about the trainer's current state; unlike
    TrainerConfig, these values change as training progresses.

    Attributes:
        step: Current training step.
        accumulated_microbatches: Number of batches accumulated since the last
            optim_step; 0 if gradients were just applied/cleared.

    Example:
        >>> status = await trainer.get_status()
        >>> status.step
        1000
        >>> status.accumulated_microbatches
        2
    """

    step: int
    accumulated_microbatches: int
19197

20198

21199
@runtime_checkable
@@ -57,7 +235,7 @@ async def forward_backward(
57235
58236
Args:
59237
batch: TextTrainBatch containing input_ids, target_ids, and optional
60-
target_mask/target_weights. See forge.api.types.TextTrainBatch for details.
238+
target_mask/target_weights. See forge.api.trainer.TextTrainBatch for details.
61239
loss_fn: Optional custom loss function. If None, uses the loss function
62240
configured at trainer creation. Signature: (outputs, batch) -> loss
63241
where outputs is a dict with at least "logits" key.

0 commit comments

Comments
 (0)