
Commit c72fa66

felipemello1 and Felipe Mello authored
[refactor 1/n] - move utilities out of main.py to src/ (#635)
Co-authored-by: Felipe Mello <felipemello@fb.com>
1 parent 374a114 commit c72fa66

8 files changed: +265 -855 lines changed


apps/grpo/main.py

Lines changed: 2 additions & 186 deletions
@@ -7,21 +7,13 @@
 # Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml

 import asyncio
-import time
 import uuid
 from dataclasses import dataclass
-from typing import Any, Callable

 import torch
-import torch.nn.functional as F
 import torchstore as ts
 import yaml
 from datasets import load_dataset
-from forge.actors._torchstore_utils import (
-    get_dcp_whole_state_dict_key,
-    get_param_prefix,
-)
-from forge.actors.generator import Generator
 from forge.actors.reference_model import ReferenceModel
 from forge.actors.replay_buffer import ReplayBuffer
 from forge.actors.trainer import TitanTrainer
@@ -32,7 +24,9 @@
 from forge.observability.metric_actors import get_or_create_metric_logger
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
+from forge.rl import collate, ComputeAdvantages, Episode, Policy, RewardActor
+from forge.types import LauncherConfig, ProvisionerConfig
+from forge.util.checkpoint import drop_weights
 from forge.util.config import parse
 from forge.util.logging import get_logger
 from forge.util.ops import compute_logprobs
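
The two added import lines above are the whole point of the refactor: every utility deleted from main.py in the hunks below is now re-imported from the library. A minimal sketch of the new import surface (only the two module paths are confirmed by this diff; the comment is illustrative):

from forge.rl import collate, ComputeAdvantages, Episode, Policy, RewardActor
from forge.util.checkpoint import drop_weights

# Same names and behavior as the deleted in-file definitions shown below;
# only the module they live in has changed.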
@@ -43,113 +37,6 @@
 logger = get_logger("INFO")


-@dataclass
-class Episode:
-    episode_id: str
-    pad_id: int
-    request_len: int
-    response_len: int
-    target: Any | None = None
-    request: str | None = None
-    response: str | None = None
-    # Processed data
-    completion: Completion | None = None
-    ref_logprobs: torch.Tensor | None = None
-    reward: float | None = None
-    reward_breakdown: dict[str, float] | None = None
-    advantage: float | None = None
-
-    @property
-    def policy_version(self) -> int | None:
-        return self.completion.generator_version
-
-    @property
-    def request_tensor(self) -> torch.Tensor:
-        tensor: torch.Tensor = self.completion.prompt_ids.to(torch.long)
-        if tensor.shape[0] < self.request_len:  # left pad
-            diff = self.request_len - tensor.shape[0]
-            tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
-        return tensor
-
-    @property
-    def response_tensor(self) -> torch.Tensor:
-        tensor: torch.Tensor = self.completion.token_ids.to(torch.long)
-        if tensor.shape[0] < self.response_len:  # right pad
-            diff = self.response_len - tensor.shape[0]
-            tensor = F.pad(tensor, (0, diff), value=self.pad_id)
-        return tensor
-
-    def to_dict(self, exclude: list[str] | None = None) -> dict[str, Any]:
-        """Convert episode to dict, optionally excluding specified fields."""
-        result = {
-            "episode_id": self.episode_id,
-            "policy_version": self.policy_version,
-            "prompt": self.request,
-            "response": self.response,
-            "target": str(self.target),
-            "reward": self.reward,
-            "advantage": self.advantage,
-            "request_len": self.request_len,
-            "response_len": self.response_len,
-            "pad_id": self.pad_id,
-            "ref_logprobs": self.ref_logprobs,
-            "completion": self.completion,
-        }
-
-        if self.reward_breakdown is not None and "reward_breakdown" not in exclude:
-            result.update(self.reward_breakdown)
-
-        if exclude:
-            for key in exclude:
-                result.pop(key, None)
-
-        return result
-
-
-# Represents the group (G) of episodes in GRPO
-Group = list[Episode]
-
-# Represents the Policy Model to collect data from
-Policy = Generator
-
-
-def collate(
-    batches: list[Group],
-) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
-    """
-    Collates a list of batches into a single batch of inputs and targets.
-    Each batch is a list of episodes, and each episode is a dict of tensors.
-    """
-    inputs = []
-    targets = []
-    for batch in batches:
-        request = [e.request_tensor for e in batch]
-        request = torch.stack(request)  # [b x s]
-
-        response = [e.response_tensor for e in batch]
-        response = torch.stack(response)  # [b x s]
-
-        ref_logprobs = [e.ref_logprobs for e in batch]
-        ref_logprobs = torch.stack(ref_logprobs).squeeze()  # [b x s]
-
-        advantages = [e.advantage for e in batch]
-        advantages = torch.tensor(advantages).unsqueeze(-1)  # [b x 1]
-
-        pad_id = batch[0].pad_id
-        mask = response != pad_id
-
-        input = {"tokens": torch.cat([request, response], dim=1)}
-        target = {
-            "response": response,
-            "ref_logprobs": ref_logprobs,
-            "advantages": advantages,
-            "padding_mask": mask,
-        }
-        inputs.append(input)
-        targets.append(target)
-    return inputs, targets
-
-
 # TODO (T245547773): Consolidate with SimpleGRPOLoss in losses/grpo_loss.py
 # Currently duplicated because of function signature differences:
 # - This function takes logits + response, computes logprobs internally
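
The deleted Episode padding contract (prompts left-padded, responses right-padded, both to fixed lengths) is easy to sanity-check in isolation. A minimal, self-contained sketch using stand-in tensors rather than real forge.rl Episodes (all values below are made up):

import torch
import torch.nn.functional as F

# Mirrors Episode.request_tensor / Episode.response_tensor above:
# prompts pad on the left, responses pad on the right.
pad_id, request_len, response_len = 0, 6, 5
prompt_ids = torch.tensor([5, 7, 9])   # 3 prompt tokens -> 3 pads on the left
token_ids = torch.tensor([11, 13])     # 2 response tokens -> 3 pads on the right

request = F.pad(prompt_ids, (request_len - prompt_ids.shape[0], 0), value=pad_id)
response = F.pad(token_ids, (0, response_len - token_ids.shape[0]), value=pad_id)

print(request.tolist())    # [0, 0, 0, 5, 7, 9]
print(response.tolist())   # [11, 13, 0, 0, 0]
mask = response != pad_id  # collate's padding_mask: True on real tokens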
@@ -199,60 +86,6 @@ def simple_grpo_loss(
     return loss


-@dataclass
-class RewardActor(ForgeActor):
-    reward_functions: list[Callable]
-
-    @endpoint
-    async def evaluate_response(
-        self, prompt: str, response: str, target: str
-    ) -> (dict[str, float], float):
-        total_rewards = 0.0
-        reward_breakdown = {}  # reward breakdown by function
-        for reward_fn in self.reward_functions:
-            reward = reward_fn(prompt, response, target)
-            total_rewards += reward
-
-            # Get a name for the reward function (works for classes, functions, lambdas)
-            reward_fn_name = getattr(
-                reward_fn, "__name__", reward_fn.__class__.__name__
-            )
-            reward_breakdown[reward_fn_name] = reward
-
-            # log per fn reward and avg total
-            record_metric(
-                f"reward/evaluate_response/avg_{reward_fn_name}_reward",
-                reward,
-                Reduce.MEAN,
-            )
-            record_metric(
-                f"reward/evaluate_response/std_{reward_fn_name}_reward",
-                reward,
-                Reduce.STD,
-            )
-
-            record_metric(
-                "reward/evaluate_response/avg_total_reward",
-                reward,
-                Reduce.MEAN,
-            )
-
-        avg_reward: float = total_rewards / len(self.reward_functions)
-        return reward_breakdown, avg_reward
-
-
-@dataclass
-class ComputeAdvantages(ForgeActor):
-    @endpoint
-    async def compute(self, group: Group) -> list[float]:
-        # TODO: add batch processing
-        rewards = torch.tensor([[e.reward for e in group]])
-        mean = rewards.mean(1, keepdim=True)
-        std = rewards.std(1, keepdim=True)
-        advantages = (rewards - mean) / (std + 1e-4)
-        return advantages.squeeze(0).tolist()
-
-
 @dataclass
 class DatasetActor(ForgeActor):
     """Actor wrapper for HuggingFace dataset to provide async interface."""
@@ -324,23 +157,6 @@ async def pad_token(self):
         return self._tokenizer.eos_token_id


-async def drop_weights(version: int):
-    print(f"Dropping weights @ version {version}")
-    start_time = time.perf_counter()
-    prefix = get_param_prefix(version)
-    matching_keys = await ts.keys(prefix)
-    # TODO: once we have something like `get_meta()` in torchstore, we can just
-    # query the type of the object instead of relying on keys.
-    dcp_key = get_dcp_whole_state_dict_key(version)
-    if dcp_key in matching_keys:
-        dcp_handle = await ts.get(dcp_key)
-        dcp_handle.drop()
-    for key in matching_keys:
-        await ts.delete(key)
-    elapsed = time.perf_counter() - start_time
-    print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds")
-
-
 async def main(cfg: DictConfig):
     """Main GRPO training loop with rollout and training processes."""
     # Convert OmegaConf config to plain dict
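
drop_weights keeps its body unchanged; it simply moved to forge.util.checkpoint (see the import hunk above). A hedged sketch of how a training loop might call it, where the one-version retention window and the helper name are assumptions, not part of this diff:

from forge.util.checkpoint import drop_weights

# Hypothetical cleanup step: free torchstore copies of the previous policy
# version once the new one has been pushed.
async def cleanup_old_weights(current_version: int) -> None:
    if current_version >= 1:
        await drop_weights(current_version - 1)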
