 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 from forge.rl.loss import create_shifted_targets
+from forge.types import TrainBatch
 from monarch.actor import endpoint
 from torch import Tensor
 from torch.distributed.checkpoint._nested_dict import flatten_state_dict
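
The new import replaces the separate `inputs`/`targets` dicts with a single `TrainBatch` object. A minimal sketch of the shape this diff assumes for `forge.types.TrainBatch`, with field names inferred from its usage below (the actual type in forge may differ):

```python
from dataclasses import dataclass

from torch import Tensor


@dataclass
class TrainBatch:
    # Inferred from usage in this diff; the real forge.types.TrainBatch
    # may carry more fields or use a different container.
    model_inputs: dict[str, Tensor]  # kwargs for the model forward, e.g. {"tokens": ...}
    loss_inputs: dict[str, Tensor]   # kwargs for the loss, e.g. {"loss_mask": ...}
```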
@@ -117,17 +118,15 @@ async def setup(self):
         self.engine.checkpointer.load(step=self.step)
         self.engine.optimizers.zero_grad()

-    def forward_backward(
-        self, inputs: dict[str, Tensor], targets: dict[str, Tensor]
-    ) -> Tensor:
+    def forward_backward(self, batch: TrainBatch) -> Tensor:
         model_parts = self.engine.model_parts
         parallel_dims = self.engine.parallel_dims
         optional_context_parallel_ctx = None

         # Create shifted target_ids for next-token prediction
         # target_ids[i] = input_ids[i+1], with loss_mask applied
-        targets["target_ids"] = create_shifted_targets(
-            inputs["tokens"], targets.get("loss_mask")
+        batch.loss_inputs["target_ids"] = create_shifted_targets(
+            batch.model_inputs["tokens"], batch.loss_inputs.get("loss_mask")
         )

         if parallel_dims.pp_enabled:
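
For reference, the shift-by-one construction described in the comment can be sketched as follows. This is an illustration of the documented behavior, not the actual `forge.rl.loss.create_shifted_targets`; the `-100` ignore index and the `loss_mask == 0` convention are assumptions:

```python
import torch
from torch import Tensor


def shifted_targets_sketch(input_ids: Tensor, loss_mask: Tensor | None) -> Tensor:
    # target_ids[i] = input_ids[i + 1]; the last position has no next token.
    target_ids = torch.roll(input_ids, shifts=-1, dims=-1)
    target_ids[..., -1] = -100  # assumed ignore index for the loss
    if loss_mask is not None:
        # Assumed convention: positions with loss_mask == 0 are dropped from the loss.
        target_ids = target_ids.masked_fill(loss_mask == 0, -100)
    return target_ids
```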
@@ -136,8 +135,8 @@ def forward_backward(
             with self.engine.train_context(optional_context_parallel_ctx):
                 assert len(model_parts) == 1
                 with self.engine.maybe_enable_amp:
-                    logits = model_parts[0](**inputs)
-                    loss_output = self.loss(logits, **targets)
+                    logits = model_parts[0](**batch.model_inputs)
+                    loss_output = self.loss(logits, **batch.loss_inputs)
                     loss = loss_output.loss

                 # Record metrics from loss output
@@ -156,19 +155,16 @@ def forward_backward(
         return loss

     @endpoint
-    async def train_step(
-        self, inputs: list[dict[str, Tensor]], targets: list[dict[str, Tensor]]
-    ) -> float:
+    async def train_step(self, batches: list[TrainBatch]) -> float:
         t = Tracer("rl_trainer_perf/step", timer="gpu", track_memory=True)
         t.start()

         self.engine.gc_handler.run(self.step)
-        local_inputs = inputs[self.engine.dp_rank]
-        local_targets = targets[self.engine.dp_rank]
-        batch_to_device(local_inputs, self.engine.device)
-        batch_to_device(local_targets, self.engine.device)
+        batch = batches[self.engine.dp_rank]
+        batch_to_device(batch.model_inputs, self.engine.device)
+        batch_to_device(batch.loss_inputs, self.engine.device)

-        loss = self.forward_backward(local_inputs, local_targets)
+        loss = self.forward_backward(batch)
         torch.distributed.all_reduce(loss)

         t.step("forward_backward")
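
Callers now pass one `TrainBatch` per data-parallel rank rather than parallel `inputs`/`targets` lists. A hypothetical call site, where `tokens_per_rank`, `masks_per_rank`, `dp_world_size`, and `trainer` are illustrative names and the endpoint invocation style is an assumption:

```python
batches = [
    TrainBatch(
        model_inputs={"tokens": tokens_per_rank[rank]},
        loss_inputs={"loss_mask": masks_per_rank[rank]},
    )
    for rank in range(dp_world_size)
]
# train_step is a Monarch @endpoint; the exact call method depends on the actor setup.
loss = await trainer.train_step.call_one(batches)
```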