From 07c1789fd8f76f1da0fad8b59bd070fa2695afcd Mon Sep 17 00:00:00 2001 From: Claudio Spiess Date: Tue, 14 Jan 2025 16:25:55 -0800 Subject: [PATCH] Set default params for watsonx chat and text gen --- pyproject.toml | 2 +- src/pdl/pdl_ast.py | 81 +++++++++++++++++++++++++--------------------- 2 files changed, 45 insertions(+), 38 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 09943777..dfdb1fcb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ dependencies = [ "jinja2~=3.0", "PyYAML~=6.0", "jsonschema~=4.0", - "litellm>=1.49", + "litellm>=1.52.1", "termcolor~=2.0", "ipython~=8.0", ] diff --git a/src/pdl/pdl_ast.py b/src/pdl/pdl_ast.py index 6ef43454..1a9ea5b0 100644 --- a/src/pdl/pdl_ast.py +++ b/src/pdl/pdl_ast.py @@ -655,41 +655,47 @@ def set_default_granite_model_parameters( if parameters is None: parameters = {} - if "watsonx" in model_id: - if "decoding_method" not in parameters: - parameters["decoding_method"] = ( - DECODING_METHOD # pylint: disable=attribute-defined-outside-init - ) - if "max_tokens" in parameters and parameters["max_tokens"] is None: - parameters["max_tokens"] = ( - MAX_NEW_TOKENS # pylint: disable=attribute-defined-outside-init + # see https://cloud.ibm.com/apidocs/watsonx-ai#text-chat-request + if model_id.startswith("watsonx/"): + parameters.setdefault("temperature", 0) # setting to decoding greedy + + # see https://cloud.ibm.com/apidocs/watsonx-ai#text-generation-request + if model_id.startswith("watsonx_text/"): + parameters.setdefault( + "decoding_method", + DECODING_METHOD, + ) + parameters.setdefault( + "max_tokens", + MAX_NEW_TOKENS, + ) + parameters.setdefault( + "min_new_tokens", + MIN_NEW_TOKENS, + ) + parameters.setdefault( + "repetition_penalty", + REPETITION_PENATLY, + ) + if parameters["decoding_method"] == "sample": + parameters.setdefault( + "temperature", + TEMPERATURE_SAMPLING, ) - if "min_new_tokens" not in parameters: - parameters["min_new_tokens"] = ( - MIN_NEW_TOKENS # pylint: disable=attribute-defined-outside-init + parameters.setdefault( + "top_k", + TOP_K_SAMPLING, ) - if "repetition_penalty" not in parameters: - parameters["repetition_penalty"] = ( - REPETITION_PENATLY # pylint: disable=attribute-defined-outside-init + parameters.setdefault( + "top_p", + TOP_P_SAMPLING, ) - if parameters["decoding_method"] == "sample": - if "temperature" not in parameters: - parameters["temperature"] = ( - TEMPERATURE_SAMPLING # pylint: disable=attribute-defined-outside-init - ) - if "top_k" not in parameters: - parameters["top_k"] = ( - TOP_K_SAMPLING # pylint: disable=attribute-defined-outside-init - ) - if "top_p" not in parameters: - parameters["top_p"] = ( - TOP_P_SAMPLING # pylint: disable=attribute-defined-outside-init - ) - if "replicate" in model_id and "granite-3.0" in model_id: - if "temperature" not in parameters or parameters["temperature"] is None: - parameters["temperature"] = 0 # setting to decoding greedy - if "roles" not in parameters: - parameters["roles"] = { + + if model_id.startswith("replicate/") and "granite-3.0" in model_id: + parameters.setdefault("temperature", 0) # setting to decoding greedy + parameters.setdefault( + "roles", + { "system": { "pre_message": "<|start_of_role|>system<|end_of_role|>", "post_message": "<|end_of_text|>", @@ -710,10 +716,11 @@ def set_default_granite_model_parameters( "pre_message": "<|start_of_role|>tool_response<|end_of_role|>", "post_message": "<|end_of_text|>", }, - } - if "final_prompt_value" not in parameters: - parameters["final_prompt_value"] = ( - "<|start_of_role|>assistant<|end_of_role|>" - ) + }, + ) + parameters.setdefault( + "final_prompt_value", + "<|start_of_role|>assistant<|end_of_role|>", + ) return parameters