-from trl import GRPOConfig, GRPOTrainer
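+# PEP 723 inline script metadata follows; runners such as `uv run` can
+# resolve these dependencies automatically before executing the script.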
+# /// script
+# dependencies = [
+#     "trackio",
+#     "trl",
+#     "datasets",
+#     "transformers",
+#     "torch",
+# ]
+# ///
+
+import random
+
+from datasets import Dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
+from trl import SFTConfig, SFTTrainer
 
 import trackio
 
+PROJECT_ID = random.randint(100000, 999999)
+PROJECT_NAME = f"trace-demo-trl-{PROJECT_ID}"
+MODEL_NAME = "sshleifer/tiny-gpt2"
+
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
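+# GPT-2-style tokenizers ship without a pad token, so reuse EOS for padding.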
+if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token
+
+examples = [
+    {"prompt": "What is 2 + 2?", "completion": "2 + 2 = 4."},
+    {
+        "prompt": "What color is the sky on a clear day?",
+        "completion": "The sky is typically blue on a clear day.",
+    },
+    {"prompt": "Translate 'good morning' to French.", "completion": "Bonjour."},
+    {
+        "prompt": "Name the capital of Japan.",
+        "completion": "Tokyo is the capital of Japan.",
+    },
+    {
+        "prompt": "Give one use of Trackio.",
+        "completion": "Trackio can be used to inspect training logs and traces.",
+    },
+]
+
+
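+# Render each prompt/completion pair as a single Alpaca-style instruction
+# text field, which is what `dataset_text_field="text"` below expects.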
+def format_example(example):
+    return {
+        "text": (
+            "### Instruction:\n"
+            f"{example['prompt']}\n\n"
+            "### Response:\n"
+            f"{example['completion']}"
+        )
+    }
+
 
-trackio.init(project="trace-demo-trl")
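+# Duplicate the handful of examples so the demo has enough rows for a few steps.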
+dataset = Dataset.from_list([format_example(example) for example in examples * 2])
 
 
-def log_rollouts(prompts, completions, rewards, step, model_version):
-    trackio.log(
-        {
-            "traces": [
-                trackio.Trace(
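+# Callback that mirrors each logged training step into a Trackio trace.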
+class TraceLoggingCallback(TrainerCallback):
+    def __init__(self, prompt_examples, run_label):
+        self.prompt_examples = prompt_examples
+        self.run_label = run_label
+
+    def on_log(self, args, state, control, logs=None, **kwargs):
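+        # Skip the end-of-training summary log (it carries no "loss") and step 0.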
+        if not logs or "loss" not in logs or state.global_step <= 0:
+            return
+
+        sample = self.prompt_examples[
+            (state.global_step - 1) % len(self.prompt_examples)
+        ]
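+        # Demo-only reward proxy derived from the loss; a real pipeline would
+        # log a genuine reward or preference signal here.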
+        reward = max(0.0, 1.0 - float(logs.get("loss", 0.0)))
+        trackio.log(
+            {
+                "trace": trackio.Trace(
                     messages=[
-                        {"role": "user", "content": prompt},
-                        {"role": "assistant", "content": completion},
+                        {
+                            "role": "system",
+                            "content": "You are a supervised fine-tuning demo model.",
+                        },
+                        {"role": "user", "content": sample["prompt"]},
+                        {"role": "assistant", "content": sample["completion"]},
                     ],
                     metadata={
-                        "reward": float(reward),
-                        "step": step,
-                        "model_version": model_version,
+                        "model_version": self.run_label,
+                        "trainer": "trl-sft",
+                        "loss": float(logs.get("loss", 0.0)),
+                        "reward": reward,
+                        "global_step": int(state.global_step),
                     },
                 )
-                for prompt, completion, reward in zip(prompts, completions, rewards)
-            ]
-        },
-        step=step,
-    )
+            },
+            step=int(state.global_step),
+        )
 
 
-trainer = GRPOTrainer(
-    model="Qwen/Qwen2.5-0.5B",
-    reward_funcs=[],
-    args=GRPOConfig(output_dir="out", report_to="trackio"),
-    train_dataset=[],
-)
+for run_idx in range(2):
+    run_name = f"trl-run-{run_idx}"
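+    # Each iteration becomes a separately named run in the same Trackio project.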
+    trackio.init(project=PROJECT_NAME, name=run_name)
+
+    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+
+    trainer = SFTTrainer(
+        model=model,
+        args=SFTConfig(
+            output_dir=f"./trl_trace_output_{PROJECT_ID}_{run_idx}",
+            per_device_train_batch_size=2,
+            max_steps=5,
+            logging_steps=1,
+            save_strategy="no",
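+            # Built-in reporting stays off; TraceLoggingCallback logs directly.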
+            report_to="none",
+            learning_rate=5e-5,
+            dataset_text_field="text",
+            max_length=64,
+        ),
+        train_dataset=dataset,
+        processing_class=tokenizer,
+        callbacks=[TraceLoggingCallback(examples, run_name)],
+    )
 
-# Wire `log_rollouts(...)` into your callback or reward loop.
-# trainer.train()
+    trainer.train()
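+    # Close the run so the next loop iteration starts a fresh one.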
+    trackio.finish()