diff --git a/src/pdl/pdl_interpreter.py b/src/pdl/pdl_interpreter.py index 4f634092..d56954f0 100644 --- a/src/pdl/pdl_interpreter.py +++ b/src/pdl/pdl_interpreter.py @@ -1051,7 +1051,9 @@ def get_transformed_inputs(kwargs): if "input" in litellm_params: append_log(state, "Model Input", litellm_params["input"]) else: - append_log(state, "Model Input", messages_to_str(concrete_block.model, model_input)) + append_log( + state, "Model Input", messages_to_str(concrete_block.model, model_input) + ) background: Messages = [msg] result = msg["content"] append_log(state, "Model Output", result) diff --git a/src/pdl/pdl_utils.py b/src/pdl/pdl_utils.py index 5cdada3c..47b5643e 100644 --- a/src/pdl/pdl_utils.py +++ b/src/pdl/pdl_utils.py @@ -33,25 +33,19 @@ def messages_concat(messages1: Messages, messages2: Messages) -> Messages: def messages_to_str(model_id: str, messages: Messages) -> str: - if "granite-3b" not in model_id and "granite-8b" not in model_id: - return "".join( - [ - ( - msg["content"] - - ) - for msg in messages - ] + if "granite-3b" not in model_id and "granite-8b" not in model_id: + return "".join([(msg["content"]) for msg in messages]) + return ( + "".join( + [ + ( + msg["content"] + if msg["role"] is None + # else f"<|{msg['role']}|>{msg['content']}" + else f"<|start_of_role|>{msg['role']}<|end_of_role|>{msg['content']}<|end_of_text|>\n" + ) + for msg in messages + ] + ) + + "<|start_of_role|>assistant<|end_of_role|>" ) - return "".join( - [ - ( - msg["content"] - if msg["role"] is None - #else f"<|{msg['role']}|>{msg['content']}" - else f"<|start_of_role|>{msg['role']}<|end_of_role|>{msg['content']}<|end_of_text|>\n" - ) - for msg in messages - ] - ) + "<|start_of_role|>assistant<|end_of_role|>" -