diff --git a/presets/inference/text-generation/inference_api.py b/presets/inference/text-generation/inference_api.py index 73c7b5095..bf739844d 100644 --- a/presets/inference/text-generation/inference_api.py +++ b/presets/inference/text-generation/inference_api.py @@ -123,30 +123,18 @@ def health_check(): return {"status": "Healthy"} class GenerateKwargs(BaseModel): - max_length: int = 200 + max_length: int = 200 # Length of input prompt+max_new_tokens min_length: int = 0 - do_sample: bool = True + do_sample: bool = False early_stopping: bool = False num_beams: int = 1 - num_beam_groups: int = 1 - diversity_penalty: float = 0.0 temperature: float = 1.0 - top_k: int = 10 + top_k: int = 50 top_p: float = 1 typical_p: float = 1 repetition_penalty: float = 1 - length_penalty: float = 1 - no_repeat_ngram_size: int = 0 - encoder_no_repeat_ngram_size: int = 0 - bad_words_ids: Optional[List[int]] = None - num_return_sequences: int = 1 - output_scores: bool = False - return_dict_in_generate: bool = False pad_token_id: Optional[int] = tokenizer.pad_token_id eos_token_id: Optional[int] = tokenizer.eos_token_id - forced_bos_token_id: Optional[int] = None - forced_eos_token_id: Optional[int] = None - remove_invalid_values: Optional[bool] = None class Config: extra = Extra.allow # Allows for additional fields not explicitly defined @@ -157,7 +145,7 @@ class UnifiedRequestModel(BaseModel): clean_up_tokenization_spaces: Optional[bool] = Field(False, description="Clean up extra spaces in text output") prefix: Optional[str] = Field(None, description="Prefix added to prompt") handle_long_generation: Optional[str] = Field(None, description="Strategy to handle long generation") - generate_kwargs: Optional[GenerateKwargs] = Field(None, description="Additional kwargs for generate method") + generate_kwargs: Optional[GenerateKwargs] = Field(default_factory=GenerateKwargs, description="Additional kwargs for generate method") # Field for conversational model messages: Optional[List[Dict[str, str]]] = Field(None, description="Messages for conversational model")