train_tiny_model_cpu.py
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""
This is a tutorial on how to train a tiny model on a small dataset using a CPU.

This script demonstrates how to:

1. Train a tiny model on TinyStories using a CPU
2. Use a CPU-specific training configuration
3. Run a quick training experiment

For GPU training, see train_tiny_model_gpu.py
"""

from fray import ResourceConfig
from levanter.data.text import TextLmDatasetFormat

from marin.execution.executor import executor_main, versioned

from experiments.defaults import default_tokenize, default_train
from experiments.llama import llama_nano
from experiments.marin_models import marin_tokenizer
from experiments.simple_train_config import SimpleTrainConfig

# 1. Choose a dataset
tinystories_hf_id = "roneneldan/TinyStories"

# 2. Tokenize the dataset with sampling.
# For this tutorial, we limit tokenization to 1000 documents per shard.
tinystories_tokenized = default_tokenize(
    name=tinystories_hf_id,
    dataset=tinystories_hf_id,
    tokenizer=marin_tokenizer,
    format=TextLmDatasetFormat(),
    sample_count=1000,
)
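
# Note: nothing is tokenized at import time. default_tokenize returns a step
# description, and executor_main() at the bottom of the file materializes it
# when the script runs. (A sketch of the executor model based on this tutorial,
# not authoritative API documentation.)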

# 3. Define the training configuration
nano_train_config = SimpleTrainConfig(
    # Here we define the hardware resources we need.
    resources=ResourceConfig.with_cpu(),
    train_batch_size=4,
    num_train_steps=100,
    # Set hyperparameters.
    learning_rate=6e-4,
    weight_decay=0.1,
    # Keep eval quick for the tutorial.
    max_eval_batches=4,
)
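
# These values are sized for a quick smoke test: 100 steps at batch size 4 sees
# only 400 sequences in total. For anything beyond a sanity check you would
# raise train_batch_size and num_train_steps (the numbers above are
# illustrative, not tuned).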

# 4. Train the model
nano_tinystories_model = default_train(
    name="marin-nano-tinystories",
    # Steps can depend on other steps: nano_tinystories_model depends on tinystories_tokenized.
    tokenized=tinystories_tokenized,
    model_config=versioned(llama_nano),
    train_config=nano_train_config,
    # wandb tags
    tags=["llama", "nano", "tinystories", "tutorial"],
    # We can run many [eval_harness](https://github.com/EleutherAI/lm-evaluation-harness)
    # tasks in the loop during training, but there's no point in running evals
    # on such a tiny model.
    eval_harness_tasks=[],
    # To keep the tutorial fast, skip the default validation sets.
    use_default_validation=False,
)

if __name__ == "__main__":
    executor_main(
        steps=[
            nano_tinystories_model,
        ]
    )