Commit 57aa6a1

Authored by felipemello1 (Felipe Mello)
[logging] clean up 1/n (#606)
Co-authored-by: Felipe Mello <felipemello@fb.com>
1 parent d983498 commit 57aa6a1

File tree

5 files changed: +67 −104 lines changed

apps/grpo/main.py

Lines changed: 49 additions & 32 deletions
@@ -183,14 +183,19 @@ def simple_grpo_loss(
     loss = -(mean_policy_loss - beta * mean_kl)
 
     # Log metrics
+    # TODO: Better design - have loss function return all metrics as a dict,
+    # then record them in rl_trainer so all training metrics are in one namespace
+    # and we avoid doing .item here, which is not compile friendly
     record_metric("grpo_loss/kl_divergence_mean", mean_kl.item(), Reduce.MEAN)
     record_metric(
         "grpo_loss/kl_divergence_max", (kl * padding_mask).max().item(), Reduce.MAX
     )
-    record_metric("grpo_loss/policy_loss", mean_policy_loss.item(), Reduce.MEAN)
+    record_metric(
+        "grpo_loss/policy_gradient_loss", mean_policy_loss.item(), Reduce.MEAN
+    )
+    record_metric("grpo_loss/total_loss", loss.item(), Reduce.MEAN)
     record_metric("grpo_loss/advantage_mean", advantages.mean().item(), Reduce.MEAN)
     record_metric("grpo_loss/advantage_std", advantages.std().item(), Reduce.MEAN)
-
     return loss
 
 
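Aside: the TODO added in this hunk points at a cleaner design where the loss function returns its metrics instead of recording them inline. A minimal sketch of that idea, with hypothetical names (grpo_loss_with_metrics, the rl_trainer/ namespace), not forge's actual API:

import torch

# Sketch of the TODO above: return metrics as tensors alongside the loss so
# the loss body stays free of .item() calls (which break torch.compile graphs)
# and the trainer records everything under one namespace.
def grpo_loss_with_metrics(
    mean_policy_loss: torch.Tensor, mean_kl: torch.Tensor, beta: float
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    loss = -(mean_policy_loss - beta * mean_kl)
    metrics = {
        "kl_divergence_mean": mean_kl.detach(),
        "policy_gradient_loss": mean_policy_loss.detach(),
        "total_loss": loss.detach(),
    }
    return loss, metrics

# In rl_trainer, outside any compiled region:
#   loss, metrics = grpo_loss_with_metrics(policy_loss, kl, beta)
#   for name, value in metrics.items():
#       record_metric(f"rl_trainer/{name}", value.item(), Reduce.MEAN)
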
@@ -213,12 +218,8 @@ async def evaluate_response(
                 reward_fn, "__name__", reward_fn.__class__.__name__
             )
             reward_breakdown[reward_fn_name] = reward
-            # per function reward
-            record_metric(
-                f"reward/evaluate_response/sum_{reward_fn_name}_reward",
-                reward,
-                Reduce.SUM,
-            )
+
+            # log per fn reward and avg total
             record_metric(
                 f"reward/evaluate_response/avg_{reward_fn_name}_reward",
                 reward,
@@ -236,12 +237,6 @@ async def evaluate_response(
                 Reduce.MEAN,
             )
 
-            record_metric(
-                f"reward/evaluate_response/count_{reward_fn_name}_calls",
-                1,
-                Reduce.SUM,
-            )
-
         avg_reward: float = total_rewards / len(self.reward_functions)
         return reward_breakdown, avg_reward
 
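For readers following the metric renames: the Reduce mode decides how values recorded under the same key collapse when metrics are flushed. A toy, self-contained illustration of the MEAN/SUM/MAX semantics (forge's real record_metric also aggregates across actors, which this ignores):

from collections import defaultdict
from typing import Callable

# Toy accumulator: each key collects raw values; the reduction is applied at
# flush time. This mirrors Reduce.MEAN / Reduce.SUM / Reduce.MAX conceptually.
_values: dict[str, list[float]] = defaultdict(list)

def record(key: str, value: float) -> None:
    _values[key].append(value)

def flush(reductions: dict[str, Callable]) -> dict[str, float]:
    return {key: reductions[key](vals) for key, vals in _values.items()}

record("avg_math_reward", 0.0)
record("avg_math_reward", 1.0)
record("max_prompt_tokens", 87)
record("max_prompt_tokens", 412)
result = flush({
    "avg_math_reward": lambda v: sum(v) / len(v),  # Reduce.MEAN
    "max_prompt_tokens": max,                      # Reduce.MAX
})
assert result == {"avg_math_reward": 0.5, "max_prompt_tokens": 412}
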
@@ -304,17 +299,6 @@ async def sample(self) -> dict[str, str] | None:
         try:
             sample = next(self._iterator)
 
-            record_metric("dataset/sample/count_samples_generated", 1, Reduce.SUM)
-            record_metric(
-                "dataset/sample/avg_sample_len",
-                len(sample["request"]),
-                Reduce.MEAN,
-            )
-            record_metric(
-                "dataset/sample/max_sample_len",
-                len(sample["request"]),
-                Reduce.MAX,
-            )
             record_metric("dataset/sample/current_epoch", self._epoch, Reduce.MAX)
 
             return sample
@@ -443,8 +427,6 @@ async def continuous_rollouts():
                 print("Dataloader is empty, exiting continuous rollout")
                 return
 
-            t.step("data_loading")
-
             prompt, target = sample["request"], sample["target"]
             responses: list[Completion] = await policy.generate.route(prompt)
             t.step("policy_generation")
@@ -478,18 +460,53 @@ async def continuous_rollouts():
                 input_ids[i, :max_req_tokens] = episode.request_tensor
                 input_ids[i, max_req_tokens:] = episode.response_tensor
 
+                # Track token-based metrics
+                prompt_tokens = episode.completion.prompt_ids.shape[0]
+                response_tokens = episode.completion.token_ids.shape[0]
+
+                record_metric("episode/avg_prompt_tokens", prompt_tokens, Reduce.MEAN)
+                record_metric("episode/max_prompt_tokens", prompt_tokens, Reduce.MAX)
+                record_metric("episode/min_prompt_tokens", prompt_tokens, Reduce.MIN)
+                record_metric(
+                    "episode/avg_response_tokens", response_tokens, Reduce.MEAN
+                )
+                record_metric(
+                    "episode/max_response_tokens", response_tokens, Reduce.MAX
+                )
+                record_metric(
+                    "episode/min_response_tokens", response_tokens, Reduce.MIN
+                )
+
             # drop episodes if
             # 1> reward std-dev is very small (including all 0s and all 1s)
-            # 2> response is potentially truncated (response_len >= max_res_tokens)
+            # 2> any response was truncated (didn't end with EOS)
+            # TODO: change it to filter only truncated episodes instead of dropping entire group
             rewards = [e.reward for e in episodes]
             rewards_std = torch.std(torch.tensor(rewards))
-            max_response_len = max(e.completion.token_ids.shape[0] for e in episodes)
-            drop = rewards_std < 1e-3 or max_response_len >= max_res_tokens
+            is_low_variance = rewards_std < 1e-3
+            num_truncated = sum(
+                1 for e in episodes if e.completion.stop_reason == "length"
+            )
+            is_truncated = num_truncated > 0
+            drop = is_low_variance or is_truncated
+
+            n = len(episodes)
+            record_metric(
+                "main/continuous_rollouts/episodes_dropped/low_variance",
+                n if is_low_variance else 0,
+                Reduce.SUM,
+            )
             record_metric(
-                "main/continuous_rollouts/dropped_episodes",
-                1 if drop else 0,
+                "main/continuous_rollouts/episodes_dropped/truncated",
+                num_truncated,
                 Reduce.SUM,
             )
+            record_metric(
+                "main/continuous_rollouts/episodes_dropped/total",
+                n if drop else 0,
+                Reduce.SUM,
+            )
+
             if drop:
                 del input_ids, episodes
                 continue

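The new TODO above (filter only truncated episodes instead of dropping the entire group) could look roughly like the following. This is a hypothetical sketch, not code from this commit; it assumes episodes expose completion.stop_reason as used in the hunk, and it leaves out the group-size bookkeeping that GRPO's advantage normalization would need:

# Keep the group, drop only episodes whose response hit the length limit.
def filter_truncated(episodes: list) -> list:
    kept = [e for e in episodes if e.completion.stop_reason != "length"]
    record_metric(
        "main/continuous_rollouts/episodes_dropped/truncated",
        len(episodes) - len(kept),
        Reduce.SUM,
    )
    return kept
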
src/forge/actors/generator.py

Lines changed: 15 additions & 44 deletions
@@ -10,6 +10,7 @@
 import logging
 import os
 import sys
+import time
 from collections.abc import Mapping
 from copy import copy
 from dataclasses import dataclass, field
@@ -258,8 +259,6 @@ async def _fetch_weights(
         version: int,
     ) -> dict[str, SharedTensorHandle]:
         """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}."""
-        t = Tracer("generator_perf/_fetch_weights")
-        t.start()
         prefix = get_param_prefix(version)
         matching_keys = await ts.keys(prefix)
         hf_param_names = [extract_param_name(key) for key in matching_keys]
@@ -282,8 +281,6 @@ def split_keys(keys):
         for sd in sub_state_dicts:
             state_dict.update(sd)
 
-        t.stop()
-
         return state_dict
 
     @endpoint
@endpoint
@@ -336,8 +333,6 @@ async def generate(
336333
priority=priority,
337334
data_parallel_rank=None, # We do not support DP
338335
)
339-
t.step("process_inputs")
340-
341336
# Wait until we're accepting requests (releases lock while waiting)
342337
# If accepting_requests is True, continue immediately (holding the lock)
343338
# If False, release lock, wait for notification, re-acquire and recheck
@@ -369,7 +364,6 @@ async def generate(
             self.requests[request_id] = (parent_req, request_fut)
 
         completions = await request_fut
-        t.step("generate")
 
         # Log some metrics
         record_metric(
@@ -378,19 +372,6 @@ async def generate(
             Reduce.SUM,
         )
 
-        for completion in completions:
-            num_generated_tokens = len(completion.token_ids)
-            record_metric(
-                "generator/generate/sum_tokens_generated",
-                num_generated_tokens,
-                Reduce.SUM,
-            )
-
-            record_metric(
-                "generator/generate/avg_tokens_generated",
-                num_generated_tokens,
-                Reduce.MEAN,
-            )
         t.stop()
         return completions
 
@@ -465,37 +446,36 @@ async def update_weights(self, version: int) -> None:
         async with self.request_lock:
             self.accepting_requests = False
             curr_requests = [fut for _, fut in self.requests.values()]
+
             if curr_requests:
-                # Record pending requests metrics
-                record_metric(
-                    "generator_perf/update_weights/avg_pending_requests",
-                    len(curr_requests),
-                    Reduce.MEAN,
-                )
+                # Record pending requests count
                 record_metric(
-                    "generator_perf/update_weights/max_pending_requests",
+                    "generator_perf/update_weights/sum_pending_gen_requests",
                     len(curr_requests),
-                    Reduce.MAX,
+                    Reduce.SUM,
                 )
                 logger.debug(f"Waiting for {len(curr_requests)} pending requests")
 
+            # Start timing the wait
+            wait_start = time.perf_counter()
+
             # Wait until all pending requests have been processed
             # TODO: If generating long sequences, this might be long and will block
             # generator weight updates
             await self.request_lock.wait_for(lambda: len(self.requests) == 0)
 
-            # Record weight update metrics
-            record_metric(
-                "generator/update_weights/count_weight_updates", 1, Reduce.SUM
-            )
+            if curr_requests:
+                wait_duration = time.perf_counter() - wait_start
+                record_metric(
+                    "generator_perf/update_weights/avg_waiting_for_generation_duration_s",
+                    wait_duration,
+                    Reduce.MEAN,
+                )
 
         logger.debug(f"Starting weight update on {self.__class__.__name__}")
 
         if fetch_fut is not None:
-            t = Tracer("generator_perf/waiting_for_fetch_weights")
-            t.start()
             fetched_weights = await fetch_fut
-            t.stop()
             # Call update_weights on every policy_worker
             await self.worker.update_weights.call(
                 shared_memory_state_dict=fetched_weights
@@ -672,10 +652,6 @@ async def update_weights(
         model = self.worker.model_runner.model
         if shared_memory_state_dict is not None:
             logger.info("[PolicyWorker] update weights from shared memory.")
-            t = Tracer(
-                "generator_worker_perf/update_weights_from_shared_memory", timer="gpu"
-            )
-            t.start()
             loaded_weights = set()
             for name, param_handle in shared_memory_state_dict.items():
                 # Use context manager for automatic cleanup
@@ -685,7 +661,6 @@ async def update_weights(
                     del param
                     loaded_weights.update(loaded)
             logger.info(f"[PolicyWorker] updated {len(loaded_weights)} parameters")
-            t.stop()
             return
         # normal update_weights without shared memory prefetching
         if version is None:
@@ -698,8 +673,6 @@ async def update_weights(
         dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
         use_dcp_for_weight_sync = dcp_whole_state_dict_key in matching_keys
         loaded_weights = set()
-        t = Tracer("generator_worker_perf/update_weights_from_torchstore", timer="gpu")
-        t.start()
 
         if use_dcp_for_weight_sync:
             dcp_handle = await ts.get(dcp_whole_state_dict_key)
@@ -720,8 +693,6 @@ async def update_weights(
                 del param
                 loaded_weights.update(loaded)
 
-        t.stop()
-
     @endpoint
     async def save_model_params(self):
         """Save model parameters before weight update, used for testing purposes only."""

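The update_weights change above swaps Tracer spans for a direct time.perf_counter() measurement around the lock wait. The pattern in isolation, as a generic asyncio sketch (hypothetical names; not the generator code itself):

import asyncio
import time

# Time how long we block waiting for in-flight work to drain, and only
# report a duration when there was actually something to wait for.
async def drain_and_time(cond: asyncio.Condition, pending: dict) -> float | None:
    async with cond:
        had_pending = bool(pending)
        start = time.perf_counter()
        await cond.wait_for(lambda: len(pending) == 0)
        return time.perf_counter() - start if had_pending else None
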
src/forge/actors/reference_model.py

Lines changed: 1 addition & 9 deletions
@@ -144,21 +144,15 @@ async def forward(
         """
         # Record reference model metrics
         record_metric("reference_perf/forward/count_forward_passes", 1, Reduce.SUM)
-        record_metric(
-            "reference_perf/forward/avg_sequence_length",
-            input_ids.shape[1],
-            Reduce.MEAN,
-        )
 
         t = Tracer("reference_perf/forward", timer="gpu", track_memory=True)
         t.start()
         self.engine.gc_handler.run(self.step)
-        t.step("garbage_collection")
 
         model_parts = self.engine.model_parts
         parallel_dims = self.engine.parallel_dims
         input_ids = input_ids.to("cuda")
-        t.step("to_device")
+
         # optional_context_parallel_ctx = (
         #     dist_utils.create_context_parallel_ctx(
         #         cp_mesh=parallel_dims.world_mesh["cp"],
@@ -182,13 +176,11 @@ async def forward(
             self.step += 1
             if isinstance(logits, DTensor):
                 logits = logits.full_tensor()
-            t.step("forward")
 
         if not return_logprobs:
             t.stop()
             return logits
         else:
             logprobs = compute_logprobs(logits, input_ids[:, max_req_tokens:])
-            t.step("compute_logprobs")
             t.stop()
             return logprobs

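Several of the deleted spans used Tracer(..., timer="gpu"). For context, timing GPU work accurately requires CUDA events rather than wall-clock timers, since kernels launch asynchronously. A minimal illustration of that idea (illustrative only, not forge's Tracer implementation):

import torch

class GpuTimer:
    """Toy CUDA-event timer; records events on the current stream."""

    def start(self) -> None:
        self._begin = torch.cuda.Event(enable_timing=True)
        self._end = torch.cuda.Event(enable_timing=True)
        self._begin.record()

    def stop_ms(self) -> float:
        self._end.record()
        torch.cuda.synchronize()  # events complete asynchronously
        return self._begin.elapsed_time(self._end)  # milliseconds
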
src/forge/actors/replay_buffer.py

Lines changed: 1 addition & 11 deletions
@@ -13,7 +13,6 @@
 
 from forge.controller import ForgeActor
 from forge.observability.metrics import record_metric, Reduce
-from forge.observability.perf_tracker import trace
 
 from monarch.actor import endpoint
 
@@ -75,7 +74,6 @@ async def add(self, episode: "Episode") -> None:
         record_metric("buffer/add/count_episodes_added", 1, Reduce.SUM)
 
     @endpoint
-    @trace("buffer_perf/sample", track_memory=False)
     async def sample(
         self, curr_policy_version: int
     ) -> tuple[tuple[Any, ...], ...] | None:
@@ -87,8 +85,6 @@ async def sample(
         Returns:
             A list of sampled episodes with shape (dp_size, bsz, ...) or None if there are not enough episodes in the buffer.
         """
-        # Record sample request metric
-        record_metric("buffer/sample/count_sample_requests", 1, Reduce.SUM)
 
         total_samples = self.dp_size * self.batch_size
 
@@ -98,7 +94,7 @@ async def sample(
         # Calculate metrics
         if len(self.buffer) > 0:
             record_metric(
-                "buffer/sample/avg_data_utilization",
+                "buffer/sample/demand_to_size_ratio",
                 total_samples / len(self.buffer),
                 Reduce.MEAN,
            )
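
The rename makes the metric's meaning explicit: samples requested divided by buffer size at sampling time. A quick worked example:

# demand_to_size_ratio: requested samples / episodes currently buffered.
dp_size, batch_size = 2, 8
total_samples = dp_size * batch_size      # 16 episodes requested
buffered = 32                             # episodes in the buffer
ratio = total_samples / buffered          # 0.5: demand is half the inventory
# Values near or above 1.0 mean the buffer is drained as fast as it fills.
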
@@ -135,12 +131,6 @@ async def sample(
             max(sampled_policy_ages),
             Reduce.MAX,
         )
-        record_metric(
-            "buffer/sample/min_sampled_policy_age",
-            min(sampled_policy_ages),
-            Reduce.MIN,
-        )
-
         # Reshape into (dp_size, bsz, ...)
         reshaped_episodes = [
             sampled_episodes[dp_idx * self.batch_size : (dp_idx + 1) * self.batch_size]
