OOM on large datasets due to O(n²) pair generation in contrastive learning #628

@Wert1996

Description

Problem

SetFit training goes OOM (Out of Memory) on main memory (not GPU VRAM) for large datasets. The root cause is that pair generation for contrastive learning materializes O(n²) data in memory before training begins.

Environment

  • SetFit version: 1.1.3
  • Python version: 3.10
  • OS: Linux

Reproduction

Any sufficiently large dataset trained with a contrastive loss (e.g., CosineSimilarityLoss) will OOM on machines with typical amounts of RAM. I reproduced this with a 40,000-sample dataset and num_iterations=1 on a 32 GB machine.

from setfit import SetFitModel, Trainer, TrainingArguments
from datasets import load_dataset

# Load a large dataset
dataset = load_dataset("some_large_dataset")  # e.g., 100k+ samples

model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
trainer = Trainer(
    model=model,
    train_dataset=dataset["train"],
    args=TrainingArguments(num_iterations=20),
)
trainer.train()  # OOM before training starts

Root Cause Analysis

There are 3 layers of memory explosion:

Layer 1: shuffle_combinations() in src/setfit/sampler.py

import numpy as np
from typing import Generator, Iterable

def shuffle_combinations(iterable: Iterable, replacement: bool = True) -> Generator:
    n = len(iterable)
    k = 1 if not replacement else 0
    idxs = np.stack(np.triu_indices(n, k), axis=-1)  # O(n²) memory: all pair indices at once!
    for i in np.random.RandomState(seed=42).permutation(len(idxs)):  # Another O(n²) array!
        _idx, idx = idxs[i, :]
        yield iterable[_idx], iterable[idx]

Despite being typed as a Generator, it allocates ALL n*(n-1)/2 pair indices upfront before yielding anything.
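
To put numbers on this, here is a rough estimate (my own arithmetic, not SetFit code) of the index arrays allocated for the 40,000-sample case above, assuming 8-byte integer indices:

```python
# Rough memory estimate for shuffle_combinations() with n = 40,000 samples.
# Assumes 8-byte (int64) indices, the NumPy default on 64-bit Linux.
n = 40_000
num_pairs = n * (n - 1) // 2  # all unique unordered pairs

stacked_bytes = num_pairs * 2 * 8  # np.stack(np.triu_indices(...)): a (num_pairs, 2) int64 array
permutation_bytes = num_pairs * 8  # permutation(len(idxs)): another int64 array of the same length

print(f"{num_pairs:,} pairs")                                  # 799,980,000 pairs
print(f"~{(stacked_bytes + permutation_bytes) / 1e9:.1f} GB")  # ~19.2 GB, allocated before any training
```

So with num_iterations=1, the index arrays alone already exceed the 32 GB machine's free RAM once the pair lists from Layer 2 are added on top.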

Layer 2: ContrastiveDataset.generate_pairs()

def generate_pairs(self) -> None:
    for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
        is_positive = _label == label
        if is_positive:
            self.pos_pairs.append(...)  # Stores every pair in a Python list
        else:
            self.neg_pairs.append(...)

Layer 3: trainer.py line 618

dataset = Dataset.from_list(list(data_sampler))  # Materializes iterator + creates copy

Suggested Solution

The solution involves replacing eager pair generation with streaming:

  1. ContrastiveDataset: Generate pairs on the fly and track uniqueness using a set.

  2. Trainers: Use IterableDataset.from_generator() instead of Dataset.from_list(list(...))

  3. Memory after fix: O(n) for label groups + O(num_pairs_sampled) for uniqueness set
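
A minimal sketch of what the streaming approach could look like (illustrative only; `stream_pairs` and its signature are my invention, not the draft PR's code):

```python
import random

def stream_pairs(texts, labels, max_pairs, seed=42):
    """Lazily yield sentence pairs, tracking uniqueness with a set of index pairs."""
    rng = random.Random(seed)
    n = len(texts)
    seen = set()  # grows to O(max_pairs), never O(n^2)
    while len(seen) < max_pairs:  # assumes max_pairs <= n * (n - 1) // 2
        i, j = rng.randrange(n), rng.randrange(n)
        pair = (min(i, j), max(i, j))
        if i == j or pair in seen:
            continue  # skip self-pairs and duplicates
        seen.add(pair)
        yield {
            "sentence_1": texts[i],
            "sentence_2": texts[j],
            "label": 1.0 if labels[i] == labels[j] else 0.0,
        }
```

The trainer would then wrap such a generator with `IterableDataset.from_generator(...)` from `datasets` instead of calling `Dataset.from_list(list(data_sampler))`, so pairs are produced during iteration rather than materialized up front.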

I have created a draft PR for this (#627) and would be happy to discuss and contribute it here.
