From 3f0863b55c891804ea9f5ddbc63289d4551d2902 Mon Sep 17 00:00:00 2001
From: bahtman
Date: Tue, 26 Nov 2024 19:23:20 +0100
Subject: [PATCH 1/3] Update lm.py

---
 dspy/clients/lm.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index 4dffc17e1..788e10c69 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -84,7 +84,8 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         cache = kwargs.pop("cache", self.cache)
         messages = messages or [{"role": "user", "content": prompt}]
         kwargs = {**self.kwargs, **kwargs}
-
+        callable_kwargs = {k: v for k, v in kwargs.items() if isinstance(v, Callable)}
+        kwargs = {k: v for k, v in kwargs.items() if not isinstance(v, Callable)}
         # Make the request and handle LRU & disk caching.
         if self.model_type == "chat":
             completion = cached_litellm_completion if cache else litellm_completion
@@ -94,6 +95,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         response = completion(
             request=ujson.dumps(dict(model=self.model, messages=messages, **kwargs)),
             num_retries=self.num_retries,
+            **callable_kwargs,
         )

         outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]

From cf84b9eb382dc6e6b66c4928d8f6aff75a8aa40e Mon Sep 17 00:00:00 2001
From: Anton Baht
Date: Tue, 26 Nov 2024 21:10:56 +0100
Subject: [PATCH 2/3] Finished the kwarg juggling

---
 dspy/clients/lm.py | 27 +++++++++++++++++----------
 1 file changed, 17 insertions(+), 10 deletions(-)

diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index 788e10c69..b25491693 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -4,7 +4,7 @@
 import threading
 import uuid
 from datetime import datetime
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional, Callable

 import litellm
 import ujson
@@ -84,8 +84,11 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         cache = kwargs.pop("cache", self.cache)
         messages = messages or [{"role": "user", "content": prompt}]
         kwargs = {**self.kwargs, **kwargs}
-        callable_kwargs = {k: v for k, v in kwargs.items() if isinstance(v, Callable)}
-        kwargs = {k: v for k, v in kwargs.items() if not isinstance(v, Callable)}
+        callable_kwargs = {}
+        for k, v in kwargs.items():
+            if isinstance(v, Callable):
+                callable_kwargs[k] = kwargs.pop(k)
+
         # Make the request and handle LRU & disk caching.
         if self.model_type == "chat":
             completion = cached_litellm_completion if cache else litellm_completion
@@ -215,16 +218,18 @@ def copy(self, **kwargs):


 @functools.lru_cache(maxsize=None)
-def cached_litellm_completion(request, num_retries: int):
+def cached_litellm_completion(request, num_retries: int, **kwargs):
     return litellm_completion(
         request,
         cache={"no-cache": False, "no-store": False},
         num_retries=num_retries,
+        **kwargs,
     )


-def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
-    kwargs = ujson.loads(request)
+def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}, **kwargs):
+    req_kwargs = ujson.loads(request)
+    kwargs = {**req_kwargs, **kwargs}
     return litellm.completion(
         num_retries=num_retries,
         cache=cache,
@@ -233,17 +238,19 @@ def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-s


 @functools.lru_cache(maxsize=None)
-def cached_litellm_text_completion(request, num_retries: int):
+def cached_litellm_text_completion(request, num_retries: int,**kwargs):
     return litellm_text_completion(
         request,
         num_retries=num_retries,
         cache={"no-cache": False, "no-store": False},
+        **kwargs,
     )


-def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
-    kwargs = ujson.loads(request)
-
+def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True},**kwargs):
+    req_kwargs = ujson.loads(request)
+    kwargs = {**req_kwargs, **kwargs}
+
     # Extract the provider and model from the model string.
     # TODO: Not all the models are in the format of "provider/model"
     model = kwargs.pop("model").split("/", 1)

From 37c719da15ed7331f58df4e8c695cb4062fc418b Mon Sep 17 00:00:00 2001
From: bahtman
Date: Sun, 1 Dec 2024 21:25:34 +0100
Subject: [PATCH 3/3] Iterate over copy of dict instead of dict

---
 dspy/clients/lm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index b25491693..e4d3b7b4d 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -85,7 +85,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         messages = messages or [{"role": "user", "content": prompt}]
         kwargs = {**self.kwargs, **kwargs}
         callable_kwargs = {}
-        for k, v in kwargs.items():
+        for k, v in list(kwargs.items()):
             if isinstance(v, Callable):
                 callable_kwargs[k] = kwargs.pop(k)
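
The motivation for the split in these patches is that callable keyword arguments cannot go through ujson.dumps, which builds the serialized request string used for caching; they are therefore pulled out of kwargs before serialization and forwarded to the completion helpers as ordinary keyword arguments, where they are merged back with the deserialized request kwargs before litellm.completion is called. Below is a minimal standalone sketch of that splitting step, not the dspy code itself; the token_provider key, the lambda, and the model string are made-up placeholders for whatever callable a provider integration might accept.

from typing import Callable

import ujson  # the standard json module behaves the same way here


def split_callable_kwargs(kwargs: dict) -> tuple[dict, dict]:
    """Split kwargs into (json_safe, callables) without mutating the input."""
    json_safe, callables = {}, {}
    for k, v in kwargs.items():
        # Mirrors the isinstance(v, Callable) test used in the patches above.
        (callables if isinstance(v, Callable) else json_safe)[k] = v
    return json_safe, callables


json_safe, callables = split_callable_kwargs(
    {"temperature": 0.0, "max_tokens": 100, "token_provider": lambda: "fake-token"}
)

# Only the JSON-serializable part is embedded in the request string that the
# LRU/disk caches key on; serializing the callable would make ujson.dumps raise.
request = ujson.dumps({"model": "openai/gpt-4o-mini", "messages": [], **json_safe})
print(request)
print(list(callables))  # ['token_provider'] -> forwarded separately as **callable_kwargs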