Commit d232f31 (parent: 19aedcf)
Update torchforge docs (#732)

1 file changed: +30 -30 lines

docs/source/tutorial_sources/zero-to-forge/2_Forge_Internals.md (30 additions, 30 deletions)

@@ -6,14 +6,14 @@ Now that you see the power of the service abstraction, let's understand what's a
 
 ## Service Anatomy: Beyond the Interface
 
-When you call `await policy_service.generate(question)`, here's what actually happens:
+When you call `await generator_service.generate(question)`, here's what actually happens:
 
 (Don't worry, we will understand Services right in the next section!)
 
 ```mermaid
 graph TD
     Call["Your Code:
-    await policy_service
+    await generator_service
     .generate.route"]
 
     subgraph ServiceLayer["Service Layer"]
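
The diagram in this hunk is worth pausing on: a `.route()` call never hits a process directly; the service layer picks a healthy replica first. Below is a minimal toy sketch of that idea; every name in it is illustrative, not TorchForge source, and the real `ServiceInterface` also handles health checks, queueing, and fault tolerance.

```python
import asyncio
import random
from dataclasses import dataclass

# Toy stand-in for a replica behind the service layer (illustrative only).
@dataclass
class ToyReplica:
    name: str
    healthy: bool = True

    async def generate(self, prompt: str) -> str:
        return f"[{self.name}] response to {prompt!r}"

# Toy stand-in for the routing step the mermaid diagram describes.
class ToyService:
    def __init__(self, replicas):
        self.replicas = replicas

    async def route(self, prompt: str) -> str:
        # Any healthy replica may serve the request (load balancing).
        replica = random.choice([r for r in self.replicas if r.healthy])
        return await replica.generate(prompt)

async def main():
    service = ToyService([ToyReplica("replica-0"), ToyReplica("replica-1")])
    print(await service.route("What is 3 + 5?"))

asyncio.run(main())
```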
@@ -66,7 +66,7 @@ Here's the actual ServiceConfig from TorchForge source code:
 
 ```python
 # Configuration pattern from apps/grpo/main.py:
-Policy.options(
+Generator.options(
     procs=1,          # Processes per replica
     num_replicas=4,   # Number of replicas
     with_gpus=True    # Allocate GPUs
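
Pieced together with the `.as_service()` call referenced later in this diff (line 125), the full creation plausibly reads as below. A hedged sketch only: the `as_service(model=...)` keyword is an assumption, since this diff elides those arguments.

```python
from forge.actors.generator import Generator

# Hedged reconstruction from fragments elsewhere in this diff; the
# as_service() arguments are assumed, not shown in the hunks above.
generator = await Generator.options(
    procs=1,          # Processes per replica
    num_replicas=4,   # Number of replicas
    with_gpus=True,   # Allocate GPUs
).as_service(model="Qwen/Qwen3-1.7B")
```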
@@ -86,11 +86,11 @@ The service creation automatically handles:
 - Message routing and serialization
 
 ```python
-from forge.actors.generator import Generator as Policy
+from forge.actors.generator import Generator
 
 model = "Qwen/Qwen3-1.7B"
 
-policy = await Policy.options(
+generator = await Generator.options(
     procs=1,
     with_gpus=True,
     num_replicas=1
@@ -110,11 +110,11 @@ policy = await Policy.options(
 )
 
 prompt = "What is 3 + 5?"
-responses = await policy.generate.route(prompt)
+responses = await generator.generate.route(prompt)
 print(f"Response: {responses[0].text}")
 
 # Cleanup when done
-await policy.shutdown()
+await generator.shutdown()
 ```
 
 ### 3. How Services Actually Work
@@ -125,7 +125,7 @@ When you call `.as_service()`, TorchForge creates a `ServiceInterface` that mana
 
 ```python
 # Your code sees this simple interface:
-responses = await policy.generate.route(prompt=prompt)
+responses = await generator.generate.route(prompt=prompt)
 # But TorchForge handles all the complexity of replica management, load balancing, and fault tolerance
 ```
 
@@ -183,7 +183,7 @@ These communication patterns ("adverbs") determine how your service calls are
 **When to use**: Normal request routing where any replica can handle the request.
 
 ```python
-responses = await policy.generate.route(prompt=question)
+responses = await generator.generate.route(prompt=question)
 answer = responses[0].text  # Extract text from Completion object
 ```
 
@@ -205,12 +205,12 @@ Behind the scenes:
 **When to use**: You need responses from ALL replicas.
 
 ```python
-# Get version from all policy replicas
-current_versions = await policy.get_version.fanout()
+# Get version from all generator replicas
+current_versions = await generator.get_version.fanout()
 # Returns: [version_replica_1, version_replica_2, ...]
 
 # Update weights on all replicas
-await policy.update_weights.fanout(new_policy_version)
+await generator.update_weights.fanout(new_policy_version)
 # Broadcasts to all replicas simultaneously
 ```
 
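A quick contrast of the two adverbs' return shapes, using only the endpoints that appear in this diff; the version arithmetic is illustrative:

```python
# .route(): one replica answers, so you get back a single value.
version = await generator.get_version.route()

# .fanout(): every replica answers, so you get back a list.
versions = await generator.get_version.fanout()
print(versions)  # e.g. [0, 0, 0, 0] for a 4-replica service

# After a .fanout() weight update, all replicas should report the new version.
await generator.update_weights.fanout(version + 1)
versions = await generator.get_version.fanout()
assert all(v == version + 1 for v in versions)
```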
@@ -291,7 +291,7 @@ print(f"All replica values: {results}")
 # Output: All replica values: [1, 2, 1, 1] - Each replica has different state!
 ```
 
-The problem: each `.route()` call can go to different replicas, creating inconsistent state.
+The problem: each `.route()` call can go to a different replica, creating inconsistent state.
 
 ```python
 # WITH SESSIONS: All calls go to the SAME replica
@@ -313,10 +313,10 @@ async with counter_service.session():  # Creates a session that picks one replic
     # Final value on this replica: 3
 
 # Same pattern works with Policy for multi-turn conversations:
-# async with policy.session():
-#     response1 = await policy.generate.route(turn1)
+# async with generator.session():
+#     response1 = await generator.generate.route(turn1)
 #     full_prompt = turn1 + response1[0].text + turn2
-#     response2 = await policy.generate.route(full_prompt)
+#     response2 = await generator.generate.route(full_prompt)
 #     # Both calls hit same replica, preserving KV cache
 
 # Cleanup
@@ -337,22 +337,22 @@ The most complex challenge in distributed RL is maintaining state consistency wh
 # This breaks KV cache optimization:
 async def naive_multi_turn():
     # Each call might go to different replica = cache miss
-    response1 = await policy_service.generate.choose(question1)
-    response2 = await policy_service.generate.choose(question1 + response1) # Cache miss!
-    response3 = await policy_service.generate.choose(conversation_so_far) # Cache miss!
+    response1 = await generator_service.generate.route(question1)
+    response2 = await generator_service.generate.route(question1 + response1) # Cache miss!
+    response3 = await generator_service.generate.route(conversation_so_far) # Cache miss!
 ```
 
 **The solution**: Sticky sessions ensure all calls go to same replica.
 
 ```python
 async def optimized_multi_turn():
-    async with policy.session():
+    async with generator.session():
         # All calls guaranteed to hit same replica = cache hits
-        response1 = await policy.generate.route(prompt=question1)
+        response1 = await generator.generate.route(prompt=question1)
         full_prompt = question1 + response1[0].text
-        response2 = await policy.generate.route(prompt=full_prompt) # Cache hit!
+        response2 = await generator.generate.route(prompt=full_prompt) # Cache hit!
         conversation = full_prompt + response2[0].text
-        response3 = await policy.generate.route(prompt=conversation) # Cache hit!
+        response3 = await generator.generate.route(prompt=conversation) # Cache hit!
 
     # Session ends, replica can be garbage collected or reused
 ```
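
The optimized pattern generalizes to arbitrarily long conversations. A hedged convenience wrapper, assuming only the `session()` and `generate.route()` API this diff already shows (`run_conversation` is our name, not a TorchForge API):

```python
# Sketch of a reusable multi-turn helper built on the sticky-session
# pattern above; purely illustrative.
async def run_conversation(generator, turns):
    transcript = ""
    async with generator.session():          # pin one replica for all turns
        for turn in turns:
            transcript += turn
            responses = await generator.generate.route(prompt=transcript)
            transcript += responses[0].text  # growing prefix stays KV-cached
    return transcript
```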
@@ -368,7 +368,7 @@ async def optimized_multi_turn():
 ```python
 # TorchForge ReplayBuffer endpoints (verified from source code)
 # Add episodes (thread-safe by actor model)
-await replay_buffer.add.call_one(episode)  # .choose() would work too, but .call_one() clarifies it's a singleton actor not ActorMesh
+await replay_buffer.add.call_one(episode)  # .route() would work too, but .call_one() clarifies it's a singleton actor not ActorMesh
 
 # Sample batches for training
 batch = await replay_buffer.sample.call_one(
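
The `.call_one()` comment is the key point of this hunk: the replay buffer is a singleton actor rather than a replicated service, so there is no replica set for `.route()` to balance over. A hedged sketch of that distinction; the `.as_actor()` spawn and its arguments are assumptions inferred from the comment, not verified against the repo:

```python
# Replicated service: adverbs like .route()/.fanout() pick among replicas.
generator = await Generator.options(num_replicas=4, with_gpus=True).as_service(model=model)

# Singleton actor (assumed spawn pattern): exactly one instance exists,
# so .call_one() names what actually happens.
replay_buffer = await ReplayBuffer.options(procs=1).as_actor()
await replay_buffer.add.call_one(episode)
```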
@@ -386,20 +386,20 @@ batch = await replay_buffer.sample.call_one(
 
 ### Weight Synchronization Strategy
 
-**The challenge**: Trainer updates policy weights, but policy service needs those weights.
+**The challenge**: Trainer updates policy weights, but generator service needs those weights.
 
 ```python
 # TorchForge weight synchronization pattern from apps/grpo/main.py
 async def real_weight_sync(trainer, policy, step):
     # Trainer pushes weights to TorchStore with version number
     await trainer.push_weights.call_one(policy_version=step + 1)
 
-    # Policy service updates to new version from TorchStore
+    # Generator service updates to new version from TorchStore
     # Use .fanout() to update ALL policy replicas
-    await policy.update_weights.fanout(policy_version=step + 1)
+    await generator.update_weights.fanout(policy_version=step + 1)
 
     # Check current policy version
-    current_version = await policy.get_version.route()
+    current_version = await generator.get_version.route()
     print(f"Current policy version: {current_version}")
 ```
 
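Read in sequence, the three calls above form one step of a push/pull loop between trainer and generator. A hedged sketch of how they might sit inside training; the `train_step` endpoint, the `sample` keyword, and the loop shape are illustrative, not quoted from the repo:

```python
# Illustrative loop around the weight-sync pattern above; names marked
# as assumed are ours, not TorchForge's.
for step in range(num_training_steps):  # num_training_steps: assumed name
    batch = await replay_buffer.sample.call_one(curr_policy_version=step)  # kwarg assumed
    await trainer.train_step.call_one(batch)  # endpoint name assumed
    await trainer.push_weights.call_one(policy_version=step + 1)
    await generator.update_weights.fanout(policy_version=step + 1)
```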
@@ -423,8 +423,8 @@ async def simple_rl_step():
     print(f"Prompt: {prompt}")
     print(f"Target: {target}")
 
-    actions = await policy.generate.route(prompt=prompt) # Policy is a service
-    print(f"Policy response: {actions[0].text}")
+    actions = await generator.generate.route(prompt=prompt) # Generator is a service
+    print(f"Generator response: {actions[0].text}")
 
     # Create input tensor for reference model (requires full context)
     input_ids = torch.cat([actions[0].prompt_ids, actions[0].token_ids])