Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ContrastiveDataset and ContrastiveDistillationDataset #579

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
cb4e803
fixed to work with processing_class instead tokenizer after transform…
DemirTonchev Dec 20, 2024
6a69303
refactor attempt for ContrastiveDataset so that it does not blow RAM …
DemirTonchev Dec 23, 2024
09869e1
added Sampling strategy enum
DemirTonchev Dec 24, 2024
82474dc
improved logic and fixed iterator pattern
DemirTonchev Dec 24, 2024
d01dc36
fix for negative samples formula
DemirTonchev Dec 24, 2024
53eaace
added multilabel support as in the original implementation
DemirTonchev Dec 24, 2024
3e3fa5f
ContrastiveDataset iterator refactor
DemirTonchev Dec 25, 2024
2d5e29b
trainer fixed to work with ContrastiveDataset iter method
DemirTonchev Dec 25, 2024
1c905b1
ContrastiveDistillationDataset iter refactor
DemirTonchev Dec 25, 2024
9cafa02
typing fix
DemirTonchev Dec 26, 2024
421b1e5
Merge branch 'main' into pr-579
tomaarsen Jan 10, 2025
ea088ff
TypeAlias will be deprecated in 3.12 again, so let's avoid it
tomaarsen Jan 10, 2025
467c507
Remove args.sampling_strategy from ContrastiveDistillationDataset init
tomaarsen Jan 10, 2025
9784485
Run formatting
tomaarsen Jan 10, 2025
3f67190
Add 'save_strategy="no"' in tests to counteract transformers v4.48.0 bug
tomaarsen Jan 10, 2025
509ee68
fix some autocomplete mistake
DemirTonchev Jan 11, 2025
8fdbbb3
Merge branch 'refactor-contrds' of github.com:DemirTonchev/setfit int…
DemirTonchev Jan 11, 2025
7d9119b
cleanup leftovers attributes from old logic
DemirTonchev Jan 12, 2025
fe30eaa
docs
DemirTonchev Jan 13, 2025
e7f9cf0
fix to correctly specify positives and negatives when passing num_ite…
DemirTonchev Jan 13, 2025
d989f22
fix for negative or positive examples only(not really possible)
DemirTonchev Jan 13, 2025
d229add
removed unnecessary np uses and casting
DemirTonchev Jan 13, 2025
ea88e9b
safeguard for oversampling strategy, when DS is negatives only or sin…
DemirTonchev Jan 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 136 additions & 105 deletions src/setfit/sampler.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,64 @@
from itertools import zip_longest
from typing import Dict, Generator, Iterable, List, Optional, Union
from collections import Counter
from itertools import combinations
from typing import Dict, Generator, Iterable, List, Literal, Optional, Union

import numpy as np
import torch
from torch.utils.data import IterableDataset
from transformers.utils import ExplicitEnum

from . import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)

SentencePair = Dict[str, Union[str, float]]

def shuffle_combinations(iterable: Iterable, replacement: bool = True) -> Generator:

class SamplingStrategy(ExplicitEnum):
    """Strategies for turning a labeled dataset into contrastive training pairs.

    The running example below assumes a dataset that yields 62 candidate positive
    pairs and 128 candidate negative pairs.

    ## Oversampling

    By default, SetFit applies the oversampling strategy for its contrastive pairs. This strategy samples an equal amount of positive and negative training
    pairs, oversampling the minority pair type to match that of the majority pair type. As the number of negative pairs is generally larger than the number
    of positive pairs, this usually involves oversampling the positive pairs.

    In our running example, this would involve oversampling the 62 positive pairs up to 128, resulting in one epoch of 128 + 128 = 256 pairs. In summary:

    * Y An equal amount of positive and negative pairs are sampled.
    * Y Every possible pair is used.
    * X There is some data duplication.

    ## Undersampling

    Like oversampling, this strategy samples an equal amount of positive and negative training pairs. However, it undersamples the majority pair type to match
    that of the minority pair type. This usually involves undersampling the negative pairs to match the positive pairs.

    In our running example, this would involve undersampling the 128 negative pairs down to 62, resulting in one epoch of 62 + 62 = 124 pairs. In summary:

    * Y An equal amount of positive and negative pairs are sampled.
    * X **Not** every possible pair is used.
    * Y There is **no** data duplication.

    ## Unique

    Thirdly, the unique strategy does not sample an equal amount of positive and negative training pairs. Instead, it simply samples all possible pairs exactly
    once. No form of oversampling or undersampling is used here.

    In our running example, this would involve sampling all negative and positive pairs, resulting in one epoch of 62 + 128 = 190 pairs. In summary:

    * X **Not** an equal amount of positive and negative pairs are sampled.
    * Y Every possible pair is used.
    * Y There is **no** data duplication.
    """

    OVERSAMPLING = "oversampling"
    UNDERSAMPLING = "undersampling"
    UNIQUE = "unique"


