Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: default params #293

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 4 additions & 16 deletions presets/inference/text-generation/inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,30 +123,18 @@ def health_check():
return {"status": "Healthy"}

class GenerateKwargs(BaseModel):
max_length: int = 200
max_length: int = 200 # Length of input prompt+max_new_tokens
min_length: int = 0
do_sample: bool = True
do_sample: bool = False
early_stopping: bool = False
num_beams: int = 1
num_beam_groups: int = 1
diversity_penalty: float = 0.0
temperature: float = 1.0
top_k: int = 10
top_k: int = 50
top_p: float = 1
typical_p: float = 1
repetition_penalty: float = 1
length_penalty: float = 1
no_repeat_ngram_size: int = 0
encoder_no_repeat_ngram_size: int = 0
bad_words_ids: Optional[List[int]] = None
num_return_sequences: int = 1
output_scores: bool = False
return_dict_in_generate: bool = False
pad_token_id: Optional[int] = tokenizer.pad_token_id
eos_token_id: Optional[int] = tokenizer.eos_token_id
forced_bos_token_id: Optional[int] = None
forced_eos_token_id: Optional[int] = None
remove_invalid_values: Optional[bool] = None
class Config:
extra = Extra.allow # Allows for additional fields not explicitly defined

Expand All @@ -157,7 +145,7 @@ class UnifiedRequestModel(BaseModel):
clean_up_tokenization_spaces: Optional[bool] = Field(False, description="Clean up extra spaces in text output")
prefix: Optional[str] = Field(None, description="Prefix added to prompt")
handle_long_generation: Optional[str] = Field(None, description="Strategy to handle long generation")
generate_kwargs: Optional[GenerateKwargs] = Field(None, description="Additional kwargs for generate method")
generate_kwargs: Optional[GenerateKwargs] = Field(default_factory=GenerateKwargs, description="Additional kwargs for generate method")

# Field for conversational model
messages: Optional[List[Dict[str, str]]] = Field(None, description="Messages for conversational model")
Expand Down
Loading