From 5c870c0c5139a4c87e064546ed3fd944da8c5e37 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 10 Jan 2025 17:57:22 -0800 Subject: [PATCH] (performance improvement - litellm sdk + proxy) - ensure litellm does not create unnecessary threads when running async functions (#7680) * fix handle_sync_success_callbacks_for_async_calls * fix handle_sync_success_callbacks_for_async_calls * fix linting / testing errors * use handle_sync_success_callbacks_for_async_calls * add unit testing for logging fixes --- litellm/litellm_core_utils/litellm_logging.py | 128 ++++++++++++++++-- litellm/utils.py | 18 +-- .../test_unit_test_litellm_logging.py | 127 +++++++++++++++++ .../test_unit_tests_init_callbacks.py | 16 ++- 4 files changed, 264 insertions(+), 25 deletions(-) create mode 100644 tests/logging_callback_tests/test_unit_test_litellm_logging.py diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index d2d71bc80452..34a270258de4 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -65,7 +65,7 @@ TranscriptionResponse, Usage, ) -from litellm.utils import _get_base_model_from_metadata, print_verbose +from litellm.utils import _get_base_model_from_metadata, executor, print_verbose from ..integrations.argilla import ArgillaLogger from ..integrations.arize_ai import ArizeLogger @@ -1019,7 +1019,7 @@ def success_handler( # noqa: PLR0915 status="success", ) ) - callbacks = get_combined_callback_list( + callbacks = self.get_combined_callback_list( dynamic_success_callbacks=self.dynamic_success_callbacks, global_callbacks=litellm.success_callback, ) @@ -1555,7 +1555,7 @@ async def async_success_handler( # noqa: PLR0915 status="success", ) ) - callbacks = get_combined_callback_list( + callbacks = self.get_combined_callback_list( dynamic_success_callbacks=self.dynamic_async_success_callbacks, global_callbacks=litellm._async_success_callback, ) @@ -1825,7 +1825,7 @@ def failure_handler( # noqa: PLR0915 start_time=start_time, end_time=end_time, ) - callbacks = get_combined_callback_list( + callbacks = self.get_combined_callback_list( dynamic_success_callbacks=self.dynamic_failure_callbacks, global_callbacks=litellm.failure_callback, ) @@ -2011,7 +2011,7 @@ async def async_failure_handler( end_time=end_time, ) - callbacks = get_combined_callback_list( + callbacks = self.get_combined_callback_list( dynamic_success_callbacks=self.dynamic_async_failure_callbacks, global_callbacks=litellm._async_failure_callback, ) @@ -2108,6 +2108,116 @@ def _get_callback_object(self, service_name: Literal["langfuse"]) -> Optional[An return None + def handle_sync_success_callbacks_for_async_calls( + self, + result: Any, + start_time: datetime.datetime, + end_time: datetime.datetime, + ) -> None: + """ + Handles calling success callbacks for Async calls. + + Why: Some callbacks - `langfuse`, `s3` are sync callbacks. We need to call them in the executor. + """ + if self._should_run_sync_callbacks_for_async_calls() is False: + return + + executor.submit( + self.success_handler, + result, + start_time, + end_time, + ) + + def _should_run_sync_callbacks_for_async_calls(self) -> bool: + """ + Returns: + - bool: True if sync callbacks should be run for async calls. eg. `langfuse`, `s3` + """ + _combined_sync_callbacks = self.get_combined_callback_list( + dynamic_success_callbacks=self.dynamic_success_callbacks, + global_callbacks=litellm.success_callback, + ) + _filtered_success_callbacks = self._remove_internal_custom_logger_callbacks( + _combined_sync_callbacks + ) + _filtered_success_callbacks = self._remove_internal_litellm_callbacks( + _filtered_success_callbacks + ) + return len(_filtered_success_callbacks) > 0 + + def get_combined_callback_list( + self, dynamic_success_callbacks: Optional[List], global_callbacks: List + ) -> List: + if dynamic_success_callbacks is None: + return global_callbacks + return list(set(dynamic_success_callbacks + global_callbacks)) + + def _remove_internal_litellm_callbacks(self, callbacks: List) -> List: + """ + Creates a filtered list of callbacks, excluding internal LiteLLM callbacks. + + Args: + callbacks: List of callback functions/strings to filter + + Returns: + List of filtered callbacks with internal ones removed + """ + filtered = [ + cb for cb in callbacks if not self._is_internal_litellm_proxy_callback(cb) + ] + + verbose_logger.debug(f"Filtered callbacks: {filtered}") + return filtered + + def _get_callback_name(self, cb) -> str: + """ + Helper to get the name of a callback function + + Args: + cb: The callback function/string to get the name of + + Returns: + The name of the callback + """ + if hasattr(cb, "__name__"): + return cb.__name__ + if hasattr(cb, "__func__"): + return cb.__func__.__name__ + return str(cb) + + def _is_internal_litellm_proxy_callback(self, cb) -> bool: + """Helper to check if a callback is internal""" + INTERNAL_PREFIXES = [ + "_PROXY", + "_service_logger.ServiceLogging", + "sync_deployment_callback_on_success", + ] + if isinstance(cb, str): + return False + + if not callable(cb): + return True + + cb_name = self._get_callback_name(cb) + return any(prefix in cb_name for prefix in INTERNAL_PREFIXES) + + def _remove_internal_custom_logger_callbacks(self, callbacks: List) -> List: + """ + Removes internal custom logger callbacks from the list. + """ + _new_callbacks = [] + for _c in callbacks: + if isinstance(_c, CustomLogger): + continue + elif ( + isinstance(_c, str) + and _c in litellm._known_custom_logger_compatible_callbacks + ): + continue + _new_callbacks.append(_c) + return _new_callbacks + def set_callbacks(callback_list, function_id=None): # noqa: PLR0915 """ @@ -3191,11 +3301,3 @@ def modify_integration(integration_name, integration_params): if integration_name == "supabase": if "table_name" in integration_params: Supabase.supabase_table_name = integration_params["table_name"] - - -def get_combined_callback_list( - dynamic_success_callbacks: Optional[List], global_callbacks: List -) -> List: - if dynamic_success_callbacks is None: - return global_callbacks - return list(set(dynamic_success_callbacks + global_callbacks)) diff --git a/litellm/utils.py b/litellm/utils.py index 494194df9b46..6205293e1d0c 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -608,10 +608,11 @@ async def _client_async_logging_helper( asyncio.create_task( logging_obj.async_success_handler(result, start_time, end_time) ) - threading.Thread( - target=logging_obj.success_handler, - args=(result, start_time, end_time), - ).start() + logging_obj.handle_sync_success_callbacks_for_async_calls( + result=result, + start_time=start_time, + end_time=end_time, + ) def client(original_function): # noqa: PLR0915 @@ -1153,11 +1154,10 @@ async def wrapper_async(*args, **kwargs): # noqa: PLR0915 is_completion_with_fallbacks=is_completion_with_fallbacks, ) ) - executor.submit( - logging_obj.success_handler, - result, - start_time, - end_time, + logging_obj.handle_sync_success_callbacks_for_async_calls( + result=result, + start_time=start_time, + end_time=end_time, ) # REBUILD EMBEDDING CACHING if ( diff --git a/tests/logging_callback_tests/test_unit_test_litellm_logging.py b/tests/logging_callback_tests/test_unit_test_litellm_logging.py new file mode 100644 index 000000000000..455d0dacb9f3 --- /dev/null +++ b/tests/logging_callback_tests/test_unit_test_litellm_logging.py @@ -0,0 +1,127 @@ +import json +import os +import sys +from datetime import datetime +from unittest.mock import AsyncMock + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system-path + +from typing import Literal + +import pytest +import litellm +from litellm.litellm_core_utils.litellm_logging import Logging +from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter +from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck +from litellm._service_logger import ServiceLogging +import asyncio + + +from litellm.litellm_core_utils.litellm_logging import Logging +import litellm + +service_logger = ServiceLogging() + + +def setup_logging(): + return Logging( + model="gpt-4o", + messages=[{"role": "user", "content": "Hello, world!"}], + stream=False, + call_type="completion", + start_time=datetime.now(), + litellm_call_id="123", + function_id="456", + ) + + +def test_get_callback_name(): + """ + Ensure we can get the name of a callback + """ + logging = setup_logging() + + # Test function with __name__ + def test_func(): + pass + + assert logging._get_callback_name(test_func) == "test_func" + + # Test function with __func__ + class TestClass: + def method(self): + pass + + bound_method = TestClass().method + assert logging._get_callback_name(bound_method) == "method" + + # Test string callback + assert logging._get_callback_name("callback_string") == "callback_string" + + +def test_is_internal_litellm_proxy_callback(): + """ + Ensure we can determine if a callback is an internal litellm proxy callback + + eg. `_PROXY_MaxBudgetLimiter`, `_PROXY_CacheControlCheck` + """ + logging = setup_logging() + + assert logging._is_internal_litellm_proxy_callback(_PROXY_MaxBudgetLimiter) == True + + # Test non-internal callbacks + def regular_callback(): + pass + + assert logging._is_internal_litellm_proxy_callback(regular_callback) == False + + # Test string callback + assert logging._is_internal_litellm_proxy_callback("callback_string") == False + + +def test_should_run_sync_callbacks_for_async_calls(): + """ + Ensure we can determine if we should run sync callbacks for async calls + + Note: We don't want to run sync callbacks for async calls because we don't want to block the event loop + """ + logging = setup_logging() + + # Test with no callbacks + logging.dynamic_success_callbacks = None + litellm.success_callback = [] + assert logging._should_run_sync_callbacks_for_async_calls() == False + + # Test with regular callback + def regular_callback(): + pass + + litellm.success_callback = [regular_callback] + assert logging._should_run_sync_callbacks_for_async_calls() == True + + # Test with internal callback only + litellm.success_callback = [_PROXY_MaxBudgetLimiter] + assert logging._should_run_sync_callbacks_for_async_calls() == False + + +def test_remove_internal_litellm_callbacks(): + logging = setup_logging() + + def regular_callback(): + pass + + callbacks = [ + regular_callback, + _PROXY_MaxBudgetLimiter, + _PROXY_CacheControlCheck, + "string_callback", + ] + + filtered = logging._remove_internal_litellm_callbacks(callbacks) + assert len(filtered) == 2 # Should only keep regular_callback and string_callback + assert regular_callback in filtered + assert "string_callback" in filtered + assert _PROXY_MaxBudgetLimiter not in filtered + assert _PROXY_CacheControlCheck not in filtered diff --git a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py index f5728b95b2f0..69bee7d402f3 100644 --- a/tests/logging_callback_tests/test_unit_tests_init_callbacks.py +++ b/tests/logging_callback_tests/test_unit_tests_init_callbacks.py @@ -295,11 +295,21 @@ def test_dynamic_logging_global_callback(): def test_get_combined_callback_list(): - from litellm.litellm_core_utils.litellm_logging import get_combined_callback_list + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + + _logging = LiteLLMLoggingObj( + model="claude-3-opus-20240229", + messages=[{"role": "user", "content": "hi"}], + stream=False, + call_type="completion", + start_time=datetime.now(), + litellm_call_id="123", + function_id="456", + ) - assert "langfuse" in get_combined_callback_list( + assert "langfuse" in _logging.get_combined_callback_list( dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"] ) - assert "lago" in get_combined_callback_list( + assert "lago" in _logging.get_combined_callback_list( dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"] )