
Different batch sizes per dataset/loss (CachedMNRL + CoSENT) in multi-dataset training #3684

@ddofer

Description

Hi,
I’m training a model using the v3 multi-dataset training API, combining a cached contrastive loss (CachedMultipleNegativesRankingLoss, cached GIST, etc.) on one dataset with a regression loss (CoSENTLoss, or a triplet loss) on a second dataset.

Because of the different memory requirements and nature of these losses, I want to use a different batch size for each dataset: for example, a 2K batch size with the cached MNRL, and 3264 with the triplet or CoSENT losses (i.e., the maximum effective batch size I can fit in per-device memory).

Here is a simplified version of my setup:

from sentence_transformers import SentenceTransformerTrainer, losses
from sentence_transformers.training_args import SentenceTransformerTrainingArguments, MultiDatasetBatchSamplers

# model, dataset_a, and dataset_b are defined elsewhere
train_dataset = {
    "contrastive_ds": dataset_a,  # e.g. positive pairs
    "regression_ds": dataset_b,   # e.g. pairs with similarity scores
}

loss_model = {
    "contrastive_ds": losses.CachedMultipleNegativesRankingLoss(model, mini_batch_size=32),
    "regression_ds": losses.CoSENTLoss(model),
}

args = SentenceTransformerTrainingArguments(
    per_device_train_batch_size=64,  # <-- applies globally?
    multi_dataset_batch_sampler=MultiDatasetBatchSamplers.ROUND_ROBIN,
    output_dir="./output",
)

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss_model,
)
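As I understand it, the mini_batch_size on the cached loss only chunks the forward pass to bound peak memory; the effective (in-batch negatives) batch size still comes from per_device_train_batch_size. A toy illustration of that chunking (plain Python, not the library's actual implementation):

```python
# Toy sketch: a cached loss processes one large "global" batch in small
# mini-batches to keep peak memory bounded, while the contrastive loss
# still sees every in-batch negative from the full global batch.
def chunked(batch, mini_batch_size):
    """Split a batch into mini-batches of at most mini_batch_size items."""
    for i in range(0, len(batch), mini_batch_size):
        yield batch[i : i + mini_batch_size]

global_batch = list(range(2048))              # stand-in for 2K examples
mini_batches = list(chunked(global_batch, 32))

print(len(mini_batches))                      # 64 forward passes of size 32
```

So if this is right, memory for the cached loss scales with mini_batch_size, not with the global batch size, which is why only the non-cached losses would be at risk of OOM.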

My questions:

  1. Dataset-specific batch sizes: Is there a supported way to specify a different per_device_train_batch_size per dataset in SentenceTransformerTrainingArguments? (Multi-GPU/DDP setup.)
  2. Interaction with CachedMNRL: Since CachedMNRL accepts a mini_batch_size parameter, is the recommended pattern to set a large global per_device_train_batch_size? If I do that, won't CoSENTLoss attempt to process the massive global batch all at once and cause an OOM error?
  3. Best practices: What is the recommended pattern for decoupling the batch sizes when mixing cached and non-cached losses via ROUND_ROBIN? (I keep getting NCCL crashes/issues, maybe due to batch sizes not being cleanly divisible in the multi-GPU setting with different losses.)
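To make question 1 concrete, here is a toy sketch of the behavior I'm after under ROUND_ROBIN sampling: alternate between datasets, but draw a different number of examples from each. This is plain, hypothetical Python, not the trainer's sampler API:

```python
# Hypothetical sketch of per-dataset batch sizes under round-robin
# multi-dataset sampling. Alternates datasets each step and stops when
# any dataset can no longer fill a full batch (as ROUND_ROBIN does).
def round_robin_batches(datasets, batch_sizes):
    """Yield (dataset_name, batch) pairs with per-dataset batch sizes."""
    positions = {name: 0 for name in datasets}
    while True:
        for name, data in datasets.items():
            start, size = positions[name], batch_sizes[name]
            if start + size > len(data):  # a dataset is exhausted: stop
                return
            yield name, data[start : start + size]
            positions[name] = start + size

datasets = {"contrastive_ds": list(range(100)), "regression_ds": list(range(90))}
batches = list(round_robin_batches(
    datasets, {"contrastive_ds": 8, "regression_ds": 6}
))
```

With these numbers, the contrastive dataset is drawn in batches of 8 and the regression dataset in batches of 6, until the contrastive side can no longer fill a batch.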

Thanks in advance for any guidance!
