Reviewed all 20 files in the codebase, fixed 4 bugs, added a compression-visualization feature, and created 2 documentation files (DEPLOYMENT.md and PAPER_OUTLINE.md).
Problem: Used globals() to resolve string type annotations, which breaks whenever an annotation can't be looked up in the module's global namespace.
Fix: Replaced with an explicit _DATACLASS_REGISTRY dict mapping field names to their concrete dataclass types.
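For illustration, a minimal usage sketch of the behaviour the fix restores (not part of the diff; the original and corrected sources follow below):

```python
# Hedged usage sketch: with the registry in place, nested YAML sections
# resolve to their concrete dataclasses instead of staying plain dicts.
from src.config import load_config

cfg = load_config()                 # reads configs/default.yaml if present, else defaults
print(type(cfg.training).__name__)  # -> "TrainingConfig", not dict
print(cfg.model.torch_dtype)        # -> "float16" unless overridden in YAML
```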
"""Central configuration for the prompt-optimizer project."""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional
import yaml
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Project root – resolved relative to this file so it works both locally
# and on Colab after cloning the repo.
# ---------------------------------------------------------------------------
PROJECT_ROOT = Path(__file__).resolve().parent.parent
@dataclass
class ModelConfig:
name: str = "mistralai/Mistral-7B-Instruct-v0.2"
trust_remote_code: bool = True
torch_dtype: str = "float16" # "float16" | "bfloat16" | "auto"
@dataclass
class QuantizationConfig:
load_in_4bit: bool = True
bnb_4bit_quant_type: str = "nf4"
bnb_4bit_compute_dtype: str = "float16"
bnb_4bit_use_double_quant: bool = True
@dataclass
class LoraConfig:
r: int = 16
lora_alpha: int = 32
lora_dropout: float = 0.05
target_modules: list[str] = field(
default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
)
bias: str = "none"
task_type: str = "CAUSAL_LM"
@dataclass
class TrainingConfig:
output_dir: str = "outputs/checkpoints"
num_train_epochs: int = 3
per_device_train_batch_size: int = 2
gradient_accumulation_steps: int = 4
learning_rate: float = 2e-4
weight_decay: float = 0.01
warmup_ratio: float = 0.03
lr_scheduler_type: str = "cosine"
max_seq_length: int = 1024
logging_steps: int = 10
save_strategy: str = "epoch"
save_total_limit: int = 2
fp16: bool = True
bf16: bool = False
gradient_checkpointing: bool = True
optim: str = "paged_adamw_8bit"
report_to: str = "none"
seed: int = 42
@dataclass
class GenerationConfig:
max_new_tokens: int = 512
temperature: float = 0.3
top_p: float = 0.9
repetition_penalty: float = 1.15
do_sample: bool = True
@dataclass
class DatasetConfig:
raw_dir: str = "data/raw"
processed_dir: str = "data/processed"
train_file: str = "train.jsonl"
val_file: str = "val.jsonl"
val_split: float = 0.1
seed: int = 42
max_samples: Optional[int] = None # limit for quick tests
@dataclass
class EvalConfig:
max_token_threshold: int = 300
min_compression_ratio: float = 0.5
@dataclass
class UIConfig:
server_name: str = "0.0.0.0"
server_port: int = 7860
share: bool = False
@dataclass
class Config:
model: ModelConfig = field(default_factory=ModelConfig)
quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
lora: LoraConfig = field(default_factory=LoraConfig)
training: TrainingConfig = field(default_factory=TrainingConfig)
generation: GenerationConfig = field(default_factory=GenerationConfig)
dataset: DatasetConfig = field(default_factory=DatasetConfig)
evaluation: EvalConfig = field(default_factory=EvalConfig)
ui: UIConfig = field(default_factory=UIConfig)
# ---- adapter / output helpers ------------------------------------------
@property
def adapter_dir(self) -> Path:
return PROJECT_ROOT / "outputs" / "adapter"
@property
def merged_dir(self) -> Path:
return PROJECT_ROOT / "outputs" / "merged"
# ---------------------------------------------------------------------------
# YAML loader
# ---------------------------------------------------------------------------
def _merge(base: dict, override: dict) -> dict:
"""Recursively merge *override* into *base*."""
for k, v in override.items():
if isinstance(v, dict) and isinstance(base.get(k), dict):
_merge(base[k], v)
else:
base[k] = v
return base
def _dict_to_dataclass(dc_cls: type, data: dict) -> Any:
"""Instantiate a nested dataclass hierarchy from a plain dict."""
field_types = {f.name: f.type for f in dc_cls.__dataclass_fields__.values()}
kwargs: dict[str, Any] = {}
for key, value in data.items():
if key in field_types and isinstance(value, dict):
inner_cls = dc_cls.__dataclass_fields__[key].type
# Resolve string type annotations
if isinstance(inner_cls, str):
inner_cls = globals().get(inner_cls, inner_cls)
kwargs[key] = _dict_to_dataclass(inner_cls, value)
else:
kwargs[key] = value
return dc_cls(**kwargs)
def load_config(yaml_path: Optional[str | Path] = None) -> Config:
"""Load configuration from a YAML file, falling back to defaults."""
if yaml_path is None:
yaml_path = PROJECT_ROOT / "configs" / "default.yaml"
path = Path(yaml_path)
if path.exists():
logger.info("Loading config from %s", path)
with open(path, "r", encoding="utf-8") as fh:
overrides = yaml.safe_load(fh) or {}
# Build default dict, merge overrides, convert back
import dataclasses
base = dataclasses.asdict(Config())
merged = _merge(base, overrides)
return _dict_to_dataclass(Config, merged)
logger.info("No config file found at %s – using defaults.", path)
return Config()
===
"""Central configuration for the prompt-optimizer project."""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional
import yaml
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Project root – resolved relative to this file so it works both locally
# and on Colab after cloning the repo.
# ---------------------------------------------------------------------------
PROJECT_ROOT = Path(__file__).resolve().parent.parent
@dataclass
class ModelConfig:
name: str = "mistralai/Mistral-7B-Instruct-v0.2"
trust_remote_code: bool = True
torch_dtype: str = "float16" # "float16" | "bfloat16" | "auto"
@dataclass
class QuantizationConfig:
load_in_4bit: bool = True
bnb_4bit_quant_type: str = "nf4"
bnb_4bit_compute_dtype: str = "float16"
bnb_4bit_use_double_quant: bool = True
@dataclass
class LoraConfig:
r: int = 16
lora_alpha: int = 32
lora_dropout: float = 0.05
target_modules: list[str] = field(
default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
)
bias: str = "none"
task_type: str = "CAUSAL_LM"
@dataclass
class TrainingConfig:
output_dir: str = "outputs/checkpoints"
num_train_epochs: int = 3
per_device_train_batch_size: int = 2
gradient_accumulation_steps: int = 4
learning_rate: float = 2e-4
weight_decay: float = 0.01
warmup_ratio: float = 0.03
lr_scheduler_type: str = "cosine"
max_seq_length: int = 1024
logging_steps: int = 10
save_strategy: str = "epoch"
save_total_limit: int = 2
fp16: bool = True
bf16: bool = False
gradient_checkpointing: bool = True
optim: str = "paged_adamw_8bit"
report_to: str = "none"
seed: int = 42
@dataclass
class GenerationConfig:
max_new_tokens: int = 512
temperature: float = 0.3
top_p: float = 0.9
repetition_penalty: float = 1.15
do_sample: bool = True
@dataclass
class DatasetConfig:
raw_dir: str = "data/raw"
processed_dir: str = "data/processed"
train_file: str = "train.jsonl"
val_file: str = "val.jsonl"
val_split: float = 0.1
seed: int = 42
max_samples: Optional[int] = None # limit for quick tests
@dataclass
class EvalConfig:
max_token_threshold: int = 300
min_compression_ratio: float = 0.5
@dataclass
class UIConfig:
server_name: str = "0.0.0.0"
server_port: int = 7860
share: bool = False
@dataclass
class Config:
model: ModelConfig = field(default_factory=ModelConfig)
quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
lora: LoraConfig = field(default_factory=LoraConfig)
training: TrainingConfig = field(default_factory=TrainingConfig)
generation: GenerationConfig = field(default_factory=GenerationConfig)
dataset: DatasetConfig = field(default_factory=DatasetConfig)
evaluation: EvalConfig = field(default_factory=EvalConfig)
ui: UIConfig = field(default_factory=UIConfig)
# ---- adapter / output helpers ------------------------------------------
@property
def adapter_dir(self) -> Path:
return PROJECT_ROOT / "outputs" / "adapter"
@property
def merged_dir(self) -> Path:
return PROJECT_ROOT / "outputs" / "merged"
# ---------------------------------------------------------------------------
# YAML loader
# ---------------------------------------------------------------------------
def _merge(base: dict, override: dict) -> dict:
"""Recursively merge *override* into *base*."""
for k, v in override.items():
if isinstance(v, dict) and isinstance(base.get(k), dict):
_merge(base[k], v)
else:
base[k] = v
return base
# Registry of known nested dataclass types for reliable resolution
_DATACLASS_REGISTRY: dict[str, type] = {
"model": ModelConfig,
"quantization": QuantizationConfig,
"lora": LoraConfig,
"training": TrainingConfig,
"generation": GenerationConfig,
"dataset": DatasetConfig,
"evaluation": EvalConfig,
"ui": UIConfig,
}
def _dict_to_dataclass(dc_cls: type, data: dict) -> Any:
"""Instantiate a nested dataclass hierarchy from a plain dict."""
kwargs: dict[str, Any] = {}
for key, value in data.items():
if isinstance(value, dict) and key in _DATACLASS_REGISTRY:
kwargs[key] = _dict_to_dataclass(_DATACLASS_REGISTRY[key], value)
else:
kwargs[key] = value
return dc_cls(**kwargs)
def load_config(yaml_path: Optional[str | Path] = None) -> Config:
"""Load configuration from a YAML file, falling back to defaults."""
if yaml_path is None:
yaml_path = PROJECT_ROOT / "configs" / "default.yaml"
path = Path(yaml_path)
if path.exists():
logger.info("Loading config from %s", path)
with open(path, "r", encoding="utf-8") as fh:
overrides = yaml.safe_load(fh) or {}
# Build default dict, merge overrides, convert back
import dataclasses
base = dataclasses.asdict(Config())
merged = _merge(base, overrides)
return _dict_to_dataclass(Config, merged)
logger.info("No config file found at %s – using defaults.", path)
return Config()
Problem: Both bnb_4bit_compute_dtype and torch_dtype were hardcoded to torch.float16 and ignored the YAML config values, so switching to bfloat16 in the config had no effect.
Fix: Added a local _resolve_dtype() helper that maps config strings to concrete torch.dtype values, including auto-detection of bf16 support.
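A brief, hedged illustration of the corrected flow (assumes the YAML sets quantization.bnb_4bit_compute_dtype to "bfloat16"; the original and fixed sources follow):

```python
# Sketch only: the YAML value now reaches BitsAndBytesConfig instead of a
# hardcoded torch.float16.
from src.config import load_config
from src.inference.engine import _resolve_dtype  # helper added by this fix

cfg = load_config()  # e.g. with quantization.bnb_4bit_compute_dtype: bfloat16
compute_dtype = _resolve_dtype(cfg.quantization.bnb_4bit_compute_dtype)
print(compute_dtype)  # torch.bfloat16 – the configured dtype takes effect
```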
"""Inference engine – load base model + LoRA adapter and generate optimised prompts.
Designed so that both the CLI script and the Gradio UI can reuse the same
``PromptOptimizer`` class without duplicating logic.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Optional
import torch
from peft import PeftModel
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
PreTrainedTokenizerBase,
)
from src.config import Config, PROJECT_ROOT
from src.dataset.formatter import SYSTEM_MESSAGE
logger = logging.getLogger(__name__)
class PromptOptimizer:
"""Wraps model loading and generation behind a simple API."""
def __init__(self, cfg: Config, adapter_path: Optional[Path] = None):
self.cfg = cfg
self._adapter_path = adapter_path or cfg.adapter_dir
self.model, self.tokenizer = self._load()
# ------------------------------------------------------------------
# Model loading
# ------------------------------------------------------------------
def _load(self):
bnb_config = BitsAndBytesConfig(
load_in_4bit=self.cfg.quantization.load_in_4bit,
bnb_4bit_quant_type=self.cfg.quantization.bnb_4bit_quant_type,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=self.cfg.quantization.bnb_4bit_use_double_quant,
)
logger.info("Loading base model: %s", self.cfg.model.name)
base_model = AutoModelForCausalLM.from_pretrained(
self.cfg.model.name,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=self.cfg.model.trust_remote_code,
torch_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained(
self.cfg.model.name,
trust_remote_code=self.cfg.model.trust_remote_code,
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
adapter = Path(self._adapter_path)
if adapter.exists():
logger.info("Loading LoRA adapter from %s", adapter)
model = PeftModel.from_pretrained(base_model, str(adapter))
else:
logger.warning(
"Adapter not found at %s – running base model only.", adapter
)
model = base_model
model.eval()
return model, tokenizer
# ------------------------------------------------------------------
# Prompt construction using chat template
# ------------------------------------------------------------------
def _build_input(self, raw_prompt: str) -> str:
"""Format the user prompt using the tokenizer's chat template."""
messages = [
{"role": "system", "content": SYSTEM_MESSAGE},
{"role": "user", "content": raw_prompt},
]
try:
return self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
except Exception:
return (
f"### System:\n{SYSTEM_MESSAGE}\n\n"
f"### Instruction:\n{raw_prompt}\n\n"
f"### Response:\n"
)
# ------------------------------------------------------------------
# Generation
# ------------------------------------------------------------------
@torch.inference_mode()
def optimize(
self,
raw_prompt: str,
*,
max_new_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
do_sample: Optional[bool] = None,
) -> str:
"""Generate an optimised prompt from a raw user prompt.
Generation parameters fall back to ``self.cfg.generation`` defaults
when not explicitly provided.
"""
gen = self.cfg.generation
max_new_tokens = max_new_tokens or gen.max_new_tokens
temperature = temperature if temperature is not None else gen.temperature
top_p = top_p if top_p is not None else gen.top_p
repetition_penalty = repetition_penalty or gen.repetition_penalty
do_sample = do_sample if do_sample is not None else gen.do_sample
input_text = self._build_input(raw_prompt)
inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
outputs = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
repetition_penalty=repetition_penalty,
do_sample=do_sample,
pad_token_id=self.tokenizer.pad_token_id,
)
# Decode only the newly generated tokens
generated_ids = outputs[0][inputs["input_ids"].shape[-1]:]
result = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
return result
# ---------------------------------------------------------------------------
# Convenience function for one-off use
# ---------------------------------------------------------------------------
def optimize_prompt(
raw_prompt: str,
cfg: Optional[Config] = None,
adapter_path: Optional[Path] = None,
) -> str:
"""Functional wrapper – loads the model once and returns the optimised prompt."""
if cfg is None:
from src.config import load_config
cfg = load_config()
engine = PromptOptimizer(cfg, adapter_path=adapter_path)
return engine.optimize(raw_prompt)
===
"""Inference engine – load base model + LoRA adapter and generate optimised prompts.
Designed so that both the CLI script and the Gradio UI can reuse the same
``PromptOptimizer`` class without duplicating logic.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Optional
import torch
from peft import PeftModel
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
PreTrainedTokenizerBase,
)
from src.config import Config, PROJECT_ROOT
from src.dataset.formatter import SYSTEM_MESSAGE
logger = logging.getLogger(__name__)
def _resolve_dtype(name: str) -> torch.dtype:
"""Map a config string to a concrete torch dtype."""
mapping = {"float16": torch.float16, "bfloat16": torch.bfloat16}
if name == "auto":
return torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
return mapping.get(name, torch.float16)
class PromptOptimizer:
"""Wraps model loading and generation behind a simple API."""
def __init__(self, cfg: Config, adapter_path: Optional[Path] = None):
self.cfg = cfg
self._adapter_path = adapter_path or cfg.adapter_dir
self.model, self.tokenizer = self._load()
# ------------------------------------------------------------------
# Model loading
# ------------------------------------------------------------------
def _load(self):
compute_dtype = _resolve_dtype(self.cfg.quantization.bnb_4bit_compute_dtype)
model_dtype = _resolve_dtype(self.cfg.model.torch_dtype)
bnb_config = BitsAndBytesConfig(
load_in_4bit=self.cfg.quantization.load_in_4bit,
bnb_4bit_quant_type=self.cfg.quantization.bnb_4bit_quant_type,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=self.cfg.quantization.bnb_4bit_use_double_quant,
)
logger.info("Loading base model: %s", self.cfg.model.name)
base_model = AutoModelForCausalLM.from_pretrained(
self.cfg.model.name,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=self.cfg.model.trust_remote_code,
torch_dtype=model_dtype,
)
tokenizer = AutoTokenizer.from_pretrained(
self.cfg.model.name,
trust_remote_code=self.cfg.model.trust_remote_code,
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
adapter = Path(self._adapter_path)
if adapter.exists():
logger.info("Loading LoRA adapter from %s", adapter)
model = PeftModel.from_pretrained(base_model, str(adapter))
else:
logger.warning(
"Adapter not found at %s – running base model only.", adapter
)
model = base_model
model.eval()
return model, tokenizer
# ------------------------------------------------------------------
# Prompt construction using chat template
# ------------------------------------------------------------------
def _build_input(self, raw_prompt: str) -> str:
"""Format the user prompt using the tokenizer's chat template."""
messages = [
{"role": "system", "content": SYSTEM_MESSAGE},
{"role": "user", "content": raw_prompt},
]
try:
return self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
except Exception:
return (
f"### System:\n{SYSTEM_MESSAGE}\n\n"
f"### Instruction:\n{raw_prompt}\n\n"
f"### Response:\n"
)
# ------------------------------------------------------------------
# Generation
# ------------------------------------------------------------------
@torch.inference_mode()
def optimize(
self,
raw_prompt: str,
*,
max_new_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
do_sample: Optional[bool] = None,
) -> str:
"""Generate an optimised prompt from a raw user prompt.
Generation parameters fall back to ``self.cfg.generation`` defaults
when not explicitly provided.
"""
gen = self.cfg.generation
max_new_tokens = max_new_tokens or gen.max_new_tokens
temperature = temperature if temperature is not None else gen.temperature
top_p = top_p if top_p is not None else gen.top_p
repetition_penalty = repetition_penalty or gen.repetition_penalty
do_sample = do_sample if do_sample is not None else gen.do_sample
input_text = self._build_input(raw_prompt)
inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
outputs = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
repetition_penalty=repetition_penalty,
do_sample=do_sample,
pad_token_id=self.tokenizer.pad_token_id,
)
# Decode only the newly generated tokens
generated_ids = outputs[0][inputs["input_ids"].shape[-1]:]
result = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
return result
# ---------------------------------------------------------------------------
# Convenience function for one-off use
# ---------------------------------------------------------------------------
def optimize_prompt(
raw_prompt: str,
cfg: Optional[Config] = None,
adapter_path: Optional[Path] = None,
) -> str:
"""Functional wrapper – loads the model once and returns the optimised prompt."""
if cfg is None:
from src.config import load_config
cfg = load_config()
engine = PromptOptimizer(cfg, adapter_path=adapter_path)
return engine.optimize(raw_prompt)
Problem: _resolve_dtype() returned the string "auto" for the auto case even though callers expect a torch.dtype. Also, SFTTrainer was called with processing_class=, which does not exist in older TRL versions.
Fix: _resolve_dtype("auto") now auto-detects bf16 support and returns a concrete dtype. Changed processing_class=tokenizer → tokenizer=tokenizer for broader compatibility.
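If you need to support both old and new TRL releases side by side, one hedged way to detect which keyword the installed version accepts (illustration only; the diff below simply switches to tokenizer=):

```python
# Sketch: introspect the installed SFTTrainer to pick the supported keyword.
import inspect
from trl import SFTTrainer

params = inspect.signature(SFTTrainer.__init__).parameters
tok_kwarg = "processing_class" if "processing_class" in params else "tokenizer"
print(f"This TRL build expects SFTTrainer({tok_kwarg}=...)")
```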
"""Fine-tuning script – LoRA / QLoRA with SFTTrainer.
This module encapsulates:
- Model + tokenizer loading with 4-bit quantisation
- LoRA adapter configuration
- Dataset formatting via the tokenizer's chat template
- SFTTrainer initialisation and training loop
- Checkpoint & adapter saving
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Optional
import torch
from datasets import DatasetDict
from peft import LoraConfig as PeftLoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TrainingArguments,
)
from trl import SFTTrainer
from src.config import Config, PROJECT_ROOT
from src.dataset.formatter import get_formatting_fn
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Hardware helpers
# ---------------------------------------------------------------------------
def _resolve_dtype(name: str) -> torch.dtype:
mapping = {"float16": torch.float16, "bfloat16": torch.bfloat16, "auto": "auto"}
return mapping.get(name, torch.float16)
def _detect_fp16_bf16(cfg: Config) -> tuple[bool, bool]:
"""Return (fp16, bf16) booleans depending on hardware + config."""
if torch.cuda.is_available():
if torch.cuda.is_bf16_supported():
return False, True
return True, False
# CPU-only – disable both
return False, False
# ---------------------------------------------------------------------------
# Model / tokenizer loading
# ---------------------------------------------------------------------------
def load_model_and_tokenizer(cfg: Config):
"""Load quantised base model and tokenizer."""
bnb_config = BitsAndBytesConfig(
load_in_4bit=cfg.quantization.load_in_4bit,
bnb_4bit_quant_type=cfg.quantization.bnb_4bit_quant_type,
bnb_4bit_compute_dtype=_resolve_dtype(cfg.quantization.bnb_4bit_compute_dtype),
bnb_4bit_use_double_quant=cfg.quantization.bnb_4bit_use_double_quant,
)
logger.info("Loading model: %s", cfg.model.name)
model = AutoModelForCausalLM.from_pretrained(
cfg.model.name,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=cfg.model.trust_remote_code,
torch_dtype=_resolve_dtype(cfg.model.torch_dtype),
)
model = prepare_model_for_kbit_training(model)
tokenizer = AutoTokenizer.from_pretrained(
cfg.model.name,
trust_remote_code=cfg.model.trust_remote_code,
)
# Ensure pad token exists (many models lack one)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
return model, tokenizer
# ---------------------------------------------------------------------------
# LoRA setup
# ---------------------------------------------------------------------------
def build_lora_config(cfg: Config) -> PeftLoraConfig:
return PeftLoraConfig(
r=cfg.lora.r,
lora_alpha=cfg.lora.lora_alpha,
lora_dropout=cfg.lora.lora_dropout,
target_modules=cfg.lora.target_modules,
bias=cfg.lora.bias,
task_type=cfg.lora.task_type,
)
# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
def train(
cfg: Config,
dataset: DatasetDict,
resume_from_checkpoint: Optional[str] = None,
) -> None:
"""Run the full fine-tuning pipeline."""
model, tokenizer = load_model_and_tokenizer(cfg)
lora_cfg = build_lora_config(cfg)
model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()
fp16, bf16 = _detect_fp16_bf16(cfg)
output_dir = str(PROJECT_ROOT / cfg.training.output_dir)
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=cfg.training.num_train_epochs,
per_device_train_batch_size=cfg.training.per_device_train_batch_size,
gradient_accumulation_steps=cfg.training.gradient_accumulation_steps,
learning_rate=cfg.training.learning_rate,
weight_decay=cfg.training.weight_decay,
warmup_ratio=cfg.training.warmup_ratio,
lr_scheduler_type=cfg.training.lr_scheduler_type,
logging_steps=cfg.training.logging_steps,
save_strategy=cfg.training.save_strategy,
save_total_limit=cfg.training.save_total_limit,
fp16=fp16,
bf16=bf16,
gradient_checkpointing=cfg.training.gradient_checkpointing,
optim=cfg.training.optim,
report_to=cfg.training.report_to,
seed=cfg.training.seed,
remove_unused_columns=False,
)
# Format dataset via chat template
formatting_fn = get_formatting_fn(tokenizer)
train_ds = dataset["train"].map(formatting_fn)
val_ds = dataset["validation"].map(formatting_fn) if "validation" in dataset else None
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=train_ds,
eval_dataset=val_ds,
processing_class=tokenizer,
dataset_text_field="text",
max_seq_length=cfg.training.max_seq_length,
packing=False,
)
logger.info("Starting training …")
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
# Save adapter
adapter_dir = str(cfg.adapter_dir)
logger.info("Saving LoRA adapter to %s", adapter_dir)
Path(adapter_dir).mkdir(parents=True, exist_ok=True)
model.save_pretrained(adapter_dir)
tokenizer.save_pretrained(adapter_dir)
logger.info("Training complete.")
===
"""Fine-tuning script – LoRA / QLoRA with SFTTrainer.
This module encapsulates:
- Model + tokenizer loading with 4-bit quantisation
- LoRA adapter configuration
- Dataset formatting via the tokenizer's chat template
- SFTTrainer initialisation and training loop
- Checkpoint & adapter saving
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Optional
import torch
from datasets import DatasetDict
from peft import LoraConfig as PeftLoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TrainingArguments,
)
from trl import SFTTrainer
from src.config import Config, PROJECT_ROOT
from src.dataset.formatter import get_formatting_fn
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Hardware helpers
# ---------------------------------------------------------------------------
def _resolve_dtype(name: str) -> torch.dtype:
"""Map a config string to a concrete torch dtype."""
mapping = {"float16": torch.float16, "bfloat16": torch.bfloat16}
if name == "auto":
return torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
return mapping.get(name, torch.float16)
def _detect_fp16_bf16(cfg: Config) -> tuple[bool, bool]:
"""Return (fp16, bf16) booleans depending on hardware + config."""
if torch.cuda.is_available():
if torch.cuda.is_bf16_supported():
return False, True
return True, False
# CPU-only – disable both
return False, False
# ---------------------------------------------------------------------------
# Model / tokenizer loading
# ---------------------------------------------------------------------------
def load_model_and_tokenizer(cfg: Config):
"""Load quantised base model and tokenizer."""
bnb_config = BitsAndBytesConfig(
load_in_4bit=cfg.quantization.load_in_4bit,
bnb_4bit_quant_type=cfg.quantization.bnb_4bit_quant_type,
bnb_4bit_compute_dtype=_resolve_dtype(cfg.quantization.bnb_4bit_compute_dtype),
bnb_4bit_use_double_quant=cfg.quantization.bnb_4bit_use_double_quant,
)
logger.info("Loading model: %s", cfg.model.name)
model = AutoModelForCausalLM.from_pretrained(
cfg.model.name,
quantization_config=bnb_config,
device_map="auto",
trust_remote_code=cfg.model.trust_remote_code,
torch_dtype=_resolve_dtype(cfg.model.torch_dtype),
)
model = prepare_model_for_kbit_training(model)
tokenizer = AutoTokenizer.from_pretrained(
cfg.model.name,
trust_remote_code=cfg.model.trust_remote_code,
)
# Ensure pad token exists (many models lack one)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
return model, tokenizer
# ---------------------------------------------------------------------------
# LoRA setup
# ---------------------------------------------------------------------------
def build_lora_config(cfg: Config) -> PeftLoraConfig:
return PeftLoraConfig(
r=cfg.lora.r,
lora_alpha=cfg.lora.lora_alpha,
lora_dropout=cfg.lora.lora_dropout,
target_modules=cfg.lora.target_modules,
bias=cfg.lora.bias,
task_type=cfg.lora.task_type,
)
# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
def train(
cfg: Config,
dataset: DatasetDict,
resume_from_checkpoint: Optional[str] = None,
) -> None:
"""Run the full fine-tuning pipeline."""
model, tokenizer = load_model_and_tokenizer(cfg)
lora_cfg = build_lora_config(cfg)
model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()
fp16, bf16 = _detect_fp16_bf16(cfg)
output_dir = str(PROJECT_ROOT / cfg.training.output_dir)
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=cfg.training.num_train_epochs,
per_device_train_batch_size=cfg.training.per_device_train_batch_size,
gradient_accumulation_steps=cfg.training.gradient_accumulation_steps,
learning_rate=cfg.training.learning_rate,
weight_decay=cfg.training.weight_decay,
warmup_ratio=cfg.training.warmup_ratio,
lr_scheduler_type=cfg.training.lr_scheduler_type,
logging_steps=cfg.training.logging_steps,
save_strategy=cfg.training.save_strategy,
save_total_limit=cfg.training.save_total_limit,
fp16=fp16,
bf16=bf16,
gradient_checkpointing=cfg.training.gradient_checkpointing,
optim=cfg.training.optim,
report_to=cfg.training.report_to,
seed=cfg.training.seed,
remove_unused_columns=False,
)
# Format dataset via chat template
formatting_fn = get_formatting_fn(tokenizer)
train_ds = dataset["train"].map(formatting_fn)
val_ds = dataset["validation"].map(formatting_fn) if "validation" in dataset else None
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=train_ds,
eval_dataset=val_ds,
tokenizer=tokenizer,
dataset_text_field="text",
max_seq_length=cfg.training.max_seq_length,
packing=False,
)
logger.info("Starting training …")
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
# Save adapter
adapter_dir = str(cfg.adapter_dir)
logger.info("Saving LoRA adapter to %s", adapter_dir)
Path(adapter_dir).mkdir(parents=True, exist_ok=True)
model.save_pretrained(adapter_dir)
tokenizer.save_pretrained(adapter_dir)
logger.info("Training complete.")torch>=2.1.0
transformers>=4.40.0
datasets>=2.18.0
peft>=0.10.0
trl>=0.8.0
accelerate>=0.28.0
bitsandbytes>=0.43.0
sentencepiece>=0.2.0
protobuf>=4.25.0
gradio>=4.20.0
pyyaml>=6.0
===
torch>=2.1.0
transformers>=4.40.0
datasets>=2.18.0
peft>=0.10.0
trl>=0.8.0
accelerate>=0.28.0
bitsandbytes>=0.43.0
sentencepiece>=0.2.0
protobuf>=4.25.0
gradio>=4.20.0
pyyaml>=6.0
matplotlib>=3.7.0
| File | Status |
|---|---|
| src/__init__.py | ✅ Verified |
| src/dataset/__init__.py | ✅ Verified |
| src/dataset/seeds.py | ✅ Verified |
| src/dataset/generator.py | ✅ Verified |
| src/dataset/formatter.py | ✅ Verified |
| src/evaluation/__init__.py | ✅ Verified |
| src/inference/__init__.py | ✅ Verified |
| src/training/__init__.py | ✅ Verified |
| src/utils/__init__.py | ✅ Verified |
| app/__init__.py | ✅ Verified |
| configs/default.yaml | ✅ Verified |
| scripts/generate_dataset.py | ✅ Verified |
| scripts/train.py | ✅ Verified |
| scripts/infer.py | ✅ Verified |
| scripts/evaluate.py | ✅ Verified |
| scripts/launch_ui.py | ✅ Verified |
| README.md | ✅ Verified |
| .gitignore | ✅ Verified |
Added two visualization functions (all original functions preserved), plus the UI wiring to display them:
- generate_compression_chart(metrics) – single-prompt bar chart with an annotation arrow showing tokens saved and % reduction (a hedged sketch follows this list)
- generate_batch_compression_chart(results) – grouped bar chart for batch evaluation
- Added a gr.Plot() component below the metrics section
- run_optimization() now returns 7 values (was 6), including the matplotlib Figure
- Added CSS styling and UI polish
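For reference, a minimal sketch of what generate_compression_chart could look like. This is an illustration only: the metric keys (raw_tokens, optimized_tokens) are assumed names, not taken from the repository, and the real implementation may differ.

```python
# Hedged sketch of generate_compression_chart; metric keys are assumptions.
import matplotlib
matplotlib.use("Agg")  # headless backend, fine for a Gradio server
import matplotlib.pyplot as plt


def generate_compression_chart(metrics: dict) -> plt.Figure:
    """Bar chart comparing raw vs. optimized token counts for one prompt."""
    raw = metrics["raw_tokens"]        # assumed key
    opt = metrics["optimized_tokens"]  # assumed key
    saved = raw - opt
    pct = 100.0 * saved / raw if raw else 0.0

    fig, ax = plt.subplots(figsize=(5, 4))
    bars = ax.bar(["Raw prompt", "Optimized prompt"], [raw, opt],
                  color=["#d62728", "#2ca02c"])
    ax.bar_label(bars, padding=3)
    ax.set_ylabel("Tokens")
    ax.set_title("Prompt compression")

    # Annotation arrow pointing at the optimized bar, labelled with the savings.
    ax.annotate(
        f"-{saved} tokens ({pct:.0f}% smaller)",
        xy=(1, opt), xytext=(0.5, raw * 0.9),
        ha="center", arrowprops=dict(arrowstyle="->"),
    )
    fig.tight_layout()
    return fig
```

The returned Figure can be handed straight to the new gr.Plot() output as the seventh value returned by run_optimization().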
Created DEPLOYMENT.md covering:
- GCP Compute Engine setup with L4/T4 GPU
- Complete Dockerfile and Docker Compose
- Vertex AI deployment with custom prediction server
- Nginx + SSL reverse proxy for secure exposure (see the launch sketch after this list)
- systemd auto-restart and monitoring
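As a companion to the reverse-proxy and systemd sections, a hedged sketch of a production-style launch entrypoint (illustration only; build_demo is an assumed factory name, and the actual scripts/launch_ui.py may differ):

```python
# Sketch only – assumes a build_demo(cfg) factory in app/ui.py (not confirmed).
from src.config import load_config
from app.ui import build_demo


def main() -> None:
    cfg = load_config()
    demo = build_demo(cfg)
    demo.launch(
        server_name=cfg.ui.server_name,  # "0.0.0.0" so a local nginx can proxy to it
        server_port=cfg.ui.server_port,  # 7860 by default
        share=cfg.ui.share,              # keep False behind a reverse proxy
    )


if __name__ == "__main__":
    main()
```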
Created PAPER_OUTLINE.md with:
- 11-section professional structure (Abstract → Appendix)
- Formal problem definition and methodology
- Evaluation metrics framework (automated + manual)
- 5 ablation study suggestions
- LaTeX template header
- 10 key references to cite
All 5 modified Python files pass py_compile:
python -m py_compile src/config.py ✅
python -m py_compile src/evaluation/metrics.py ✅
python -m py_compile src/training/train.py ✅
python -m py_compile src/inference/engine.py ✅
python -m py_compile app/ui.py ✅