Conversation
The underlying SentenceTransformerTrainer needs the fp16 flag already at initialization to properly set up amp.
---
Hello! I believe the fp16 argument is already propagated correctly. See e.g. this script:

```python
from datasets import load_dataset
from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset

# Load a dataset from the Hugging Face Hub
dataset = load_dataset("sst2")

# Simulate the few-shot regime by sampling 8 examples per class
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
eval_dataset = dataset["validation"].select(range(100))
test_dataset = dataset["validation"].select(range(100, len(dataset["validation"])))

# Load a SetFit model from the Hub
model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
    labels=["negative", "positive"],
)

args = TrainingArguments(
    batch_size=16,
    num_epochs=4,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    metric="accuracy",
    column_mapping={"sentence": "text", "label": "label"},
)

print(trainer.st_trainer.args.fp16)

# Train and evaluate
trainer.train()
metrics = trainer.evaluate(test_dataset)
print(metrics)
# {'accuracy': 0.8691709844559585}
```
---
Hello Tom,

Thanks for taking the time to review my commit. You're correct that the fp16 argument gets propagated from the setfit `Trainer` to the sentence-transformers trainer. However, this argument is not the only thing driving the use of amp: `self.use_apex`, `self.use_cpu_amp`, and `self.amp_dtype` are also set based on the fp16 argument inside the transformers `Trainer.__init__`, starting from line 724 (the block commented `# Mixed precision setup`). That init is called from the sentence-transformers trainer's init, which runs right before `_apply_training_arguments`. As far as I can tell, these three attributes are not adjusted by `_apply_training_arguments` afterwards.

Best regards,
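The ordering issue described above can be illustrated with a minimal, library-independent sketch. All class and function names here are hypothetical stand-ins, not the actual transformers or setfit code; the point is only that attributes derived from `args.fp16` at construction time are not recomputed when the args are patched later:

```python
class MixedPrecisionArgs:
    """Hypothetical stand-in for a TrainingArguments object."""
    def __init__(self, fp16=False):
        self.fp16 = fp16


class InnerTrainer:
    """Stand-in for an inner trainer that derives amp state in __init__."""
    def __init__(self, args):
        self.args = args
        # Mirrors the "# Mixed precision setup" idea: this attribute is
        # computed once, from args.fp16 as it is *at construction time*.
        # (Greatly simplified; the real logic also checks device/backend.)
        self.use_cpu_amp = args.fp16


def apply_training_arguments(trainer, fp16):
    """Stand-in for patching the args *after* the trainer is constructed."""
    trainer.args.fp16 = fp16  # too late: derived attributes are not recomputed


# fp16 not known at init -> derived state stays stale even after patching args
trainer = InnerTrainer(MixedPrecisionArgs(fp16=False))
apply_training_arguments(trainer, fp16=True)
print(trainer.args.fp16)    # True: the flag itself *is* propagated...
print(trainer.use_cpu_amp)  # False: ...but the init-time derived state is stale
```

This is why passing fp16 at initialization, rather than patching it in afterwards, changes behavior.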
---
Hello @tomaarsen,

I had a little spare time today to investigate a bit further. The three parameters described in my previous comment only matter when using APEX or when running on a CPU, yet I see a performance increase even though neither applies to my setup. I slightly altered your script to print the additional parameters and to measure the training time. Running it several times, I get 70.695 ± 1.225s with my commit, versus 78.675 ± 2.405s for the original. More importantly, the accuracy differs between the two versions (but stays constant across multiple runs of the same version), clearly indicating an impact.

```python
from datasets import load_dataset
from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset
import os
import time

# Load a dataset from the Hugging Face Hub
dataset = load_dataset("sst2")

# Simulate the few-shot regime by sampling 8 examples per class
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
eval_dataset = dataset["validation"].select(range(100))
test_dataset = dataset["validation"].select(range(100, len(dataset["validation"])))

# Load a SetFit model from the Hub
model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
    labels=["negative", "positive"],
)

args = TrainingArguments(
    batch_size=16,
    num_epochs=4,
    use_amp=True,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    metric="accuracy",
    column_mapping={"sentence": "text", "label": "label"},
)

print(f'args.fp16: {trainer.st_trainer.args.fp16}')
print(f'args.backend: {trainer.st_trainer.args.half_precision_backend}')
print(f'trainer.use_apex: {trainer.st_trainer.use_apex}')
print(f'trainer.use_cpu_amp: {trainer.st_trainer.use_cpu_amp}')
print(f'env.accelerate: {os.environ.get("ACCELERATE_MIXED_PRECISION", "no")}')

start_time = time.perf_counter()

# Train and evaluate
trainer.train()

elapsed_time = time.perf_counter() - start_time
print(f'training time: {elapsed_time}')

metrics = trainer.evaluate(test_dataset)
print(metrics)
```

Best regards,
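For reference, mean ± standard deviation timings like the ones quoted above can be collected with a small helper along these lines (`time_runs` is a hypothetical helper; a cheap dummy workload stands in for `trainer.train()` here):

```python
import statistics
import time


def time_runs(fn, repeats=5):
    """Time repeated calls of fn, returning (mean, stdev) in seconds."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return statistics.mean(durations), statistics.stdev(durations)


# Dummy workload instead of trainer.train(), just to show the usage:
mean_s, std_s = time_runs(lambda: sum(range(100_000)), repeats=5)
print(f"training time: {mean_s:.3f} ± {std_s:.3f}s")
```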