 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 from forge.rl import collate, ComputeAdvantages, Episode, RewardActor
+from forge.rl.loss import GRPOLoss
 from forge.types import LauncherConfig, ProvisionerConfig
 from forge.util.checkpoint import drop_weights
 from forge.util.config import parse
 from forge.util.logging import get_logger
-from forge.util.ops import compute_logprobs
 from omegaconf import DictConfig, OmegaConf
 
 logger = get_logger("INFO")
 
 
-# TODO (T245547773): Consolidate with SimpleGRPOLoss in losses/grpo_loss.py
-# Currently duplicated because of function signature differences:
-# - This function takes logits + response, computes logprobs internally
-# - SimpleGRPOLoss takes pre-computed logprobs
-# - TitanTrainer passes logits, so would need wrapper or signature change
-# Consider refactoring TitanTrainer's loss interface to standardize this.
-def simple_grpo_loss(
-    logits: torch.Tensor,
-    response: torch.Tensor,
-    ref_logprobs: torch.Tensor,
-    advantages: torch.Tensor,
-    padding_mask: torch.Tensor,
-    beta: float = 1e-6,
-) -> torch.Tensor:
-    logprobs: torch.Tensor = compute_logprobs(logits, response)
-    kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
-    per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
-
-    # Compute mean KL per valid token
-    mean_kl = (
-        ((kl * padding_mask).sum(dim=1)) / (padding_mask.sum(dim=1).clamp(min=1.0))
-    ).mean()
-
-    # Compute mean policy loss per valid token
-    mean_policy_loss = (
-        ((per_token_policy_loss * padding_mask).sum(dim=1))
-        / (padding_mask.sum(dim=1).clamp(min=1.0))
-    ).mean()
-
-    # Compute loss using the means (mathematically equivalent)
-    loss = -(mean_policy_loss - beta * mean_kl)
-
-    # Log metrics
-    # TODO: Better design - have loss function return all metrics as a dict,
-    # then record them in rl_trainer so all training metrics are in one namespace
-    # and we avoid doing .item here, which is not compile friendly
-    record_metric("grpo_loss/kl_divergence_mean", mean_kl.item(), Reduce.MEAN)
-    record_metric(
-        "grpo_loss/kl_divergence_max", (kl * padding_mask).max().item(), Reduce.MAX
-    )
-    record_metric(
-        "grpo_loss/policy_gradient_loss", mean_policy_loss.item(), Reduce.MEAN
-    )
-    record_metric("grpo_loss/total_loss", loss.item(), Reduce.MEAN)
-    record_metric("grpo_loss/advantage_mean", advantages.mean().item(), Reduce.MEAN)
-    record_metric("grpo_loss/advantage_std", advantages.std().item(), Reduce.MEAN)
-    return loss
-
-
 async def main(cfg: DictConfig):
     """Main GRPO training loop with rollout and training processes."""
     # Convert OmegaConf config to plain dict
@@ -116,8 +67,32 @@ async def main(cfg: DictConfig):
         backend_config=metric_logging_cfg, run_config=run_config_for_logging
     )
 
+    # ---- Setup loss function ---- #
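+    # Sketch of the assumed objective, inferred from the parameter names:
+    # ratio = exp(logprobs - old_logprobs); the per-token loss is roughly
+    # -min(ratio * A, clip(ratio, 1 - clip_low, 1 + clip_high) * A), plus a
+    # beta-weighted KL penalty against the reference model. clip_high >
+    # clip_low leaves extra headroom for upward policy updates.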
+    loss_fn = GRPOLoss(
+        clip_low=0.2,
+        clip_high=0.28,
+        beta=0.1,
+        agg_type="fixed_horizon",
+    )
+
+    # Fail-fast: Check loss/ref_model compatibility before spawning actors
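+    # (Spawning actors and services is expensive, so a config mismatch should
+    # surface here rather than after the cluster is up.)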
+    uses_ref_model = cfg.get("services", {}).get("ref_model") is not None
+    if uses_ref_model and not isinstance(loss_fn, GRPOLoss):
+        raise ValueError(
+            f"ref_model is configured but {type(loss_fn).__name__} does not use ref_logprobs. "
+            "Either remove the ref_model service config or use GRPOLoss with beta > 0."
+        )
+    if isinstance(loss_fn, GRPOLoss) and loss_fn.beta > 0 and not uses_ref_model:
+        raise ValueError(
+            f"GRPOLoss with beta={loss_fn.beta} requires ref_logprobs, but ref_model is not configured. "
+            "Either add ref_model to services config or set beta=0."
+        )
+
+
     # ---- Setup services ---- #
 
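+    # noop() fills the ref_model slot in the service setup below, so the
+    # unpacked result keeps the same shape when no reference model is spawned.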
+    async def noop():
+        return None
+
     (
         dataloader,
         generator,
@@ -130,13 +105,17 @@ async def main(cfg: DictConfig):
         DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),
         Generator.options(**cfg.services.generator).as_service(**cfg.generator),
         TitanTrainer.options(**cfg.actors.trainer).as_actor(
-            **cfg.trainer, loss=simple_grpo_loss
+            **cfg.trainer, loss=loss_fn
         ),
         ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(
             **cfg.replay_buffer, collate=collate
         ),
         ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),
-        ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
+        (
+            ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model)
+            if uses_ref_model
+            else noop()
+        ),
         RewardActor.options(**cfg.services.reward_actor).as_service(
             reward_functions=[MathReward(), ThinkingReward()]
         ),
@@ -187,7 +166,34 @@ async def continuous_rollouts():
                 (group_size, max_req_tokens + max_res_tokens),
                 dtype=torch.long,
             )
+            seq_len = max_req_tokens + max_res_tokens
+
             for i, response in enumerate(responses):
+                # Validate logprobs exist
+                if response.logprobs is None:
+                    raise ValueError(
+                        "Completion.logprobs is None. "
+                        "Ensure Generator returns logprobs by setting 'logprobs: 1' in sampling_params config."
+                    )
+
+                # Prepare generator_logprobs
+                # Shift by -1 to align with next-token prediction
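+                # Example: with max_req_tokens=5 and a 3-token response, the
+                # raw logprobs fill positions [5:8]; after the -1 roll they sit
+                # at [4:7], the positions whose logits predict those tokens.
+                # The wrapped-around final position is zeroed out below.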
+                actual_response_len = response.token_ids.shape[0]
+                generator_logprobs = torch.zeros(seq_len, dtype=response.logprobs.dtype)
+                generator_logprobs[
+                    max_req_tokens : max_req_tokens + actual_response_len
+                ] = response.logprobs
+                generator_logprobs = torch.roll(generator_logprobs, shifts=-1, dims=0)
+                generator_logprobs[-1] = 0.0
+
+                # Prepare loss_mask
+                response_mask = torch.zeros(seq_len, dtype=torch.float32)
+                response_mask[max_req_tokens : max_req_tokens + actual_response_len] = (
+                    1.0
+                )
+                loss_mask = torch.roll(response_mask, shifts=-1, dims=0)
+                loss_mask[-1] = 0.0
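+                # loss_mask gets the same -1 shift, so it selects exactly the
+                # positions whose next-token targets are response tokens.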
+
                 episode = Episode(
                     episode_id=str(uuid.uuid4()),
                     pad_id=pad_id,
@@ -197,7 +203,10 @@ async def continuous_rollouts():
                     request=prompt,
                     response=response.text,
                     completion=response,
+                    generator_logprobs=generator_logprobs,
+                    loss_mask=loss_mask,
                 )
+
                 (
                     episode.reward_breakdown,
                     episode.reward,
@@ -263,21 +272,33 @@ async def continuous_rollouts():
 
             t.step("reward_evaluation")
 
-            ref_logprobs = await ref_model.forward.route(
-                input_ids, max_req_tokens, return_logprobs=True
-            )
-            t.step("reference_model_calculate_logprobs")
+            # Compute ref_logprobs only if ref_model is configured
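+            # (With beta == 0 the loss has no KL term, so ref_logprobs are
+            # never read and the reference service can be skipped entirely.)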
+            if ref_model is not None:
+                ref_logprobs = await ref_model.forward.route(
+                    input_ids, return_logprobs=True
+                )
+                t.step("reference_model_calculate_logprobs")
+
+                for i, episode in enumerate(episodes):
+                    episode.ref_logprobs = ref_logprobs[i]  # [seq_len]
 
-            for i, episode in enumerate(episodes):
-                episode.ref_logprobs = ref_logprobs[i]
-            del ref_logprobs, input_ids
+                del ref_logprobs
+
+            del input_ids
 
             advantages = await compute_advantages.compute.call_one(episodes)
             for episode, advantage in zip(episodes, advantages):
                 episode.advantage = advantage
                 await replay_buffer.add.call_one(episode)
 
-                sample = episode.to_dict(exclude=["ref_logprobs", "completion"])
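+                # Tensor-valued fields are excluded so the logged sample stays
+                # a small, human-readable record for the metrics table.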
+                sample = episode.to_dict(
+                    exclude=[
+                        "completion",
+                        "loss_mask",
+                        "generator_logprobs",
+                        "ref_logprobs",
+                    ]
+                )
                 sample["score"] = sample["reward"]
                 record_metric(
                     "main_samples/continuous_rollouts/sample_table",