# /// script
# dependencies = [
#     "trackio>=0.23.0",
#     "trl>=1.2.0",
# ]
# ///
"""Minimal TRL + Trackio demo script.

Fine-tunes ``Qwen/Qwen3-0.6B`` with ``SFTTrainer`` for one epoch on a tiny
in-memory question/answer dataset, reports to Trackio under a randomly
suffixed project name, and finally prints the command that opens that
project.
"""

import random

from datasets import Dataset
from trl import SFTConfig, SFTTrainer

# A random six-digit suffix gives each run its own project name.
suffix = random.randint(100000, 999999)
project_name = f"trackio-trl-demo-{suffix}"

# Each question sits next to its answer so the two dataset columns built
# below cannot drift out of alignment.
qa_pairs = [
    ("What is the capital of France?", "Paris."),
    ("Who wrote Hamlet?", "Shakespeare."),
    ("What is 2 + 2?", "4."),
    ("What color is the sky?", "Blue."),
    ("Name a primary color.", "Red."),
    ("What is the largest planet?", "Jupiter."),
]
repeated_pairs = qa_pairs * 4  # 6 pairs x 4 repetitions = 24 examples

# One single-message conversation per row: a user turn in "prompt" and the
# matching assistant turn in "completion".
prompts = [
    [{"role": "user", "content": question}] for question, _ in repeated_pairs
]
completions = [
    [{"role": "assistant", "content": answer}] for _, answer in repeated_pairs
]
dataset = Dataset.from_dict({"prompt": prompts, "completion": completions})

training_args = SFTConfig(
    output_dir="./model_output",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    learning_rate=2e-5,
    logging_steps=1,  # log at every step so the short run still yields data
    report_to="trackio",
    project=project_name,
)

trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B",
    args=training_args,
    train_dataset=dataset,
)

trainer.train()

print(f"Run complete. Open with: trackio show --project {project_name}")