Skip to content

Commit

Permalink
fix(image): fix validation error, improve image saving function
Browse files Browse the repository at this point in the history
  • Loading branch information
HanaokaYuzu committed Feb 16, 2024
1 parent 2c035e0 commit 5b78c88
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 16 deletions.
1 change: 1 addition & 0 deletions src/gemini/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ async def generate_content(
url=image[0][3][3],
title=f"[Generated Image {image[3][6]}]",
alt=image[3][5][i],
cookies=self.cookies,
)
for i, image in enumerate(candidate[12][7][0])
]
Expand Down
47 changes: 33 additions & 14 deletions src/gemini/types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import re
from datetime import datetime

from pydantic import BaseModel, validator
from loguru import logger
from httpx import AsyncClient, HTTPError
from pydantic import BaseModel, field_validator


class Image(BaseModel):
Expand Down Expand Up @@ -30,21 +32,33 @@ def __repr__(self):

async def save(
self, path: str = "temp/", filename: str = None, cookies: dict = None
):
) -> None:
"""
Save the image to disk.
Parameters
----------
path: `str`
path: `str`, optional
Path to save the image
filename: `str`, optional
Filename to save the image, by default will use the original filename from the URL
cookies: `dict`, optional
Cookies used for requesting the content of the image
"""
try:
filename = filename or re.search(r"^(.*\.\w+)", self.url.split("/")[-1]).group()
except AttributeError:
filename = self.url.split("/")[-1]

async with AsyncClient(follow_redirects=True, cookies=cookies) as client:
response = await client.get(self.url)
if response.status_code == 200:
filename = filename or self.url.split("/")[-1]
content_type = response.headers.get("content-type")
if content_type and "image" not in content_type:
logger.warning(
f"Content type of {filename} is not image, but {content_type}."
)

with open(f"{path}{filename}", "wb") as file:
file.write(response.content)
else:
Expand All @@ -71,15 +85,20 @@ class GeneratedImage(Image):
Cookies used for requesting the content of the generated image, inherit from GeminiClient object or manually set.
Must contain valid "__Secure-1PSID" and "__Secure-1PSIDTS" values
"""
cookies: dict

@validator("cookies")
def validate_cookies(cls, value):
if "__Secure-1PSID" not in value or "__Secure-1PSIDTS" not in value:
raise ValueError("Cookies must contain '__Secure-1PSID' and '__Secure-1PSIDTS'")
return value
cookies: dict[str, str]

async def save(self, path: str = "temp/", filename: str = None):
@field_validator("cookies")
@classmethod
def validate_cookies(cls, v: dict) -> dict:
if "__Secure-1PSID" not in v or "__Secure-1PSIDTS" not in v:
raise ValueError(
"Cookies must contain '__Secure-1PSID' and '__Secure-1PSIDTS'"
)
return v

# @override
async def save(self, path: str = "temp/", filename: str = None) -> None:
"""
Save the image to disk.
Expand Down Expand Up @@ -156,15 +175,15 @@ def __repr__(self):
return f"ModelOutput(metadata={self.metadata}, chosen={self.chosen}, candidates={self.candidates})"

@property
def text(self):
def text(self) -> str:
return self.candidates[self.chosen].text

@property
def images(self):
def images(self) -> list[Image]:
return self.candidates[self.chosen].images

@property
def rcid(self):
def rcid(self) -> str:
return self.candidates[self.chosen].rcid


Expand Down
4 changes: 2 additions & 2 deletions tests/test_save_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async def asyncSetUp(self):

async def test_save_web_image(self):
response = await self.geminiclient.generate_content(
"Send me some pictures of cats"
"Send me 10 pictures of random subjects"
)
self.assertTrue(response.images)
for i, image in enumerate(response.images):
Expand All @@ -26,7 +26,7 @@ async def test_save_web_image(self):

async def test_save_generated_image(self):
response = await self.geminiclient.generate_content(
"Generate some pictures of cats"
"Generate some pictures of random subjects"
)
self.assertTrue(response.images)
for i, image in enumerate(response.images):
Expand Down

0 comments on commit 5b78c88

Please sign in to comment.