From b35d7906a9e92972ddab51810eebb3961f58a81e Mon Sep 17 00:00:00 2001 From: UZQueen <157540577+HanaokaYuzu@users.noreply.github.com> Date: Sun, 11 Feb 2024 15:40:13 -0600 Subject: [PATCH] chore: fix basic functionalities, add basic unit test, update build info --- .vscode/settings.json | 11 +++ pyproject.toml | 1 + src/gemini/client.py | 138 ++++++++++++++++++++++++--------- src/gemini/utils.py | 13 ---- tests/test_generate_content.py | 27 +++++++ 5 files changed, 139 insertions(+), 51 deletions(-) create mode 100644 .vscode/settings.json delete mode 100644 src/gemini/utils.py create mode 100644 tests/test_generate_content.py diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..7f00b59 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,11 @@ +{ + "python.testing.unittestArgs": [ + "-v", + "-s", + "./tests", + "-p", + "test_*.py" + ], + "python.testing.pytestEnabled": false, + "python.testing.unittestEnabled": true +} diff --git a/pyproject.toml b/pyproject.toml index 15c6884..443bbea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ] requires-python = ">=3.7" dependencies = [ diff --git a/src/gemini/client.py b/src/gemini/client.py index 4042f7d..42981f7 100644 --- a/src/gemini/client.py +++ b/src/gemini/client.py @@ -8,10 +8,29 @@ from loguru import logger from .consts import HEADERS -from .utils import running from .types import Image, Candidate, ModelOutput +def running(func) -> callable: + """ + Decorator to check if client is running before making a request. + """ + + async def wrapper(self: "GeminiClient", *args, **kwargs): + if not self.running: + await self.init(auto_close=self.auto_close, close_delay=self.close_delay) + if self.running: + return await func(self, *args, **kwargs) + + raise Exception( + f"Invalid function call: GeminiClient.{func.__name__}. Client initialization failed." + ) + else: + return await func(self, *args, **kwargs) + + return wrapper + + class GeminiClient: """ Async httpx client interface for gemini.google.com @@ -26,7 +45,16 @@ class GeminiClient: Dict of proxies """ - __slots__ = ["running", "posttoken", "close_task", "client"] + __slots__ = [ + "cookies", + "proxy", + "client", + "access_token", + "running", + "auto_close", + "close_delay", + "close_task", + ] def __init__( self, @@ -34,48 +62,80 @@ def __init__( secure_1psidts: Optional[str] = None, proxy: Optional[dict] = None, ): + self.cookies = { + "__Secure-1PSID": secure_1psid, + "__Secure-1PSIDTS": secure_1psidts, + } + self.proxy = proxy + self.client: AsyncClient | None = None + self.access_token: Optional[str] = None self.running: bool = False - self.posttoken: Optional[str] = None - self.close_task: Optional[Task] = None - self.client: AsyncClient = AsyncClient( - timeout=20, - proxies=proxy, - follow_redirects=True, - headers=HEADERS, - cookies={ - "__Secure-1PSID": secure_1psid, - "__Secure-1PSIDTS": secure_1psidts, - }, - ) + self.auto_close: bool = False + self.close_delay: int = 0 + self.close_task: Task | None = None - async def init(self) -> None: - """ - Get SNlM0e value as posting token. Without this token posting will fail with 400 bad request. + async def init( + self, timeout: float = 30, auto_close: bool = False, close_delay: int = 300 + ) -> None: """ - async with self.client: - response = await self.client.get("https://gemini.google.com/chat") + Get SNlM0e value as access token. Without this token posting will fail with 400 bad request. - if response.status_code != 200: - raise Exception( - f"Failed to initiate client. Request failed with status code {response.status_code}" + Parameters + ---------- + timeout: `int`, optional + Request timeout of the client in seconds. Used to limit the max waiting time when sending a request + auto_close: `bool`, optional + If `True`, the client will close connections and clear resource usage after a certain period + of inactivity. Useful for keep-alive services + close_delay: `int`, optional + Time to wait before auto-closing the client in seconds. Effective only if `auto_close` is `True` + """ + try: + self.client = AsyncClient( + timeout=timeout, + proxies=self.proxy, + follow_redirects=True, + headers=HEADERS, + cookies=self.cookies, ) - else: - match = re.search(r'"SNlM0e":"(.*?)"', response.text) - if match: - self.posttoken = match.group(1) - self.running = True - logger.success("Gemini client initiated successfully.") - else: + + response = await self.client.get("https://gemini.google.com/app") + + if response.status_code != 200: raise Exception( - "Failed to initiate client. SNlM0e not found in response, make sure cookie values are valid." + f"Failed to initiate client. Request failed with status code {response.status_code}" ) - - async def close_client(self, timeout=300) -> None: + else: + match = re.search(r'"SNlM0e":"(.*?)"', response.text) + if match: + self.access_token = match.group(1) + self.running = True + logger.success("Gemini client initiated successfully.") + else: + raise Exception( + "Failed to initiate client. SNlM0e not found in response, make sure cookie values are valid." + ) + + self.auto_close = auto_close + self.close_delay = close_delay + if self.auto_close: + await self.reset_close_task() + except Exception: + await self.close(0) + raise + + async def close(self, wait: int | None = None) -> None: """ - Close the client after a certain period of inactivity. + Close the client after a certain period of inactivity, or call manually to close immediately. + + Parameters + ---------- + wait: `int`, optional + Time to wait before closing the client in seconds """ - await asyncio.sleep(timeout) + await asyncio.sleep(wait is not None and wait or self.close_delay) await self.client.aclose() + self.running = False async def reset_close_task(self) -> None: """ @@ -84,7 +144,7 @@ async def reset_close_task(self) -> None: if self.close_task: self.close_task.cancel() self.close_task = None - self.close_task = asyncio.create_task(self.close_client()) + self.close_task = asyncio.create_task(self.close()) @running async def generate_content( @@ -108,11 +168,13 @@ async def generate_content( """ assert prompt, "Prompt cannot be empty." - await self.reset_close_task() + if self.auto_close: + await self.reset_close_task() + response = await self.client.post( - "https://gemini.google.com/_/GeminiChatUi/data/assistant.lamda.GeminiFrontendService/StreamGenerate", + "https://gemini.google.com/_/BardChatUi/data/assistant.lamda.BardFrontendService/StreamGenerate", data={ - "at": self.posttoken, + "at": self.access_token, "f.req": json.dumps( [None, json.dumps([[prompt], None, chat and chat.metadata])] ), diff --git a/src/gemini/utils.py b/src/gemini/utils.py deleted file mode 100644 index bfdf9b3..0000000 --- a/src/gemini/utils.py +++ /dev/null @@ -1,13 +0,0 @@ -def running(func): - """ - Decorator to check if client is running before making a request. - """ - - async def wrapper(self, *args, **kwargs): - if not self.running: - raise Exception( - f"Invalid function call: GeminiClient.{func.__name__}. Client is not running. Re-initiate client to try again." - ) - return await func(self, *args, **kwargs) - - return wrapper diff --git a/tests/test_generate_content.py b/tests/test_generate_content.py new file mode 100644 index 0000000..5bc1918 --- /dev/null +++ b/tests/test_generate_content.py @@ -0,0 +1,27 @@ +import os +import unittest + +from gemini import GeminiClient + + +class TestGenerateContent(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.geminiclient = GeminiClient( + os.getenv("SECURE_1PSID") or "test_1psid", + os.getenv("SECURE_1PSIDTS") or "test_ipsidts", + ) + + @unittest.skipIf( + not (os.getenv("SECURE_1PSID") and os.getenv("SECURE_1PSIDTS")), + "Skipping test_success...", + ) + async def test_success(self): + await self.geminiclient.init() + self.assertTrue(self.geminiclient.running) + + response = await self.geminiclient.generate_content("Hello World!") + self.assertTrue(response.text) + + +if __name__ == "__main__": + unittest.main()