Skip to content

Commit 53328b4

Browse files
authored
ping-pong weight sync (#763)
1 parent d3eb3bf commit 53328b4

File tree

3 files changed

+20
-10
lines changed

3 files changed

+20
-10
lines changed

apps/grpo/main.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from forge.rl import collate, ComputeAdvantages, Episode, RewardActor
2828
from forge.rl.loss import DAPOLoss, GRPOLoss
2929
from forge.types import LauncherConfig, ProvisionerConfig
30-
from forge.util.checkpoint import drop_weights
3130
from forge.util.config import parse
3231
from forge.util.logging import get_logger
3332
from omegaconf import DictConfig, OmegaConf
@@ -335,10 +334,6 @@ async def continuous_training():
335334
await generator.update_weights.fanout(training_step)
336335
t.step("update_weights")
337336

338-
if training_step >= 2:
339-
await drop_weights(training_step - 1)
340-
t.step("drop_weights")
341-
342337
t.stop()
343338
restart_tracer = True
344339

src/forge/actors/_torchstore_utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,25 @@
1010

1111
KEY_DELIM = "."
1212

13+
# Alternate (ping-pong) between two fixed storage version IDs.
14+
# This reuses the same two allocations instead of writing a new version every step and deleting old ones.
15+
VERSION_A = 0
16+
VERSION_B = 1
17+
18+
19+
def get_storage_version(step: int) -> int:
20+
"""Map incrementing step to ping-pong storage version (0 or 1)."""
21+
return VERSION_A if step % 2 == 0 else VERSION_B
22+
1323

1424
def get_param_prefix(policy_version: int) -> str:
15-
return f"policy_ver_{policy_version:010d}"
25+
storage_version = get_storage_version(policy_version)
26+
return f"policy_ver_{storage_version:010d}"
1627

1728

1829
def get_param_key(policy_version: int, name: str) -> str:
19-
return f"policy_ver_{policy_version:010d}{KEY_DELIM}{name}"
30+
storage_version = get_storage_version(policy_version)
31+
return f"policy_ver_{storage_version:010d}{KEY_DELIM}{name}"
2032

2133

2234
def extract_param_name(key: str) -> str:

src/forge/actors/trainer/titan.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,9 +327,12 @@ async def push_weights(self, policy_version: int) -> None:
327327
"Trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
328328
)
329329
hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
330-
for name, param in hf_state_dict.items():
331-
key = get_param_key(policy_version, name)
332-
await ts.put(key, param)
330+
331+
entries = [
332+
(get_param_key(policy_version, name), param)
333+
for name, param in hf_state_dict.items()
334+
]
335+
await ts.put_batch(entries)
333336
end_time = time.perf_counter()
334337
logger.info("Completed weights push in %.2f seconds", end_time - start_time)
335338

0 commit comments

Comments
 (0)