Skip to content

Commit

Permalink
fix: minor bugs that prevented connecting and migrating to dbs and auth
Browse files Browse the repository at this point in the history
  • Loading branch information
ntindle committed Oct 17, 2024
1 parent f04f328 commit 8c0dd41
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@

class SupabaseIntegrationCredentialsStore:
def __init__(self, redis: "Redis"):
self.prisma: Prisma = db.prisma
self.prisma: Prisma = Prisma()
self.locks = RedisKeyedMutex(redis)

async def add_creds(self, user_id: str, credentials: Credentials) -> None:
with self.locked_user_metadata(user_id):
if self.get_creds_by_id(user_id, credentials.id):
if await self.get_creds_by_id(user_id, credentials.id):
raise ValueError(
f"Can not re-create existing credentials #{credentials.id} "
f"for user #{user_id}"
Expand All @@ -43,7 +43,9 @@ async def get_all_creds(self, user_id: str) -> list[Credentials]:
user_metadata.model_dump()
).integration_credentials

async def get_creds_by_id(self, user_id: str, credentials_id: str) -> Credentials | None:
async def get_creds_by_id(
self, user_id: str, credentials_id: str
) -> Credentials | None:
all_credentials = await self.get_all_creds(user_id)
return next((c for c in all_credentials if c.id == credentials_id), None)

Expand Down Expand Up @@ -117,8 +119,11 @@ async def store_state_token(
oauth_states.append(state.model_dump())
user_metadata.integration_oauth_states = oauth_states

if not self.prisma.is_connected():
await self.prisma.connect()
await self.prisma.user.update(
where={"id": user_id}, data={"metadata": Json(user_metadata.model_dump())}
where={"id": user_id},
data={"metadata": Json(user_metadata.model_dump())},
)

return token
Expand Down Expand Up @@ -174,6 +179,8 @@ async def verify_state_token(self, user_id: str, token: str, provider: str) -> b
# Remove the used state
oauth_states.remove(valid_state)
user_metadata.integration_oauth_states = oauth_states
if not self.prisma.is_connected():
await self.prisma.connect()
await self.prisma.user.update(
where={"id": user_id},
data={"metadata": Json(user_metadata.model_dump())},
Expand All @@ -187,11 +194,15 @@ async def _set_user_integration_creds(
) -> None:
raw_metadata = await self._get_user_metadata(user_id)
raw_metadata.integration_credentials = [c.model_dump() for c in credentials]
if not self.prisma.is_connected():
await self.prisma.connect()
await self.prisma.user.update(
where={"id": user_id}, data={"metadata": Json(raw_metadata.model_dump())}
)

async def _get_user_metadata(self, user_id: str) -> UserMetadataRaw:
if not self.prisma.is_connected():
await self.prisma.connect()
user = await self.prisma.user.find_unique(where={"id": user_id})
if not user:
raise ValueError(f"User with ID {user_id} not found")
Expand All @@ -202,5 +213,5 @@ async def _get_user_metadata(self, user_id: str) -> UserMetadataRaw:
)

def locked_user_metadata(self, user_id: str):
key = ("usermetadatalock", f"user:{user_id}", "metadata")
key = (self.prisma, f"user:{user_id}", "metadata")
return self.locks.locked(key)
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ async def delete(self, user_id: str, credentials_id: str) -> None:

def _acquire_lock(self, user_id: str, credentials_id: str, *args: str) -> RedisLock:
key = (
"usermetadatalock",
self.store.prisma,
f"user:{user_id}",
f"credentials:{credentials_id}",
*args,
Expand Down
5 changes: 5 additions & 0 deletions autogpt_platform/backend/backend/util/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="What environment to behave as: local or cloud",
)

direct_database_url: str = Field(
default="",
description="The URL for the direct database. This is used to run migrations.",
)

backend_cors_allow_origins: List[str] = Field(default_factory=list)

@field_validator("backend_cors_allow_origins")
Expand Down
7 changes: 6 additions & 1 deletion autogpt_platform/backend/migrate/run_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import subprocess

from backend.util.settings import Settings
import psycopg2


Expand All @@ -17,7 +18,11 @@ def run_prisma_migrations():


def run_custom_migrations():
db_url = os.environ.get("DIRECT_DATABASE_URL")
settings = Settings()
db_url = settings.config.direct_database_url
if not db_url:
logging.error("No Direct DB URL Provided")
return
conn = psycopg2.connect(db_url)

try:
Expand Down
1 change: 1 addition & 0 deletions autogpt_platform/backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ cli = "backend.cli:main"
format = "linter:format"
lint = "linter:lint"
test = "run_tests:test"
migrate = "migrate.run_migrations:main"
migrate_db = "migrate.run_migrations:main"

[tool.pytest-watcher]
Expand Down

0 comments on commit 8c0dd41

Please sign in to comment.