-
Notifications
You must be signed in to change notification settings - Fork 98
Expand file tree
/
Copy pathmain.py
More file actions
303 lines (256 loc) · 10.4 KB
/
main.py
File metadata and controls
303 lines (256 loc) · 10.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""To run:
python -m apps.sft.main --config apps/sft/llama3_8b.yaml
"""
import asyncio
import logging
import math
import os
import sys
from functools import partial
from typing import Any
import torch
import torchtitan.experiments.forge.train_spec as forge_train_spec
from forge.controller import ForgeActor
from forge.data.collate import collate_packed
from forge.data.datasets.packed import PackedDataset, TextPacker
from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
from forge.data.tokenizer import HuggingFaceModelTokenizer
from forge.util.config import parse
from monarch.actor import current_rank, current_size, endpoint
from omegaconf import DictConfig, OmegaConf
from torch import nn
from torchdata.stateful_dataloader import StatefulDataLoader
from torchtitan.components.loss import LossFunction
from torchtitan.components.lr_scheduler import LRSchedulersContainer
from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig
# from tqdm import tqdm
# Placeholder type aliases: each of these component types is simply ``Any``
# until a concrete interface exists for it.
Checkpointer = Dataloader = MetricLogger = Profiler = Tokenizer = Any

# Module-level logger, emitting INFO and above.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class ForgeSFTRecipe(ForgeActor, ForgeEngine):
    """Supervised fine-tuning recipe exposed as an actor.

    The constructor merges the user config into the ``ForgeJobConfig``
    defaults, exports torchrun-style environment variables for this rank and
    then initialises the base classes. Training is driven through the
    ``setup``, ``train`` and ``cleanup`` endpoints.

    NOTE(review): the annotations below declare ``model``, ``optimizer``,
    ``lr_scheduler`` and ``step``, but the methods read ``model_parts``,
    ``optimizers``, ``lr_schedulers`` and ``current_step``. Apart from
    ``current_step`` (set in ``__init__``) these are presumably provided by
    ``ForgeEngine`` -- confirm and align the names.
    """

    job_config: ForgeJobConfig
    train_spec: forge_train_spec.ForgeTrainSpec
    parallel_dims: ParallelDims
    model: list[nn.Module]
    loss_fn: LossFunction
    optimizer: OptimizersContainer
    lr_scheduler: LRSchedulersContainer
    checkpointer: Checkpointer
    tokenizer: Tokenizer
    train_dataloader: Dataloader
    # val_dataloader: Dataloader
    metric_logger: MetricLogger
    profiler: Profiler
    device: torch.device
    step: int

    def __init__(self, config: DictConfig):
        """Build the job config, set up the distributed env and init the bases.

        Args:
            config: User configuration merged on top of the
                ``ForgeJobConfig`` defaults.
        """
        job_config = ForgeJobConfig().to_dict()
        # Hack to deal with literal types from titan
        job_config = OmegaConf.merge(job_config, config)

        self.current_step = 0
        self.num_training_steps = job_config.training.steps
        self.metric_logger = None  # TODO: fix this
        self.gradient_accumulation_steps = 1  # Example value, adjust as needed
        self._rank = current_rank().rank
        self._size = math.prod(current_size().values())
        # The environment variables must be exported before the base classes
        # are initialised.
        self._init_dist()
        super().__init__(job_config)

    def _init_dist(self):
        """Export the environment variables torch distributed expects.

        torchrun normally handles this, but we need to do it ourselves
        in monarch for now.

        We should consider putting this into ForgeActor, but having this
        be explicit for now.
        """
        rank = str(self._rank)
        world_size = str(self._size)
        env = {
            "RANK": rank,
            "LOCAL_RANK": rank,
            "LOCAL_WORLD_SIZE": world_size,
            # NOTE(review): GROUP_RANK is set to the world size, not to a
            # rank. Kept as-is because the intended value is not derivable
            # from this file -- confirm against torchrun's definition.
            "GROUP_RANK": world_size,
            "GROUP_WORLD_SIZE": world_size,
            "ROLE_RANK": rank,
            "ROLE_WORLD_SIZE": world_size,
            "ROLE_NAME": "rank",
            "WORLD_SIZE": world_size,
            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        }
        os.environ.update(env)
        logger.info("env: %s", env)

    @endpoint
    async def setup(self):
        """Create the training dataloader and load checkpoint state."""
        self.train_dataloader = self.setup_data()
        # self.train_dataloader = self.setup_data(
        #     self.train_config.train_dataset_config,
        #     self.train_config.train_dataloader_config,
        #     self.train_config.packing_config,
        # )
        # self.val_dataloader = self.setup_data(
        #     self.train_config.val_dataset_config,
        #     self.train_config.val_dataloader_config,
        #     self.train_config.packing_config,
        # )

        # TODO: confirm that this is working properly
        # Should also use load, not dcp_load
        self.checkpointer.load(step=self.current_step)
        # self.profiler = self.setup_profiler(self.train_config.profiler_config)
        # self.logger = self.setup_logger(self.train_config.logger_config)

    def setup_data(self):
        """Build the packed Alpaca dataloader used for training.

        The tokenizer files are read from ``job_config.model.hf_assets_path``.
        Samples are packed to ``job_config.training.seq_len`` tokens and
        batched with ``job_config.training.local_batch_size``.

        Returns:
            A ``StatefulDataLoader`` over the packed dataset.
        """
        assets_dir = self.job_config.model.hf_assets_path
        tokenizer_json_path = os.path.join(assets_dir, "tokenizer.json")
        logger.info("Loading tokenizer from %s", tokenizer_json_path)
        tokenizer = HuggingFaceModelTokenizer(
            tokenizer_json_path=tokenizer_json_path,
            tokenizer_config_json_path=os.path.join(
                assets_dir, "tokenizer_config.json"
            ),
            generation_config_path=os.path.join(
                assets_dir, "generation_config.json"
            ),
        )

        dataset = sft_iterable_dataset(
            model_transform=tokenizer,
            message_transform=AlpacaToMessages(),
            path="yahma/alpaca-cleaned",
            split="train",
        )
        packer = TextPacker(padding_idx=0)
        dataset = PackedDataset(
            dataset=dataset,
            packer=packer,
            target_tokens_per_pack=self.job_config.training.seq_len,  # TODO: get this from model
        )
        dataloader = StatefulDataLoader(
            dataset=dataset,
            batch_size=self.job_config.training.local_batch_size,
            collate_fn=partial(
                collate_packed, mask_fn=packer.create_block_mask, device=self.device
            ),
        )

        # Ultimately we probably want something like this
        # packer = build_packing_strategy(packing_config)
        # dataset = build_dataset(dataset_config)
        # dataloader = build_dataloader(dataloader_config, dataset, packer)
        return dataloader

    def forward_backward(
        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
    ) -> torch.Tensor:
        """Run one forward and backward pass and return the loss.

        Args:
            input_dict: Batch inputs; only the ``"tokens"`` entry is read.
            labels: Targets handed to the loss function (or to the pipeline
                schedule when pipeline parallelism is enabled).

        Returns:
            The loss tensor. With pipeline parallelism, ranks without the
            last stage return the placeholder ``tensor([-1.0])``.
        """
        model_parts = self.model_parts
        parallel_dims = self.parallel_dims

        # apply context parallelism if cp is enabled
        # ensure CP handles the separate freqs_cis buffer for each pp stage
        inputs = input_dict["tokens"]
        optional_context_parallel_ctx = (
            dist_utils.create_context_parallel_ctx(
                cp_mesh=parallel_dims.world_mesh["cp"],
                cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                cp_no_restore_buffers={inputs, labels},
                cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
            )
            if parallel_dims.cp_enabled
            else None
        )

        if parallel_dims.pp_enabled:
            # Pipeline Parallel forward / backward inside step() call
            with self.train_context(optional_context_parallel_ctx):
                targets, losses = (
                    (labels, []) if self.pp_has_last_stage else (None, None)
                )
                if self.pp_has_first_stage:
                    self.pp_schedule.step(
                        inputs, target=targets, losses=losses, input_batch=inputs
                    )
                else:
                    self.pp_schedule.step(
                        target=targets, losses=losses, input_batch=inputs
                    )

            # accumulate losses across pipeline microbatches
            # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
            loss = (
                torch.mean(torch.stack(losses)).to(self.device)
                if self.pp_has_last_stage
                else torch.tensor([-1.0], device=self.device)
            )
        else:
            # Non-PP forward / backward
            with self.train_context(optional_context_parallel_ctx):
                assert len(model_parts) == 1
                with self.maybe_enable_amp:
                    pred = model_parts[0](inputs)
                    loss = self.loss_fn(pred, labels)
                # free the predictions before backward to avoid peaking memory
                del pred
                loss.backward()

        return loss

    def train_step(self, batch) -> None:
        """Run one optimisation step on ``batch``.

        Args:
            batch: Mapping of batch tensors. Its ``"labels"`` entry is popped
                (mutating ``batch``); the remainder is passed to
                ``forward_backward``.
        """
        # TODO
        # with GradientAccumulation(
        #     self.gradient_accumulation_steps,
        #     self.model,
        #     self.data_parallel_size,
        # ) as grad_acc:
        labels = batch.pop("labels")
        loss = self.forward_backward(batch, labels)

        logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
        # self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
        # self.pbar.update(1)
        self.optimizers.step()
        # Clear the gradients that were just applied; otherwise the next
        # backward pass would add to them and every step would use the sum of
        # all gradients seen so far.
        self.optimizers.zero_grad()
        self.lr_schedulers.step()

    @endpoint
    async def train(self) -> None:
        """Train until ``current_step`` reaches ``num_training_steps``.

        A checkpoint save is requested after every step.

        Raises:
            RuntimeError: If the training dataloader yields no batches.
        """
        dataloader = iter(self.train_dataloader)
        self.optimizers.zero_grad()

        # TODO: tqdm is broken in Monarch actors
        # self.pbar = tqdm(initial=self.current_step, total=self.num_training_steps)

        while self.current_step < self.num_training_steps:
            try:
                batch = next(dataloader)
            except StopIteration:
                # A StopIteration escaping a coroutine surfaces as an opaque
                # RuntimeError, so handle exhaustion here: start another pass
                # over the data, and fail clearly if there is no data at all.
                logger.info(
                    "Dataloader exhausted at step %d; restarting it.",
                    self.current_step,
                )
                dataloader = iter(self.train_dataloader)
                try:
                    batch = next(dataloader)
                except StopIteration:
                    raise RuntimeError(
                        "train_dataloader yielded no batches"
                    ) from None

            # Move tensors to the device used by the rest of the recipe
            for key, value in batch.items():
                if isinstance(value, torch.Tensor):
                    batch[key] = value.to(self.device)
            self.train_step(batch)
            # self.profiler.step()
            self.current_step += 1

            self.checkpointer.save(
                curr_step=self.current_step,
                last_step=self.current_step == self.num_training_steps,
            )

        # self.pbar.close()

    @endpoint
    async def cleanup(self) -> None:
        """Close the checkpointer and the metric logger when they are set."""
        if self.checkpointer:
            self.checkpointer.close()
        if self.metric_logger:
            self.metric_logger.close()

    def __repr__(self) -> str:
        return "Trainer"
async def run(cfg: DictConfig) -> None:
    """Spawn the SFT recipe actor and drive setup, training and cleanup.

    Args:
        cfg: Job configuration. Its ``processes`` entry is popped (mutating
            ``cfg``) and forwarded to ``ForgeSFTRecipe.options``; the rest is
            passed to the actor.
    """
    logger.info("Spawning recipe...")
    process_cfg = cfg.pop("processes")
    recipe = await ForgeSFTRecipe.options(**process_cfg).as_actor(cfg)

    try:
        logger.info("Created recipe, running setup.")
        await recipe.setup.call()

        logger.info("Recipe has been setup. Training now.")
        await recipe.train.call()

        logger.info("Done training. Clean up")
        await recipe.cleanup.call()
    finally:
        # Always stop the mesh, even when setup, training or cleanup raised,
        # so a failed run does not leave it running.
        await recipe.mesh.stop()
    logger.info("All done!")
@parse
def recipe_main(cfg: DictConfig) -> None:
    """Synchronous entry point: run the async ``run`` driver to completion.

    Args:
        cfg: Job configuration, handed through to ``run``. It is presumably
            supplied by the ``parse`` decorator -- confirm in
            ``forge.util.config``.
    """
    driver = run(cfg)
    asyncio.run(driver)
if __name__ == "__main__":
    # Script entry point: whatever the decorated ``recipe_main`` returns is
    # used as the process exit status (``None`` means exit status 0).
    exit_status = recipe_main()
    sys.exit(exit_status)