diff --git a/src/synthesizrr/base/algorithm/huggingface/transformers.py b/src/synthesizrr/base/algorithm/huggingface/transformers.py
index 4049974..a336ff9 100644
--- a/src/synthesizrr/base/algorithm/huggingface/transformers.py
+++ b/src/synthesizrr/base/algorithm/huggingface/transformers.py
@@ -26,7 +26,7 @@
     from transformers.models.auto.modeling_auto import _BaseAutoModelClass, MODEL_MAPPING_NAMES, \
         MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, \
         MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
-    from transformers.generation.utils import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
+    from transformers.generation.utils import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, StoppingCriteria
     from transformers import (
         LogitsProcessorList, MinLengthLogitsProcessor, TemperatureLogitsWarper,
@@ -507,6 +507,45 @@ class HFGenerativeLMTokenizerConfig(HFTokenizerConfig):
     truncation_side: Literal['left', 'right'] = 'left'  ## Keeps tokens at the end of the string, useful for LLMs
 
 
+class HFSubstringMatchStoppingCriteria(StoppingCriteria):
+    def __init__(
+            self,
+            *,
+            stop_sequences: List[str],
+            tokenizer: Any,
+            tokenizer_decode_dict: Dict,
+            prompt_input_ids: Tensor,
+    ):
+        self.tokenizer: PreTrainedTokenizerBase = tokenizer
+        self.tokenizer_decode_dict: Dict = tokenizer_decode_dict
+        self.stop_sequences: List[str] = as_list(stop_sequences)
+        self.prompt_input_ids: Tensor = prompt_input_ids
+
+    def __call__(self, input_ids: Tensor, scores: Tensor, **kwargs) -> bool:
+        ## Decode only the newly-generated tokens, i.e. slice off the prompt:
+        generated_texts: List[str] = self.tokenizer.batch_decode(
+            input_ids[:, self.prompt_input_ids.shape[1]:],
+            **self.tokenizer_decode_dict,
+        )
+        ## Check whether a stop sequence appears in ALL generated texts in the batch:
+        should_stop_generating: List[bool] = []
+        for generated_text in generated_texts:
+            should_stop_generating.append(False)
+            for stop_seq in self.stop_sequences:
+                if stop_seq in generated_text:
+                    should_stop_generating[-1] = True
+                    break
+        if bool(all(should_stop_generating)):
+            return True  ## Stop generation
+        return False  ## Continue generation
+
+    def __len__(self):  ## Lets `generate` treat this object as a list-like of criteria.
+        return len(self.stop_sequences)
+
+    def __iter__(self):  ## Yields itself, i.e. acts as a one-element StoppingCriteriaList.
+        yield self
+
+
 class HFPyTorchGenerativeLMMixin(GenerativeLM, HFPyTorchTextModel, ABC):
     class Hyperparameters(HFPyTorchTextModel.Hyperparameters):
         prompt_prefix: str = ''
@@ -529,6 +574,14 @@ def set_generative_lm_params(cls, params: Dict) -> Dict:
     @property
     def max_num_generated_tokens(self) -> int:
         return self.hyperparams.generation_params.max_new_tokens
+
+    @property
+    def tokenizer_decode_dict(self) -> Dict:
+        return self.hyperparams.tokenizer_decode.dict()
+
+    @property
+    def stop_sequences(self) -> Optional[List[str]]:
+        return self.hyperparams.generation_params.stop_sequences
 
     def _task_preprocess(self, batch: Prompts, **kwargs) -> Prompts:
         batch: Prompts = super(HFPyTorchGenerativeLMMixin, self)._task_preprocess(
@@ -539,12 +592,20 @@ def _task_preprocess(self, batch: Prompts, **kwargs) -> Prompts:
 
     def forward(self, input: Dict, **kwargs) -> Dict:
         ## Feed the input_ids and masks to the model:
         input.pop('token_type_ids', None)
+        input_ids: Tensor = input['input_ids']
         with disable_hf_logging():
             gen_kwargs: Dict = {
                 **input,
                 **self.hyperparams.generation_params.hf_dict(),
                 **dict(return_dict_in_generate=True),  ## Always return a *DecoderOnlyOutput
             }
+            if self.stop_sequences is not None:
+                gen_kwargs['stopping_criteria'] = HFSubstringMatchStoppingCriteria(
+                    stop_sequences=self.stop_sequences,
+                    tokenizer=self.tokenizer,
+                    tokenizer_decode_dict=self.tokenizer_decode_dict,
+                    prompt_input_ids=input_ids,
+                )
             out: Union[GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput] = self.model.generate(**gen_kwargs)
         return dict(out)
@@ -566,8 +627,22 @@ def prepare_predictions(self, output: Dict, input: Dict, **kwargs) -> Any:
         num_generated_tokens: int = generated_sequences.shape[1]
         generated_texts: List[str] = self.tokenizer.batch_decode(
             generated_sequences,
-            **self.hyperparams.tokenizer_decode.dict(),
+            **self.tokenizer_decode_dict,
         )
+        ## Post-process stop sequences: truncate each text at its earliest stop-sequence match.
+        if self.stop_sequences is not None:
+            for gen_text_i, generated_text in enumerate(generated_texts):
+                earliest_stop_idx: Optional[int] = None
+                for stop_seq in self.stop_sequences:
+                    stop_idx: int = generated_text.find(stop_seq)
+                    if stop_idx != -1:
+                        if earliest_stop_idx is None:
+                            earliest_stop_idx: int = stop_idx
+                        else:
+                            earliest_stop_idx: int = min(earliest_stop_idx, stop_idx)
+                if earliest_stop_idx is not None:
+                    generated_texts[gen_text_i]: str = generated_text[:earliest_stop_idx]
+
         predictions: Dict = {
             GENERATED_TEXTS_COL: generated_texts
         }
diff --git a/src/synthesizrr/base/framework/evaluator/LocalEvaluator.py b/src/synthesizrr/base/framework/evaluator/LocalEvaluator.py
index f34308a..89b777b 100644
--- a/src/synthesizrr/base/framework/evaluator/LocalEvaluator.py
+++ b/src/synthesizrr/base/framework/evaluator/LocalEvaluator.py
@@ -17,7 +17,7 @@ class LocalEvaluator(Evaluator):
     aliases = ['local', 'SimpleEvaluator', 'simple']
 
-    ## Cache model locally for 15 mins:
-    cache_timeout: Optional[Union[Timeout, confloat(gt=0)]] = Timeout24Hr(timeout=60 * 15)
+    ## Cache model locally for 3 hours:
+    cache_timeout: Optional[Union[Timeout, confloat(gt=0)]] = Timeout24Hr(timeout=3 * 60 * 60)
 
     def _load_model(
         self,
diff --git a/src/synthesizrr/base/framework/task/text_generation.py b/src/synthesizrr/base/framework/task/text_generation.py
index bb5d9b6..325868f 100644
--- a/src/synthesizrr/base/framework/task/text_generation.py
+++ b/src/synthesizrr/base/framework/task/text_generation.py
@@ -553,6 +553,9 @@ def set_gen_params(cls, params: Dict) -> Dict:
             params['output_scores_tolerance']: Optional[float] = None  ## Do not filter out any tokens.
         else:
             raise NotImplementedError(f'Unsupported `output_scores_format`: "{params["output_scores_format"]}"')
+
+        if params.get('stop_sequences') is not None:
+            params['stop_sequences']: List[str] = as_list(params['stop_sequences'])
         return params
 
     def hf_dict(self) -> Dict:
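
Usage sketch (reviewer note, not part of the diff): the new HFSubstringMatchStoppingCriteria can be exercised directly against a plain transformers model. The model name ('gpt2'), prompt, and stop sequence below are illustrative assumptions; wrapping the criteria in a StoppingCriteriaList follows the standard transformers convention, though the __len__/__iter__ shims also let the bare object be passed, as forward() does above.

    from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
    from synthesizrr.base.algorithm.huggingface.transformers import HFSubstringMatchStoppingCriteria

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    model = AutoModelForCausalLM.from_pretrained('gpt2')
    inputs = tokenizer('Q: What is 2 + 2?\nA:', return_tensors='pt')

    criteria = HFSubstringMatchStoppingCriteria(
        stop_sequences=['\nQ:'],  ## Stop once the model starts a new question
        tokenizer=tokenizer,
        tokenizer_decode_dict=dict(skip_special_tokens=True),
        prompt_input_ids=inputs['input_ids'],  ## Used to slice off the prompt before matching
    )
    out = model.generate(
        **inputs,
        max_new_tokens=50,
        stopping_criteria=StoppingCriteriaList([criteria]),
    )
    ## Generation halts only once EVERY sequence in the batch contains a stop
    ## sequence; per-row truncation happens later, in prepare_predictions().
    print(tokenizer.batch_decode(out[:, inputs['input_ids'].shape[1]:], skip_special_tokens=True))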
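
The post-processing added to prepare_predictions() truncates each decoded text at the earliest occurrence of any stop sequence; the stopping criteria alone cannot do this, since generation only halts once every row of the batch has matched. A minimal standalone sketch of that loop (the helper name is hypothetical):

    from typing import List, Optional

    def truncate_at_stop_sequences(generated_texts: List[str], stop_sequences: List[str]) -> List[str]:
        truncated: List[str] = []
        for text in generated_texts:
            ## Find the earliest index at which any stop sequence occurs:
            earliest: Optional[int] = None
            for stop_seq in stop_sequences:
                idx: int = text.find(stop_seq)
                if idx != -1:
                    earliest = idx if earliest is None else min(earliest, idx)
            ## Keep the text up to (but excluding) the earliest match, if any:
            truncated.append(text if earliest is None else text[:earliest])
        return truncated

    assert truncate_at_stop_sequences(['4\nQ: next question', 'four'], ['\nQ:']) == ['4', 'four']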