Fixed logging bug, added chat template (#125)
* fixed logging model input, added template

Signed-off-by: Mandana Vaziri <[email protected]>

* cleanup

Signed-off-by: Mandana Vaziri <[email protected]>

---------

Signed-off-by: Mandana Vaziri <[email protected]>
vazirim authored Oct 4, 2024
1 parent 842994d commit 14cee1e
Showing 2 changed files with 22 additions and 15 deletions.
src/pdl/pdl_interpreter.py: 10 changes (6 additions, 4 deletions)
@@ -1051,7 +1051,9 @@ def get_transformed_inputs(kwargs):
     if "input" in litellm_params:
         append_log(state, "Model Input", litellm_params["input"])
     else:
-        append_log(state, "Model Input", model_input)
+        append_log(
+            state, "Model Input", messages_to_str(concrete_block.model, model_input)
+        )
     background: Messages = [msg]
     result = msg["content"]
     append_log(state, "Model Output", result)

@@ -1093,7 +1095,7 @@ def generate_client_response_streaming(
     model_input: Messages,
 ) -> Generator[YieldMessage, Any, Message]:
     msg_stream: Generator[Message, Any, None]
-    model_input_str = messages_to_str(model_input)
+    model_input_str = messages_to_str(block.model, model_input)
     match block:
         case BamModelBlock():
             msg_stream = BamModel.generate_text_stream(

@@ -1148,7 +1150,7 @@ def generate_client_response_single(
     model_input: Messages,
 ) -> Generator[YieldMessage, Any, Message]:
     msg: Message
-    model_input_str = messages_to_str(model_input)
+    model_input_str = messages_to_str(block.model, model_input)
     match block:
         case BamModelBlock():
             msg = BamModel.generate_text(

@@ -1178,7 +1180,7 @@ def generate_client_response_batching(  # pylint: disable=too-many-arguments
     # model: str,
     model_input: Messages,
 ) -> Generator[YieldMessage, Any, Message]:
-    model_input_str = messages_to_str(model_input)
+    model_input_str = messages_to_str(block.model, model_input)
     match block:
         case BamModelBlock():
             msg = yield ModelCallMessage(
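All four interpreter call sites make the same change: messages_to_str now receives the model id from the block, so the string that is logged (and streamed) reflects that model's chat template. Below is a minimal, self-contained sketch of the logging fix, for illustration only: append_log is reduced to a print, the Message type is simplified, the model id is made up, and messages_to_str is stubbed to the plain-concatenation path (the real implementation is in src/pdl/pdl_utils.py).

# Illustrative sketch only -- not the interpreter's actual code.
from typing import Optional, TypedDict


class Message(TypedDict):
    role: Optional[str]
    content: str


def messages_to_str(model_id: str, messages: list[Message]) -> str:
    # Stub of the non-Granite path: plain concatenation of the contents.
    return "".join(msg["content"] for msg in messages)


def append_log(kind: str, value: object) -> None:
    print(f"--- {kind} ---\n{value}")


model_input: list[Message] = [
    {"role": "user", "content": "What is the capital of France?\n"}
]

# Before this commit, the raw list of message dicts was logged:
append_log("Model Input (before)", model_input)

# After this commit, the rendered prompt string is logged instead:
append_log("Model Input (after)", messages_to_str("ollama/llama2", model_input))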
src/pdl/pdl_utils.py: 27 changes (16 additions, 11 deletions)
@@ -32,15 +32,20 @@ def messages_concat(messages1: Messages, messages2: Messages) -> Messages:
     return messages1 + messages2


-def messages_to_str(messages: Messages) -> str:
-    # TODO
-    return "".join(
-        [
-            (
-                msg["content"]
-                if msg["role"] is None
-                else f"<|{msg['role']}|>{msg['content']}"
-            )
-            for msg in messages
-        ]
+def messages_to_str(model_id: str, messages: Messages) -> str:
+    if "granite-3b" not in model_id and "granite-8b" not in model_id:
+        return "".join([(msg["content"]) for msg in messages])
+    return (
+        "".join(
+            [
+                (
+                    msg["content"]
+                    if msg["role"] is None
+                    # else f"<|{msg['role']}|>{msg['content']}"
+                    else f"<|start_of_role|>{msg['role']}<|end_of_role|>{msg['content']}<|end_of_text|>\n"
+                )
+                for msg in messages
+            ]
+        )
+        + "<|start_of_role|>assistant<|end_of_role|>"
     )
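To make the new chat template concrete, here is a hedged usage sketch of the updated messages_to_str signature, assuming the pdl package is installed. The model ids and messages are made-up examples, and the commented output is what the function above should produce for them.

# Usage sketch; model ids and messages are illustrative examples.
from pdl.pdl_utils import messages_to_str

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is PDL?"},
]

# Granite 3b/8b model ids get the Granite chat template, ending with an
# assistant header so the model knows it should respond next:
print(messages_to_str("ibm/granite-8b-code-instruct", messages))
# <|start_of_role|>system<|end_of_role|>You are a helpful assistant.<|end_of_text|>
# <|start_of_role|>user<|end_of_role|>What is PDL?<|end_of_text|>
# <|start_of_role|>assistant<|end_of_role|>

# Any other model id falls back to plain concatenation of the contents:
print(messages_to_str("gpt-4o-mini", messages))
# You are a helpful assistant.What is PDL?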
