Skip to content

Commit

Permalink
(performance improvement - litellm sdk + proxy) - ensure litellm does…
Browse files Browse the repository at this point in the history
… not create unnecessary threads when running async functions (#7680)

* fix handle_sync_success_callbacks_for_async_calls

* fix handle_sync_success_callbacks_for_async_calls

* fix linting / testing errors

* use handle_sync_success_callbacks_for_async_calls

* add unit testing for logging fixes
  • Loading branch information
ishaan-jaff authored Jan 11, 2025
1 parent a3e65c9 commit 5c870c0
Show file tree
Hide file tree
Showing 4 changed files with 264 additions and 25 deletions.
128 changes: 115 additions & 13 deletions litellm/litellm_core_utils/litellm_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
TranscriptionResponse,
Usage,
)
from litellm.utils import _get_base_model_from_metadata, print_verbose
from litellm.utils import _get_base_model_from_metadata, executor, print_verbose

from ..integrations.argilla import ArgillaLogger
from ..integrations.arize_ai import ArizeLogger
Expand Down Expand Up @@ -1019,7 +1019,7 @@ def success_handler( # noqa: PLR0915
status="success",
)
)
callbacks = get_combined_callback_list(
callbacks = self.get_combined_callback_list(
dynamic_success_callbacks=self.dynamic_success_callbacks,
global_callbacks=litellm.success_callback,
)
Expand Down Expand Up @@ -1555,7 +1555,7 @@ async def async_success_handler( # noqa: PLR0915
status="success",
)
)
callbacks = get_combined_callback_list(
callbacks = self.get_combined_callback_list(
dynamic_success_callbacks=self.dynamic_async_success_callbacks,
global_callbacks=litellm._async_success_callback,
)
Expand Down Expand Up @@ -1825,7 +1825,7 @@ def failure_handler( # noqa: PLR0915
start_time=start_time,
end_time=end_time,
)
callbacks = get_combined_callback_list(
callbacks = self.get_combined_callback_list(
dynamic_success_callbacks=self.dynamic_failure_callbacks,
global_callbacks=litellm.failure_callback,
)
Expand Down Expand Up @@ -2011,7 +2011,7 @@ async def async_failure_handler(
end_time=end_time,
)

callbacks = get_combined_callback_list(
callbacks = self.get_combined_callback_list(
dynamic_success_callbacks=self.dynamic_async_failure_callbacks,
global_callbacks=litellm._async_failure_callback,
)
Expand Down Expand Up @@ -2108,6 +2108,116 @@ def _get_callback_object(self, service_name: Literal["langfuse"]) -> Optional[An

return None

def handle_sync_success_callbacks_for_async_calls(
self,
result: Any,
start_time: datetime.datetime,
end_time: datetime.datetime,
) -> None:
"""
Handles calling success callbacks for Async calls.
Why: Some callbacks - `langfuse`, `s3` are sync callbacks. We need to call them in the executor.
"""
if self._should_run_sync_callbacks_for_async_calls() is False:
return

executor.submit(
self.success_handler,
result,
start_time,
end_time,
)

def _should_run_sync_callbacks_for_async_calls(self) -> bool:
"""
Returns:
- bool: True if sync callbacks should be run for async calls. eg. `langfuse`, `s3`
"""
_combined_sync_callbacks = self.get_combined_callback_list(
dynamic_success_callbacks=self.dynamic_success_callbacks,
global_callbacks=litellm.success_callback,
)
_filtered_success_callbacks = self._remove_internal_custom_logger_callbacks(
_combined_sync_callbacks
)
_filtered_success_callbacks = self._remove_internal_litellm_callbacks(
_filtered_success_callbacks
)
return len(_filtered_success_callbacks) > 0

def get_combined_callback_list(
self, dynamic_success_callbacks: Optional[List], global_callbacks: List
) -> List:
if dynamic_success_callbacks is None:
return global_callbacks
return list(set(dynamic_success_callbacks + global_callbacks))

def _remove_internal_litellm_callbacks(self, callbacks: List) -> List:
"""
Creates a filtered list of callbacks, excluding internal LiteLLM callbacks.
Args:
callbacks: List of callback functions/strings to filter
Returns:
List of filtered callbacks with internal ones removed
"""
filtered = [
cb for cb in callbacks if not self._is_internal_litellm_proxy_callback(cb)
]

verbose_logger.debug(f"Filtered callbacks: {filtered}")
return filtered

def _get_callback_name(self, cb) -> str:
"""
Helper to get the name of a callback function
Args:
cb: The callback function/string to get the name of
Returns:
The name of the callback
"""
if hasattr(cb, "__name__"):
return cb.__name__
if hasattr(cb, "__func__"):
return cb.__func__.__name__
return str(cb)

def _is_internal_litellm_proxy_callback(self, cb) -> bool:
"""Helper to check if a callback is internal"""
INTERNAL_PREFIXES = [
"_PROXY",
"_service_logger.ServiceLogging",
"sync_deployment_callback_on_success",
]
if isinstance(cb, str):
return False

if not callable(cb):
return True

cb_name = self._get_callback_name(cb)
return any(prefix in cb_name for prefix in INTERNAL_PREFIXES)

def _remove_internal_custom_logger_callbacks(self, callbacks: List) -> List:
"""
Removes internal custom logger callbacks from the list.
"""
_new_callbacks = []
for _c in callbacks:
if isinstance(_c, CustomLogger):
continue
elif (
isinstance(_c, str)
and _c in litellm._known_custom_logger_compatible_callbacks
):
continue
_new_callbacks.append(_c)
return _new_callbacks


def set_callbacks(callback_list, function_id=None): # noqa: PLR0915
"""
Expand Down Expand Up @@ -3191,11 +3301,3 @@ def modify_integration(integration_name, integration_params):
if integration_name == "supabase":
if "table_name" in integration_params:
Supabase.supabase_table_name = integration_params["table_name"]


def get_combined_callback_list(
dynamic_success_callbacks: Optional[List], global_callbacks: List
) -> List:
if dynamic_success_callbacks is None:
return global_callbacks
return list(set(dynamic_success_callbacks + global_callbacks))
18 changes: 9 additions & 9 deletions litellm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,10 +608,11 @@ async def _client_async_logging_helper(
asyncio.create_task(
logging_obj.async_success_handler(result, start_time, end_time)
)
threading.Thread(
target=logging_obj.success_handler,
args=(result, start_time, end_time),
).start()
logging_obj.handle_sync_success_callbacks_for_async_calls(
result=result,
start_time=start_time,
end_time=end_time,
)


def client(original_function): # noqa: PLR0915
Expand Down Expand Up @@ -1153,11 +1154,10 @@ async def wrapper_async(*args, **kwargs): # noqa: PLR0915
is_completion_with_fallbacks=is_completion_with_fallbacks,
)
)
executor.submit(
logging_obj.success_handler,
result,
start_time,
end_time,
logging_obj.handle_sync_success_callbacks_for_async_calls(
result=result,
start_time=start_time,
end_time=end_time,
)
# REBUILD EMBEDDING CACHING
if (
Expand Down
127 changes: 127 additions & 0 deletions tests/logging_callback_tests/test_unit_test_litellm_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import json
import os
import sys
from datetime import datetime
from unittest.mock import AsyncMock

sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system-path

from typing import Literal

import pytest
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
from litellm._service_logger import ServiceLogging
import asyncio


from litellm.litellm_core_utils.litellm_logging import Logging
import litellm

service_logger = ServiceLogging()


def setup_logging():
return Logging(
model="gpt-4o",
messages=[{"role": "user", "content": "Hello, world!"}],
stream=False,
call_type="completion",
start_time=datetime.now(),
litellm_call_id="123",
function_id="456",
)


def test_get_callback_name():
"""
Ensure we can get the name of a callback
"""
logging = setup_logging()

# Test function with __name__
def test_func():
pass

assert logging._get_callback_name(test_func) == "test_func"

# Test function with __func__
class TestClass:
def method(self):
pass

bound_method = TestClass().method
assert logging._get_callback_name(bound_method) == "method"

# Test string callback
assert logging._get_callback_name("callback_string") == "callback_string"


def test_is_internal_litellm_proxy_callback():
"""
Ensure we can determine if a callback is an internal litellm proxy callback
eg. `_PROXY_MaxBudgetLimiter`, `_PROXY_CacheControlCheck`
"""
logging = setup_logging()

assert logging._is_internal_litellm_proxy_callback(_PROXY_MaxBudgetLimiter) == True

# Test non-internal callbacks
def regular_callback():
pass

assert logging._is_internal_litellm_proxy_callback(regular_callback) == False

# Test string callback
assert logging._is_internal_litellm_proxy_callback("callback_string") == False


def test_should_run_sync_callbacks_for_async_calls():
"""
Ensure we can determine if we should run sync callbacks for async calls
Note: We don't want to run sync callbacks for async calls because we don't want to block the event loop
"""
logging = setup_logging()

# Test with no callbacks
logging.dynamic_success_callbacks = None
litellm.success_callback = []
assert logging._should_run_sync_callbacks_for_async_calls() == False

# Test with regular callback
def regular_callback():
pass

litellm.success_callback = [regular_callback]
assert logging._should_run_sync_callbacks_for_async_calls() == True

# Test with internal callback only
litellm.success_callback = [_PROXY_MaxBudgetLimiter]
assert logging._should_run_sync_callbacks_for_async_calls() == False


def test_remove_internal_litellm_callbacks():
logging = setup_logging()

def regular_callback():
pass

callbacks = [
regular_callback,
_PROXY_MaxBudgetLimiter,
_PROXY_CacheControlCheck,
"string_callback",
]

filtered = logging._remove_internal_litellm_callbacks(callbacks)
assert len(filtered) == 2 # Should only keep regular_callback and string_callback
assert regular_callback in filtered
assert "string_callback" in filtered
assert _PROXY_MaxBudgetLimiter not in filtered
assert _PROXY_CacheControlCheck not in filtered
16 changes: 13 additions & 3 deletions tests/logging_callback_tests/test_unit_tests_init_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,21 @@ def test_dynamic_logging_global_callback():


def test_get_combined_callback_list():
from litellm.litellm_core_utils.litellm_logging import get_combined_callback_list
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj

_logging = LiteLLMLoggingObj(
model="claude-3-opus-20240229",
messages=[{"role": "user", "content": "hi"}],
stream=False,
call_type="completion",
start_time=datetime.now(),
litellm_call_id="123",
function_id="456",
)

assert "langfuse" in get_combined_callback_list(
assert "langfuse" in _logging.get_combined_callback_list(
dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"]
)
assert "lago" in get_combined_callback_list(
assert "lago" in _logging.get_combined_callback_list(
dynamic_success_callbacks=["langfuse"], global_callbacks=["lago"]
)

0 comments on commit 5c870c0

Please sign in to comment.