diff --git a/README.md b/README.md index ebae817..0527dc5 100644 --- a/README.md +++ b/README.md @@ -117,13 +117,13 @@ Note: by default, when asked to send images (like the previous example), Gemini ### Save images to local files -You can save images returned from Gemini to local files under `/temp` by calling `Image.save()`. Optionally, you can specify the file path and file name by passing `path` and `filename` arguments to the function. Works for both `WebImage` and `GeneratedImage`. +You can save images returned from Gemini to local files under `/temp` by calling `Image.save()`. Optionally, you can specify the file path and file name by passing `path` and `filename` arguments to the function and skip images with invalid file names by passing `skip_invalid_filename=True`. Works for both `WebImage` and `GeneratedImage`. ```python async def main(): response = await client.generate_content("Generate some pictures of cats") for i, image in enumerate(response.images): - await image.save(path="temp/", filename=f"cat_{i}.png") + await image.save(path="temp/", filename=f"cat_{i}.png", verbose=True) asyncio.run(main()) ``` diff --git a/src/gemini/__init__.py b/src/gemini/__init__.py index 077cffb..c1ac8cc 100644 --- a/src/gemini/__init__.py +++ b/src/gemini/__init__.py @@ -1,2 +1,3 @@ from .client import GeminiClient, ChatSession # noqa: F401 +from .exceptions import * # noqa: F401, F403 from .types import * # noqa: F401, F403 diff --git a/src/gemini/client.py b/src/gemini/client.py index c5b11e0..86e2c8b 100644 --- a/src/gemini/client.py +++ b/src/gemini/client.py @@ -7,17 +7,9 @@ from httpx import AsyncClient, ReadTimeout from loguru import logger -from .consts import HEADERS -from .types import ( - WebImage, - GeneratedImage, - Candidate, - ModelOutput, - AuthError, - APIError, - GeminiError, - TimeoutError, -) +from .types import WebImage, GeneratedImage, Candidate, ModelOutput +from .exceptions import APIError, AuthError, TimeoutError, GeminiError +from .constant import HEADERS def running(func) -> callable: @@ -71,10 +63,7 @@ def __init__( secure_1psidts: Optional[str] = None, proxy: Optional[dict] = None, ): - self.cookies = { - "__Secure-1PSID": secure_1psid, - "__Secure-1PSIDTS": secure_1psidts, - } + self.cookies = {"__Secure-1PSID": secure_1psid} self.proxy = proxy self.client: AsyncClient | None = None self.access_token: Optional[str] = None @@ -83,6 +72,9 @@ def __init__( self.close_delay: float = 300 self.close_task: Task | None = None + if secure_1psidts: + self.cookies["__Secure-1PSIDTS"] = secure_1psidts + async def init( self, timeout: float = 30, auto_close: bool = False, close_delay: float = 300 ) -> None: @@ -248,7 +240,9 @@ async def generate_content( GeneratedImage( url=image[0][3][3], title=f"[Generated Image {image[3][6]}]", - alt=image[3][5][i], + alt=len(image[3][5]) > i + and image[3][5][i] + or image[3][5][0], cookies=self.cookies, ) for i, image in enumerate(candidate[12][7][0]) diff --git a/src/gemini/consts.py b/src/gemini/constant.py similarity index 100% rename from src/gemini/consts.py rename to src/gemini/constant.py diff --git a/src/gemini/exceptions.py b/src/gemini/exceptions.py new file mode 100644 index 0000000..de736e2 --- /dev/null +++ b/src/gemini/exceptions.py @@ -0,0 +1,30 @@ +class AuthError(Exception): + """ + Exception for authentication errors caused by invalid credentials/cookies. + """ + + pass + + +class APIError(Exception): + """ + Exception for package-level errors which need to be fixed in the future development (e.g. validation errors). + """ + + pass + + +class GeminiError(Exception): + """ + Exception for errors returned from Gemini server which are not handled by the package. + """ + + pass + + +class TimeoutError(GeminiError): + """ + Exception for request timeouts. + """ + + pass diff --git a/src/gemini/types/__init__.py b/src/gemini/types/__init__.py new file mode 100644 index 0000000..7e43010 --- /dev/null +++ b/src/gemini/types/__init__.py @@ -0,0 +1,3 @@ +from .image import Image, WebImage, GeneratedImage # noqa: F401 +from .candidate import Candidate # noqa: F401 +from .modeloutput import ModelOutput # noqa: F401 diff --git a/src/gemini/types/candidate.py b/src/gemini/types/candidate.py new file mode 100644 index 0000000..e044388 --- /dev/null +++ b/src/gemini/types/candidate.py @@ -0,0 +1,35 @@ +from pydantic import BaseModel + +from .image import Image, WebImage, GeneratedImage + + +class Candidate(BaseModel): + """ + A single reply candidate object in the model output. A full response from Gemini usually contains multiple reply candidates. + + Parameters + ---------- + rcid: `str` + Reply candidate ID to build the metadata + text: `str` + Text output + web_images: `list[WebImage]`, optional + List of web images in reply, can be empty. + generated_images: `list[GeneratedImage]`, optional + List of generated images in reply, can be empty + """ + + rcid: str + text: str + web_images: list[WebImage] = [] + generated_images: list[GeneratedImage] = [] + + def __str__(self): + return self.text + + def __repr__(self): + return f"Candidate(rcid='{self.rcid}', text='{len(self.text) <= 20 and self.text or self.text[:20] + '...'}', images={self.images})" + + @property + def images(self) -> list[Image]: + return self.web_images + self.generated_images diff --git a/src/gemini/types.py b/src/gemini/types/image.py similarity index 51% rename from src/gemini/types.py rename to src/gemini/types/image.py index 7da502f..fde4e7f 100644 --- a/src/gemini/types.py +++ b/src/gemini/types/image.py @@ -36,8 +36,9 @@ async def save( path: str = "temp", filename: str | None = None, cookies: dict | None = None, - verbose: bool = False - ) -> None: + verbose: bool = False, + skip_invalid_filename: bool = False, + ) -> str | None: """ Save the image to disk. @@ -46,22 +47,27 @@ async def save( path: `str`, optional Path to save the image, by default will save to ./temp filename: `str`, optional - Filename to save the image, by default will use the original filename from the URL + File name to save the image, by default will use the original file name from the URL cookies: `dict`, optional Cookies used for requesting the content of the image verbose : `bool`, optional - If True, print the path of the saved file, by default False + If True, print the path of the saved file or warning for invalid file name, by default False + skip_invalid_filename: `bool`, optional + If True, will only save the image if the file name and extension are valid, by default False + + Returns + ------- + `str | None` + Absolute path of the saved image if successful, None if filename is invalid and `skip_invalid_filename` is True """ + filename = filename or self.url.split("/")[-1].split("?")[0] try: - filename = ( - filename - or ( - re.search(r"^(.*\.\w+)", self.url.split("/")[-1]) - or re.search(r"^(.*)\?", self.url.split("/")[-1]) - ).group() - ) + filename = re.search(r"^(.*\.\w+)", filename).group() except AttributeError: - filename = self.url.split("/")[-1] + if verbose: + logger.warning(f"Invalid filename: {filename}") + if skip_invalid_filename: + return None async with AsyncClient(follow_redirects=True, cookies=cookies) as client: response = await client.get(self.url) @@ -80,6 +86,8 @@ async def save( if verbose: logger.info(f"Image saved as {dest.resolve()}") + + return dest.resolve() else: raise HTTPError( f"Error downloading image: {response.status_code} {response.reason_phrase}" @@ -110,129 +118,28 @@ class GeneratedImage(Image): @field_validator("cookies") @classmethod def validate_cookies(cls, v: dict) -> dict: - if "__Secure-1PSID" not in v or "__Secure-1PSIDTS" not in v: + if len(v) == 0: raise ValueError( - "Cookies must contain '__Secure-1PSID' and '__Secure-1PSIDTS'" + "GeneratedImage is designed to be initiated with same cookies as GeminiClient." ) return v # @override - async def save(self, path: str = "temp/", filename: str = None) -> None: + async def save(self, **kwargs) -> None: """ Save the image to disk. Parameters ---------- - path: `str` - Path to save the image filename: `str`, optional Filename to save the image, generated images are always in .png format, but file extension will not be included in the URL. And since the URL ends with a long hash, by default will use timestamp + end of the hash as the filename + **kwargs: `dict`, optional + Other arguments to pass to `Image.save` """ await super().save( - path, - filename + filename=kwargs.pop("filename", None) or f"{datetime.now().strftime('%Y%m%d%H%M%S')}_{self.url[-10:]}.png", - self.cookies, + cookies=self.cookies, + **kwargs, ) - - -class Candidate(BaseModel): - """ - A single reply candidate object in the model output. A full response from Gemini usually contains multiple reply candidates. - - Parameters - ---------- - rcid: `str` - Reply candidate ID to build the metadata - text: `str` - Text output - web_images: `list[WebImage]`, optional - List of web images in reply, can be empty. - generated_images: `list[GeneratedImage]`, optional - List of generated images in reply, can be empty - """ - - rcid: str - text: str - web_images: list[WebImage] = [] - generated_images: list[GeneratedImage] = [] - - def __str__(self): - return self.text - - def __repr__(self): - return f"Candidate(rcid='{self.rcid}', text='{len(self.text) <= 20 and self.text or self.text[:20] + '...'}', images={self.images})" - - @property - def images(self) -> list[Image]: - return self.web_images + self.generated_images - - -class ModelOutput(BaseModel): - """ - Classified output from gemini.google.com - - Parameters - ---------- - metadata: `list[str]` - List of chat metadata `[cid, rid, rcid]`, can be shorter than 3 elements, like `[cid, rid]` or `[cid]` only - candidates: `list[Candidate]` - List of all candidates returned from gemini - chosen: `int`, optional - Index of the chosen candidate, by default will choose the first one - """ - - metadata: list[str] - candidates: list[Candidate] - chosen: int = 0 - - def __str__(self): - return self.text - - def __repr__(self): - return f"ModelOutput(metadata={self.metadata}, chosen={self.chosen}, candidates={self.candidates})" - - @property - def text(self) -> str: - return self.candidates[self.chosen].text - - @property - def images(self) -> list[Image]: - return self.candidates[self.chosen].images - - @property - def rcid(self) -> str: - return self.candidates[self.chosen].rcid - - -class AuthError(Exception): - """ - Exception for authentication errors caused by invalid credentials/cookies. - """ - - pass - - -class APIError(Exception): - """ - Exception for package-level errors which need to be fixed in the future development (e.g. validation errors). - """ - - pass - - -class GeminiError(Exception): - """ - Exception for errors returned from Gemini server which are not handled by the package. - """ - - pass - - -class TimeoutError(GeminiError): - """ - Exception for request timeouts. - """ - - pass diff --git a/src/gemini/types/modeloutput.py b/src/gemini/types/modeloutput.py new file mode 100644 index 0000000..3ebb979 --- /dev/null +++ b/src/gemini/types/modeloutput.py @@ -0,0 +1,41 @@ +from pydantic import BaseModel + +from .image import Image +from .candidate import Candidate + + +class ModelOutput(BaseModel): + """ + Classified output from gemini.google.com + + Parameters + ---------- + metadata: `list[str]` + List of chat metadata `[cid, rid, rcid]`, can be shorter than 3 elements, like `[cid, rid]` or `[cid]` only + candidates: `list[Candidate]` + List of all candidates returned from gemini + chosen: `int`, optional + Index of the chosen candidate, by default will choose the first one + """ + + metadata: list[str] + candidates: list[Candidate] + chosen: int = 0 + + def __str__(self): + return self.text + + def __repr__(self): + return f"ModelOutput(metadata={self.metadata}, chosen={self.chosen}, candidates={self.candidates})" + + @property + def text(self) -> str: + return self.candidates[self.chosen].text + + @property + def images(self) -> list[Image]: + return self.candidates[self.chosen].images + + @property + def rcid(self) -> str: + return self.candidates[self.chosen].rcid diff --git a/tests/test_client_features.py b/tests/test_client_features.py index 158a856..d4b53cb 100644 --- a/tests/test_client_features.py +++ b/tests/test_client_features.py @@ -67,7 +67,11 @@ async def test_reply_candidates(self): response = await chat.send_message( "What's the best Japanese dish? Recommend one only." ) - self.assertTrue(len(response.candidates) > 1) + + if len(response.candidates) == 1: + logger.debug(response.candidates[0]) + self.skipTest("Only one candidate was returned. Test skipped") + for candidate in response.candidates: logger.debug(candidate) diff --git a/tests/test_save_image.py b/tests/test_save_image.py index 50861ca..44d8a58 100644 --- a/tests/test_save_image.py +++ b/tests/test_save_image.py @@ -22,7 +22,7 @@ async def test_save_web_image(self): self.assertTrue(response.images) for i, image in enumerate(response.images): self.assertTrue(image.url) - await image.save() + await image.save(verbose=True, skip_invalid_filename=True) async def test_save_generated_image(self): response = await self.geminiclient.generate_content( @@ -31,7 +31,7 @@ async def test_save_generated_image(self): self.assertTrue(response.images) for i, image in enumerate(response.images): self.assertTrue(image.url) - await image.save() + await image.save(verbose=True) if __name__ == "__main__":