Skip to content

Commit

Permalink
refactor(backend): Move credentials storage to prisma user (#8283)
Browse files Browse the repository at this point in the history
* feat(frontend,backend): testing

* feat: testing

* feat(backend): it works for reading email

* feat(backend): more docs on google

* fix(frontend,backend): formatting

* feat(backend): more logigin (i know this should be debug)

* feat(backend): make real the default scopes

* feat(backend): tests and linting

* fix: code review prep

* feat: sheets block

* feat: liniting

* Update route.ts

* Update autogpt_platform/backend/backend/integrations/oauth/google.py

Co-authored-by: Reinier van der Leer <[email protected]>

* Update autogpt_platform/backend/backend/server/routers/integrations.py

Co-authored-by: Reinier van der Leer <[email protected]>

* fix: revert opener change

* feat(frontend): add back opener

required to work on mac edge

* feat(frontend): drop typing list import from gmail

* fix: code review comments

* feat: code review changes

* feat: code review changes

* fix(backend): move from asserts to checks so they don't get optimized away in the future

* fix(backend): code review changes

* fix(backend): remove google specific check

* fix: add typing

* fix: only enable google blocks when oauth is configured for google

* fix: errors are real and valid outputs always when output

* fix(backend): add provider detail for debuging scope declines

* Update autogpt_platform/frontend/src/components/integrations/credentials-input.tsx

Co-authored-by: Reinier van der Leer <[email protected]>

* fix(frontend): enhance with comment, typeof error isn't known so this is best way to ensure the stringifyication will work

* feat: code review change requests

* fix: linting

* fix: reduce error catching

* fix: doc messages in code

* fix: check the correct scopes object 😄

* fix: remove double (and not needed) try catch

* fix: lint

* fix: scopes

* feat: handle the default scopes better

* feat: better email objectification

* feat: process attachements

turns out an email doesn't need a body

* fix: lint

* Update google.py

* Update autogpt_platform/backend/backend/data/block.py

Co-authored-by: Reinier van der Leer <[email protected]>

* fix: quit trying and except failure

* Update autogpt_platform/backend/backend/server/routers/integrations.py

Co-authored-by: Reinier van der Leer <[email protected]>

* feat: don't allow expired states

* fix: clarify function name and purpose

* feat: code links updates

* feat: additional docs on adding a block

* fix: type hint missing which means the block won't work

* fix: linting

* fix: docs formatting

* Update issues.py

* fix: improve the naming

* fix: formatting

* Update new_blocks.md

* Update new_blocks.md

* feat: better docs on what the args mean

* feat: more details on yield

* Update new_blocks.md

* fix: remove ignore from docs build

* feat: initial migration

* feat: migration tested with supabase-> prisma data location

* add custom migrations and script

* update migration command

* formatting and linting

* updated migration script

* add direct db url

* add find files

* rename

* use binary instead of source

* temp adding supabase

* remove unused functions

* adding missed merge

* fix: commit hash for lock

* ci: fix lint

* fix: minor bugs that prevented connecting and migrating to dbs and auth

* fix: linting

* fix: missed await

* fix(backend): phase one pr updates

* fix: handle error with returning user object from database_manager

* fix: linting

* Address comments

* Make the migration safe

* Update migration doc

* Move misplaced model functions

* Grammar

* Revert lock

* Remove irrelevant changes

* Remove irrelevant changes

* Avoid adding trigger on public schema

---------

Co-authored-by: Reinier van der Leer <[email protected]>
Co-authored-by: Zamil Majdy <[email protected]>
Co-authored-by: Aarushi <[email protected]>
Co-authored-by: Aarushi <[email protected]>
  • Loading branch information
5 people authored Oct 22, 2024
1 parent 5e386fd commit 1622a4a
Show file tree
Hide file tree
Showing 17 changed files with 123 additions and 65 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import secrets
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from redis import Redis
from supabase import Client
from backend.executor.database import DatabaseManager

from autogpt_libs.utils.synchronize import RedisKeyedMutex

Expand All @@ -18,8 +18,8 @@


class SupabaseIntegrationCredentialsStore:
def __init__(self, supabase: "Client", redis: "Redis"):
self.supabase = supabase
def __init__(self, redis: "Redis", db: "DatabaseManager"):
self.db_manager: DatabaseManager = db
self.locks = RedisKeyedMutex(redis)

def add_creds(self, user_id: str, credentials: Credentials) -> None:
Expand All @@ -35,7 +35,9 @@ def add_creds(self, user_id: str, credentials: Credentials) -> None:

def get_all_creds(self, user_id: str) -> list[Credentials]:
user_metadata = self._get_user_metadata(user_id)
return UserMetadata.model_validate(user_metadata).integration_credentials
return UserMetadata.model_validate(
user_metadata.model_dump()
).integration_credentials

def get_creds_by_id(self, user_id: str, credentials_id: str) -> Credentials | None:
all_credentials = self.get_all_creds(user_id)
Expand Down Expand Up @@ -90,9 +92,7 @@ def delete_creds_by_id(self, user_id: str, credentials_id: str) -> None:
]
self._set_user_integration_creds(user_id, filtered_credentials)

