sft.py
124 lines (107 loc) · 3.74 KB
from huggingface_hub import login
import os
import torch
import transformers
from datasets import Dataset
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType
)
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments
)
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin
import random
import numpy as np
import pandas as pd
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM
from utils import create_bnb_config, create_peft_config, prepare_dataset, seed_everything
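# NOTE: utils.py is not shown here; the helpers imported above are assumed to behave roughly as follows:
#   create_bnb_config()     -> a BitsAndBytesConfig used for quantized (QLoRA-style) model loading
#   create_peft_config(...) -> a LoraConfig built from the LORA_R / LORA_ALPHA / LORA_DROPOUT values below
#   prepare_dataset(df)     -> a datasets.Dataset with prompt/response text formatted for SFT
#   seed_everything(seed)   -> seeds Python, NumPy and torch for reproducibility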
seed_everything(0)
train_df = pd.read_parquet('/projects/florence_echo/ankush_agent_projects/SFT/train-00000-of-00001_new.parquet')
val_df = pd.read_parquet('/projects/florence_echo/ankush_agent_projects/SFT/validation-00000-of-00001_new.parquet')
test_df = pd.read_parquet('/projects/florence_echo/ankush_agent_projects/SFT/test-00000-of-00001_new.parquet')
# Model and training configuration
MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
OUTPUT_DIR = "medical_qa_model"
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.1
LEARNING_RATE = 2e-4
BATCH_SIZE = 32
NUM_EPOCHS = 3
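# LoRA and optimizer hyperparameters. With gradient_accumulation_steps=1 (set below),
# the effective batch size is BATCH_SIZE x number of devices.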
train_dataset = prepare_dataset(train_df)
val_dataset = prepare_dataset(val_df)
test_dataset = prepare_dataset(test_df)
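# Assumption: prepare_dataset() emits examples whose completion begins with "### Response:\n",
# matching the response_template used by the completion-only collator in main().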
def setup_model_and_tokenizer():
    """Set up the model and tokenizer with quantization and LoRA."""
    # Load tokenizer and use the EOS token for padding (the Llama tokenizer has no pad token by default)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    # Load the model with quantization (config comes from utils.create_bnb_config)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=create_bnb_config(),
        trust_remote_code=True
    )

    # Prepare the quantized model for k-bit training
    model = prepare_model_for_kbit_training(model)

    # Add the LoRA adapter
    model = get_peft_model(model, create_peft_config(LORA_R=LORA_R, LORA_ALPHA=LORA_ALPHA, LORA_DROPOUT=LORA_DROPOUT))
    return model, tokenizer
def main():
    """Main training function."""
    # Initialize the accelerator (DeepSpeed, if used, is configured externally via `accelerate launch`)
    accelerator = Accelerator()

    # Set up model and tokenizer
    model, tokenizer = setup_model_and_tokenizer()

    # Completion-only training: mask the prompt so the loss is computed only on the response tokens
    response_template = "### Response:\n"
    collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

    # Set up training arguments
    training_args = SFTConfig(
        output_dir=OUTPUT_DIR,
        eval_strategy="epoch",
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=1,
        lr_scheduler_type="cosine",
        warmup_ratio=0.3,
        learning_rate=LEARNING_RATE,
        bf16=True,
        fp16=False,
        logging_strategy="epoch",
        save_strategy="epoch",
        report_to="none",
        dataloader_num_workers=8,
        dataloader_prefetch_factor=2,
        max_seq_length=512,
        dataloader_persistent_workers=True,
        remove_unused_columns=True,
        optim="paged_adamw_32bit",
    )

    # Create trainer. The model already carries the LoRA adapter from setup_model_and_tokenizer(),
    # so passing peft_config again here would be redundant.
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        args=training_args,
        processing_class=tokenizer,
        data_collator=collator,  # completion-only collator defined above
    )

    # SFTTrainer handles device placement and distributed setup internally,
    # so the trainer itself is not passed through accelerator.prepare().

    # Train the model
    trainer.train()

    # Save the model (only on the main process)
    if accelerator.is_main_process:
        trainer.save_model(OUTPUT_DIR)
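# Typical launch command (assumption; adjust to the actual cluster/accelerate config):
#   accelerate launch sft.py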
if __name__ == "__main__":
    main()