diff --git a/src/gemini/client.py b/src/gemini/client.py index a58e3fb..07362e1 100644 --- a/src/gemini/client.py +++ b/src/gemini/client.py @@ -228,6 +228,7 @@ async def generate_content( url=image[0][3][3], title=f"[Generated Image {image[3][6]}]", alt=image[3][5][i], + cookies=self.cookies, ) for i, image in enumerate(candidate[12][7][0]) ] diff --git a/src/gemini/types.py b/src/gemini/types.py index 93597d8..999cae8 100644 --- a/src/gemini/types.py +++ b/src/gemini/types.py @@ -1,7 +1,9 @@ +import re from datetime import datetime -from pydantic import BaseModel, validator +from loguru import logger from httpx import AsyncClient, HTTPError +from pydantic import BaseModel, field_validator class Image(BaseModel): @@ -30,21 +32,33 @@ def __repr__(self): async def save( self, path: str = "temp/", filename: str = None, cookies: dict = None - ): + ) -> None: """ Save the image to disk. Parameters ---------- - path: `str` + path: `str`, optional Path to save the image filename: `str`, optional Filename to save the image, by default will use the original filename from the URL + cookies: `dict`, optional + Cookies used for requesting the content of the image """ + try: + filename = filename or re.search(r"^(.*\.\w+)", self.url.split("/")[-1]).group() + except AttributeError: + filename = self.url.split("/")[-1] + async with AsyncClient(follow_redirects=True, cookies=cookies) as client: response = await client.get(self.url) if response.status_code == 200: - filename = filename or self.url.split("/")[-1] + content_type = response.headers.get("content-type") + if content_type and "image" not in content_type: + logger.warning( + f"Content type of {filename} is not image, but {content_type}." + ) + with open(f"{path}{filename}", "wb") as file: file.write(response.content) else: @@ -71,15 +85,20 @@ class GeneratedImage(Image): Cookies used for requesting the content of the generated image, inherit from GeminiClient object or manually set. Must contain valid "__Secure-1PSID" and "__Secure-1PSIDTS" values """ - cookies: dict - @validator("cookies") - def validate_cookies(cls, value): - if "__Secure-1PSID" not in value or "__Secure-1PSIDTS" not in value: - raise ValueError("Cookies must contain '__Secure-1PSID' and '__Secure-1PSIDTS'") - return value + cookies: dict[str, str] - async def save(self, path: str = "temp/", filename: str = None): + @field_validator("cookies") + @classmethod + def validate_cookies(cls, v: dict) -> dict: + if "__Secure-1PSID" not in v or "__Secure-1PSIDTS" not in v: + raise ValueError( + "Cookies must contain '__Secure-1PSID' and '__Secure-1PSIDTS'" + ) + return v + + # @override + async def save(self, path: str = "temp/", filename: str = None) -> None: """ Save the image to disk. @@ -156,15 +175,15 @@ def __repr__(self): return f"ModelOutput(metadata={self.metadata}, chosen={self.chosen}, candidates={self.candidates})" @property - def text(self): + def text(self) -> str: return self.candidates[self.chosen].text @property - def images(self): + def images(self) -> list[Image]: return self.candidates[self.chosen].images @property - def rcid(self): + def rcid(self) -> str: return self.candidates[self.chosen].rcid diff --git a/tests/test_save_image.py b/tests/test_save_image.py index 4ab266f..50861ca 100644 --- a/tests/test_save_image.py +++ b/tests/test_save_image.py @@ -17,7 +17,7 @@ async def asyncSetUp(self): async def test_save_web_image(self): response = await self.geminiclient.generate_content( - "Send me some pictures of cats" + "Send me 10 pictures of random subjects" ) self.assertTrue(response.images) for i, image in enumerate(response.images): @@ -26,7 +26,7 @@ async def test_save_web_image(self): async def test_save_generated_image(self): response = await self.geminiclient.generate_content( - "Generate some pictures of cats" + "Generate some pictures of random subjects" ) self.assertTrue(response.images) for i, image in enumerate(response.images):