Problem
SetFit training goes OOM (out of memory) in main memory (not GPU VRAM) on large datasets. The root cause is that pair generation for contrastive learning materializes O(n²) data in memory before training begins.
Environment
- SetFit version: 1.1.3
- Python version: 3.10
- OS: Linux
Reproduction
Any dataset with a sufficiently large number of samples trained with a contrastive loss (e.g., CosineSimilarityLoss) will OOM, no matter how much RAM the machine has. In my case, a 40,000-sample dataset with num_iterations=1 exhausted a 32 GB machine.
from setfit import SetFitModel, Trainer, TrainingArguments
from datasets import load_dataset

# Load a large dataset (e.g., 100k+ samples)
dataset = load_dataset("some_large_dataset")

model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
trainer = Trainer(
    model=model,
    train_dataset=dataset["train"],
    args=TrainingArguments(num_iterations=20),
)
trainer.train()  # OOM before training starts
Root Cause Analysis
There are 3 layers of memory explosion:
Layer 1: shuffle_combinations() in src/setfit/sampler.py
def shuffle_combinations(iterable: Iterable, replacement: bool = True) -> Generator:
    n = len(iterable)
    k = 1 if not replacement else 0
    idxs = np.stack(np.triu_indices(n, k), axis=-1)  # O(n²) memory!
    for i in np.random.RandomState(seed=42).permutation(len(idxs)):  # another O(n²) array!
        _idx, idx = idxs[i, :]
        yield iterable[_idx], iterable[idx]
Despite being typed as a Generator, it allocates ALL n*(n-1)/2 pair indices upfront before yielding anything.
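A rough back-of-the-envelope for the 40,000-sample repro above (assuming the default int64 indices NumPy produces on 64-bit Linux) shows why 32 GB is not enough:

n = 40_000
num_pairs = n * (n - 1) // 2                         # ≈ 8.0e8 pair indices allocated up front
idxs_bytes = num_pairs * 2 * 8                       # np.stack of two int64 index arrays
perm_bytes = num_pairs * 8                           # permutation array over all pair indices
print(f"{(idxs_bytes + perm_bytes) / 1e9:.1f} GB")   # ~19.2 GB before a single pair is yielded

And that is before Layer 2 starts copying the sentence pairs themselves into Python lists.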
Layer 2: ContrastiveDataset.generate_pairs()
def generate_pairs(self) -> None:
    for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
        is_positive = _label == label
        if is_positive:
            self.pos_pairs.append(...)  # stores every pair in an in-memory list
        else:
            self.neg_pairs.append(...)
Layer 3: trainer.py line 618
dataset = Dataset.from_list(list(data_sampler)) # Materializes iterator + creates copy
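For illustration, here is a self-contained toy contrast between the two APIs named in the suggestion below; toy_sampler is only a stand-in for the real sampler in src/setfit/sampler.py, not SetFit code:

from datasets import Dataset, IterableDataset

# Toy stand-in for data_sampler, which yields pair dicts
def toy_sampler():
    for i in range(3):
        yield {"sentence_1": f"text {i}", "sentence_2": f"text {i + 1}", "label": 1.0}

eager = Dataset.from_list(list(toy_sampler()))      # materializes every pair up front
lazy = IterableDataset.from_generator(toy_sampler)  # pulls pairs only when iterated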
Suggested Solution
The solution involves replacing eager pair generation with streaming:
- ContrastiveDataset: generate pairs on the fly and track uniqueness with a set.
- Trainers: use IterableDataset.from_generator() instead of Dataset.from_list(list(...)) (see the sketch below).
- Memory after the fix: O(n) for the label groups + O(num_pairs_sampled) for the uniqueness set.
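A minimal sketch of the streaming idea, assuming plain Python lists texts and labels with at least two classes; the name stream_pairs, the 50/50 positive/negative split, and the toy data are illustrative only and not the draft PR's actual implementation:

import random
from collections import defaultdict

from datasets import IterableDataset


def stream_pairs(texts, labels, max_pairs, seed=42):
    """Yield unique positive/negative sentence pairs lazily instead of materializing O(n²) of them."""
    rng = random.Random(seed)
    by_label = defaultdict(list)                     # O(n): sample indices grouped by label
    for i, label in enumerate(labels):
        by_label[label].append(i)
    label_keys = list(by_label)
    multi = [k for k in label_keys if len(by_label[k]) > 1]  # labels usable for positive pairs
    seen = set()                                     # O(num_pairs_sampled): uniqueness tracking
    while len(seen) < max_pairs:                     # a real implementation would also bound retries
        if multi and rng.random() < 0.5:
            # positive pair: two different samples sharing a label
            i, j = rng.sample(by_label[rng.choice(multi)], 2)
        else:
            # negative pair: one sample from each of two different labels
            a, b = rng.sample(label_keys, 2)
            i, j = rng.choice(by_label[a]), rng.choice(by_label[b])
        key = (min(i, j), max(i, j))
        if key in seen:
            continue
        seen.add(key)
        yield {
            "sentence_1": texts[i],
            "sentence_2": texts[j],
            "label": 1.0 if labels[i] == labels[j] else 0.0,
        }


# Toy usage: stream pairs into an IterableDataset instead of Dataset.from_list(list(...))
texts = ["good movie", "great film", "terrible plot", "awful acting"]
labels = [1, 1, 0, 0]
pair_dataset = IterableDataset.from_generator(
    stream_pairs,
    gen_kwargs={"texts": texts, "labels": labels, "max_pairs": 4},
)
for pair in pair_dataset:
    print(pair)

The trainer would then consume pair_dataset lazily, so peak memory stays at the per-label index groups plus the seen set instead of every pair at once.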
I have created a draft PR for this (#627) and would be happy to discuss and contribute it here.