From a6483f0e6c6ef5ba9c194a6cc77f7e961ecc6945 Mon Sep 17 00:00:00 2001
From: Sina
Date: Fri, 1 Nov 2024 08:12:02 +0000
Subject: [PATCH] Add option to output top logprobs

---
 chainlite/llm_generate.py | 40 ++++++++++++++++++++++++++++------------
 tasks/main.py             |  1 +
 tests/test_logprobs.py    | 30 ++++++++++++++++++++++++++++++
 3 files changed, 59 insertions(+), 12 deletions(-)
 create mode 100644 tests/test_logprobs.py

diff --git a/chainlite/llm_generate.py b/chainlite/llm_generate.py
index 0541576..73da93d 100644
--- a/chainlite/llm_generate.py
+++ b/chainlite/llm_generate.py
@@ -1,15 +1,9 @@
-"""
-Functionality to work with .prompt files
-"""
-
 import json
-import logging
 import os
 import random
 import re
 from datetime import datetime
 
 import litellm
-from rich import print as pprint
 from typing import Any, Callable, Dict, List, Optional
 from uuid import UUID
@@ -287,6 +281,10 @@ async def return_response_and_tool(
         return tool_outputs
     return response, tool_outputs
 
+@chain
+async def return_response_and_logprobs(llm_output):
+    response = await StrOutputParser().ainvoke(input=llm_output)
+    return response, llm_output.response_metadata.get("logprobs")
 
 def llm_generation_chain(
     template_file: str,
@@ -303,6 +301,7 @@ def llm_generation_chain(
     additional_postprocessing_runnable: Runnable = None,
     tools: Optional[list[Callable]] = None,
     force_tool_calling: bool = False,
+    return_top_logprobs: int = 0,
     bind_prompt_values: Dict = {},
     force_skip_cache: bool = False,
 ) -> Runnable:
@@ -323,10 +322,12 @@ def llm_generation_chain(
         template_blocks: If provided, will use this instead of `template_file`. The format is [(role, string)] where role is one of "instruction", "input", "output"
         keep_indentation (bool, optional): If True, will keep indentations at the beginning of each line in the template_file. Defaults to False.
         progress_bar_name (str, optional): If provided, will display a `tqdm` progress bar using this name
+        additional_postprocessing_runnable (Runnable, optional): If provided, will be applied to the output of LLM generation, and the final output will be logged
         tools (List[Callable], optional): If provided, will be made available to the underlying LLM, to optionally output it for function calling. Defaults to None.
         force_tool_calling (bool, optional): If True, will force the LLM to output the tools for function calling. Defaults to False.
-        additional_postprocessing_runnable (Runnable, optional): If provided, will be applied to the output of LLM generation, and the final output will be logged
+        return_top_logprobs (int, optional): If > 0, will return the top logprobs for each token, so the output will be Tuple[str, dict]. Defaults to 0.
         bind_prompt_values (Dict, optional): A dictionary containing {Variable: str : Value}. Binds values to the prompt. Additional variables can be provided when the chain is called. Defaults to {}.
+        force_skip_cache (bool, optional): If True, will force the LLM to skip the cache, and the new value won't be saved in cache either. Defaults to False.
 
     Returns:
         Runnable: The language model generation chain
@@ -351,13 +352,20 @@ def llm_generation_chain(
         raise IndexError(
             f"Could not find any matching engines for {engine}. Please check that llm_config.yaml is configured correctly and that the API key is set in the terminal before running this script."
         )
+
     if (
-        (pydantic_class and tools)
-        or (pydantic_class and output_json)
-        or (pydantic_class and output_json)
+        sum(
+            [
+                bool(pydantic_class),
+                bool(output_json),
+                bool(tools),
+                return_top_logprobs > 0,
+            ]
+        )
+        > 1
     ):
         raise ValueError(
-            "At most one of `pydantic_class`, `output_json` and `tools` can be used."
+            "At most one of `pydantic_class`, `output_json`, `return_top_logprobs` and `tools` can be used."
         )
 
     llm_resource = random.choice(potential_llm_resources)
@@ -404,6 +412,10 @@ def llm_generation_chain(
         },
     }
 
+    if return_top_logprobs > 0:
+        model_kwargs["logprobs"] = True
+        model_kwargs["top_logprobs"] = return_top_logprobs
+
     if tools:
         function_json = [
             {"type": "function", "function": litellm.utils.function_to_dict(t)}
@@ -443,12 +455,16 @@ def llm_generation_chain(
             tools=tools, force_tool_calling=force_tool_calling
         )
     else:
-        llm_generation_chain = llm_generation_chain | StrOutputParser()
+        if return_top_logprobs > 0:
+            llm_generation_chain = llm_generation_chain | return_response_and_logprobs
+        else:
+            llm_generation_chain = llm_generation_chain | StrOutputParser()
     if pydantic_class:
         llm_generation_chain = llm_generation_chain | string_to_pydantic_object.bind(
             pydantic_class=pydantic_class
         )
+
     if additional_postprocessing_runnable:
         llm_generation_chain = llm_generation_chain | additional_postprocessing_runnable
 
diff --git a/tasks/main.py b/tasks/main.py
index 50a78da..483ea7b 100644
--- a/tasks/main.py
+++ b/tasks/main.py
@@ -62,6 +62,7 @@ def tests(c, log_level="info", parallel=False):
     test_files = [
         "./tests/test_llm_generate.py",
         "./tests/test_function_calling.py",
+        "./tests/test_logprobs.py",
     ]
 
     pytest_command = (
diff --git a/tests/test_logprobs.py b/tests/test_logprobs.py
new file mode 100644
index 0000000..098a04d
--- /dev/null
+++ b/tests/test_logprobs.py
@@ -0,0 +1,30 @@
+import pytest
+from chainlite import (
+    get_logger,
+    llm_generation_chain,
+)
+
+logger = get_logger(__name__)
+
+
+test_engine = "gpt-4o-openai"
+
+
+@pytest.mark.asyncio(scope="session")
+async def test_llm_generate():
+    response, logprobs = await llm_generation_chain(
+        template_file="test.prompt",  # prompt path relative to one of the paths specified in `prompt_dirs`
+        engine=test_engine,
+        max_tokens=5,
+        force_skip_cache=True,
+        return_top_logprobs=10,
+    ).ainvoke({})
+
+    assert response is not None, "The response should not be None"
+    assert isinstance(response, str), "The response should be a string"
+    assert len(response) > 0, "The response should not be empty"
+
+    assert len(logprobs) == 5
+    for i in range(len(logprobs)):
+        assert "top_logprobs" in logprobs[i]
+        assert len(logprobs[i]["top_logprobs"]) == 10
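
Reviewer note (not part of the patch): a minimal usage sketch of the new option, mirroring tests/test_logprobs.py above. The prompt file and engine name are placeholders taken from that test; when return_top_logprobs > 0 the chain yields a (response, logprobs) tuple instead of a plain string.

    import asyncio

    from chainlite import llm_generation_chain


    async def main():
        chain = llm_generation_chain(
            template_file="test.prompt",  # placeholder prompt, resolved via `prompt_dirs`
            engine="gpt-4o-openai",       # placeholder engine name from llm_config.yaml
            max_tokens=5,
            return_top_logprobs=10,       # request the 10 most likely tokens per position
        )
        response, logprobs = await chain.ainvoke({})
        # Each logprobs entry corresponds to one generated token and carries its
        # "top_logprobs" candidates, as asserted in the test above.
        for token_info in logprobs:
            print(token_info["top_logprobs"])


    asyncio.run(main())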