diff --git a/litellm/litellm_core_utils/logging_callback_manager.py b/litellm/litellm_core_utils/logging_callback_manager.py index 860a57c5f628..e55df4447468 100644 --- a/litellm/litellm_core_utils/logging_callback_manager.py +++ b/litellm/litellm_core_utils/logging_callback_manager.py @@ -85,6 +85,21 @@ def add_litellm_async_failure_callback( callback=callback, parent_list=litellm._async_failure_callback ) + def remove_callback_from_list_by_object( + self, callback_list, obj + ): + """ + Remove callbacks that are methods of a particular object (e.g., router cleanup) + """ + if not isinstance(callback_list, list): # Not list -> do nothing + return + + remove_list=[c for c in callback_list if hasattr(c, '__self__') and c.__self__ == obj] + + for c in remove_list: + callback_list.remove(c) + + def _add_string_callback_to_list( self, callback: str, parent_list: List[Union[CustomLogger, Callable, str]] ): diff --git a/litellm/proxy/route_llm_request.py b/litellm/proxy/route_llm_request.py index 1ff211c20ccc..6102a26b232f 100644 --- a/litellm/proxy/route_llm_request.py +++ b/litellm/proxy/route_llm_request.py @@ -58,7 +58,9 @@ async def route_request( elif "user_config" in data: router_config = data.pop("user_config") user_router = litellm.Router(**router_config) - return getattr(user_router, f"{route_type}")(**data) + ret_val = getattr(user_router, f"{route_type}")(**data) + user_router.discard() + return ret_val elif ( route_type == "acompletion" diff --git a/litellm/router.py b/litellm/router.py index 597ba9fd06aa..bdac540f1ab7 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -573,6 +573,21 @@ def __init__( # noqa: PLR0915 litellm.amoderation, call_type="moderation" ) + + def discard(self): + """ + Pseudo-destructor to be invoked to clean up global data structures when router is no longer used. + For now, unhook router's callbacks from all lists + """ + litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm._async_success_callback, self) + litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.success_callback, self) + litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm._async_failure_callback, self) + litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.failure_callback, self) + litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.input_callback, self) + litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.service_callback, self) + litellm.logging_callback_manager.remove_callback_from_list_by_object(litellm.callbacks, self) + + def _update_redis_cache(self, cache: RedisCache): """ Update the redis cache for the router, if none set. @@ -587,6 +602,7 @@ def _update_redis_cache(self, cache: RedisCache): if self.cache.redis_cache is None: self.cache.redis_cache = cache + def initialize_assistants_endpoint(self): ## INITIALIZE PASS THROUGH ASSISTANTS ENDPOINT ## self.acreate_assistants = self.factory_function(litellm.acreate_assistants) diff --git a/tests/litellm_utils_tests/test_logging_callback_manager.py b/tests/litellm_utils_tests/test_logging_callback_manager.py index 71ffb1867819..1b70631e4d17 100644 --- a/tests/litellm_utils_tests/test_logging_callback_manager.py +++ b/tests/litellm_utils_tests/test_logging_callback_manager.py @@ -160,6 +160,39 @@ def test_async_callbacks(): assert async_failure in litellm._async_failure_callback +def test_remove_callback_from_list_by_object(): + manager = LoggingCallbackManager() + # Reset all callbacks + manager._reset_all_callbacks() + + def TestObject(): + def __init__(self): + manager.add_litellm_callback(self.callback) + manager.add_litellm_success_callback(self.callback) + manager.add_litellm_failure_callback(self.callback) + manager.add_litellm_async_success_callback(self.callback) + manager.add_litellm_async_failure_callback(self.callback) + + def callback(self): + pass + + obj = TestObject() + + manager.remove_callback_from_list_by_object(litellm.callbacks, obj) + manager.remove_callback_from_list_by_object(litellm.success_callback, obj) + manager.remove_callback_from_list_by_object(litellm.failure_callback, obj) + manager.remove_callback_from_list_by_object(litellm._async_success_callback, obj) + manager.remove_callback_from_list_by_object(litellm._async_failure_callback, obj) + + # Verify all callback lists are empty + assert len(litellm.callbacks) == 0 + assert len(litellm.success_callback) == 0 + assert len(litellm.failure_callback) == 0 + assert len(litellm._async_success_callback) == 0 + assert len(litellm._async_failure_callback) == 0 + + + def test_reset_callbacks(callback_manager): # Add various callbacks callback_manager.add_litellm_callback("test") diff --git a/tests/router_unit_tests/test_router_helper_utils.py b/tests/router_unit_tests/test_router_helper_utils.py index e02b47ec365e..f12371baebc2 100644 --- a/tests/router_unit_tests/test_router_helper_utils.py +++ b/tests/router_unit_tests/test_router_helper_utils.py @@ -918,6 +918,31 @@ def test_flush_cache(model_list): assert router.cache.get_cache("test") is None +def test_discard(model_list): + """ + Test that discard properly removes a Router from the callback lists + """ + litellm.callbacks = [] + litellm.success_callback = [] + litellm._async_success_callback = [] + litellm.failure_callback = [] + litellm._async_failure_callback = [] + litellm.input_callback = [] + litellm.service_callback = [] + + router = Router(model_list=model_list) + router.discard() + + # Verify all callback lists are empty + assert len(litellm.callbacks) == 0 + assert len(litellm.success_callback) == 0 + assert len(litellm.failure_callback) == 0 + assert len(litellm._async_success_callback) == 0 + assert len(litellm._async_failure_callback) == 0 + assert len(litellm.input_callback) == 0 + assert len(litellm.service_callback) == 0 + + def test_initialize_assistants_endpoint(model_list): """Test if the 'initialize_assistants_endpoint' function is working correctly""" router = Router(model_list=model_list)