Description
Problem
SetFit training goes OOM (Out of Memory) on main memory (not GPU VRAM) for large datasets. The root cause is that pair generation for contrastive learning materializes O(n²) data in memory before training begins.
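To put numbers on the O(n²) growth, here is a quick back-of-the-envelope sketch (the 16 bytes/pair assumes two int64 indices per pair, matching what `np.triu_indices` produces on a 64-bit system):

```python
# Rough memory needed just to hold all unique pair indices upfront,
# assuming two int64 indices (16 bytes) per pair.
def pair_memory_gib(n: int, bytes_per_pair: int = 16) -> float:
    num_pairs = n * (n - 1) // 2  # O(n²) growth
    return num_pairs * bytes_per_pair / 2**30

for n in (1_000, 10_000, 40_000, 100_000):
    print(f"n={n:>7,}: {n * (n - 1) // 2:>13,} pairs, ~{pair_memory_gib(n):.1f} GiB")
```

At 40,000 samples that is already ~800 million pairs (~12 GiB of raw indices alone), consistent with the 32 GB machine going OOM below.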
Environment
- SetFit version: 1.1.3
- Python version: 3.10
- OS: Linux
Reproduction
Any sufficiently large dataset trained with a contrastive loss (e.g., CosineSimilarityLoss) will OOM once pair generation exceeds available RAM. I reproduced this with a 40,000-sample dataset and num_iterations=1 on a 32 GB machine.
```python
from datasets import load_dataset
from setfit import SetFitModel, Trainer, TrainingArguments

# Load a large dataset
dataset = load_dataset("some_large_dataset")  # e.g., 100k+ samples

model = SetFitModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
trainer = Trainer(
    model=model,
    train_dataset=dataset["train"],
    args=TrainingArguments(num_iterations=20),
)
trainer.train()  # OOM before training starts
```
Root Cause Analysis
There are 3 layers of memory explosion:
Layer 1: `shuffle_combinations()` in `src/setfit/sampler.py`
```python
from typing import Generator, Iterable

import numpy as np


def shuffle_combinations(iterable: Iterable, replacement: bool = True) -> Generator:
    n = len(iterable)
    k = 1 if not replacement else 0
    idxs = np.stack(np.triu_indices(n, k), axis=-1)  # O(n²) memory!
    for i in np.random.RandomState(seed=42).permutation(len(idxs)):  # Another O(n²) array!
        _idx, idx = idxs[i, :]
        yield iterable[_idx], iterable[idx]
```
Despite being typed as a `Generator`, it allocates all n*(n-1)/2 pair indices upfront before yielding anything.
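The eager allocation is easy to verify at a small scale (the n of 2,000 here is a hypothetical value chosen to keep the demo cheap, far below the failing 40,000):

```python
import numpy as np

# The full index array exists before a single pair is yielded,
# and it grows quadratically with n.
n = 2_000
idxs = np.stack(np.triu_indices(n, 1), axis=-1)
print(idxs.shape)   # (1999000, 2) -- all n*(n-1)/2 pairs at once
print(idxs.nbytes)  # tens of MB even at this small n
```

The permutation array in the loop then duplicates that cost a second time.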
Layer 2: `ContrastiveDataset.generate_pairs()`
```python
def generate_pairs(self) -> None:
    for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
        is_positive = _label == label
        if is_positive:
            self.pos_pairs.append(...)  # Stores all pairs in lists
        else:
            self.neg_pairs.append(...)
```
Layer 3: `trainer.py` line 618
```python
dataset = Dataset.from_list(list(data_sampler))  # Materializes iterator + creates copy
```
Suggested Solution
The solution involves replacing eager pair generation with streaming:
- `ContrastiveDataset`: Generate pairs on the fly and track uniqueness using a set.
- Trainers: Use `IterableDataset.from_generator()` instead of `Dataset.from_list(list(...))`.
- Memory after fix: O(n) for label groups + O(num_pairs_sampled) for the uniqueness set.
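To make the streaming approach concrete, here is a minimal sketch (not the draft PR's actual code; `stream_pairs` and its field names are hypothetical, and it assumes at least two labels and that `num_pairs` does not exceed the unique pairs available):

```python
import random
from collections import defaultdict
from typing import Dict, Iterator, Sequence


def stream_pairs(
    sentences: Sequence[str],
    labels: Sequence[int],
    num_pairs: int,
    seed: int = 42,
) -> Iterator[Dict]:
    """Lazily yield contrastive pairs, alternating positive/negative.

    Memory: O(n) for the per-label index groups plus O(num_pairs) for
    the uniqueness set -- never the O(n²) full pair space.
    """
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)
    label_list = list(by_label)
    seen = set()  # tracks emitted (i, j) index pairs
    emitted = 0
    while emitted < num_pairs:
        if emitted % 2 == 0:  # positive pair: two items with the same label
            group = by_label[rng.choice(label_list)]
            if len(group) < 2:
                continue
            i, j = rng.sample(group, 2)
        else:                 # negative pair: items from two different labels
            la, lb = rng.sample(label_list, 2)
            i, j = rng.choice(by_label[la]), rng.choice(by_label[lb])
        key = (min(i, j), max(i, j))
        if key in seen:
            continue          # duplicate; resample
        seen.add(key)
        yield {
            "sentence_1": sentences[i],
            "sentence_2": sentences[j],
            "label": 1.0 if emitted % 2 == 0 else 0.0,
        }
        emitted += 1
```

A generator like this can then be handed to `datasets.IterableDataset.from_generator`, passing the arguments through `gen_kwargs`, so the trainer never materializes more than one pair at a time.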
I have created a draft PR for this and would be happy to discuss and contribute it here: #627