 # Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml

 import asyncio
-import time
 import uuid
 from dataclasses import dataclass
-from typing import Any, Callable

 import torch
-import torch.nn.functional as F
 import torchstore as ts
 import yaml
 from datasets import load_dataset
-from forge.actors._torchstore_utils import (
-    get_dcp_whole_state_dict_key,
-    get_param_prefix,
-)
-from forge.actors.generator import Generator
 from forge.actors.reference_model import ReferenceModel
 from forge.actors.replay_buffer import ReplayBuffer
 from forge.actors.trainer import TitanTrainer
...
 from forge.observability.metric_actors import get_or_create_metric_logger
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
+from forge.rl import collate, ComputeAdvantages, Episode, Policy, RewardActor
 from forge.types import LauncherConfig, ProvisionerConfig
+from forge.util.checkpoint import drop_weights
 from forge.util.config import parse
 from forge.util.logging import get_logger
 from forge.util.ops import compute_logprobs
...
 logger = get_logger("INFO")

-@dataclass
-class Episode:
-    episode_id: str
-    pad_id: int
-    request_len: int
-    response_len: int
-    target: Any | None = None
-    request: str | None = None
-    response: str | None = None
-    # Processed data
-    completion: Completion | None = None
-    ref_logprobs: torch.Tensor | None = None
-    reward: float | None = None
-    reward_breakdown: dict[str, float] | None = None
-    advantage: float | None = None
-
-    @property
-    def policy_version(self) -> int | None:
-        return self.completion.generator_version
-
-    @property
-    def request_tensor(self) -> torch.Tensor:
-        tensor: torch.Tensor = self.completion.prompt_ids.to(torch.long)
-        if tensor.shape[0] < self.request_len:  # left pad
-            diff = self.request_len - tensor.shape[0]
-            tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
-        return tensor
-
-    @property
-    def response_tensor(self) -> torch.Tensor:
-        tensor: torch.Tensor = self.completion.token_ids.to(torch.long)
-        if tensor.shape[0] < self.response_len:  # right pad
-            diff = self.response_len - tensor.shape[0]
-            tensor = F.pad(tensor, (0, diff), value=self.pad_id)
-        return tensor
-
-    def to_dict(self, exclude: list[str] | None = None) -> dict[str, Any]:
-        """Convert episode to dict, optionally excluding specified fields."""
-        result = {
-            "episode_id": self.episode_id,
-            "policy_version": self.policy_version,
-            "prompt": self.request,
-            "response": self.response,
-            "target": str(self.target),
-            "reward": self.reward,
-            "advantage": self.advantage,
-            "request_len": self.request_len,
-            "response_len": self.response_len,
-            "pad_id": self.pad_id,
-            "ref_logprobs": self.ref_logprobs,
-            "completion": self.completion,
-        }
-
-        if self.reward_breakdown is not None and "reward_breakdown" not in exclude:
-            result.update(self.reward_breakdown)
-
-        if exclude:
-            for key in exclude:
-                result.pop(key, None)
-
-        return result
-
-
-# Represents the group (G) of episodes in GRPO
-Group = list[Episode]
-
-# Represents the Policy Model to collect data from
-Policy = Generator
-
-def collate(
-    batches: list[Group],
-) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
-    """
-    Collates a list of batches into a single batch of inputs and targets.
-    Each batch is a list of episodes, and each episode is a dict of tensors.
-    """
-    inputs = []
-    targets = []
-    for batch in batches:
-        request = [e.request_tensor for e in batch]
-        request = torch.stack(request)  # [b x s]
-
-        response = [e.response_tensor for e in batch]
-        response = torch.stack(response)  # [b x s]
-
-        ref_logprobs = [e.ref_logprobs for e in batch]
-        ref_logprobs = torch.stack(ref_logprobs).squeeze()  # [b x s]
-
-        advantages = [e.advantage for e in batch]
-        advantages = torch.tensor(advantages).unsqueeze(-1)  # [b x 1]
-
-        pad_id = batch[0].pad_id
-        mask = response != pad_id
-
-        input = {"tokens": torch.cat([request, response], dim=1)}
-        target = {
-            "response": response,
-            "ref_logprobs": ref_logprobs,
-            "advantages": advantages,
-            "padding_mask": mask,
-        }
-        inputs.append(input)
-        targets.append(target)
-    return inputs, targets
-
-
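Note: `Episode`, `Group`, `Policy`, and `collate` are now imported from `forge.rl` (see the new imports above). For readers tracing the data path, here is a self-contained sketch of the padding and masking the removed `collate` performed per episode; the token values are made up for illustration:

```python
# Sketch of the removed collate/Episode padding logic: prompts are left-padded,
# responses right-padded, and the padding mask marks real response tokens.
import torch
import torch.nn.functional as F

pad_id, request_len, response_len = 0, 6, 4
prompt_ids = torch.tensor([5, 6, 7])   # 3 real prompt tokens
token_ids = torch.tensor([8, 9])       # 2 real response tokens

request = F.pad(prompt_ids, (request_len - prompt_ids.shape[0], 0), value=pad_id)
response = F.pad(token_ids, (0, response_len - token_ids.shape[0]), value=pad_id)

tokens = torch.cat([request, response])  # what the trainer receives as "tokens"
padding_mask = response != pad_id        # restricts the loss to real response tokens

print(tokens.tolist())        # [0, 0, 0, 5, 6, 7, 8, 9, 0, 0]
print(padding_mask.tolist())  # [True, True, False, False]
```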
 # TODO (T245547773): Consolidate with SimpleGRPOLoss in losses/grpo_loss.py
 # Currently duplicated because of function signature differences:
 # - This function takes logits + response, computes logprobs internally

@@ -199,60 +86,6 @@ def simple_grpo_loss(
     return loss


-@dataclass
-class RewardActor(ForgeActor):
-    reward_functions: list[Callable]
-
-    @endpoint
-    async def evaluate_response(
-        self, prompt: str, response: str, target: str
-    ) -> (dict[str, float], float):
-        total_rewards = 0.0
-        reward_breakdown = {}  # reward breakdown by function
-        for reward_fn in self.reward_functions:
-            reward = reward_fn(prompt, response, target)
-            total_rewards += reward
-
-            # Get a name for the reward function (works for classes, functions, lambdas)
-            reward_fn_name = getattr(
-                reward_fn, "__name__", reward_fn.__class__.__name__
-            )
-            reward_breakdown[reward_fn_name] = reward
-
-            # log per fn reward and avg total
-            record_metric(
-                f"reward/evaluate_response/avg_{reward_fn_name}_reward",
-                reward,
-                Reduce.MEAN,
-            )
-            record_metric(
-                f"reward/evaluate_response/std_{reward_fn_name}_reward",
-                reward,
-                Reduce.STD,
-            )
-
-            record_metric(
-                "reward/evaluate_response/avg_total_reward",
-                reward,
-                Reduce.MEAN,
-            )
-
-        avg_reward: float = total_rewards / len(self.reward_functions)
-        return reward_breakdown, avg_reward
-
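Note: `RewardActor` likewise moved to `forge.rl`. As the removed code shows, it invokes each entry of `reward_functions` as `reward_fn(prompt, response, target) -> float`, so custom rewards remain plain callables. A hypothetical example (the function name and heuristic are illustrative, not part of the library):

```python
# Hypothetical reward function matching the reward_fn(prompt, response, target)
# calling convention used by RewardActor.evaluate_response above.
def exact_match_reward(prompt: str, response: str, target: str) -> float:
    # 1.0 if the target answer appears verbatim in the response, else 0.0.
    return 1.0 if target.strip() and target.strip() in response else 0.0
```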
-@dataclass
-class ComputeAdvantages(ForgeActor):
-    @endpoint
-    async def compute(self, group: Group) -> list[float]:
-        # TODO: add batch processing
-        rewards = torch.tensor([[e.reward for e in group]])
-        mean = rewards.mean(1, keepdim=True)
-        std = rewards.std(1, keepdim=True)
-        advantages = (rewards - mean) / (std + 1e-4)
-        return advantages.squeeze(0).tolist()
-
-
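Note: the removed `ComputeAdvantages` (now imported from `forge.rl`) implements GRPO's group-relative baseline: each episode's advantage is its reward standardized within its group, advantage_i = (reward_i - mean) / (std + 1e-4). A runnable sketch with made-up rewards:

```python
# Group-relative advantage normalization, mirroring the removed actor.
import torch

rewards = torch.tensor([[0.0, 1.0, 1.0, 0.0]])  # one group of G=4 rollouts
mean = rewards.mean(1, keepdim=True)            # 0.5
std = rewards.std(1, keepdim=True)              # ~0.577 (unbiased std)
advantages = (rewards - mean) / (std + 1e-4)    # eps avoids divide-by-zero
print(advantages.squeeze(0).tolist())           # ~[-0.866, 0.866, 0.866, -0.866]
```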
 @dataclass
 class DatasetActor(ForgeActor):
     """Actor wrapper for HuggingFace dataset to provide async interface."""

@@ -324,23 +157,6 @@ async def pad_token(self):
         return self._tokenizer.eos_token_id


-async def drop_weights(version: int):
-    print(f"Dropping weights @ version {version}")
-    start_time = time.perf_counter()
-    prefix = get_param_prefix(version)
-    matching_keys = await ts.keys(prefix)
-    # TODO: once we have something like `get_meta()` in torchstore, we can just
-    # query the type of the object instead of relying on keys.
-    dcp_key = get_dcp_whole_state_dict_key(version)
-    if dcp_key in matching_keys:
-        dcp_handle = await ts.get(dcp_key)
-        dcp_handle.drop()
-    for key in matching_keys:
-        await ts.delete(key)
-    elapsed = time.perf_counter() - start_time
-    print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds")
-
-
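Note: `drop_weights` now ships in `forge.util.checkpoint` (see the new import above). Assuming the moved helper keeps the async signature removed here, call sites only need the import change; a hypothetical call site:

```python
# Hypothetical call site; assumes forge.util.checkpoint.drop_weights keeps the
# `async def drop_weights(version: int)` signature removed above.
from forge.util.checkpoint import drop_weights

async def evict_stale_policy(current_version: int) -> None:
    # Free torchstore entries for the previous policy version once the
    # trainer has advanced to current_version.
    await drop_weights(current_version - 1)
```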
 async def main(cfg: DictConfig):
     """Main GRPO training loop with rollout and training processes."""
     # Convert OmegaConf config to plain dict