Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: update imports to allow global dspy.Google #419

Merged
merged 3 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dsp/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .ollama import *
from .clarifai import *
from .bedrock import *
from .google import *


from .hf_client import HFClientTGI
Expand Down
20 changes: 10 additions & 10 deletions dsp/modules/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import google.generativeai as genai
except ImportError:
google_api_error = Exception
print("Not loading Google because it is not installed.")
# print("Not loading Google because it is not installed.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why commenting this out?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it stays, we print this every time anybody imports dspy, regardless of whether they are importing Google specifically. It's similar to how imports are managed on the Cohere side in: dsp/modules/cohere.py.

I imagine there may be a cleaner way throughout the project to manage optional dependencies.



def backoff_hdlr(details):
"""Handler from https://pypi.org/project/backoff/"""
Expand All @@ -33,10 +34,7 @@ class Google(LM):
"""

def __init__(
self,
model: str = "gemini-pro-1.0",
api_key: Optional[str] = None,
**kwargs
self, model: str = "gemini-pro-1.0", api_key: Optional[str] = None, **kwargs
):
"""
Parameters
Expand All @@ -51,15 +49,17 @@ def __init__(
Additional arguments to pass to the API provider.
"""
super().__init__(model)
self.google = genai.configure(api_key=self.api_key)
self.google = genai.configure(api_key=api_key)
self.provider = "google"
self.kwargs = {
"model_name": model,
"temperature": 0.0 if "temperature" not in kwargs else kwargs["temperature"],
"temperature": 0.0
if "temperature" not in kwargs
else kwargs["temperature"],
"max_output_tokens": 2048,
"top_p": 1,
"top_k": 1,
**kwargs
**kwargs,
}

self.history: list[dict[str, Any]] = []
Expand All @@ -85,7 +85,7 @@ def basic_request(self, prompt: str, **kwargs):

@backoff.on_exception(
backoff.expo,
(google_api_error),
(Exception),
max_time=1000,
on_backoff=backoff_hdlr,
giveup=giveup_hdlr,
Expand All @@ -99,6 +99,6 @@ def __call__(
prompt: str,
only_completed: bool = True,
return_sorted: bool = False,
**kwargs
**kwargs,
):
return self.request(prompt, **kwargs)
1 change: 1 addition & 0 deletions dspy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
ColBERTv2 = dsp.ColBERTv2
Pyserini = dsp.PyseriniRetriever
Clarifai = dsp.ClarifaiLLM
Google = dsp.Google

HFClientTGI = dsp.HFClientTGI
HFClientVLLM = HFClientVLLM
Expand Down
Loading