async def store_state_token(
self, user_id: str, provider: str, scopes: list[str]
) -> str:
def store_state_token(self, user_id: str, provider: str, scopes: list[str]) -> str:
token = secrets.token_urlsafe(32)
expires_at = datetime.now(timezone.utc) + timedelta(minutes=10)

Expand All @@ -105,17 +105,17 @@ async def store_state_token(

with self.locked_user_metadata(user_id):
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.get("integration_oauth_states", [])
oauth_states = user_metadata.integration_oauth_states
oauth_states.append(state.model_dump())
user_metadata["integration_oauth_states"] = oauth_states
user_metadata.integration_oauth_states = oauth_states

self.supabase.auth.admin.update_user_by_id(
user_id, {"user_metadata": user_metadata}
self.db_manager.update_user_metadata(
user_id=user_id, metadata=user_metadata
)

return token

async def get_any_valid_scopes_from_state_token(
def get_any_valid_scopes_from_state_token(
self, user_id: str, token: str, provider: str
) -> list[str]:
"""
Expand All @@ -126,7 +126,7 @@ async def get_any_valid_scopes_from_state_token(
THE CODE FOR TOKENS.
"""
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.get("integration_oauth_states", [])
oauth_states = user_metadata.integration_oauth_states

now = datetime.now(timezone.utc)
valid_state = next(
Expand All @@ -145,10 +145,10 @@ async def get_any_valid_scopes_from_state_token(

return []

async def verify_state_token(self, user_id: str, token: str, provider: str) -> bool:
def verify_state_token(self, user_id: str, token: str, provider: str) -> bool:
with self.locked_user_metadata(user_id):
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.get("integration_oauth_states", [])
oauth_states = user_metadata.integration_oauth_states

now = datetime.now(timezone.utc)
valid_state = next(
Expand All @@ -165,10 +165,8 @@ async def verify_state_token(self, user_id: str, token: str, provider: str) -> b
if valid_state:
# Remove the used state
oauth_states.remove(valid_state)
user_metadata["integration_oauth_states"] = oauth_states
self.supabase.auth.admin.update_user_by_id(
user_id, {"user_metadata": user_metadata}
)
user_metadata.integration_oauth_states = oauth_states
self.db_manager.update_user_metadata(user_id, user_metadata)
return True

return False
Expand All @@ -177,19 +175,13 @@ def _set_user_integration_creds(
self, user_id: str, credentials: list[Credentials]
) -> None:
raw_metadata = self._get_user_metadata(user_id)
raw_metadata.update(
{"integration_credentials": [c.model_dump() for c in credentials]}
)
self.supabase.auth.admin.update_user_by_id(
user_id, {"user_metadata": raw_metadata}
)
raw_metadata.integration_credentials = [c.model_dump() for c in credentials]
self.db_manager.update_user_metadata(user_id, raw_metadata)

def _get_user_metadata(self, user_id: str) -> UserMetadataRaw:
response = self.supabase.auth.admin.get_user_by_id(user_id)
if not response.user:
raise ValueError(f"User with ID {user_id} not found")
return cast(UserMetadataRaw, response.user.user_metadata)
metadata: UserMetadataRaw = self.db_manager.get_user_metadata(user_id=user_id)
return metadata

def locked_user_metadata(self, user_id: str):
key = (self.supabase.supabase_url, f"user:{user_id}", "metadata")
key = (self.db_manager, f"user:{user_id}", "metadata")
return self.locks.locked(key)
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class OAuthState(BaseModel):
token: str
provider: str
expires_at: int
scopes: list[str]
"""Unix timestamp (seconds) indicating when this OAuth state expires"""


Expand All @@ -64,6 +65,6 @@ class UserMetadata(BaseModel):
integration_oauth_states: list[OAuthState] = Field(default_factory=list)


class UserMetadataRaw(TypedDict, total=False):
integration_credentials: list[dict]
integration_oauth_states: list[dict]
class UserMetadataRaw(BaseModel):
integration_credentials: list[dict] = Field(default_factory=list)
integration_oauth_states: list[dict] = Field(default_factory=list)
2 changes: 1 addition & 1 deletion autogpt_platform/backend/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ WORKDIR /app

# Install build dependencies
RUN apt-get update \
&& apt-get install -y build-essential curl ffmpeg wget libcurl4-gnutls-dev libexpat1-dev gettext libz-dev libssl-dev postgresql-client git \
&& apt-get install -y build-essential curl ffmpeg wget libcurl4-gnutls-dev libexpat1-dev libpq5 gettext libz-dev libssl-dev postgresql-client git \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*

Expand Down
4 changes: 2 additions & 2 deletions autogpt_platform/backend/README.advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ We use the Poetry to manage the dependencies. To set up the project, follow thes
5. Generate the Prisma client

```sh
poetry run prisma generate --schema postgres/schema.prisma
poetry run prisma generate
```


Expand All @@ -61,7 +61,7 @@ We use the Poetry to manage the dependencies. To set up the project, follow thes

```sh
cd ../backend
prisma migrate dev --schema postgres/schema.prisma
prisma migrate deploy
```

## Running The Server
Expand Down
2 changes: 1 addition & 1 deletion autogpt_platform/backend/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ We use the Poetry to manage the dependencies. To set up the project, follow thes

```sh
docker compose up db redis -d
poetry run prisma migrate dev
poetry run prisma migrate deploy
```

## Running The Server
Expand Down
20 changes: 20 additions & 0 deletions autogpt_platform/backend/backend/data/user.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Optional

from autogpt_libs.supabase_integration_credentials_store.types import UserMetadataRaw
from fastapi import HTTPException
from prisma import Json
from prisma.models import User

from backend.data.db import prisma
Expand Down Expand Up @@ -48,3 +50,21 @@ async def create_default_user(enable_auth: str) -> Optional[User]:
)
return User.model_validate(user)
return None


async def get_user_metadata(user_id: str) -> UserMetadataRaw:
user = await User.prisma().find_unique_or_raise(
where={"id": user_id},
)
return (
UserMetadataRaw.model_validate(user.metadata)
if user.metadata
else UserMetadataRaw()
)


async def update_user_metadata(user_id: str, metadata: UserMetadataRaw):
await User.prisma().update(
where={"id": user_id},
data={"metadata": Json(metadata.model_dump())},
)
5 changes: 5 additions & 0 deletions autogpt_platform/backend/backend/executor/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from backend.data.graph import get_graph, get_node
from backend.data.queue import RedisEventQueue
from backend.data.user import get_user_metadata, update_user_metadata
from backend.util.service import AppService, expose
from backend.util.settings import Config

Expand Down Expand Up @@ -73,3 +74,7 @@ def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
Callable[[Any, str, int, str, dict[str, str], float, float], int],
exposed_run_and_wait(user_credit_model.spend_credits),
)

# User + User Metadata
get_user_metadata = exposed_run_and_wait(get_user_metadata)
update_user_metadata = exposed_run_and_wait(update_user_metadata)
9 changes: 5 additions & 4 deletions autogpt_platform/backend/backend/executor/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from backend.data.model import CREDENTIALS_FIELD_NAME, CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util import json
from backend.util.cache import thread_cached_property
from backend.util.cache import thread_cached
from backend.util.decorator import error_logged, time_measured
from backend.util.logging import configure_logging
from backend.util.process import set_service_name
Expand Down Expand Up @@ -417,7 +417,7 @@ def on_node_executor_start(cls):
redis.connect()
cls.pid = os.getpid()
cls.db_client = get_db_client()
cls.creds_manager = IntegrationCredentialsManager()
cls.creds_manager = IntegrationCredentialsManager(db_manager=cls.db_client)

# Set up shutdown handlers
cls.shutdown_lock = threading.Lock()
Expand Down Expand Up @@ -670,7 +670,7 @@ def run_service(self):
)

self.credentials_store = SupabaseIntegrationCredentialsStore(
self.supabase, redis.get_redis()
redis=redis.get_redis(), db=self.db_client
)
self.executor = ProcessPoolExecutor(
max_workers=self.pool_size,
Expand Down Expand Up @@ -701,7 +701,7 @@ def cleanup(self):

super().cleanup()

@thread_cached_property
@property
def db_client(self) -> "DatabaseManager":
return get_db_client()

Expand Down Expand Up @@ -857,6 +857,7 @@ def _validate_node_input_credentials(self, graph: Graph, user_id: str):
# ------- UTILITIES ------- #


@thread_cached
def get_db_client() -> "DatabaseManager":
from backend.executor import DatabaseManager

Expand Down
11 changes: 6 additions & 5 deletions autogpt_platform/backend/backend/integrations/creds_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
from redis.lock import Lock as RedisLock

from backend.data import redis
from backend.executor.database import DatabaseManager
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.util.settings import Settings

from ..server.integrations.utils import get_supabase

logger = logging.getLogger(__name__)
settings = Settings()

Expand Down Expand Up @@ -51,10 +50,12 @@ class IntegrationCredentialsManager:
cause so much latency that it's worth implementing.
"""

def __init__(self):
def __init__(self, db_manager: DatabaseManager):
redis_conn = redis.get_redis()
self._locks = RedisKeyedMutex(redis_conn)
self.store = SupabaseIntegrationCredentialsStore(get_supabase(), redis_conn)
self.store = SupabaseIntegrationCredentialsStore(
redis=redis_conn, db=db_manager
)

def create(self, user_id: str, credentials: Credentials) -> None:
return self.store.add_creds(user_id, credentials)
Expand Down Expand Up @@ -131,7 +132,7 @@ def delete(self, user_id: str, credentials_id: str) -> None:

def _acquire_lock(self, user_id: str, credentials_id: str, *args: str) -> RedisLock:
key = (
self.store.supabase.supabase_url,
self.store.db_manager,
f"user:{user_id}",
f"credentials:{credentials_id}",
*args,
Expand Down
10 changes: 6 additions & 4 deletions autogpt_platform/backend/backend/server/integrations/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
from pydantic import BaseModel, Field, SecretStr

from backend.executor.manager import get_db_client
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.util.settings import Settings
Expand All @@ -19,7 +20,8 @@
logger = logging.getLogger(__name__)
settings = Settings()
router = APIRouter()
creds_manager = IntegrationCredentialsManager()

creds_manager = IntegrationCredentialsManager(db_manager=get_db_client())


class LoginResponse(BaseModel):
Expand All @@ -41,7 +43,7 @@ async def login(
requested_scopes = scopes.split(",") if scopes else []

# Generate and store a secure random state token along with the scopes
state_token = await creds_manager.store.store_state_token(
state_token = creds_manager.store.store_state_token(
user_id, provider, requested_scopes
)

Expand Down Expand Up @@ -70,12 +72,12 @@ async def callback(
handler = _get_provider_oauth_handler(request, provider)

# Verify the state token
if not await creds_manager.store.verify_state_token(user_id, state_token, provider):
if not creds_manager.store.verify_state_token(user_id, state_token, provider):
logger.warning(f"Invalid or expired state token for user {user_id}")
raise HTTPException(status_code=400, detail="Invalid or expired state token")

try:
scopes = await creds_manager.store.get_any_valid_scopes_from_state_token(
scopes = creds_manager.store.get_any_valid_scopes_from_state_token(
user_id, state_token, provider
)
logger.debug(f"Retrieved scopes from state token: {scopes}")
Expand Down
3 changes: 2 additions & 1 deletion autogpt_platform/backend/backend/server/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from backend.data.credit import get_block_costs, get_user_credit_model
from backend.data.user import get_or_create_user
from backend.executor import ExecutionManager, ExecutionScheduler
from backend.executor.manager import get_db_client
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.server.model import CreateGraph, SetGraphActiveVersion
from backend.util.cache import thread_cached_property
Expand Down Expand Up @@ -97,7 +98,7 @@ def run_service(self):
tags=["integrations"],
dependencies=[Depends(auth_middleware)],
)
self.integration_creds_manager = IntegrationCredentialsManager()
self.integration_creds_manager = IntegrationCredentialsManager(get_db_client())

api_router.include_router(
backend.server.routers.analytics.router,
Expand Down
Loading

0 comments on commit 1622a4a

Please sign in to comment.