From cb4e803a10fab421db5b6fb24ef99f6274ade84f Mon Sep 17 00:00:00 2001 From: Demir Tonchev Date: Fri, 20 Dec 2024 15:03:15 +0200 Subject: [PATCH 01/21] fixed to work with processing_class instead tokenizer after transformers 4.45.2 --- src/setfit/trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/setfit/trainer.py b/src/setfit/trainer.py index 20411d3b..b20409b9 100644 --- a/src/setfit/trainer.py +++ b/src/setfit/trainer.py @@ -72,7 +72,7 @@ def overwritten_call_event(self, event, args, state, control, **kwargs): model=self.setfit_model, st_model=self.model, st_args=args, - tokenizer=self.tokenizer, + tokenizer=self.processing_class, optimizer=self.optimizer, lr_scheduler=self.lr_scheduler, train_dataloader=self.train_dataloader, @@ -156,9 +156,9 @@ def _set_logs_prefix(self, logs_prefix: str) -> None: """ self.logs_prefix = logs_prefix - def log(self, logs: Dict[str, float]) -> None: + def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: logs = {f"{self.logs_prefix}_{k}" if k == "loss" else k: v for k, v in logs.items()} - return super().log(logs) + return super().log(logs, start_time) def evaluate( self, From 6a6930383ebb332ecd944de95c4527af0ab61a1a Mon Sep 17 00:00:00 2001 From: Demir Tonchev Date: Tue, 24 Dec 2024 00:31:55 +0200 Subject: [PATCH 02/21] refactor attempt for ContrastiveDataset so that it does not blow RAM with bigger dataset --- src/setfit/sampler.py | 108 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 107 insertions(+), 1 deletion(-) diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index 16bc9938..08ed47c3 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -12,7 +12,7 @@ logger = logging.get_logger(__name__) -def shuffle_combinations(iterable: Iterable, replacement: bool = True) -> Generator: +def shuffle_combinations(iterable: Iterable, replacement: bool = False) -> Generator: """Generates shuffled pair combinations for any iterable data 
provided. Args: @@ -31,6 +31,112 @@ def shuffle_combinations(iterable: Iterable, replacement: bool = True) -> Genera yield iterable[_idx], iterable[idx] +class ContrastiveDatasetIt(IterableDataset): + def __init__( + self, + sentences: List[str], + labels: List[Union[int, float]], + multilabel: bool = False, # False for now + num_iterations: Optional[None] = None, + sampling_strategy: str = "oversampling", + max_pairs: int = -1, + ) -> None: + """Generates positive and negative text pairs for contrastive learning. + + Args: + sentences (List[str]): text sentences to generate pairs from + labels (List[Union[int, float]]): labels for each sentence + multilabel: set to process "multilabel" labels array + sampling_strategy: "unique", "oversampling", or "undersampling" + num_iterations: if provided explicitly sets the number of pairs to be generated + where n_pairs = n_iterations * n_sentences * 2 (for pos & neg pairs) + max_pairs: If not -1, then we only sample pairs until we have certainly reached + max_pairs pairs. 
+ """ + super().__init__() + self.pos_index = 0 + self.neg_index = 0 + self.pos_pairs = [] + self.neg_pairs = [] + self.sentences = sentences + self.labels = labels + self.sentence_labels = list(zip(self.sentences, self.labels)) + self.max_pos_or_neg = np.inf if max_pairs == -1 else max_pairs // 2 + + from collections import Counter + from math import prod + label_counts = Counter(labels) + # postive number of pairs from an n element set without replacement + self.total_pos_pairs = int(sum([n * (n - 1) / 2 for n in label_counts.values()])) + # negative product + self.total_neg_pairs = prod(label_counts.values()) + + self.generated_pos_pairs = 0 + self.generated_neg_pairs = 0 + + if num_iterations is not None and num_iterations > 0: + self.len_pos_pairs = num_iterations * len(self.sentences) + self.len_neg_pairs = num_iterations * len(self.sentences) + + elif sampling_strategy == "unique": + self.len_pos_pairs = int(np.min([self.total_pos_pairs, self.max_pos_or_neg])) + self.len_neg_pairs = int(np.min([self.total_neg_pairs, self.max_pos_or_neg])) + + elif sampling_strategy == "undersampling": + self.len_pos_pairs = int(np.min([min(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg])) + self.len_neg_pairs = int(np.min([min(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg])) + + elif sampling_strategy == "oversampling": + self.len_pos_pairs = int(np.min([max(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg])) + self.len_neg_pairs = int(np.min([max(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg])) + + else: + raise ValueError("Invalid sampling strategy. 
Must be one of 'unique', 'oversampling', or 'undersampling'.") + + # generate pair functions are not ideal but still wont blow the memory if you decide to train on big dataset + def generate_positive_pair(self): + + pair_generator = shuffle_combinations(self.sentence_labels) + while True: + for (_text, _label), (text, label) in pair_generator: + is_positive = _label == label + + if is_positive and self.generated_pos_pairs <= self.len_pos_pairs: + self.generated_pos_pairs += 1 + yield {"sentence_1": _text, "sentence_2": text, "label": 1.0} + # restart + pair_generator = shuffle_combinations(self.sentence_labels) + + def generate_negative_pair(self): + pair_generator = shuffle_combinations(self.sentence_labels) + while True: + for (_text, _label), (text, label) in pair_generator: + is_negative = _label != label + + if is_negative and self.generated_neg_pairs <= self.len_neg_pairs: + self.generated_neg_pairs += 1 + yield {"sentence_1": _text, "sentence_2": text, "label": 0.0} + # restart + pair_generator = shuffle_combinations(self.sentence_labels) + + def __iter__(self): + # reset to starting values(state) so that iterator can be recreated and used again if needed. 
+ self.generated_pos_pairs = 0 + self.generated_neg_pairs = 0 + + pos_generator = self.generate_positive_pair() + neg_generator = self.generate_negative_pair() + + while (self.generated_pos_pairs + self.generated_neg_pairs) < len(self): + if self.generated_pos_pairs < self.len_pos_pairs: + yield next(pos_generator) + if self.generated_neg_pairs < self.len_neg_pairs: + yield next(neg_generator) + + def __len__(self) -> int: + return self.len_pos_pairs + self.len_neg_pairs + + class ContrastiveDataset(IterableDataset): def __init__( self, From 09869e1c17128e54895972636e863737f3b42f70 Mon Sep 17 00:00:00 2001 From: Demir Tonchev Date: Tue, 24 Dec 2024 13:13:33 +0200 Subject: [PATCH 03/21] added Samplit strategy enum --- src/setfit/sampler.py | 46 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index 08ed47c3..3a414f53 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -1,9 +1,12 @@ from itertools import zip_longest -from typing import Dict, Generator, Iterable, List, Optional, Union +from typing import Dict, Generator, Iterable, List, Literal, Optional, Union +from collections import Counter +from math import prod import numpy as np import torch from torch.utils.data import IterableDataset +from transformers.utils import ExplicitEnum from . import logging @@ -12,6 +15,47 @@ logger = logging.get_logger(__name__) +class SamplingStrategy(ExplicitEnum): + """ + ## Oversampling + + By default, SetFit applies the oversampling strategy for its contrastive pairs. This strategy samples an equal amount of positive and negative training + pairs, oversampling the minority pair type to match that of the majority pair type. As the number of negative pairs is generally larger than the number + of positive pairs, this usually involves oversampling the positive pairs. 
+ + In our running example, this would involve oversampling the 62 positive pairs up to 128, resulting in one epoch of 128 + 128 = 256 pairs. In summary: + + * Y An equal amount of positive and negative pairs are sampled. + * Y Every possible pair is used. + * X There is some data duplication. + + ## Undersampling + + Like oversampling, this strategy samples an equal amount of positive and negative training pairs. However, it undersamples the majority pair type to match + that of the minority pair type. This usually involves undersampling the negative pairs to match the positive pairs. + + In our running example, this would involve undersampling the 128 negative pairs down to 62, resulting in one epoch of 62 + 62 = 124 pairs. In summary: + + * Y An equal amount of positive and negative pairs are sampled. + * X **Not** every possible pair is used. + * Y There is **no** data duplication. + + ## Unique + + Thirdly, the unique strategy does not sample an equal amount of positive and negative training pairs. Instead, it simply samples all possible pairs exactly + once. No form of oversampling or undersampling is used here. + + In our running example, this would involve sampling all negative and positive pairs, resulting in one epoch of 62 + 128 = 190 pairs. In summary: + + * X **Not** an equal amount of positive and negative pairs are sampled. + * Y Every possible pair is used. + * Y There is **no** data duplication. + """ + OVERSAMPLING = "oversampling" + UNDERSAMPLING = "undersampling" + UNIQUE = "unique" + + def shuffle_combinations(iterable: Iterable, replacement: bool = False) -> Generator: """Generates shuffled pair combinations for any iterable data provided. 
From 82474dcf5ac293e6b9ae6923c6b276d4bb88e97a Mon Sep 17 00:00:00 2001 From: Demir Tonchev Date: Tue, 24 Dec 2024 13:15:35 +0200 Subject: [PATCH 04/21] improved logic and fixed iterator pattern --- src/setfit/sampler.py | 41 +++++++++++++++++------------------------ 1 file changed, 17 insertions(+), 24 deletions(-) diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index 3a414f53..64b28012 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -82,7 +82,7 @@ def __init__( labels: List[Union[int, float]], multilabel: bool = False, # False for now num_iterations: Optional[None] = None, - sampling_strategy: str = "oversampling", + sampling_strategy: Literal["oversampling", "undersampling", "unique"] = "oversampling", max_pairs: int = -1, ) -> None: """Generates positive and negative text pairs for contrastive learning. @@ -107,46 +107,39 @@ def __init__( self.sentence_labels = list(zip(self.sentences, self.labels)) self.max_pos_or_neg = np.inf if max_pairs == -1 else max_pairs // 2 - from collections import Counter - from math import prod + sampling_strategy = SamplingStrategy(sampling_strategy) + + # calculate number of positive and negative combinations label_counts = Counter(labels) # postive number of pairs from an n element set without replacement self.total_pos_pairs = int(sum([n * (n - 1) / 2 for n in label_counts.values()])) # negative product self.total_neg_pairs = prod(label_counts.values()) - self.generated_pos_pairs = 0 - self.generated_neg_pairs = 0 - if num_iterations is not None and num_iterations > 0: self.len_pos_pairs = num_iterations * len(self.sentences) self.len_neg_pairs = num_iterations * len(self.sentences) - elif sampling_strategy == "unique": + elif sampling_strategy == SamplingStrategy.UNIQUE: self.len_pos_pairs = int(np.min([self.total_pos_pairs, self.max_pos_or_neg])) self.len_neg_pairs = int(np.min([self.total_neg_pairs, self.max_pos_or_neg])) - elif sampling_strategy == "undersampling": + elif sampling_strategy == 
SamplingStrategy.UNDERSAMPLING: self.len_pos_pairs = int(np.min([min(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg])) self.len_neg_pairs = int(np.min([min(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg])) - elif sampling_strategy == "oversampling": + elif sampling_strategy == SamplingStrategy.OVERSAMPLING: self.len_pos_pairs = int(np.min([max(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg])) self.len_neg_pairs = int(np.min([max(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg])) - else: - raise ValueError("Invalid sampling strategy. Must be one of 'unique', 'oversampling', or 'undersampling'.") - # generate pair functions are not ideal but still wont blow the memory if you decide to train on big dataset def generate_positive_pair(self): - pair_generator = shuffle_combinations(self.sentence_labels) while True: for (_text, _label), (text, label) in pair_generator: is_positive = _label == label - if is_positive and self.generated_pos_pairs <= self.len_pos_pairs: - self.generated_pos_pairs += 1 + if is_positive: yield {"sentence_1": _text, "sentence_2": text, "label": 1.0} # restart pair_generator = shuffle_combinations(self.sentence_labels) @@ -157,25 +150,25 @@ def generate_negative_pair(self): for (_text, _label), (text, label) in pair_generator: is_negative = _label != label - if is_negative and self.generated_neg_pairs <= self.len_neg_pairs: - self.generated_neg_pairs += 1 + if is_negative: yield {"sentence_1": _text, "sentence_2": text, "label": 0.0} - # restart pair_generator = shuffle_combinations(self.sentence_labels) def __iter__(self): - # reset to starting values(state) so that iterator can be recreated and used again if needed. 
- self.generated_pos_pairs = 0 - self.generated_neg_pairs = 0 + + generated_pos_pairs = 0 + generated_neg_pairs = 0 pos_generator = self.generate_positive_pair() neg_generator = self.generate_negative_pair() - while (self.generated_pos_pairs + self.generated_neg_pairs) < len(self): - if self.generated_pos_pairs < self.len_pos_pairs: + while (generated_pos_pairs + generated_neg_pairs) < len(self): + if generated_pos_pairs < self.len_pos_pairs: yield next(pos_generator) - if self.generated_neg_pairs < self.len_neg_pairs: + generated_pos_pairs += 1 + if generated_neg_pairs < self.len_neg_pairs: yield next(neg_generator) + generated_neg_pairs += 1 def __len__(self) -> int: return self.len_pos_pairs + self.len_neg_pairs From d01dc3660cffd4c73ac05af78a7b58e6d4001dcf Mon Sep 17 00:00:00 2001 From: Demir Tonchev Date: Tue, 24 Dec 2024 13:47:52 +0200 Subject: [PATCH 05/21] fix for negative samples formula --- src/setfit/sampler.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index 64b28012..cc4afa6f 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -1,7 +1,6 @@ -from itertools import zip_longest +from itertools import combinations, zip_longest from typing import Dict, Generator, Iterable, List, Literal, Optional, Union from collections import Counter -from math import prod import numpy as np import torch @@ -114,7 +113,7 @@ def __init__( # postive number of pairs from an n element set without replacement self.total_pos_pairs = int(sum([n * (n - 1) / 2 for n in label_counts.values()])) # negative product - self.total_neg_pairs = prod(label_counts.values()) + self.total_neg_pairs = self.total_neg_pairs = sum(a * b for a, b in combinations(label_counts.values(), 2)) if num_iterations is not None and num_iterations > 0: self.len_pos_pairs = num_iterations * len(self.sentences) From 53eaace56d272c283dce6aadddc6d927f3b7f807 Mon Sep 17 00:00:00 2001 From: Demir Tonchev Date: Tue, 24 Dec 2024 
14:26:04 +0200 Subject: [PATCH 06/21] added multilalbel support as in the original implementation --- src/setfit/sampler.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index cc4afa6f..46acd0d2 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -105,6 +105,7 @@ def __init__( self.labels = labels self.sentence_labels = list(zip(self.sentences, self.labels)) self.max_pos_or_neg = np.inf if max_pairs == -1 else max_pairs // 2 + self._multilabel = multilabel sampling_strategy = SamplingStrategy(sampling_strategy) @@ -136,7 +137,10 @@ def generate_positive_pair(self): pair_generator = shuffle_combinations(self.sentence_labels) while True: for (_text, _label), (text, label) in pair_generator: - is_positive = _label == label + if self._multilabel: + is_positive = any(np.logical_and(_label, label)) + else: + is_positive = _label == label if is_positive: yield {"sentence_1": _text, "sentence_2": text, "label": 1.0} @@ -147,7 +151,10 @@ def generate_negative_pair(self): pair_generator = shuffle_combinations(self.sentence_labels) while True: for (_text, _label), (text, label) in pair_generator: - is_negative = _label != label + if self._multilabel: + is_negative = not any(np.logical_and(_label, label)) + else: + is_negative = _label != label if is_negative: yield {"sentence_1": _text, "sentence_2": text, "label": 0.0} From 3e3fa5fbe8e2ae89b289318101570e2ae896f621 Mon Sep 17 00:00:00 2001 From: Demir Tonchev Date: Wed, 25 Dec 2024 12:08:57 +0200 Subject: [PATCH 07/21] ContrastiveDataset iterator refactor --- src/setfit/sampler.py | 127 ++---------------------------------------- 1 file changed, 5 insertions(+), 122 deletions(-) diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index 46acd0d2..c768fac8 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -1,4 +1,4 @@ -from itertools import combinations, zip_longest +from itertools import combinations from typing 
import Dict, Generator, Iterable, List, Literal, Optional, Union from collections import Counter @@ -74,7 +74,7 @@ def shuffle_combinations(iterable: Iterable, replacement: bool = False) -> Gener yield iterable[_idx], iterable[idx] -class ContrastiveDatasetIt(IterableDataset): +class ContrastiveDataset(IterableDataset): def __init__( self, sentences: List[str], @@ -132,8 +132,7 @@ def __init__( self.len_pos_pairs = int(np.min([max(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg])) self.len_neg_pairs = int(np.min([max(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg])) - # generate pair functions are not ideal but still wont blow the memory if you decide to train on big dataset - def generate_positive_pair(self): + def generate_positive_pair(self) -> Generator[Dict[str, Union[str, float]]]: pair_generator = shuffle_combinations(self.sentence_labels) while True: for (_text, _label), (text, label) in pair_generator: @@ -147,7 +146,7 @@ def generate_positive_pair(self): # restart pair_generator = shuffle_combinations(self.sentence_labels) - def generate_negative_pair(self): + def generate_negative_pair(self) -> Generator[Dict[str, Union[str, float]]]: pair_generator = shuffle_combinations(self.sentence_labels) while True: for (_text, _label), (text, label) in pair_generator: @@ -160,7 +159,7 @@ def generate_negative_pair(self): yield {"sentence_1": _text, "sentence_2": text, "label": 0.0} pair_generator = shuffle_combinations(self.sentence_labels) - def __iter__(self): + def __iter__(self) -> Generator[Dict[str, Union[str, float]]]: generated_pos_pairs = 0 generated_neg_pairs = 0 @@ -180,122 +179,6 @@ def __len__(self) -> int: return self.len_pos_pairs + self.len_neg_pairs -class ContrastiveDataset(IterableDataset): - def __init__( - self, - sentences: List[str], - labels: List[Union[int, float]], - multilabel: bool, - num_iterations: Optional[None] = None, - sampling_strategy: str = "oversampling", - max_pairs: int = -1, - ) -> None: - 
"""Generates positive and negative text pairs for contrastive learning. - - Args: - sentences (List[str]): text sentences to generate pairs from - labels (List[Union[int, float]]): labels for each sentence - multilabel: set to process "multilabel" labels array - sampling_strategy: "unique", "oversampling", or "undersampling" - num_iterations: if provided explicitly sets the number of pairs to be generated - where n_pairs = n_iterations * n_sentences * 2 (for pos & neg pairs) - max_pairs: If not -1, then we only sample pairs until we have certainly reached - max_pairs pairs. - """ - super().__init__() - self.pos_index = 0 - self.neg_index = 0 - self.pos_pairs = [] - self.neg_pairs = [] - self.sentences = sentences - self.labels = labels - self.sentence_labels = list(zip(self.sentences, self.labels)) - self.max_pos_or_neg = -1 if max_pairs == -1 else max_pairs // 2 - - if multilabel: - self.generate_multilabel_pairs() - else: - self.generate_pairs() - - if num_iterations is not None and num_iterations > 0: - self.len_pos_pairs = num_iterations * len(self.sentences) - self.len_neg_pairs = num_iterations * len(self.sentences) - - elif sampling_strategy == "unique": - self.len_pos_pairs = len(self.pos_pairs) - self.len_neg_pairs = len(self.neg_pairs) - - elif sampling_strategy == "undersampling": - self.len_pos_pairs = min(len(self.pos_pairs), len(self.neg_pairs)) - self.len_neg_pairs = min(len(self.pos_pairs), len(self.neg_pairs)) - - elif sampling_strategy == "oversampling": - self.len_pos_pairs = max(len(self.pos_pairs), len(self.neg_pairs)) - self.len_neg_pairs = max(len(self.pos_pairs), len(self.neg_pairs)) - - else: - raise ValueError("Invalid sampling strategy. 
Must be one of 'unique', 'oversampling', or 'undersampling'.") - - def generate_pairs(self) -> None: - for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels): - is_positive = _label == label - is_positive_full = self.max_pos_or_neg != -1 and len(self.pos_pairs) >= self.max_pos_or_neg - is_negative_full = self.max_pos_or_neg != -1 and len(self.neg_pairs) >= self.max_pos_or_neg - - if is_positive: - if not is_positive_full: - self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0}) - elif not is_negative_full: - self.neg_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 0.0}) - - if is_positive_full and is_negative_full: - break - - def generate_multilabel_pairs(self) -> None: - for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels): - # logical_and checks if labels are both set for each class - is_positive = any(np.logical_and(_label, label)) - is_positive_full = self.max_pos_or_neg != -1 and len(self.pos_pairs) >= self.max_pos_or_neg - is_negative_full = self.max_pos_or_neg != -1 and len(self.neg_pairs) >= self.max_pos_or_neg - - if is_positive: - if not is_positive_full: - self.pos_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 1.0}) - elif not is_negative_full: - self.neg_pairs.append({"sentence_1": _text, "sentence_2": text, "label": 0.0}) - - if is_positive_full and is_negative_full: - break - - def get_positive_pairs(self) -> List[Dict[str, Union[str, float]]]: - pairs = [] - for _ in range(self.len_pos_pairs): - if self.pos_index >= len(self.pos_pairs): - self.pos_index = 0 - pairs.append(self.pos_pairs[self.pos_index]) - self.pos_index += 1 - return pairs - - def get_negative_pairs(self) -> List[Dict[str, Union[str, float]]]: - pairs = [] - for _ in range(self.len_neg_pairs): - if self.neg_index >= len(self.neg_pairs): - self.neg_index = 0 - pairs.append(self.neg_pairs[self.neg_index]) - self.neg_index += 1 - return pairs - - def __iter__(self): - for 
pos_pair, neg_pair in zip_longest(self.get_positive_pairs(), self.get_negative_pairs()): - if pos_pair is not None: - yield pos_pair - if neg_pair is not None: - yield neg_pair - - def __len__(self) -> int: - return self.len_pos_pairs + self.len_neg_pairs - - class ContrastiveDistillationDataset(ContrastiveDataset): def __init__( self, From 2d5e29b784b6573dba41ef9d814be7f32e051d54 Mon Sep 17 00:00:00 2001 From: Demir Tonchev Date: Wed, 25 Dec 2024 12:19:06 +0200 Subject: [PATCH 08/21] trainer fixed to work with ContrastiveDataset iter method --- src/setfit/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/setfit/trainer.py b/src/setfit/trainer.py index b20409b9..82d9faa0 100644 --- a/src/setfit/trainer.py +++ b/src/setfit/trainer.py @@ -602,7 +602,7 @@ def get_dataset( args.sampling_strategy, max_pairs=max_pairs, ) - dataset = Dataset.from_list(list(data_sampler)) + dataset = Dataset.from_generator(data_sampler.__iter__) loss = args.loss(self.model.model_body) return dataset, loss From 1c905b1455d8062b5b497e79e02147a31196cdb8 Mon Sep 17 00:00:00 2001 From: Demir Tonchev Date: Wed, 25 Dec 2024 12:24:06 +0200 Subject: [PATCH 09/21] ContrastiveDistillationDataset iter refactor --- src/setfit/sampler.py | 28 ++++++++++------------------ src/setfit/trainer_distillation.py | 2 +- 2 files changed, 11 insertions(+), 19 deletions(-) diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index c768fac8..f84aa1a0 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -185,7 +185,6 @@ def __init__( sentences: List[str], cos_sim_matrix: torch.Tensor, num_iterations: Optional[None] = None, - sampling_strategy: str = "oversampling", max_pairs: int = -1, ) -> None: self.cos_sim_matrix = cos_sim_matrix @@ -194,23 +193,16 @@ def __init__( [0] * len(sentences), multilabel=False, num_iterations=num_iterations, - sampling_strategy=sampling_strategy, + sampling_strategy=SamplingStrategy.UNIQUE, # use unique to create all pos pairs, 
implementation choice to use generate_positive_pair method (*) max_pairs=max_pairs, ) - # Internally we store all pairs in pos_pairs, regardless of sampling strategy. - # After all, without labels, there isn't much of a strategy. - self.sentence_labels = list(enumerate(self.sentences)) - self.len_neg_pairs = 0 - if num_iterations is not None and num_iterations > 0: - self.len_pos_pairs = num_iterations * len(self.sentences) - else: - self.len_pos_pairs = len(self.pos_pairs) - - def generate_pairs(self) -> None: - for (text_one, id_one), (text_two, id_two) in shuffle_combinations(self.sentence_labels): - self.pos_pairs.append( - {"sentence_1": text_one, "sentence_2": text_two, "label": self.cos_sim_matrix[id_one][id_two]} - ) - if self.max_pos_or_neg != -1 and len(self.pos_pairs) > self.max_pos_or_neg: - break + self.sentence_labels = list(zip(self.sentences, range(len(self.sentences)))) + + # (*) Internally we use generate_positive_pair + def generate_positive_pair(self) -> Generator[Dict[str, Union[str, float]]]: + pair_generator = shuffle_combinations(self.sentence_labels) + while True: + for (text_one, id_one), (text_two, id_two) in pair_generator: + yield {"sentence_1": text_one, "sentence_2": text_two, "label": self.cos_sim_matrix[id_one][id_two]} + pair_generator = shuffle_combinations(self.sentence_labels) diff --git a/src/setfit/trainer_distillation.py b/src/setfit/trainer_distillation.py index 20beb300..19e6edcf 100644 --- a/src/setfit/trainer_distillation.py +++ b/src/setfit/trainer_distillation.py @@ -92,7 +92,7 @@ def get_dataset( data_sampler = ContrastiveDistillationDataset( x, cos_sim_matrix, args.num_iterations, args.sampling_strategy, max_pairs=max_pairs ) - dataset = Dataset.from_list(list(data_sampler)) + dataset = Dataset.from_generator(data_sampler.__iter__) loss = args.loss(self.model.model_body) return dataset, loss From 9cafa023b793cd751198fdfd9183a20a704a887f Mon Sep 17 00:00:00 2001 From: Demir Tonchev Date: Thu, 26 Dec 2024 14:13:31 +0200 
Subject: [PATCH 10/21] typing fix --- src/setfit/sampler.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index f84aa1a0..2d67c2cd 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -1,5 +1,5 @@ from itertools import combinations -from typing import Dict, Generator, Iterable, List, Literal, Optional, Union +from typing import Dict, Generator, Iterable, List, Literal, Optional, Union, TypeAlias from collections import Counter import numpy as np @@ -13,6 +13,8 @@ logging.set_verbosity_info() logger = logging.get_logger(__name__) +SentencePair: TypeAlias = Dict[str, Union[str, float]] + class SamplingStrategy(ExplicitEnum): """ @@ -132,7 +134,7 @@ def __init__( self.len_pos_pairs = int(np.min([max(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg])) self.len_neg_pairs = int(np.min([max(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg])) - def generate_positive_pair(self) -> Generator[Dict[str, Union[str, float]]]: + def generate_positive_pair(self) -> Generator[SentencePair, None, None]: pair_generator = shuffle_combinations(self.sentence_labels) while True: for (_text, _label), (text, label) in pair_generator: @@ -146,7 +148,7 @@ def generate_positive_pair(self) -> Generator[Dict[str, Union[str, float]]]: # restart pair_generator = shuffle_combinations(self.sentence_labels) - def generate_negative_pair(self) -> Generator[Dict[str, Union[str, float]]]: + def generate_negative_pair(self) -> Generator[SentencePair, None, None]: pair_generator = shuffle_combinations(self.sentence_labels) while True: for (_text, _label), (text, label) in pair_generator: @@ -159,7 +161,7 @@ def generate_negative_pair(self) -> Generator[Dict[str, Union[str, float]]]: yield {"sentence_1": _text, "sentence_2": text, "label": 0.0} pair_generator = shuffle_combinations(self.sentence_labels) - def __iter__(self) -> Generator[Dict[str, Union[str, float]]]: + def 
__iter__(self) -> Generator[SentencePair, None, None]: generated_pos_pairs = 0 generated_neg_pairs = 0 @@ -200,7 +202,7 @@ def __init__( self.sentence_labels = list(zip(self.sentences, range(len(self.sentences)))) # (*) Internally we use generate_positive_pair - def generate_positive_pair(self) -> Generator[Dict[str, Union[str, float]]]: + def generate_positive_pair(self) -> Generator[SentencePair, None, None]: pair_generator = shuffle_combinations(self.sentence_labels) while True: for (text_one, id_one), (text_two, id_two) in pair_generator: From ea088ff8a44ea37cb11e2b8c3937aae5a44e84e7 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Fri, 10 Jan 2025 16:06:00 +0100 Subject: [PATCH 11/21] TypeAlias will be deprecated in 3.12 again, so let's avoid it It also doesn't exist in 3.9 yet. https://docs.python.org/3/library/typing.html#typing.TypeAlias --- src/setfit/sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index 2d67c2cd..6a40f07b 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -1,5 +1,5 @@ from itertools import combinations -from typing import Dict, Generator, Iterable, List, Literal, Optional, Union, TypeAlias +from typing import Dict, Generator, Iterable, List, Literal, Optional, Union from collections import Counter import numpy as np @@ -13,7 +13,7 @@ logging.set_verbosity_info() logger = logging.get_logger(__name__) -SentencePair: TypeAlias = Dict[str, Union[str, float]] +SentencePair = Dict[str, Union[str, float]] class SamplingStrategy(ExplicitEnum): From 467c5076c60b3e64e540768c3398df1b5f76ee83 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Fri, 10 Jan 2025 16:08:57 +0100 Subject: [PATCH 12/21] Remove args.sampling_strategy from ContrastiveDistillationDataset init --- src/setfit/trainer_distillation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/setfit/trainer_distillation.py b/src/setfit/trainer_distillation.py index 
19e6edcf..60080fe3 100644 --- a/src/setfit/trainer_distillation.py +++ b/src/setfit/trainer_distillation.py @@ -89,9 +89,7 @@ def get_dataset( ) cos_sim_matrix = util.cos_sim(x_embd_student, x_embd_student) - data_sampler = ContrastiveDistillationDataset( - x, cos_sim_matrix, args.num_iterations, args.sampling_strategy, max_pairs=max_pairs - ) + data_sampler = ContrastiveDistillationDataset(x, cos_sim_matrix, args.num_iterations, max_pairs=max_pairs) dataset = Dataset.from_generator(data_sampler.__iter__) loss = args.loss(self.model.model_body) return dataset, loss From 9784485510cc3f88e0c6771c33de3aff001bab07 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Fri, 10 Jan 2025 16:09:04 +0100 Subject: [PATCH 13/21] Run formatting --- src/setfit/sampler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index 6a40f07b..a0d71b1f 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -1,6 +1,6 @@ +from collections import Counter from itertools import combinations from typing import Dict, Generator, Iterable, List, Literal, Optional, Union -from collections import Counter import numpy as np import torch @@ -52,6 +52,7 @@ class SamplingStrategy(ExplicitEnum): * Y Every possible pair is used. * Y There is **no** data duplication. 
""" + OVERSAMPLING = "oversampling" UNDERSAMPLING = "undersampling" UNIQUE = "unique" From 3f67190c8287e818bc7df0bbbde8ec42546154a5 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Fri, 10 Jan 2025 16:45:00 +0100 Subject: [PATCH 14/21] Add 'save_strategy="no"' in tests to counteract transformers v4.48.0 bug --- tests/span/test_model_card.py | 1 + tests/test_model_card.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/span/test_model_card.py b/tests/span/test_model_card.py index 60fd830b..6e4585f8 100644 --- a/tests/span/test_model_card.py +++ b/tests/span/test_model_card.py @@ -26,6 +26,7 @@ def test_model_card(absa_dataset: Dataset, tmp_path: Path) -> None: logging_steps=1, max_steps=2, eval_strategy="steps", + save_strategy="no", ) trainer = AbsaTrainer( model=model, diff --git a/tests/test_model_card.py b/tests/test_model_card.py index 910f7237..9bdb199c 100644 --- a/tests/test_model_card.py +++ b/tests/test_model_card.py @@ -36,6 +36,7 @@ def test_model_card(tmp_path: Path) -> None: logging_steps=1, max_steps=2, eval_strategy="steps", + save_strategy="no", ) trainer = Trainer( model=model, From 509ee68d695895539c2c8a289512cc5ce36dfaff Mon Sep 17 00:00:00 2001 From: Demir Tonchev Date: Sat, 11 Jan 2025 21:22:27 +0200 Subject: [PATCH 15/21] fix some autocomplete mistake --- src/setfit/sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index 2d67c2cd..97d8cd1f 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -116,7 +116,7 @@ def __init__( # postive number of pairs from an n element set without replacement self.total_pos_pairs = int(sum([n * (n - 1) / 2 for n in label_counts.values()])) # negative product - self.total_neg_pairs = self.total_neg_pairs = sum(a * b for a, b in combinations(label_counts.values(), 2)) + self.total_neg_pairs = sum(a * b for a, b in combinations(label_counts.values(), 2)) if num_iterations is not None and num_iterations > 0: 
self.len_pos_pairs = num_iterations * len(self.sentences) From 7d9119bb974c0a73b9f199b6c349d3cf68c90a74 Mon Sep 17 00:00:00 2001 From: Demir Tonchev Date: Sun, 12 Jan 2025 15:08:12 +0200 Subject: [PATCH 16/21] cleanup leftovers attributes from old logic --- src/setfit/sampler.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index 1ac34237..9d120d64 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -100,10 +100,6 @@ def __init__( max_pairs pairs. """ super().__init__() - self.pos_index = 0 - self.neg_index = 0 - self.pos_pairs = [] - self.neg_pairs = [] self.sentences = sentences self.labels = labels self.sentence_labels = list(zip(self.sentences, self.labels)) From fe30eaaf662663af5a70e76aae5c4a241b1429f9 Mon Sep 17 00:00:00 2001 From: Demir Tonchev Date: Mon, 13 Jan 2025 16:53:24 +0200 Subject: [PATCH 17/21] docs --- src/setfit/sampler.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index 9d120d64..48528354 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -159,6 +159,9 @@ def generate_negative_pair(self) -> Generator[SentencePair, None, None]: pair_generator = shuffle_combinations(self.sentence_labels) def __iter__(self) -> Generator[SentencePair, None, None]: + """Use this to create a generator to loop over the dataset. + You should override this method if you want to change how pairs are generated. + """ generated_pos_pairs = 0 generated_neg_pairs = 0 @@ -200,6 +203,8 @@ def __init__( # (*) Internally we use generate_positive_pair def generate_positive_pair(self) -> Generator[SentencePair, None, None]: + """We define either generate_positive_pair or generate_negative_pair to change the pair generation behavior. It does not matter which one we choose.
+ """ pair_generator = shuffle_combinations(self.sentence_labels) while True: for (text_one, id_one), (text_two, id_two) in pair_generator: From e7f9cf08afbc8db46018aaad836c6029627f6ca1 Mon Sep 17 00:00:00 2001 From: Demir Tonchev Date: Mon, 13 Jan 2025 16:55:32 +0200 Subject: [PATCH 18/21] fix to correctly specify positives and negatives when passing num_iterations --- src/setfit/sampler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index 48528354..bbbddbd5 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -116,8 +116,9 @@ def __init__( self.total_neg_pairs = sum(a * b for a, b in combinations(label_counts.values(), 2)) if num_iterations is not None and num_iterations > 0: - self.len_pos_pairs = num_iterations * len(self.sentences) - self.len_neg_pairs = num_iterations * len(self.sentences) + iterations = num_iterations * len(self.sentences) + self.len_pos_pairs = int(np.min([self.total_pos_pairs, iterations])) + self.len_neg_pairs = int(np.min([self.total_neg_pairs, iterations])) elif sampling_strategy == SamplingStrategy.UNIQUE: self.len_pos_pairs = int(np.min([self.total_pos_pairs, self.max_pos_or_neg])) From d989f22c1fdc9954b4ff3bfb6f5a3ef016d8d300 Mon Sep 17 00:00:00 2001 From: Demir Tonchev Date: Mon, 13 Jan 2025 18:34:53 +0200 Subject: [PATCH 19/21] fix for negative or positive examples only(not really possible) --- src/setfit/sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index bbbddbd5..fe583fa4 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -117,8 +117,8 @@ def __init__( if num_iterations is not None and num_iterations > 0: iterations = num_iterations * len(self.sentences) - self.len_pos_pairs = int(np.min([self.total_pos_pairs, iterations])) - self.len_neg_pairs = int(np.min([self.total_neg_pairs, iterations])) + self.len_pos_pairs = iterations if 
self.pos_pairs_combination > 0 else 0 + self.len_neg_pairs = iterations if self.neg_pairs_combination > 0 else 0 elif sampling_strategy == SamplingStrategy.UNIQUE: self.len_pos_pairs = int(np.min([self.total_pos_pairs, self.max_pos_or_neg])) From d229add6f4bb5351120742c46980645b7b034325 Mon Sep 17 00:00:00 2001 From: Demir Tonchev Date: Mon, 13 Jan 2025 18:35:27 +0200 Subject: [PATCH 20/21] removed unnecessary np uses and casting --- src/setfit/sampler.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index fe583fa4..0dd48140 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -111,9 +111,9 @@ def __init__( # calculate number of positive and negative combinations label_counts = Counter(labels) # postive number of pairs from an n element set without replacement - self.total_pos_pairs = int(sum([n * (n - 1) / 2 for n in label_counts.values()])) + self.pos_pairs_combination = sum([n * (n - 1) // 2 for n in label_counts.values()]) # negative product - self.total_neg_pairs = sum(a * b for a, b in combinations(label_counts.values(), 2)) + self.neg_pairs_combination = sum(a * b for a, b in combinations(label_counts.values(), 2)) if num_iterations is not None and num_iterations > 0: iterations = num_iterations * len(self.sentences) @@ -121,16 +121,16 @@ def __init__( self.len_neg_pairs = iterations if self.neg_pairs_combination > 0 else 0 elif sampling_strategy == SamplingStrategy.UNIQUE: - self.len_pos_pairs = int(np.min([self.total_pos_pairs, self.max_pos_or_neg])) - self.len_neg_pairs = int(np.min([self.total_neg_pairs, self.max_pos_or_neg])) + self.len_pos_pairs = min(self.pos_pairs_combination, self.max_pos_or_neg) + self.len_neg_pairs = min(self.neg_pairs_combination, self.max_pos_or_neg) elif sampling_strategy == SamplingStrategy.UNDERSAMPLING: - self.len_pos_pairs = int(np.min([min(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg])) - self.len_neg_pairs = 
int(np.min([min(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg])) + self.len_pos_pairs = min([min(self.pos_pairs_combination, self.neg_pairs_combination), self.max_pos_or_neg]) + self.len_neg_pairs = min([min(self.pos_pairs_combination, self.neg_pairs_combination), self.max_pos_or_neg]) elif sampling_strategy == SamplingStrategy.OVERSAMPLING: - self.len_pos_pairs = int(np.min([max(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg])) - self.len_neg_pairs = int(np.min([max(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg])) + self.len_pos_pairs = min([max(self.pos_pairs_combination, self.neg_pairs_combination), self.max_pos_or_neg]) + self.len_neg_pairs = min([max(self.pos_pairs_combination, self.neg_pairs_combination), self.max_pos_or_neg]) def generate_positive_pair(self) -> Generator[SentencePair, None, None]: pair_generator = shuffle_combinations(self.sentence_labels) From ea88e9bd9c930ad6b845ce93761e1bbd05428d28 Mon Sep 17 00:00:00 2001 From: Demir Tonchev Date: Mon, 13 Jan 2025 21:43:58 +0200 Subject: [PATCH 21/21] safeguard for oversampling strategy, when DS is negatives only or single pos pair --- src/setfit/sampler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index 0dd48140..ce025852 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -129,8 +129,10 @@ def __init__( self.len_neg_pairs = min([min(self.pos_pairs_combination, self.neg_pairs_combination), self.max_pos_or_neg]) elif sampling_strategy == SamplingStrategy.OVERSAMPLING: - self.len_pos_pairs = min([max(self.pos_pairs_combination, self.neg_pairs_combination), self.max_pos_or_neg]) - self.len_neg_pairs = min([max(self.pos_pairs_combination, self.neg_pairs_combination), self.max_pos_or_neg]) + num_pos_or_neg_pairs = max(self.pos_pairs_combination, self.neg_pairs_combination) + # safeguard for either negative samples only or single positive.
+ self.len_pos_pairs = min([num_pos_or_neg_pairs, self.max_pos_or_neg]) if self.pos_pairs_combination > 0 else 0 + self.len_neg_pairs = min([num_pos_or_neg_pairs, self.max_pos_or_neg]) if self.neg_pairs_combination > 0 else 0 def generate_positive_pair(self) -> Generator[SentencePair, None, None]: pair_generator = shuffle_combinations(self.sentence_labels)