
Update trainer.py #605

Closed

smartIU2 wants to merge 1 commit into huggingface:main from smartIU2:main

Conversation

@smartIU2

@smartIU2 smartIU2 commented Jun 1, 2025

The underlying SentenceTransformerTrainer needs the fp16 flag already at initialization, to properly set up AMP.
@tomaarsen
Member

Hello!

I believe the setfit_args setter, which calls _apply_training_arguments on the BCSentenceTransformersTrainer, should take care of propagating the SetFit training arguments to the Sentence Transformers training arguments. In short, I don't believe this is necessary.

See e.g. this script, with print(trainer.st_trainer.args.fp16) before training. If you set use_amp=True, then it'll be True; otherwise, it's False.

from datasets import load_dataset
from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset


# Load a dataset from the Hugging Face Hub
dataset = load_dataset("sst2")

# Simulate the few-shot regime by sampling 8 examples per class
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
eval_dataset = dataset["validation"].select(range(100))
test_dataset = dataset["validation"].select(range(100, len(dataset["validation"])))

# Load a SetFit model from Hub
model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
    labels=["negative", "positive"],
)

args = TrainingArguments(
    batch_size=16,
    num_epochs=4,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    metric="accuracy",
    column_mapping={"sentence": "text", "label": "label"}
)
print(trainer.st_trainer.args.fp16)

# Train and evaluate
trainer.train()
metrics = trainer.evaluate(test_dataset)
print(metrics)
# {'accuracy': 0.8691709844559585}
- Tom Aarsen

@tomaarsen tomaarsen closed this Aug 5, 2025
@smartIU2
Author

smartIU2 commented Aug 5, 2025

Hello Tom,

Thanks for taking the time to review my commit.

You're correct that the fp16 argument gets propagated from the setfit Trainer to the sentence-transformers Trainer.

However, this argument is not the only one driving the use of AMP: self.use_apex, self.use_cpu_amp, and self.amp_dtype are also set based on the fp16 argument inside the transformers Trainer __init__, starting from line 724, commented with "# Mixed precision setup". That __init__ is called from the sentence-transformers Trainer __init__, which in turn runs right before _apply_training_arguments.

As far as I can tell, these three underlying parameters are not adjusted by _apply_training_arguments afterwards.
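The ordering issue can be illustrated with a toy sketch (a hypothetical class, not the real transformers code): a flag derived from args inside __init__ is not recomputed when the args object is updated afterwards.

```python
class InnerTrainer:
    """Toy stand-in for a trainer whose mixed-precision state is
    derived from its args once, at construction time."""

    def __init__(self, args):
        self.args = args
        # analogue of the "# Mixed precision setup" block: derived at init time
        self.amp_dtype = "float16" if args["fp16"] else None


args = {"fp16": False}
trainer = InnerTrainer(args)  # constructed before fp16 is propagated
trainer.args["fp16"] = True   # later propagation only touches args
print(trainer.amp_dtype)      # -> None: the derived flag is stale
```

Even though trainer.args["fp16"] now reads True, the derived amp_dtype stays None, because it was computed once in __init__.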

Best regards,
Martin

@smartIU2
Author

smartIU2 commented Aug 9, 2025

Hello @tomaarsen,

I had a little spare time today to investigate a bit further. The three parameters described in my previous comment only matter when using APEX or running on a CPU. But I saw a performance increase even though neither applies to my setup.
Digging deeper, I found that the transformers TrainingArguments has a __post_init__ method, which also consumes the fp16 argument before setfit's _apply_training_arguments runs. In particular, it sets the environment variable "ACCELERATE_MIXED_PRECISION", which made the difference for me, since I have accelerate installed.

I slightly altered your script to output the additional parameters and measure the training time. When I run it several times with my commit I get 70.695 ± 1.225s, versus 78.675 ± 2.405s for the original. More importantly, though, the accuracy changes between the two versions (but stays constant across multiple runs of the same version), clearly indicating an impact.

from datasets import load_dataset
from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset

import os
import time


# Load a dataset from the Hugging Face Hub
dataset = load_dataset("sst2")

# Simulate the few-shot regime by sampling 8 examples per class
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
eval_dataset = dataset["validation"].select(range(100))
test_dataset = dataset["validation"].select(range(100, len(dataset["validation"])))

# Load a SetFit model from Hub
model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
    labels=["negative", "positive"],
)

args = TrainingArguments(
    batch_size=16,
    num_epochs=4,
    use_amp=True,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    metric="accuracy",
    column_mapping={"sentence": "text", "label": "label"}
)

print(f'args.fp16: {trainer.st_trainer.args.fp16}')
print(f'args.backend: {trainer.st_trainer.args.half_precision_backend}')
print(f'trainer.use_apex: {trainer.st_trainer.use_apex}')
print(f'trainer.use_cpu_amp: {trainer.st_trainer.use_cpu_amp}')
print(f'env.accelerate: {os.environ.get("ACCELERATE_MIXED_PRECISION", "no")}')

start_time = time.perf_counter()

# Train and evaluate
trainer.train()

elapsed_time = time.perf_counter() - start_time
print(f'training time: {elapsed_time}')

metrics = trainer.evaluate(test_dataset)
print(metrics)
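The __post_init__ ordering described above can be sketched with a toy dataclass (hypothetical, mirroring only the ordering, not the real TrainingArguments): the environment variable is exported once at construction, so propagating fp16 afterwards comes too late.

```python
import os
from dataclasses import dataclass


@dataclass
class ToyArguments:
    """Toy mirror of an arguments class whose __post_init__ exports
    its mixed-precision choice to the environment at construction."""
    fp16: bool = False

    def __post_init__(self):
        # runs once, when the instance is built
        if self.fp16:
            os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"


args = ToyArguments(fp16=False)  # built before fp16 is propagated
args.fp16 = True                 # propagating afterwards is too late:
print(os.environ.get("ACCELERATE_MIXED_PRECISION", "no"))  # -> no
```

This matches what the altered script prints for env.accelerate when use_amp is only applied after the inner trainer has been constructed.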

Best regards,
Martin
