Skip to content

Commit

Permalink
feat: set up exception classes for better error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
HanaokaYuzu committed Feb 12, 2024
1 parent b35d790 commit 2069d62
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 21 deletions.
43 changes: 28 additions & 15 deletions src/gemini/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,19 @@
from asyncio import Task
from typing import Any, Optional

from httpx import AsyncClient
from httpx import AsyncClient, ReadTimeout
from loguru import logger

from .consts import HEADERS
from .types import Image, Candidate, ModelOutput
from .types import (
Image,
Candidate,
ModelOutput,
AuthError,
APIError,
GeminiError,
TimeoutError,
)


def running(func) -> callable:
Expand Down Expand Up @@ -102,7 +110,7 @@ async def init(
response = await self.client.get("https://gemini.google.com/app")

if response.status_code != 200:
raise Exception(
raise APIError(
f"Failed to initiate client. Request failed with status code {response.status_code}"
)
else:
Expand All @@ -112,7 +120,7 @@ async def init(
self.running = True
logger.success("Gemini client initiated successfully.")
else:
raise Exception(
raise AuthError(
"Failed to initiate client. SNlM0e not found in response, make sure cookie values are valid."
)

Expand Down Expand Up @@ -171,18 +179,23 @@ async def generate_content(
if self.auto_close:
await self.reset_close_task()

response = await self.client.post(
"https://gemini.google.com/_/BardChatUi/data/assistant.lamda.BardFrontendService/StreamGenerate",
data={
"at": self.access_token,
"f.req": json.dumps(
[None, json.dumps([[prompt], None, chat and chat.metadata])]
),
},
)
try:
response = await self.client.post(
"https://gemini.google.com/_/BardChatUi/data/assistant.lamda.BardFrontendService/StreamGenerate",
data={
"at": self.access_token,
"f.req": json.dumps(
[None, json.dumps([[prompt], None, chat and chat.metadata])]
),
},
)
except ReadTimeout:
raise TimeoutError(
"Request timed out, please try again. If the problem persists, consider setting a higher `timeout` value when initiating GeminiClient."
)

if response.status_code != 200:
raise Exception(
raise APIError(
f"Failed to generate contents. Request failed with status code {response.status_code}"
)
else:
Expand All @@ -202,7 +215,7 @@ async def generate_content(
Candidate(rcid=candidate[0], text=candidate[1][0], images=images)
)
if not candidates:
raise Exception(
raise GeminiError(
"Failed to generate contents. No output data found in response."
)

Expand Down
55 changes: 49 additions & 6 deletions src/gemini/types.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,29 @@
from typing import Optional

from pydantic import BaseModel


class Image(BaseModel):
"""
A single image object returned from Gemini.
Parameters
----------
url: `str`
URL of the image
title: `str`, optional
Title of the image, by default is "[Image]"
alt: `str`, optional
Optional description of the image
"""

url: str
title: Optional[str] = "[Image]"
alt: Optional[str] = ""
title: str = "[Image]"
alt: str = ""

def __str__(self):
return f"{self.title}({self.url}) - {self.alt}"

def __repr__(self):
return f"Image(title='{self.title}', url='{len(self.url)<=20 and self.url or self.url[:8] + '...' + self.url[-12:]}', alt='{self.alt}')"
return f"""Image(title='{self.title}', url='{len(self.url) <= 20 and self.url or self.url[:8] + '...' + self.url[-12:]}', alt='{self.alt}')"""


class Candidate(BaseModel):
Expand All @@ -24,7 +35,7 @@ def __str__(self):
return self.text

def __repr__(self):
return f"Candidate(rcid='{self.rcid}', text='{len(self.text)<=20 and self.text or self.text[:20] + '...'}', images={self.images})"
return f"Candidate(rcid='{self.rcid}', text='{len(self.text) <= 20 and self.text or self.text[:20] + '...'}', images={self.images})"


class ModelOutput(BaseModel):
Expand Down Expand Up @@ -62,3 +73,35 @@ def images(self):
@property
def rcid(self):
return self.candidates[self.chosen].rcid


class AuthError(Exception):
"""
Exception for authentication errors caused by invalid credentials/cookies.
"""

pass


class APIError(Exception):
"""
Exception for package-level errors which need to be fixed in the future development (e.g. validation errors).
"""

pass


class GeminiError(Exception):
"""
Exception for errors returned from Gemini server which are not handled by the package.
"""

pass


class TimeoutError(GeminiError):
"""
Exception for request timeouts.
"""

pass

0 comments on commit 2069d62

Please sign in to comment.