From c58db163256e51a64a11868eff0a4b052a062b51 Mon Sep 17 00:00:00 2001 From: Arash Date: Mon, 16 Dec 2024 11:18:07 +0100 Subject: [PATCH] partial update for credentials api to match the new changes --- lib/galaxy/schema/credentials.py | 141 ++++-- lib/galaxy/webapps/galaxy/api/credentials.py | 118 ++--- .../webapps/galaxy/services/credentials.py | 450 ++++++++++-------- 3 files changed, 420 insertions(+), 289 deletions(-) diff --git a/lib/galaxy/schema/credentials.py b/lib/galaxy/schema/credentials.py index aa6420bc866b..408744022d8f 100644 --- a/lib/galaxy/schema/credentials.py +++ b/lib/galaxy/schema/credentials.py @@ -1,11 +1,15 @@ from enum import Enum -from typing import List +from typing import ( + Dict, + List, + Optional, +) from pydantic import ( - BaseModel, Field, RootModel, ) +from typing_extensions import Literal from galaxy.schema.fields import ( DecodedDatabaseIdField, @@ -13,6 +17,8 @@ ) from galaxy.schema.schema import Model +SOURCE_TYPE = Literal["tool"] + class CredentialType(str, Enum): secret = "secret" @@ -30,33 +36,98 @@ class CredentialResponse(Model): title="Credential Name", description="Name of the credential", ) - type: CredentialType = Field( + + +class VariableResponse(CredentialResponse): + value: Optional[str] = Field( + None, + title="Value", + description="Value of the credential", + ) + + +class SecretResponse(CredentialResponse): + already_set: bool = Field( ..., - title="Type", - description="Type of the credential", + title="Already Set", + description="Whether the secret is already set", ) -class CredentialsListResponse(Model): - service_reference: str = Field( +class CredentialGroupResponse(Model): + id: EncodedDatabaseIdField = Field( ..., - title="Service Reference", - description="Reference to the service", + title="Group ID", + description="ID of the group", + ) + name: str = Field( + ..., + title="Group Name", + description="Name of the group", + ) + variables: List[VariableResponse] = Field( + ..., + title="Variables", + description="List of variables", ) - user_credentials_id: EncodedDatabaseIdField = Field( + secrets: List[SecretResponse] = Field( + ..., + title="Secrets", + description="List of secrets", + ) + + +class UserCredentialBaseResponse(Model): + user_id: EncodedDatabaseIdField = Field( + ..., + title="User ID", + description="ID of the user", + ) + id: EncodedDatabaseIdField = Field( ..., title="User Credentials ID", description="ID of the user credentials", ) - credentials: List[CredentialResponse] = Field( + source_type: SOURCE_TYPE = Field( ..., - title="Credentials", - description="List of credentials", + title="Source Type", + description="Type of the source", + ) + source_id: str = Field( + ..., + title="Source ID", + description="ID of the source", + ) + reference: str = Field( + ..., + title="Service Reference", + description="Reference to the service", + ) + current_group_name: str = Field( + ..., + title="Current Group Name", + description="Name of the current group", + ) + + +class UserCredentialsResponse(UserCredentialBaseResponse): + groups: Dict[str, CredentialGroupResponse] = Field( + ..., + title="Groups", + description="Groups of credentials", + ) + + +class UserCredentialCreateResponse(UserCredentialBaseResponse): + group: CredentialGroupResponse = Field( + ..., + title="Group", + description="Group of credentials", ) class UserCredentialsListResponse(RootModel): - root: List[CredentialsListResponse] = Field( + root: List[UserCredentialsResponse] = Field( ..., title="User Credentials", description="List of user credentials", @@ -82,11 +153,26 @@ class CredentialPayload(Model): class CredentialsPayload(Model): - service_reference: str = Field( + source_type: SOURCE_TYPE = Field( + ..., + title="Source Type", + description="Type of the source", + ) + source_id: str = Field( + ..., + title="Source ID", + description="ID of the source", + ) + reference: str = Field( ..., title="Service Reference", description="Reference to the service", ) + group_name: Optional[str] = Field( + "default", + title="Group Name", + description="Name of the group", + ) credentials: List[CredentialPayload] = Field( ..., title="Credentials", @@ -107,25 +193,14 @@ class UpdateCredentialPayload(Model): ) -class UpdateCredentialsPayload(BaseModel): - root: List[UpdateCredentialPayload] = Field( - ..., - title="Update Credentials", - description="List of credentials to update", - ) - - -class VerifyCredentialsResponse(Model): - exists: bool = Field( +class UpdateCredentialsPayload(Model): + group_id: DecodedDatabaseIdField = Field( ..., - title="Exists", - description="Indicates if the credentials exist", + title="Group ID", + description="ID of the group", ) - - -class DeleteCredentialsResponse(Model): - deleted: bool = Field( + credentials: List[UpdateCredentialPayload] = Field( ..., - title="Deleted", - description="Indicates if the credentials were deleted", + title="Update Credentials", + description="List of credentials to update", ) diff --git a/lib/galaxy/webapps/galaxy/api/credentials.py b/lib/galaxy/webapps/galaxy/api/credentials.py index 647ff5801103..a670e65194bc 100644 --- a/lib/galaxy/webapps/galaxy/api/credentials.py +++ b/lib/galaxy/webapps/galaxy/api/credentials.py @@ -5,16 +5,19 @@ import logging from typing import Optional -from fastapi import Query +from fastapi import ( + Query, + Response, + status, +) from galaxy.managers.context import ProvidesUserContext from galaxy.schema.credentials import ( - CredentialsListResponse, CredentialsPayload, - DeleteCredentialsResponse, + SOURCE_TYPE, UpdateCredentialsPayload, + UserCredentialCreateResponse, UserCredentialsListResponse, - VerifyCredentialsResponse, ) from galaxy.schema.fields import DecodedDatabaseIdField from galaxy.webapps.galaxy.api import ( @@ -42,7 +45,7 @@ def list_user_credentials( self, user_id: UserIdPathParam, trans: ProvidesUserContext = DependsOnTrans, - source_type: Optional[str] = Query( + source_type: Optional[SOURCE_TYPE] = Query( None, description="The type of source to filter by.", ), @@ -50,33 +53,12 @@ def list_user_credentials( None, description="The ID of the source to filter by.", ), + group_name: Optional[str] = Query( + None, + description="The name of the group to filter by.", + ), ) -> UserCredentialsListResponse: - return self.service.list_user_credentials(trans, user_id, source_type, source_id) - - @router.get( - "/api/users/{user_id}/credentials/{user_credentials_id}", - summary="Verifies if credentials have been provided for a specific service", - ) - def verify_service_credentials( - self, - user_id: UserIdPathParam, - user_credentials_id: DecodedDatabaseIdField, - trans: ProvidesUserContext = DependsOnTrans, - ) -> VerifyCredentialsResponse: - return self.service.verify_service_credentials(trans, user_id, user_credentials_id) - - @router.get( - "/api/users/{user_id}/credentials/{user_credentials_id}/{credentials_id}", - summary="Verifies if a credential have been provided", - ) - def verify_credentials( - self, - user_id: UserIdPathParam, - user_credentials_id: DecodedDatabaseIdField, - credentials_id: DecodedDatabaseIdField, - trans: ProvidesUserContext = DependsOnTrans, - ) -> VerifyCredentialsResponse: - return self.service.verify_credentials(trans, user_credentials_id, credentials_id) + return self.service.list_user_credentials(trans, user_id, source_type, source_id, group_name) @router.post( "/api/users/{user_id}/credentials", @@ -87,43 +69,45 @@ def provide_credential( user_id: UserIdPathParam, payload: CredentialsPayload, trans: ProvidesUserContext = DependsOnTrans, - ) -> CredentialsListResponse: + ) -> UserCredentialCreateResponse: return self.service.provide_credential(trans, user_id, payload) - @router.put( - "/api/users/{user_id}/credentials/{user_credentials_id}", - summary="Updates credentials for a specific secret/variable", - ) - def update_credential( - self, - user_id: UserIdPathParam, - user_credentials_id: DecodedDatabaseIdField, - payload: UpdateCredentialsPayload, - trans: ProvidesUserContext = DependsOnTrans, - ) -> CredentialsListResponse: - return self.service.update_credential(trans, user_id, user_credentials_id, payload) + # @router.put( + # "/api/users/{user_id}/credentials/{user_credentials_id}", + # summary="Updates credentials for a specific secret/variable", + # ) + # def update_credential( + # self, + # user_id: UserIdPathParam, + # user_credentials_id: DecodedDatabaseIdField, + # payload: UpdateCredentialsPayload, + # trans: ProvidesUserContext = DependsOnTrans, + # ) -> CredentialsListResponse: + # return self.service.update_credential(trans, user_id, user_credentials_id, payload) - @router.delete( - "/api/users/{user_id}/credentials/{user_credentials_id}", - summary="Deletes all credentials for a specific service", - ) - def delete_service_credentials( - self, - user_id: UserIdPathParam, - user_credentials_id: DecodedDatabaseIdField, - trans: ProvidesUserContext = DependsOnTrans, - ) -> DeleteCredentialsResponse: - return self.service.delete_service_credentials(trans, user_id, user_credentials_id) + # @router.delete( + # "/api/users/{user_id}/credentials/{user_credentials_id}", + # summary="Deletes all credentials for a specific service", + # ) + # def delete_service_credentials( + # self, + # user_id: UserIdPathParam, + # user_credentials_id: DecodedDatabaseIdField, + # trans: ProvidesUserContext = DependsOnTrans, + # ): + # self.service.delete_service_credentials(trans, user_id, user_credentials_id) + # return Response(status_code=status.HTTP_204_NO_CONTENT) - @router.delete( - "/api/users/{user_id}/credentials/{user_credentials_id}/{credentials_id}", - summary="Deletes a specific credential", - ) - def delete_credentials( - self, - user_id: UserIdPathParam, - user_credentials_id: DecodedDatabaseIdField, - credentials_id: DecodedDatabaseIdField, - trans: ProvidesUserContext = DependsOnTrans, - ) -> DeleteCredentialsResponse: - return self.service.delete_credentials(trans, user_id, user_credentials_id, credentials_id) + # @router.delete( + # "/api/users/{user_id}/credentials/{user_credentials_id}/{group_id}", + # summary="Deletes a specific credential", + # ) + # def delete_credentials( + # self, + # user_id: UserIdPathParam, + # user_credentials_id: DecodedDatabaseIdField, + # group_id: DecodedDatabaseIdField, + # trans: ProvidesUserContext = DependsOnTrans, + # ): + # self.service.delete_credentials(trans, user_id, user_credentials_id, group_id) + # return Response(status_code=status.HTTP_204_NO_CONTENT) diff --git a/lib/galaxy/webapps/galaxy/services/credentials.py b/lib/galaxy/webapps/galaxy/services/credentials.py index eb82583e4753..36611554fcac 100644 --- a/lib/galaxy/webapps/galaxy/services/credentials.py +++ b/lib/galaxy/webapps/galaxy/services/credentials.py @@ -1,25 +1,38 @@ from typing import ( + Any, Dict, List, Optional, + Sequence, Tuple, ) +from sqlalchemy import ( + false, + select, + update, +) +from sqlalchemy.orm import aliased + from galaxy import exceptions from galaxy.managers.context import ProvidesUserContext from galaxy.model import ( - Credentials, + Credential, + CredentialsGroup, UserCredentials, ) from galaxy.model.base import transaction from galaxy.schema.credentials import ( + CredentialGroupResponse, CredentialResponse, - CredentialsListResponse, CredentialsPayload, - DeleteCredentialsResponse, + SecretResponse, + SOURCE_TYPE, UpdateCredentialsPayload, + UserCredentialCreateResponse, UserCredentialsListResponse, - VerifyCredentialsResponse, + UserCredentialsResponse, + VariableResponse, ) from galaxy.schema.fields import DecodedDatabaseIdField from galaxy.security.vault import UserVaultWrapper @@ -37,249 +50,308 @@ def list_user_credentials( self, trans: ProvidesUserContext, user_id: UserIdPathParam, - source_type: Optional[str] = None, + source_type: Optional[SOURCE_TYPE] = None, source_id: Optional[str] = None, + group_name: Optional[str] = None, ) -> UserCredentialsListResponse: """Lists all credentials the user has provided (credentials themselves are not included).""" - service_reference = f"{source_type}|{source_id}".strip("|") if source_type else None - user_credentials, credentials_dict = self._user_credentials( - trans, user_id=user_id, service_reference=service_reference + db_user_credentials = self._user_credentials( + trans, user_id=user_id, source_type=source_type, source_id=source_id, group_name=group_name ) - user_credentials_list = [ - CredentialsListResponse( - service_reference=sref, - user_credentials_id=next( - (uc.id for uc in user_credentials if uc.service_reference == sref), - None, - ), - credentials=self._credentials_response(creds), - ) - for sref, creds in credentials_dict.items() - ] - return UserCredentialsListResponse(root=user_credentials_list) - - def verify_service_credentials( - self, - trans: ProvidesUserContext, - user_id: UserIdPathParam, - user_credentials_id: DecodedDatabaseIdField, - ) -> VerifyCredentialsResponse: - """Verifies if credentials have been provided for a specific service (no credential data returned).""" - _, credentials_dict = self._user_credentials(trans, user_id=user_id, user_credentials_id=user_credentials_id) - return VerifyCredentialsResponse(exists=bool(credentials_dict)) - - def verify_credentials( - self, - trans: ProvidesUserContext, - user_credentials_id: DecodedDatabaseIdField, - credentials_id: DecodedDatabaseIdField, - ) -> VerifyCredentialsResponse: - """Verifies if a credential have been provided (no credential data returned).""" - credentials = self._credentials(trans, user_credentials_id=user_credentials_id, id=credentials_id) - return VerifyCredentialsResponse(exists=bool(credentials)) + credentials_dict = self._user_credentials_to_dict(db_user_credentials) + return UserCredentialsListResponse(root=[UserCredentialsResponse(**cred) for cred in credentials_dict.values()]) def provide_credential( self, trans: ProvidesUserContext, user_id: UserIdPathParam, payload: CredentialsPayload, - ) -> CredentialsListResponse: + ) -> UserCredentialCreateResponse: """Allows users to provide credentials for a secret/variable.""" return self._create_user_credential(trans, user_id, payload) - def update_credential( - self, - trans: ProvidesUserContext, - user_id: UserIdPathParam, - user_credentials_id: DecodedDatabaseIdField, - payload: UpdateCredentialsPayload, - ) -> CredentialsListResponse: - """Updates credentials for a specific secret/variable.""" - return self._update_user_credential(trans, user_id, user_credentials_id, payload) + # def update_credential( + # self, + # trans: ProvidesUserContext, + # user_id: UserIdPathParam, + # user_credentials_id: DecodedDatabaseIdField, + # payload: UpdateCredentialsPayload, + # ) -> CredentialsListResponse: + # """Updates credentials for a specific secret/variable.""" + # return self._update_user_credential(trans, user_id, user_credentials_id, payload) - def delete_service_credentials( - self, - trans: ProvidesUserContext, - user_id: UserIdPathParam, - user_credentials_id: DecodedDatabaseIdField, - ) -> DeleteCredentialsResponse: - """Deletes all credentials for a specific service.""" - user_credentials, credentials_dict = self._user_credentials( - trans, user_id=user_id, user_credentials_id=user_credentials_id - ) - session = trans.sa_session - for credentials in credentials_dict.values(): - for credential in credentials: - session.delete(credential) - for user_credential in user_credentials: - session.delete(user_credential) - with transaction(session): - session.commit() - return DeleteCredentialsResponse(deleted=True) + # def delete_service_credentials( + # self, + # trans: ProvidesUserContext, + # user_id: UserIdPathParam, + # user_credentials_id: DecodedDatabaseIdField, + # ): + # """Deletes all credentials for a specific service.""" + # user_credentials = self._user_credentials(trans, user_id=user_id, user_credentials_id=user_credentials_id) + # session = trans.sa_session + # for credentials in credentials_dict.values(): + # for credential in credentials: + # session.delete(credential) + # for user_credential in user_credentials: + # session.delete(user_credential) + # with transaction(session): + # session.commit() - def delete_credentials( - self, - trans: ProvidesUserContext, - user_id: UserIdPathParam, - user_credentials_id: DecodedDatabaseIdField, - credentials_id: DecodedDatabaseIdField, - ) -> DeleteCredentialsResponse: - """Deletes a specific credential.""" - credentials = self._credentials(trans, user_credentials_id=user_credentials_id, id=credentials_id) - session = trans.sa_session - for credential in credentials: - session.delete(credential) - with transaction(session): - session.commit() - return DeleteCredentialsResponse(deleted=True) + # def delete_credentials( + # self, + # trans: ProvidesUserContext, + # user_id: UserIdPathParam, + # group_id: DecodedDatabaseIdField, + # credentials_id: DecodedDatabaseIdField, + # ): + # """Deletes a specific credential group.""" + # credentials = self._credentials(trans, group_id=group_id, id=credentials_id) + # session = trans.sa_session + # for credential in credentials: + # session.delete(credential) + # with transaction(session): + # session.commit() def _user_credentials( self, trans: ProvidesUserContext, user_id: UserIdPathParam, - service_reference: Optional[str] = None, + source_type: Optional[SOURCE_TYPE] = None, + source_id: Optional[str] = None, + reference: Optional[str] = None, + group_name: Optional[str] = None, user_credentials_id: Optional[DecodedDatabaseIdField] = None, - ) -> Tuple[List[UserCredentials], Dict[str, List[Credentials]]]: + ) -> List[Tuple[UserCredentials, CredentialsGroup, Credential]]: if not trans.user_is_admin and (not trans.user or trans.user != user_id): raise exceptions.ItemOwnershipException( "Only admins and the user can manage their own credentials.", type="error" ) - query = trans.sa_session.query(UserCredentials).filter(UserCredentials.user_id == user_id) - if service_reference: - query = query.filter(UserCredentials.service_reference.startswith(service_reference)) - if user_credentials_id: - query = query.filter(UserCredentials.id == user_credentials_id) - user_credentials_list = query.all() - credentials_dict = {} - for user_credential in user_credentials_list: - credentials_list = self._credentials(trans, user_credentials_id=user_credential.id) - credentials_dict[user_credential.service_reference] = credentials_list - return user_credentials_list, credentials_dict + group_alias = aliased(CredentialsGroup) + credential_alias = aliased(Credential) + stmt = ( + select(UserCredentials, group_alias, credential_alias) + .join(group_alias, UserCredentials.groups) + .join(credential_alias, credential_alias.user_credential_group_id == group_alias.id) + .where(UserCredentials.user_id == user_id) + ) + if source_type: + stmt = stmt.where(UserCredentials.source_type == source_type) + if source_id: + if not source_type: + raise exceptions.RequestParameterInvalidException( + "Source type is required when source ID is provided.", type="error" + ) + stmt = stmt.where(UserCredentials.source_id == source_id) + if group_name: + if not source_type or not source_id: + raise exceptions.RequestParameterInvalidException( + "Source type and source ID are required when group name is provided.", type="error" + ) + stmt = stmt.where(group_alias.name == group_name) - def _credentials( - self, - trans: ProvidesUserContext, - user_credentials_id: Optional[DecodedDatabaseIdField] = None, - id: Optional[DecodedDatabaseIdField] = None, - name: Optional[str] = None, - type: Optional[str] = None, - ) -> List[Credentials]: - query = trans.sa_session.query(Credentials) + if reference: + stmt = stmt.where(UserCredentials.reference == reference) if user_credentials_id: - query = query.filter(Credentials.user_credentials_id == user_credentials_id) - if id: - query = query.filter(Credentials.id == id) - if name: - query = query.filter(Credentials.name == name) - if type: - query = query.filter(Credentials.type == type) - return query.all() + stmt = stmt.where(UserCredentials.id == user_credentials_id) - def _credentials_response(self, credentials_list: List[Credentials]) -> List[CredentialResponse]: - return [ - CredentialResponse( - id=credential.id, - name=credential.name, - type=credential.type, + result = trans.sa_session.execute(stmt).all() + return [(uc, cg, c) for uc, cg, c in result] + + def _user_credentials_to_dict( + self, db_user_credentials: List[Tuple[UserCredentials, CredentialsGroup, Credential]] + ) -> Dict[int, Dict[str, Any]]: + grouped_data: Dict[int, Dict[str, Any]] = {} + for user_credentials, credentials_group, credential in db_user_credentials: + grouped_data.setdefault( + user_credentials.id, + dict( + user_id=user_credentials.user_id, + id=user_credentials.id, + reference=user_credentials.reference, + source_type=user_credentials.source_type, + source_id=user_credentials.source_id, + current_group_id=user_credentials.current_group_id, + current_group_name=credentials_group.name, + groups=dict(), + ), ) - for credential in credentials_list - ] - def _update_user_credential( - self, - trans: ProvidesUserContext, - user_id: UserIdPathParam, - user_credential_id: DecodedDatabaseIdField, - payload: UpdateCredentialsPayload, - ) -> CredentialsListResponse: - user_credentials, credentials_dict = self._user_credentials( - trans, user_id, user_credentials_id=user_credential_id - ) - user_credential = user_credentials[0] if user_credentials else None - if not user_credential: - raise exceptions.ObjectNotFound(f"User credential {user_credential_id} not found.", type="error") - db_credentials = sum(credentials_dict.values(), []) - session = trans.sa_session - for credential in payload.root: - existing_credential = next( - (cred for cred in db_credentials if cred.id == credential.id), - None, + grouped_data[user_credentials.id]["groups"].setdefault( + credentials_group.name, + dict( + id=credentials_group.id, + name=credentials_group.name, + variables=[], + secrets=[], + ), ) - if not existing_credential: - raise exceptions.ObjectNotFound(f"Credential {credential.id} not found.", type="error") - if existing_credential.type == "secret": - user_vault = UserVaultWrapper(self._app.vault, trans.user) - user_vault.write_secret( - f"{user_credential.service_reference}|{existing_credential.name}", credential.value + if credential.type == "secret": + grouped_data[user_credentials.id]["groups"][credentials_group.name]["secrets"].append( + dict( + id=credential.id, + name=credential.name, + already_set=True, + ) + ) + elif credential.type == "variable": + grouped_data[user_credentials.id]["groups"][credentials_group.name]["variables"].append( + dict( + id=credential.id, + name=credential.name, + value=credential.value, + ) ) - elif existing_credential.type == "variable": - existing_credential.value = credential.value - session.add(existing_credential) - with transaction(session): - session.commit() - return CredentialsListResponse( - service_reference=user_credential.service_reference, - user_credentials_id=user_credential_id, - credentials=self._credentials_response(db_credentials), - ) + + return grouped_data def _create_user_credential( self, trans: ProvidesUserContext, user_id: UserIdPathParam, payload: CredentialsPayload, - ) -> CredentialsListResponse: - service_reference = payload.service_reference - user_credentials_list, credentials_dict = self._user_credentials( - trans, user_id, service_reference=service_reference + ) -> UserCredentialCreateResponse: + session = trans.sa_session + + source_type, source_id, reference, group_name = ( + payload.source_type, + payload.source_id, + payload.reference, + payload.group_name, ) - user_credential = user_credentials_list[0] if user_credentials_list else None - session = trans.sa_session + db_user_credentials = self._user_credentials( + trans, + user_id=user_id, + source_type=source_type, + source_id=source_id, + reference=reference, + ) + user_credential_dict = self._user_credentials_to_dict(db_user_credentials) + if user_credential_dict: + for user_credential_data in user_credential_dict.values(): + if group_name in user_credential_data["groups"]: + raise exceptions.RequestParameterInvalidException( + f"Group name '{group_name}' already exists for the given user credentials.", type="error" + ) - if not user_credential: - user_credential = UserCredentials( + credentials_group = CredentialsGroup(name=group_name) + existing_user_credentials = next(iter(db_user_credentials), None) + if existing_user_credentials: + user_credentials = existing_user_credentials[0] + user_credentials.current_group = credentials_group + else: + user_credentials = UserCredentials( user_id=user_id, - service_reference=service_reference, + reference=reference, + source_type=source_type, + source_id=source_id, + current_group=credentials_group, ) - session.add(user_credential) - session.flush() + credentials_group.user_credentials_rel = user_credentials + session.add(user_credentials) + session.flush() + user_credentials_id = user_credentials.id + user_credential_group_id = credentials_group.id - user_credential_id = user_credential.id - db_credentials = credentials_dict.get(service_reference, []) - provided_credentials_list: List[Credentials] = [] + provided_credentials_list: List[Credential] = [] for credential_payload in payload.credentials: - credential_name = credential_payload.name - credential_type = credential_payload.type - credential_value = credential_payload.value - - existing_credential = next( - (cred for cred in db_credentials if cred.name == credential_name and cred.type == credential_type), - None, + credential_name, credential_type, credential_value = ( + credential_payload.name, + credential_payload.type, + credential_payload.value, ) - if existing_credential: - raise exceptions.RequestParameterInvalidException( - f"Credential {service_reference}|{credential_name} already exists.", type="error" - ) - credential = Credentials( - user_credentials_id=user_credential_id, + credential = Credential( + user_credential_group_id=user_credential_group_id, name=credential_name, type=credential_type, + value="", ) if credential_type == "secret": user_vault = UserVaultWrapper(self._app.vault, trans.user) - user_vault.write_secret(f"{service_reference}|{credential_name}", credential_value) + user_vault.write_secret( + f"{source_type}|{source_id}|{reference}|{group_name}|{credential_name}", credential_value + ) elif credential_type == "variable": credential.value = credential_value provided_credentials_list.append(credential) session.add(credential) with transaction(session): session.commit() - return CredentialsListResponse( - service_reference=service_reference, - user_credentials_id=user_credential_id, - credentials=self._credentials_response(provided_credentials_list), + + variables = [ + VariableResponse( + id=credential.id, + name=credential.name, + value=credential.value, + ) + for credential in provided_credentials_list + if credential.type == "variable" + ] + + secrets = [ + SecretResponse( + id=credential.id, + name=credential.name, + already_set=True, + ) + for credential in provided_credentials_list + if credential.type == "secret" + ] + + credentials_group_response = CredentialGroupResponse( + id=user_credential_group_id, + name=group_name, + variables=variables, + secrets=secrets, + ) + + return UserCredentialCreateResponse( + user_id=user_id, + id=user_credentials_id, + source_type=source_type, + source_id=source_id, + reference=reference, + current_group_name=group_name, + group=credentials_group_response, ) + + # def _update_user_credential( + # self, + # trans: ProvidesUserContext, + # user_id: UserIdPathParam, + # user_credential_id: DecodedDatabaseIdField, + # payload: UpdateCredentialsPayload, + # ) -> CredentialsListResponse: + # user_credential = next(self._user_credentials(trans, user_id, user_credentials_id=user_credential_id), None) + # if not user_credential: + # raise exceptions.ObjectNotFound(f"User credential {user_credential_id} not found.", type="error") + # db_credentials = user_credential.credentials + # session = trans.sa_session + # group_id = payload.group_id + # for credential in payload.credentials: + # existing_credential = next( + # (cred for cred in db_credentials if cred.id == credential.id), + # None, + # ) + # if not existing_credential: + # raise exceptions.ObjectNotFound(f"Credential {credential.id} not found.", type="error") + + # if existing_credential.type == "secret": + # user_vault = UserVaultWrapper(self._app.vault, trans.user) + # user_vault.write_secret(f"{user_credential.reference}|{existing_credential.name}", credential.value) + # elif existing_credential.type == "variable": + # existing_credential.value = credential.value + # session.add(existing_credential) + # with transaction(session): + # session.commit() + # return CredentialsListResponse( + # user_credentials_id=user_credential_id, + # source_type=user_credential.source_type, + # source_id=user_credential.source_id, + # reference=user_credential.reference, + # group_name=user_credential.group_name, + # credentials=self._credentials_response(db_credentials), + # )