Skip to content

Commit

Permalink
adding missed merge
Browse files Browse the repository at this point in the history
  • Loading branch information
ntindle committed Oct 17, 2024
1 parent 55a085c commit 62134f4
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions autogpt_platform/backend/backend/integrations/creds_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
from contextlib import contextmanager
from datetime import datetime
Expand Down Expand Up @@ -54,18 +55,18 @@ class IntegrationCredentialsManager:
def __init__(self):
redis_conn = redis.get_redis()
self._locks = RedisKeyedMutex(redis_conn)
self.store = SupabaseIntegrationCredentialsStore(get_supabase(), redis_conn)
self.store = SupabaseIntegrationCredentialsStore(redis_conn)

def create(self, user_id: str, credentials: Credentials) -> None:
return self.store.add_creds(user_id, credentials)
async def create(self, user_id: str, credentials: Credentials) -> None:
return await self.store.add_creds(user_id, credentials)

def exists(self, user_id: str, credentials_id: str) -> bool:
async def exists(self, user_id: str, credentials_id: str) -> bool:
return self.store.get_creds_by_id(user_id, credentials_id) is not None

def get(
async def get(
self, user_id: str, credentials_id: str, lock: bool = True
) -> Credentials | None:
credentials = self.store.get_creds_by_id(user_id, credentials_id)
credentials = await self.store.get_creds_by_id(user_id, credentials_id)
if not credentials:
return None

Expand All @@ -90,7 +91,7 @@ def get(
_lock = self._acquire_lock(user_id, credentials_id)

fresh_credentials = oauth_handler.refresh_tokens(credentials)
self.store.update_creds(user_id, fresh_credentials)
await self.store.update_creds(user_id, fresh_credentials)
if _lock:
_lock.release()

Expand All @@ -112,26 +113,26 @@ def acquire(
# to allow priority access for refreshing/updating the tokens.
with self._locked(user_id, credentials_id, "!time_sensitive"):
lock = self._acquire_lock(user_id, credentials_id)
credentials = self.get(user_id, credentials_id, lock=False)
credentials = asyncio.run(self.get(user_id, credentials_id, lock=False))
if not credentials:
raise ValueError(
f"Credentials #{credentials_id} for user #{user_id} not found"
)
return credentials, lock

def update(self, user_id: str, updated: Credentials) -> None:
async def update(self, user_id: str, updated: Credentials) -> None:
with self._locked(user_id, updated.id):
self.store.update_creds(user_id, updated)
await self.store.update_creds(user_id, updated)

def delete(self, user_id: str, credentials_id: str) -> None:
async def delete(self, user_id: str, credentials_id: str) -> None:
with self._locked(user_id, credentials_id):
self.store.delete_creds_by_id(user_id, credentials_id)
await self.store.delete_creds_by_id(user_id, credentials_id)

# -- Locking utilities -- #

def _acquire_lock(self, user_id: str, credentials_id: str, *args: str) -> RedisLock:
key = (
self.store.supabase.supabase_url,
"usermetadatalock",
f"user:{user_id}",
f"credentials:{credentials_id}",
*args,
Expand Down

0 comments on commit 62134f4

Please sign in to comment.