Commit

Add option to output top logprobs
s-jse committed Nov 1, 2024
1 parent 7283e51 commit a6483f0
Showing 3 changed files with 59 additions and 12 deletions.
40 changes: 28 additions & 12 deletions chainlite/llm_generate.py
@@ -1,15 +1,9 @@
"""
Functionality to work with .prompt files
"""

import json
import logging
import os
import random
import re
from datetime import datetime
import litellm
from rich import print as pprint
from typing import Any, Callable, Dict, List, Optional
from uuid import UUID

@@ -287,6 +281,10 @@ async def return_response_and_tool(
return tool_outputs
return response, tool_outputs

@chain
async def return_response_and_logprobs(llm_output):
response = await StrOutputParser().ainvoke(input=llm_output)
return response, llm_output.response_metadata.get("logprobs")

def llm_generation_chain(
template_file: str,
@@ -303,6 +301,7 @@ def llm_generation_chain(
additional_postprocessing_runnable: Runnable = None,
tools: Optional[list[Callable]] = None,
force_tool_calling: bool = False,
return_top_logprobs: int = 0,
bind_prompt_values: Dict = {},
force_skip_cache: bool = False,
) -> Runnable:
@@ -323,10 +322,12 @@ def llm_generation_chain(
template_blocks: If provided, will use this instead of `template_file`. The format is [(role, string)] where role is one of "instruction", "input", "output"
keep_indentation (bool, optional): If True, will keep indentations at the beginning of each line in the template_file. Defaults to False.
progress_bar_name (str, optional): If provided, will display a `tqdm` progress bar using this name
additional_postprocessing_runnable (Runnable, optional): If provided, will be applied to the output of LLM generation, and the final output will be logged
tools (List[Callable], optional): If provided, will be made available to the underlying LLM, to optionally output it for function calling. Defaults to None.
force_tool_calling (bool, optional): If True, will force the LLM to output the tools for function calling. Defaults to False.
additional_postprocessing_runnable (Runnable, optional): If provided, will be applied to the output of LLM generation, and the final output will be logged
return_top_logprobs (int, optional): If > 0, will return the top logprobs for each token, so the output will be Tuple[str, dict]. Defaults to 0.
bind_prompt_values (Dict, optional): A dictionary containing {Variable: str : Value}. Binds values to the prompt. Additional variables can be provided when the chain is called. Defaults to {}.
force_skip_cache (bool, optional): If True, will force the LLM to skip the cache, and the new value won't be saved in cache either. Defaults to False.
Returns:
Runnable: The language model generation chain
@@ -351,13 +352,20 @@ def llm_generation_chain(
raise IndexError(
f"Could not find any matching engines for {engine}. Please check that llm_config.yaml is configured correctly and that the API key is set in the terminal before running this script."
)

if (
(pydantic_class and tools)
or (pydantic_class and output_json)
sum(
[
bool(pydantic_class),
bool(output_json),
bool(tools),
return_top_logprobs > 0,
]
)
> 1
):
raise ValueError(
"At most one of `pydantic_class`, `output_json` and `tools` can be used."
"At most one of `pydantic_class`, `output_json`, `return_top_logprobs` and `tools` can be used."
)
llm_resource = random.choice(potential_llm_resources)

@@ -404,6 +412,10 @@ def llm_generation_chain(
},
}

if return_top_logprobs > 0:
model_kwargs["logprobs"] = True
model_kwargs["top_logprobs"] = return_top_logprobs

if tools:
function_json = [
{"type": "function", "function": litellm.utils.function_to_dict(t)}
@@ -443,12 +455,16 @@ def llm_generation_chain(
tools=tools, force_tool_calling=force_tool_calling
)
else:
llm_generation_chain = llm_generation_chain | StrOutputParser()
if return_top_logprobs > 0:
llm_generation_chain = llm_generation_chain | return_response_and_logprobs
else:
llm_generation_chain = llm_generation_chain | StrOutputParser()

if pydantic_class:
llm_generation_chain = llm_generation_chain | string_to_pydantic_object.bind(
pydantic_class=pydantic_class
)


if additional_postprocessing_runnable:
llm_generation_chain = llm_generation_chain | additional_postprocessing_runnable
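With `return_top_logprobs > 0`, the chain routes the raw LLM output through `return_response_and_logprobs` instead of `StrOutputParser()`, so invoking it yields a `(response, logprobs)` tuple rather than a plain string. A minimal usage sketch, reusing the prompt file, engine name, and parameter values that appear in the new test further down (they are illustrative placeholders, not the only valid choices):

```python
import asyncio

from chainlite import llm_generation_chain


async def main() -> None:
    # With return_top_logprobs set, the chain returns (text, logprobs)
    # instead of a plain string.
    response, logprobs = await llm_generation_chain(
        template_file="test.prompt",  # placeholder prompt file, as in the test below
        engine="gpt-4o-openai",       # placeholder engine name from llm_config.yaml
        max_tokens=5,
        return_top_logprobs=10,
    ).ainvoke({})
    print(response)
    print(f"{len(logprobs)} token entries returned")


if __name__ == "__main__":
    asyncio.run(main())
```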
1 change: 1 addition & 0 deletions tasks/main.py
@@ -62,6 +62,7 @@ def tests(c, log_level="info", parallel=False):
test_files = [
"./tests/test_llm_generate.py",
"./tests/test_function_calling.py",
"./tests/test_logprobs.py",
]

pytest_command = (
30 changes: 30 additions & 0 deletions tests/test_logprobs.py
@@ -0,0 +1,30 @@
import pytest
from chainlite import (
get_logger,
llm_generation_chain,
)

logger = get_logger(__name__)


test_engine = "gpt-4o-openai"


@pytest.mark.asyncio(scope="session")
async def test_llm_generate():
response, logprobs = await llm_generation_chain(
template_file="test.prompt", # prompt path relative to one of the paths specified in `prompt_dirs`
engine=test_engine,
max_tokens=5,
force_skip_cache=True,
return_top_logprobs=10,
).ainvoke({})

assert response is not None, "The response should not be None"
assert isinstance(response, str), "The response should be a string"
assert len(response) > 0, "The response should not be empty"

assert len(logprobs) == 5
for i in range(len(logprobs)):
assert "top_logprobs" in logprobs[i]
assert len(logprobs[i]["top_logprobs"]) == 10

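The assertions above pin down the payload shape: one entry per generated token, each carrying a `top_logprobs` list with `return_top_logprobs` alternatives. A minimal sketch of turning those entries into plain probabilities, assuming OpenAI-style `"token"` and `"logprob"` keys inside each alternative (the exact keys depend on the provider that litellm forwards the request to):

```python
import math


def summarize_top_logprobs(logprobs: list[dict]) -> None:
    """Print each generated token's top alternatives as probabilities.

    Assumes OpenAI-style entries: each item holds a "top_logprobs" list whose
    elements carry "token" and "logprob" keys (provider-dependent).
    """
    for position, entry in enumerate(logprobs):
        alternatives = {
            alt["token"]: round(math.exp(alt["logprob"]), 4)  # logprob -> probability
            for alt in entry["top_logprobs"]
        }
        print(f"token {position}: {alternatives}")
```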