def shuffle_combinations(iterable: Iterable, replacement: bool = False) -> Generator:
"""Generates shuffled pair combinations for any iterable data provided.

Args:
Expand All @@ -36,9 +82,9 @@ def __init__(
self,
sentences: List[str],
labels: List[Union[int, float]],
multilabel: bool,
multilabel: bool = False, # False for now
num_iterations: Optional[None] = None,
sampling_strategy: str = "oversampling",
sampling_strategy: Literal["oversampling", "undersampling", "unique"] = "oversampling",
max_pairs: int = -1,
) -> None:
"""Generates positive and negative text pairs for contrastive learning.
Expand All @@ -54,94 +100,85 @@ def __init__(
max_pairs pairs.
"""
super().__init__()
self.pos_index = 0
self.neg_index = 0
self.pos_pairs = []
self.neg_pairs = []
self.sentences = sentences
self.labels = labels
self.sentence_labels = list(zip(self.sentences, self.labels))
self.max_pos_or_neg = -1 if max_pairs == -1 else max_pairs // 2
self.max_pos_or_neg = np.inf if max_pairs == -1 else max_pairs // 2
self._multilabel = multilabel

sampling_strategy = SamplingStrategy(sampling_strategy)

if multilabel:
self.generate_multilabel_pairs()
else:
self.generate_pairs()
# calculate number of positive and negative combinations
label_counts = Counter(labels)
# positive pairs: number of 2-combinations within each label group (without replacement)
self.pos_pairs_combination = sum([n * (n - 1) // 2 for n in label_counts.values()])
# negative product
self.neg_pairs_combination = sum(a * b for a, b in combinations(label_counts.values(), 2))

if num_iterations is not None and num_iterations > 0:
self.len_pos_pairs = num_iterations * len(self.sentences)
self.len_neg_pairs = num_iterations * len(self.sentences)

elif sampling_strategy == "unique":
self.len_pos_pairs = len(self.pos_pairs)
self.len_neg_pairs = len(self.neg_pairs)

elif sampling_strategy == "undersampling":
self.len_pos_pairs = min(len(self.pos_pairs), len(self.neg_pairs))
self.len_neg_pairs = min(len(self.pos_pairs), len(self.neg_pairs))

elif sampling_strategy == "oversampling":
self.len_pos_pairs = max(len(self.pos_pairs), len(self.neg_pairs))
self.len_neg_pairs = max(len(self.pos_pairs), len(self.neg_pairs))

else:
raise ValueError("Invalid sampling strategy. Must be one of 'unique', 'oversampling', or 'undersampling'.")

def generate_pairs(self) -> None:
for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
is_positive = _label == label
is_positive_full = self.max_pos_or_neg != -1 and len(self.pos_pairs) >= self.max_pos_or_neg
is_negative_full = self.max_pos_or_neg != -1 and len(self.neg_pairs) >= self.max_pos_or_neg

if is_positive:
if not is_positive_full:
self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0})
elif not is_negative_full:
self.neg_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 0.0})

if is_positive_full and is_negative_full:
break

def generate_multilabel_pairs(self) -> None:
for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
# logical_and checks if labels are both set for each class
is_positive = any(np.logical_and(_label, label))
is_positive_full = self.max_pos_or_neg != -1 and len(self.pos_pairs) >= self.max_pos_or_neg
is_negative_full = self.max_pos_or_neg != -1 and len(self.neg_pairs) >= self.max_pos_or_neg

if is_positive:
if not is_positive_full:
self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0})
elif not is_negative_full:
self.neg_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 0.0})

if is_positive_full and is_negative_full:
break

def get_positive_pairs(self) -> List[Dict[str, Union[str, float]]]:
pairs = []
for _ in range(self.len_pos_pairs):
if self.pos_index >= len(self.pos_pairs):
self.pos_index = 0
pairs.append(self.pos_pairs[self.pos_index])
self.pos_index += 1
return pairs

def get_negative_pairs(self) -> List[Dict[str, Union[str, float]]]:
pairs = []
for _ in range(self.len_neg_pairs):
if self.neg_index >= len(self.neg_pairs):
self.neg_index = 0
pairs.append(self.neg_pairs[self.neg_index])
self.neg_index += 1
return pairs

def __iter__(self):
for pos_pair, neg_pair in zip_longest(self.get_positive_pairs(), self.get_negative_pairs()):
if pos_pair is not None:
yield pos_pair
if neg_pair is not None:
yield neg_pair
iterations = num_iterations * len(self.sentences)
self.len_pos_pairs = iterations if self.pos_pairs_combination > 0 else 0
self.len_neg_pairs = iterations if self.neg_pairs_combination > 0 else 0

elif sampling_strategy == SamplingStrategy.UNIQUE:
self.len_pos_pairs = min(self.pos_pairs_combination, self.max_pos_or_neg)
self.len_neg_pairs = min(self.neg_pairs_combination, self.max_pos_or_neg)

elif sampling_strategy == SamplingStrategy.UNDERSAMPLING:
self.len_pos_pairs = min([min(self.pos_pairs_combination, self.neg_pairs_combination), self.max_pos_or_neg])
self.len_neg_pairs = min([min(self.pos_pairs_combination, self.neg_pairs_combination), self.max_pos_or_neg])

elif sampling_strategy == SamplingStrategy.OVERSAMPLING:
num_pos_or_neg_pairs = max(self.pos_pairs_combination, self.neg_pairs_combination)
# safeguard for datasets with only negative pairs or only a single positive example.
self.len_pos_pairs = min([num_pos_or_neg_pairs, self.max_pos_or_neg]) if self.pos_pairs_combination > 0 else 0
self.len_neg_pairs = min([num_pos_or_neg_pairs, self.max_pos_or_neg]) if self.neg_pairs_combination > 0 else 0

def generate_positive_pair(self) -> Generator[SentencePair, None, None]:
    """Endlessly yield positive sentence pairs.

    All shuffled pair combinations are scanned; once exhausted, a fresh
    shuffled pass is started, so the generator never terminates on its own.
    """
    while True:
        for (text_a, label_a), (text_b, label_b) in shuffle_combinations(self.sentence_labels):
            if self._multilabel:
                # Positive when the two examples share at least one active class.
                matched = any(np.logical_and(label_a, label_b))
            else:
                matched = label_a == label_b

            if matched:
                yield {"sentence_1": text_a, "sentence_2": text_b, "label": 1.0}

def generate_negative_pair(self) -> Generator[SentencePair, None, None]:
    """Endlessly yield negative sentence pairs.

    Mirrors ``generate_positive_pair``: shuffled combinations are scanned and
    reshuffled after every full pass, so the generator never terminates on its own.
    """
    while True:
        for (text_a, label_a), (text_b, label_b) in shuffle_combinations(self.sentence_labels):
            if self._multilabel:
                # Negative when the two examples share no active class at all.
                mismatched = not any(np.logical_and(label_a, label_b))
            else:
                mismatched = label_a != label_b

            if mismatched:
                yield {"sentence_1": text_a, "sentence_2": text_b, "label": 0.0}

def __iter__(self) -> Generator[SentencePair, None, None]:
    """Yield positive and negative pairs, interleaved, until ``len(self)`` pairs are produced.

    Positive and negative pairs alternate while both quotas
    (``len_pos_pairs`` / ``len_neg_pairs``) remain unfilled; afterwards only
    the remaining type is emitted. Override this method to change how pairs
    are generated.
    """
    pos_source = self.generate_positive_pair()
    neg_source = self.generate_negative_pair()
    pos_emitted = 0
    neg_emitted = 0

    while pos_emitted + neg_emitted < len(self):
        if pos_emitted < self.len_pos_pairs:
            pos_emitted += 1
            yield next(pos_source)
        if neg_emitted < self.len_neg_pairs:
            neg_emitted += 1
            yield next(neg_source)

def __len__(self) -> int:
    """Return the total number of pairs in one epoch (positive quota + negative quota)."""
    total_pairs = self.len_pos_pairs + self.len_neg_pairs
    return total_pairs
Expand All @@ -153,7 +190,6 @@ def __init__(
sentences: List[str],
cos_sim_matrix: torch.Tensor,
num_iterations: Optional[None] = None,
sampling_strategy: str = "oversampling",
max_pairs: int = -1,
) -> None:
self.cos_sim_matrix = cos_sim_matrix
Expand All @@ -162,23 +198,18 @@ def __init__(
[0] * len(sentences),
multilabel=False,
num_iterations=num_iterations,
sampling_strategy=sampling_strategy,
sampling_strategy=SamplingStrategy.UNIQUE, # use unique to create all pos pairs, implementation choice to use generate_positive_pair method (*)
max_pairs=max_pairs,
)
# Internally we store all pairs in pos_pairs, regardless of sampling strategy.
# After all, without labels, there isn't much of a strategy.
self.sentence_labels = list(enumerate(self.sentences))

self.len_neg_pairs = 0
if num_iterations is not None and num_iterations > 0:
self.len_pos_pairs = num_iterations * len(self.sentences)
else:
self.len_pos_pairs = len(self.pos_pairs)

def generate_pairs(self) -> None:
for (text_one, id_one), (text_two, id_two) in shuffle_combinations(self.sentence_labels):
self.pos_pairs.append(
{"sentence_1": text_one, "sentence_2": text_two, "label": self.cos_sim_matrix[id_one][id_two]}
)
if self.max_pos_or_neg != -1 and len(self.pos_pairs) > self.max_pos_or_neg:
break
self.sentence_labels = list(zip(self.sentences, range(len(self.sentences))))

# (*) Internally we use generate_positive_pair
def generate_positive_pair(self) -> Generator[SentencePair, None, None]:
    """Endlessly yield sentence pairs labelled with their cosine similarity.

    Only one of ``generate_positive_pair`` / ``generate_negative_pair`` needs
    to be overridden to change pair-generation behavior; it does not matter
    which one is chosen — this class overrides the positive side.
    """
    while True:
        for (first_text, first_id), (second_text, second_id) in shuffle_combinations(self.sentence_labels):
            similarity = self.cos_sim_matrix[first_id][second_id]
            yield {"sentence_1": first_text, "sentence_2": second_text, "label": similarity}
2 changes: 1 addition & 1 deletion src/setfit/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ def get_dataset(
args.sampling_strategy,
max_pairs=max_pairs,
)
dataset = Dataset.from_list(list(data_sampler))
dataset = Dataset.from_generator(data_sampler.__iter__)
loss = args.loss(self.model.model_body)

return dataset, loss
Expand Down
6 changes: 2 additions & 4 deletions src/setfit/trainer_distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,8 @@ def get_dataset(
)
cos_sim_matrix = util.cos_sim(x_embd_student, x_embd_student)

data_sampler = ContrastiveDistillationDataset(
x, cos_sim_matrix, args.num_iterations, args.sampling_strategy, max_pairs=max_pairs
)
dataset = Dataset.from_list(list(data_sampler))
data_sampler = ContrastiveDistillationDataset(x, cos_sim_matrix, args.num_iterations, max_pairs=max_pairs)
dataset = Dataset.from_generator(data_sampler.__iter__)
loss = args.loss(self.model.model_body)
return dataset, loss

Expand Down
1 change: 1 addition & 0 deletions tests/span/test_model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def test_model_card(absa_dataset: Dataset, tmp_path: Path) -> None:
logging_steps=1,
max_steps=2,
eval_strategy="steps",
save_strategy="no",
)
trainer = AbsaTrainer(
model=model,
Expand Down
1 change: 1 addition & 0 deletions tests/test_model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def test_model_card(tmp_path: Path) -> None:
logging_steps=1,
max_steps=2,
eval_strategy="steps",
save_strategy="no",
)
trainer = Trainer(
model=model,
Expand Down