
bakery

Where LLMs go to get baked.

Prompt baking distills a system prompt into model weights via KL divergence training with LoRA, so you get the behavior of a prompted model at zero inference-time prompt cost.

Based on Prompt Baking.

How it works

A single model serves as both teacher and student through PEFT adapter toggling:

  • Teacher (adapters disabled): sees the system prompt, generates reference behavior
  • Student (adapters enabled): no system prompt, trained to match the teacher's output distribution

The training objective minimizes per-token KL divergence between teacher and student logits on the response portion of each conversation.
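The objective can be illustrated with a small framework-free sketch (plain Python with toy logits; the actual trainer operates on batched model outputs, and the helper names below are hypothetical, not part of the bakery API):

```python
import math

def softmax(logits):
    """Convert raw logits into a probability distribution."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) between two discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def per_token_kl_loss(teacher_logits, student_logits):
    """Mean KL over response tokens: the prompted teacher is the target."""
    losses = [
        kl_divergence(softmax(t), softmax(s))
        for t, s in zip(teacher_logits, student_logits)
    ]
    return sum(losses) / len(losses)

# Two response tokens over a toy vocabulary of size 3.
teacher = [[2.0, 0.5, -1.0], [0.1, 1.2, 0.3]]
student = [[1.5, 0.7, -0.5], [0.0, 1.0, 0.5]]
loss = per_token_kl_loss(teacher, student)
assert loss > 0.0                                   # distributions differ
assert per_token_kl_loss(teacher, teacher) == 0.0   # zero when they match
```

Training drives this loss toward zero, so the adapter-enabled student reproduces the teacher's prompted distribution without ever seeing the system prompt.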

Data sources

The dataset field accepts a local JSON file or a HuggingFace dataset ID (auto-detected). The format determines the training mode:

| Data format | Training mode |
| --- | --- |
| Prompts only (list of strings, prompt-only columns) | On-the-fly trajectory generation from teacher |
| Paired data (prompt + response, chat messages) | Train directly on precomputed pairs |

You can also use training_prompts for inline prompt lists in YAML.
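For example, the two modes might be selected like this (the file path and dataset ID below are placeholders, not shipped examples):

```yaml
# Prompts only: the teacher generates trajectories during training
dataset: "./data/prompts.json"

# Paired data: train directly on precomputed prompt/response pairs
# dataset: "some-org/some-chat-dataset"   # chat-format HF dataset ID
```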

Install

pip install git+https://github.com/marksverdhei/bakery.git

Or for development:

git clone https://github.com/marksverdhei/bakery.git
cd bakery
uv sync --dev
uv pip install -e .

Quick start

bakery --config examples/basic.yaml

All config is flat YAML parsed by HfArgumentParser, so any TrainingArguments field works:

# Standard HF training
output_dir: "./outputs/my_bake"
num_train_epochs: 3
learning_rate: 1e-4
bf16: true

# Prompt baking
system_prompt: "You are a helpful assistant."
num_trajectories: 4
trajectory_length: 128

# Model
model_name_or_path: "Qwen/Qwen3-0.6B"

# LoRA
r: 64
lora_alpha: 128

# Data
training_prompts:
  - "What is the capital of France?"
  - "Explain photosynthesis."

Override any field from CLI:

bakery --config examples/basic.yaml --num_train_epochs 5 --learning_rate 5e-5

As a library

from bakery import BakeryConfig, PromptBakingTrainer, PromptBakingDataset, prompt_baking_collator

config = BakeryConfig(
    output_dir="./outputs",
    system_prompt="You are helpful.",
    num_train_epochs=3,
    learning_rate=1e-4,
)

dataset = PromptBakingDataset(
    prompts=["What is AI?", "Explain gravity."],
    responses=["AI is...", "Gravity is..."],  # optional precomputed
)

trainer = PromptBakingTrainer(
    model=peft_model,
    args=config,
    train_dataset=dataset,
    processing_class=tokenizer,
    data_collator=prompt_baking_collator,
)
trainer.train()

Examples

| Config | Description |
| --- | --- |
| examples/basic.yaml | On-the-fly trajectory generation from inline prompts |
| examples/sft_dataset.yaml | Bake from an existing HF chat dataset |
