Conversation
JenniferWang left a comment
+1 for removing the `Group` abstraction.
```diff
  # Calculate advantages and add to replay buffer
- advantages = await compute_advantages.compute.call_one(group)
- for episode, advantage in zip(group.episodes, advantages):
+ advantages = await compute_advantages.compute.call_one(episodes)
```
Not related to this diff, but since we're scrutinizing the main flow again: I think making `compute_advantages` its own Actor is very weird, and probably the opposite of an "optimization":
- We do not expose the capability to specify the host mesh for a specific actor -- ideally, this should be collocated with the generator replica that produces this batch.
- `ComputeAdvantage` only needs the rewards, so very likely the entire episodes are serialized.

I wonder if, for now, it should just be inlined in the `sample` call; or we could allocate a proc on the Policy mesh alongside the `PolicyWorker` to handle the computation, but chain the calls and return the result together in `sample`.
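Inlining could be as small as a plain function called from the rollout loop. A rough sketch, pure Python with illustrative names only (not torchforge's actual API), assuming the usual group-relative advantage of reward minus group mean over group std:

```python
from dataclasses import dataclass


@dataclass
class Episode:
    # Minimal stand-in for the real Episode; only the reward matters here,
    # which is Jennifer's point 2 above.
    reward: float = 0.0
    advantage: float = 0.0


def compute_advantages(episodes: list[Episode]) -> list[float]:
    # Group-relative baseline: (reward - mean) / std, falling back to
    # std = 1.0 when all rewards in the group are identical.
    rewards = [e.reward for e in episodes]
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    std = var ** 0.5 or 1.0
    return [(r - mean) / std for r in rewards]


# Inlined in the sample path: no actor hop, no serializing full episodes.
episodes = [Episode(reward=r) for r in (1.0, 0.0, 0.5, 0.5)]
for episode, advantage in zip(episodes, compute_advantages(episodes)):
    episode.advantage = advantage
```

Whether this lives in `sample` itself or on a proc next to the `PolicyWorker`, the computation is cheap enough that only the reward list needs to move.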
@JenniferWang these are good points. I want to propose an idea (not for you to implement @Jack-Khuu, just brainstorming if this makes sense):
```python
with policy.session() as s:
    host: HostMesh = await s.get_host_mesh()  # returns the host mesh associated with this replica
    advantages = host.run_task(compute_advantages)  # where compute_advantages is a function
```
Chained calls would be cool 👀
This looks legit; +1 on chained calls
joecummings left a comment
Awesome stuff! Just a bunch of small comments.
The updates boil down to 2 changes that don't alter behavior in `grpo/main`:
- `Group` is downgraded from a dataclass to a typedef of `list[Episode]`, since it's never required
- `Episode` now directly holds a `Completion`, with redundant attributes removed from `Episode` and `ScoredCompletion`

(There's also various typehint improvements sprinkled in)
Note: This PR does not address or utilize `Episode` from `data_models`, but convergence is imminent.

Wandb looks roughly the same:
Before: torchforge/grpo-training/runs/wca6wke2
After: torchforge/grpo-training/runs/ul34xjr9