Skip to content

Commit

Permalink
Merge pull request #419 from KCaverly/google_import_fix
Browse files Browse the repository at this point in the history
fix: update imports to allow global dspy.Google
  • Loading branch information
okhat authored Feb 23, 2024
2 parents 3e8e4b6 + 96949a7 commit 9d8a40c
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
1 change: 1 addition & 0 deletions dsp/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .ollama import *
from .clarifai import *
from .bedrock import *
from .google import *


from .hf_client import HFClientTGI
Expand Down
20 changes: 10 additions & 10 deletions dsp/modules/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import google.generativeai as genai
except ImportError:
google_api_error = Exception
print("Not loading Google because it is not installed.")
# print("Not loading Google because it is not installed.")


def backoff_hdlr(details):
"""Handler from https://pypi.org/project/backoff/"""
Expand All @@ -33,10 +34,7 @@ class Google(LM):
"""

def __init__(
self,
model: str = "gemini-pro-1.0",
api_key: Optional[str] = None,
**kwargs
self, model: str = "gemini-pro-1.0", api_key: Optional[str] = None, **kwargs
):
"""
Parameters
Expand All @@ -51,15 +49,17 @@ def __init__(
Additional arguments to pass to the API provider.
"""
super().__init__(model)
self.google = genai.configure(api_key=self.api_key)
self.google = genai.configure(api_key=api_key)
self.provider = "google"
self.kwargs = {
"model_name": model,
"temperature": 0.0 if "temperature" not in kwargs else kwargs["temperature"],
"temperature": 0.0
if "temperature" not in kwargs
else kwargs["temperature"],
"max_output_tokens": 2048,
"top_p": 1,
"top_k": 1,
**kwargs
**kwargs,
}

self.history: list[dict[str, Any]] = []
Expand All @@ -85,7 +85,7 @@ def basic_request(self, prompt: str, **kwargs):

@backoff.on_exception(
backoff.expo,
(google_api_error),
(Exception),
max_time=1000,
on_backoff=backoff_hdlr,
giveup=giveup_hdlr,
Expand All @@ -99,6 +99,6 @@ def __call__(
prompt: str,
only_completed: bool = True,
return_sorted: bool = False,
**kwargs
**kwargs,
):
return self.request(prompt, **kwargs)
1 change: 1 addition & 0 deletions dspy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
ColBERTv2 = dsp.ColBERTv2
Pyserini = dsp.PyseriniRetriever
Clarifai = dsp.ClarifaiLLM
Google = dsp.Google

HFClientTGI = dsp.HFClientTGI
HFClientVLLM = HFClientVLLM
Expand Down

0 comments on commit 9d8a40c

Please sign in to comment.