diff --git a/guacamole_user_sync/ldap/ldap_client.py b/guacamole_user_sync/ldap/ldap_client.py index c233c14..c5400aa 100644 --- a/guacamole_user_sync/ldap/ldap_client.py +++ b/guacamole_user_sync/ldap/ldap_client.py @@ -5,7 +5,7 @@ from ldap.ldapobject import LDAPObject from guacamole_user_sync.models import ( - LDAPException, + LDAPError, LDAPGroup, LDAPQuery, LDAPSearchResult, @@ -16,6 +16,8 @@ class LDAPClient: + """Client for connecting to an LDAP server.""" + def __init__( self, hostname: str, @@ -30,14 +32,14 @@ def __init__( def connect(self) -> LDAPObject: if not self.cnxn: - logger.info(f"Initialising connection to LDAP host at {self.hostname}") + logger.info("Initialising connection to LDAP host at %s", self.hostname) self.cnxn = ldap.initialize(f"ldap://{self.hostname}") if self.bind_dn: try: self.cnxn.simple_bind_s(self.bind_dn, self.bind_password) except ldap.INVALID_CREDENTIALS as exc: logger.warning("Connection credentials were incorrect.") - raise LDAPException from exc + raise LDAPError from exc return self.cnxn def search_groups(self, query: LDAPQuery) -> list[LDAPGroup]: @@ -53,9 +55,9 @@ def search_groups(self, query: LDAPQuery) -> list[LDAPGroup]: group.decode("utf-8") for group in attr_dict["memberUid"] ], name=attr_dict[query.id_attr][0].decode("utf-8"), - ) + ), ) - logger.debug(f"Loaded {len(output)} LDAP groups") + logger.debug("Loaded %s LDAP groups", len(output)) return output def search_users(self, query: LDAPQuery) -> list[LDAPUser]: @@ -70,16 +72,16 @@ def search_users(self, query: LDAPQuery) -> list[LDAPUser]: ], name=attr_dict[query.id_attr][0].decode("utf-8"), uid=attr_dict["uid"][0].decode("utf-8"), - ) + ), ) - logger.debug(f"Loaded {len(output)} LDAP users") + logger.debug("Loaded %s LDAP users", len(output)) return output def search(self, query: LDAPQuery) -> LDAPSearchResult: results: LDAPSearchResult = [] logger.info("Querying LDAP host with:") - logger.info(f"... base DN: {query.base_dn}") - logger.info(f"... filter: {query.filter}") + logger.info("... base DN: %s", query.base_dn) + logger.info("... filter: %s", query.filter) searcher = AsyncSearchList(self.connect()) try: searcher.startSearch( @@ -89,15 +91,16 @@ def search(self, query: LDAPQuery) -> LDAPSearchResult: ) if searcher.processResults() != 0: logger.warning("Only partial results received.") - results = searcher.allResults - logger.debug(f"Server returned {len(results)} results.") - return results except ldap.NO_SUCH_OBJECT as exc: logger.warning("Server returned no results.") - raise LDAPException from exc + raise LDAPError from exc except ldap.SERVER_DOWN as exc: logger.warning("Server could not be reached.") - raise LDAPException from exc + raise LDAPError from exc except ldap.SIZELIMIT_EXCEEDED as exc: logger.warning("Server-side size limit exceeded.") - raise LDAPException from exc + raise LDAPError from exc + else: + results = searcher.allResults + logger.debug("Server returned %s results.", len(results)) + return results diff --git a/guacamole_user_sync/models/__init__.py b/guacamole_user_sync/models/__init__.py index e0ca516..9695664 100644 --- a/guacamole_user_sync/models/__init__.py +++ b/guacamole_user_sync/models/__init__.py @@ -1,14 +1,16 @@ -from .exceptions import LDAPException, PostgreSQLException +from .exceptions import LDAPError, PostgreSQLError +from .guacamole import GuacamoleUserDetails from .ldap_objects import LDAPGroup, LDAPUser from .ldap_query import LDAPQuery LDAPSearchResult = list[tuple[int, tuple[str, dict[str, list[bytes]]]]] __all__ = [ - "LDAPException", + "GuacamoleUserDetails", + "LDAPError", "LDAPGroup", "LDAPQuery", "LDAPSearchResult", "LDAPUser", - "PostgreSQLException", + "PostgreSQLError", ] diff --git a/guacamole_user_sync/models/exceptions.py b/guacamole_user_sync/models/exceptions.py index ac1da8f..dc589c7 100644 --- a/guacamole_user_sync/models/exceptions.py +++ b/guacamole_user_sync/models/exceptions.py @@ -1,6 +1,6 @@ -class LDAPException(Exception): - pass +class LDAPError(Exception): + """LDAP error.""" -class PostgreSQLException(Exception): - pass +class PostgreSQLError(Exception): + """PostgreSQL error.""" diff --git a/guacamole_user_sync/models/guacamole.py b/guacamole_user_sync/models/guacamole.py new file mode 100644 index 0000000..a1771ba --- /dev/null +++ b/guacamole_user_sync/models/guacamole.py @@ -0,0 +1,10 @@ +from dataclasses import dataclass + + +@dataclass +class GuacamoleUserDetails: + """A Guacamole user with required attributes only.""" + + entity_id: int + full_name: str + name: str diff --git a/guacamole_user_sync/postgresql/__init__.py b/guacamole_user_sync/postgresql/__init__.py index 6ff288d..c5c4081 100644 --- a/guacamole_user_sync/postgresql/__init__.py +++ b/guacamole_user_sync/postgresql/__init__.py @@ -1,9 +1,10 @@ -from .postgresql_backend import PostgreSQLBackend +from .postgresql_backend import PostgreSQLBackend, PostgreSQLConnectionDetails from .postgresql_client import PostgreSQLClient from .sql import SchemaVersion __all__ = [ "PostgreSQLBackend", + "PostgreSQLConnectionDetails", "PostgreSQLClient", "SchemaVersion", ] diff --git a/guacamole_user_sync/postgresql/orm.py b/guacamole_user_sync/postgresql/orm.py index dfa2c58..3fece71 100644 --- a/guacamole_user_sync/postgresql/orm.py +++ b/guacamole_user_sync/postgresql/orm.py @@ -1,44 +1,47 @@ import enum from datetime import datetime -from sqlalchemy import DateTime, Enum, Integer, String -from sqlalchemy.dialects.postgresql import BYTEA -from sqlalchemy.orm import ( # type:ignore - DeclarativeBase, - Mapped, - mapped_column, -) +from sqlalchemy import DateTime, Enum, Integer, LargeBinary, String +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column -class guacamole_entity_type(enum.Enum): +class GuacamoleEntityType(enum.Enum): + """Guacamole entity enum.""" + USER = "USER" USER_GROUP = "USER_GROUP" -class GuacamoleBase(DeclarativeBase): # type:ignore - pass +class GuacamoleBase(DeclarativeBase): # type: ignore[misc] + """Guacamole database base table.""" class GuacamoleEntity(GuacamoleBase): + """Guacamole database GuacamoleEntity table.""" + __tablename__ = "guacamole_entity" entity_id: Mapped[int] = mapped_column(Integer, primary_key=True) name: Mapped[str] = mapped_column(String(128)) - type: Mapped[guacamole_entity_type] = mapped_column(Enum(guacamole_entity_type)) + type: Mapped[GuacamoleEntityType] = mapped_column(Enum(GuacamoleEntityType)) class GuacamoleUser(GuacamoleBase): + """Guacamole database GuacamoleUser table.""" + __tablename__ = "guacamole_user" user_id: Mapped[int] = mapped_column(Integer, primary_key=True) entity_id: Mapped[int] = mapped_column(Integer) full_name: Mapped[str] = mapped_column(String(256)) - password_hash: Mapped[bytes] = mapped_column(BYTEA) - password_salt: Mapped[bytes] = mapped_column(BYTEA) + password_hash: Mapped[bytes] = mapped_column(LargeBinary) + password_salt: Mapped[bytes] = mapped_column(LargeBinary) password_date: Mapped[datetime] = mapped_column(DateTime(timezone=True)) class GuacamoleUserGroup(GuacamoleBase): + """Guacamole database GuacamoleUserGroup table.""" + __tablename__ = "guacamole_user_group" user_group_id: Mapped[int] = mapped_column(Integer, primary_key=True) @@ -46,6 +49,8 @@ class GuacamoleUserGroup(GuacamoleBase): class GuacamoleUserGroupMember(GuacamoleBase): + """Guacamole database GuacamoleUserGroupMember table.""" + __tablename__ = "guacamole_user_group_member" user_group_id: Mapped[int] = mapped_column(Integer, primary_key=True) diff --git a/guacamole_user_sync/postgresql/postgresql_backend.py b/guacamole_user_sync/postgresql/postgresql_backend.py index 5639643..f939166 100644 --- a/guacamole_user_sync/postgresql/postgresql_backend.py +++ b/guacamole_user_sync/postgresql/postgresql_backend.py @@ -1,32 +1,38 @@ import logging -from typing import Any, Type, TypeVar +from dataclasses import dataclass +from typing import Any, TypeVar -from sqlalchemy import create_engine -from sqlalchemy.engine import URL, Engine # type:ignore -from sqlalchemy.orm import Session -from sqlalchemy.sql.expression import TextClause +from sqlalchemy import URL, Engine, TextClause, create_engine +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import DeclarativeBase, Session logger = logging.getLogger("guacamole_user_sync") -T = TypeVar("T") + +@dataclass +class PostgreSQLConnectionDetails: + """Dataclass for holding PostgreSQL connection details.""" + + database_name: str + host_name: str + port: int + user_name: str + user_password: str + + +T = TypeVar("T", bound=DeclarativeBase) class PostgreSQLBackend: + """Backend for connecting to a PostgreSQL database.""" + def __init__( self, *, - database_name: str, - host_name: str, - port: int, - user_name: str, - user_password: str, + connection_details: PostgreSQLConnectionDetails, session: Session | None = None, - ): - self.database_name = database_name - self.host_name = host_name - self.port = port - self.user_name = user_name - self.user_password = user_password + ) -> None: + self.connection_details = connection_details self._engine: Engine | None = None self._session = session @@ -35,11 +41,11 @@ def engine(self) -> Engine: if not self._engine: url_object = URL.create( "postgresql+psycopg", - username=self.user_name, - password=self.user_password, - host=self.host_name, - port=self.port, - database=self.database_name, + username=self.connection_details.user_name, + password=self.connection_details.user_password, + host=self.connection_details.host_name, + port=self.connection_details.port, + database=self.connection_details.database_name, ) self._engine = create_engine(url_object, echo=False) return self._engine @@ -50,29 +56,37 @@ def session(self) -> Session: return Session(self.engine) def add_all(self, items: list[T]) -> None: - with self.session() as session: # type:ignore - with session.begin(): - session.add_all(items) - - def delete(self, table: Type[T], *filter_args: Any) -> None: - with self.session() as session: # type:ignore - with session.begin(): - if filter_args: - session.query(table).filter(*filter_args).delete() - else: - session.query(table).delete() + with self.session() as session, session.begin(): + session.add_all(items) + + def delete( + self, + table: type[T], + *filter_args: Any, # noqa: ANN401 + ) -> None: + with self.session() as session, session.begin(): + if filter_args: + session.query(table).filter(*filter_args).delete() + else: + session.query(table).delete() def execute_commands(self, commands: list[TextClause]) -> None: - with self.session() as session: # type:ignore - with session.begin(): + try: + with self.session() as session, session.begin(): for command in commands: session.execute(command) + except SQLAlchemyError: + logger.warning("Unable to execute PostgreSQL commands.") + raise - def query(self, table: Type[T], **filter_kwargs: Any) -> list[T]: - with self.session() as session: # type:ignore - with session.begin(): - if filter_kwargs: - result = session.query(table).filter_by(**filter_kwargs) - else: - result = session.query(table) - return [item for item in result] + def query( + self, + table: type[T], + **filter_kwargs: Any, # noqa: ANN401 + ) -> list[T]: + with self.session() as session, session.begin(): + if filter_kwargs: + result = session.query(table).filter_by(**filter_kwargs) + else: + result = session.query(table) + return list(result) diff --git a/guacamole_user_sync/postgresql/postgresql_client.py b/guacamole_user_sync/postgresql/postgresql_client.py index 7b17803..3086e7e 100644 --- a/guacamole_user_sync/postgresql/postgresql_client.py +++ b/guacamole_user_sync/postgresql/postgresql_client.py @@ -1,29 +1,32 @@ -import datetime import logging import secrets +from datetime import UTC, datetime -from sqlalchemy.exc import OperationalError +from sqlalchemy.exc import SQLAlchemyError from guacamole_user_sync.models import ( + GuacamoleUserDetails, LDAPGroup, LDAPUser, - PostgreSQLException, + PostgreSQLError, ) from .orm import ( GuacamoleEntity, + GuacamoleEntityType, GuacamoleUser, GuacamoleUserGroup, GuacamoleUserGroupMember, - guacamole_entity_type, ) -from .postgresql_backend import PostgreSQLBackend +from .postgresql_backend import PostgreSQLBackend, PostgreSQLConnectionDetails from .sql import GuacamoleSchema, SchemaVersion logger = logging.getLogger("guacamole_user_sync") class PostgreSQLClient: + """Client for connecting to a PostgreSQL database.""" + def __init__( self, *, @@ -32,48 +35,56 @@ def __init__( port: int, user_name: str, user_password: str, - ): + ) -> None: self.backend = PostgreSQLBackend( - database_name=database_name, - host_name=host_name, - port=port, - user_name=user_name, - user_password=user_password, + connection_details=PostgreSQLConnectionDetails( + database_name=database_name, + host_name=host_name, + port=port, + user_name=user_name, + user_password=user_password, + ), ) def assign_users_to_groups( - self, groups: list[LDAPGroup], users: list[LDAPUser] + self, + groups: list[LDAPGroup], + users: list[LDAPUser], ) -> None: logger.info( - f"Ensuring that {len(users)} user(s)" - f" are correctly assigned among {len(groups)} group(s)" + "Ensuring that %s user(s) are correctly assigned among %s group(s)", + len(users), + len(groups), ) user_group_members = [] for group in groups: - logger.debug(f"Working on group '{group.name}'") + logger.debug("Working on group '%s'", group.name) # Get the user_group_id for each group (via looking up the entity_id) try: - group_entity_id = [ + group_entity_id = next( item.entity_id for item in self.backend.query( GuacamoleEntity, name=group.name, - type=guacamole_entity_type.USER_GROUP, + type=GuacamoleEntityType.USER_GROUP, ) - ][0] - user_group_id = [ + ) + user_group_id = next( item.user_group_id for item in self.backend.query( GuacamoleUserGroup, entity_id=group_entity_id, ) - ][0] + ) logger.debug( - f"-> entity_id: {group_entity_id}; user_group_id: {user_group_id}" + "-> entity_id: %s; user_group_id: %s", + group_entity_id, + user_group_id, ) - except IndexError: + except StopIteration: logger.debug( - f"Could not determine user_group_id for group '{group.name}'." + "Could not determine user_group_id for group '%s'.", + group.name, ) continue # Get the user_entity_id for each user belonging to this group @@ -81,44 +92,52 @@ def assign_users_to_groups( try: user = next(filter(lambda u: u.uid == user_uid, users)) except StopIteration: - logger.debug(f"Could not find LDAP user with UID {user_uid}") + logger.debug("Could not find LDAP user with UID %s", user_uid) continue try: - user_entity_id = [ + user_entity_id = next( item.entity_id for item in self.backend.query( GuacamoleEntity, name=user.name, - type=guacamole_entity_type.USER, + type=GuacamoleEntityType.USER, ) - ][0] + ) logger.debug( - f"... group member '{user}' has entity_id '{user_entity_id}'" + "... group member '%s' has entity_id '%s'", + user, + user_entity_id, + ) + except StopIteration: + logger.debug( + "Could not find entity ID for LDAP user '%s'", + user_uid, ) - except IndexError: - logger.debug(f"Could not find entity ID for LDAP user {user_uid}") continue # Create an entry in the user group member table user_group_members.append( GuacamoleUserGroupMember( user_group_id=user_group_id, member_entity_id=user_entity_id, - ) + ), ) # Clear existing assignments then reassign - logger.debug(f"... creating {len(user_group_members)} user/group assignments.") + logger.debug( + "... creating %s user/group assignments.", + len(user_group_members), + ) self.backend.delete(GuacamoleUserGroupMember) self.backend.add_all(user_group_members) def ensure_schema(self, schema_version: SchemaVersion) -> None: try: self.backend.execute_commands(GuacamoleSchema.commands(schema_version)) - except OperationalError as exc: - logger.warning("Unable to connect to the PostgreSQL server.") - raise PostgreSQLException("Unable to ensure PostgreSQL schema.") from exc + except SQLAlchemyError as exc: + msg = "Unable to ensure PostgreSQL schema." + raise PostgreSQLError(msg) from exc def update(self, *, groups: list[LDAPGroup], users: list[LDAPUser]) -> None: - """Update the relevant tables to match lists of LDAP users and groups""" + """Update the relevant tables to match lists of LDAP users and groups.""" self.update_groups(groups) self.update_users(users) self.update_group_entities() @@ -126,31 +145,33 @@ def update(self, *, groups: list[LDAPGroup], users: list[LDAPUser]) -> None: self.assign_users_to_groups(groups, users) def update_groups(self, groups: list[LDAPGroup]) -> None: - """Update the entities table with desired groups""" + """Update the entities table with desired groups.""" # Set groups to desired list - logger.info(f"Ensuring that {len(groups)} group(s) are registered") + logger.info("Ensuring that %s group(s) are registered", len(groups)) desired_group_names = [group.name for group in groups] current_group_names = [ item.name for item in self.backend.query( - GuacamoleEntity, type=guacamole_entity_type.USER_GROUP + GuacamoleEntity, + type=GuacamoleEntityType.USER_GROUP, ) ] # Add groups logger.debug( - f"There are {len(current_group_names)} group(s) currently registered" + "There are %s group(s) currently registered", + len(current_group_names), ) group_names_to_add = [ group_name for group_name in desired_group_names if group_name not in current_group_names ] - logger.debug(f"... {len(group_names_to_add)} group(s) will be added") + logger.debug("... %s group(s) will be added", len(group_names_to_add)) self.backend.add_all( [ - GuacamoleEntity(name=group_name, type=guacamole_entity_type.USER_GROUP) + GuacamoleEntity(name=group_name, type=GuacamoleEntityType.USER_GROUP) for group_name in group_names_to_add - ] + ], ) # Remove groups group_names_to_remove = [ @@ -158,32 +179,34 @@ def update_groups(self, groups: list[LDAPGroup]) -> None: for group_name in current_group_names if group_name not in desired_group_names ] - logger.debug(f"... {len(group_names_to_remove)} group(s) will be removed") + logger.debug("... %s group(s) will be removed", len(group_names_to_remove)) for group_name in group_names_to_remove: self.backend.delete( GuacamoleEntity, GuacamoleEntity.name == group_name, - GuacamoleEntity.type == guacamole_entity_type.USER_GROUP, + GuacamoleEntity.type == GuacamoleEntityType.USER_GROUP, ) def update_group_entities(self) -> None: - """Add group entities to the groups table""" + """Add group entities to the groups table.""" current_user_group_entity_ids = [ group.entity_id for group in self.backend.query(GuacamoleUserGroup) ] logger.debug( - f"There are {len(current_user_group_entity_ids)}" - " user group entit(y|ies) currently registered" + "There are %s user group entit(y|ies) currently registered", + len(current_user_group_entity_ids), ) new_group_entity_ids = [ group.entity_id for group in self.backend.query( - GuacamoleEntity, type=guacamole_entity_type.USER_GROUP + GuacamoleEntity, + type=GuacamoleEntityType.USER_GROUP, ) if group.entity_id not in current_user_group_entity_ids ] logger.debug( - f"... {len(new_group_entity_ids)} user group entit(y|ies) will be added" + "... %s user group entit(y|ies) will be added", + len(new_group_entity_ids), ) self.backend.add_all( [ @@ -191,44 +214,53 @@ def update_group_entities(self) -> None: entity_id=group_entity_id, ) for group_entity_id in new_group_entity_ids - ] + ], ) # Clean up any unused entries valid_entity_ids = [ group.entity_id for group in self.backend.query( - GuacamoleEntity, type=guacamole_entity_type.USER_GROUP + GuacamoleEntity, + type=GuacamoleEntityType.USER_GROUP, ) ] - logger.debug(f"There are {len(valid_entity_ids)} valid user group entit(y|ies)") + logger.debug( + "There are %s valid user group entit(y|ies)", + len(valid_entity_ids), + ) self.backend.delete( - GuacamoleUserGroup, GuacamoleUserGroup.entity_id.not_in(valid_entity_ids) + GuacamoleUserGroup, + GuacamoleUserGroup.entity_id.not_in(valid_entity_ids), ) def update_users(self, users: list[LDAPUser]) -> None: - """Update the entities table with desired users""" + """Update the entities table with desired users.""" # Set users to desired list - logger.info(f"Ensuring that {len(users)} user(s) are registered") + logger.info("Ensuring that %s user(s) are registered", len(users)) desired_usernames = [user.name for user in users] current_usernames = [ user.name for user in self.backend.query( - GuacamoleEntity, type=guacamole_entity_type.USER + GuacamoleEntity, + type=GuacamoleEntityType.USER, ) ] # Add users - logger.debug(f"There are {len(current_usernames)} user(s) currently registered") + logger.debug( + "There are %s user(s) currently registered", + len(current_usernames), + ) usernames_to_add = [ username for username in desired_usernames if username not in current_usernames ] - logger.debug(f"... {len(usernames_to_add)} user(s) will be added") + logger.debug("... %s user(s) will be added", len(usernames_to_add)) self.backend.add_all( [ - GuacamoleEntity(name=username, type=guacamole_entity_type.USER) + GuacamoleEntity(name=username, type=GuacamoleEntityType.USER) for username in usernames_to_add - ] + ], ) # Remove users usernames_to_remove = [ @@ -236,53 +268,62 @@ def update_users(self, users: list[LDAPUser]) -> None: for username in current_usernames if username not in desired_usernames ] - logger.debug(f"... {len(usernames_to_remove)} user(s) will be removed") + logger.debug("... %s user(s) will be removed", len(usernames_to_remove)) for username in usernames_to_remove: self.backend.delete( GuacamoleEntity, GuacamoleEntity.name == username, - GuacamoleEntity.type == guacamole_entity_type.USER, + GuacamoleEntity.type == GuacamoleEntityType.USER, ) def update_user_entities(self, users: list[LDAPUser]) -> None: - """Add user entities to the users table""" + """Add user entities to the users table.""" current_user_entity_ids = [ user.entity_id for user in self.backend.query(GuacamoleUser) ] logger.debug( - f"There are {len(current_user_entity_ids)} " - "user entit(y|ies) currently registered" + "There are %s user entit(y|ies) currently registered", + len(current_user_entity_ids), ) - new_user_tuples: list[tuple[int, LDAPUser]] = [ - (user.entity_id, [u for u in users if u.name == user.name][0]) - for user in self.backend.query( - GuacamoleEntity, type=guacamole_entity_type.USER + user_entities = self.backend.query( + GuacamoleEntity, + type=GuacamoleEntityType.USER, + ) + new_users = [ + GuacamoleUserDetails( + entity_id=entity.entity_id, + full_name=user.display_name, + name=user.name, ) - if user.entity_id not in current_user_entity_ids + for user in users + for entity in user_entities + if entity.name == user.name + and entity.entity_id not in current_user_entity_ids ] - logger.debug( - f"... {len(current_user_entity_ids)} user entit(y|ies) will be added" - ) + logger.debug("... %s user entit(y|ies) will be added", len(new_users)) + self.backend.add_all( [ GuacamoleUser( - entity_id=user_tuple[0], - full_name=user_tuple[1].display_name, - password_date=datetime.datetime.now(), + entity_id=new_user.entity_id, + full_name=new_user.full_name, + password_date=datetime.now(tz=UTC), password_hash=secrets.token_bytes(32), password_salt=secrets.token_bytes(32), ) - for user_tuple in new_user_tuples - ] + for new_user in new_users + ], ) # Clean up any unused entries valid_entity_ids = [ user.entity_id for user in self.backend.query( - GuacamoleEntity, type=guacamole_entity_type.USER + GuacamoleEntity, + type=GuacamoleEntityType.USER, ) ] - logger.debug(f"There are {len(valid_entity_ids)} valid user entit(y|ies)") + logger.debug("There are %s valid user entit(y|ies)", len(valid_entity_ids)) self.backend.delete( - GuacamoleUser, GuacamoleUser.entity_id.not_in(valid_entity_ids) + GuacamoleUser, + GuacamoleUser.entity_id.not_in(valid_entity_ids), ) diff --git a/guacamole_user_sync/postgresql/sql.py b/guacamole_user_sync/postgresql/sql.py index 6523440..c4df5bc 100644 --- a/guacamole_user_sync/postgresql/sql.py +++ b/guacamole_user_sync/postgresql/sql.py @@ -3,26 +3,28 @@ from pathlib import Path import sqlparse -from sqlalchemy import text -from sqlalchemy.sql.expression import TextClause +from sqlalchemy import TextClause, text logger = logging.getLogger("guacamole_user_sync") class SchemaVersion(StrEnum): + """Version for Guacamole database schema.""" + v1_5_5 = "1.5.5" class GuacamoleSchema: + """Schema for Guacamole database.""" @classmethod def commands(cls, schema_version: SchemaVersion) -> list[TextClause]: - logger.info(f"Ensuring correct schema for Guacamole {schema_version.value}") + logger.info("Ensuring correct schema for Guacamole %s", schema_version.value) commands = [] sql_file_path = Path(__file__).with_name( - f"guacamole_schema.{schema_version.value}.sql" + f"guacamole_schema.{schema_version.value}.sql", ) - with open(sql_file_path, "r") as f_sql: + with Path.open(sql_file_path) as f_sql: statements = sqlparse.split(f_sql.read()) for statement in statements: # Extract the first comment if there is one @@ -34,9 +36,9 @@ def commands(cls, schema_version: SchemaVersion) -> list[TextClause]: if isinstance(token, sqlparse.sql.Comment) ] if first_comment := next(filter(lambda item: item, comment_lines), None): - logger.debug(f"... {first_comment}") + logger.debug("... %s", first_comment) # Extract the command commands.append( - text(sqlparse.format(statement, strip_comments=True, compact=True)) + text(sqlparse.format(statement, strip_comments=True, compact=True)), ) return commands diff --git a/pyproject.toml b/pyproject.toml index f7bf290..cdfeecb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,6 @@ lint = [ "black==24.8.0", "mypy==1.11.2", "ruff==0.6.2", - "types-SQLAlchemy==1.4.53.38", ] test = [ "coverage[toml]==7.6.1", @@ -95,15 +94,70 @@ strict = true # enable all optional error checking flags module = [ "ldap.*", "pytest.*", + "sqlalchemy.*", "sqlparse.*", ] ignore_missing_imports = true [tool.ruff.lint] select = [ - # See https://beta.ruff.rs/docs/rules/ - "E", # pycodestyle errors - "F", # pyflakes - "I", # isort - "W", # pycodestyle warnings + "A", # flake8-builtins + "ANN", # flake8-annotations + "ARG", # flake8-unused-arguments + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "C90", # McCabe complexity + "COM", # flake8-commas + "D", # pydocstyle + "DTZ", # flake8-datetimez + "E", # pycodestyle errors + "EM", # flake8-errmsg + "F", # pyflakes + "FA", # flake8-future-annotations + "FBT", # flake8-boolean-trap + "FLY", # flynt + "FURB", # refurb + "G", # flake8-logging-format + "I", # isort + "ICN", # flake8-import-conventions + "INP", # flake8-no-pep420 + "INT", # flake8-gettext + "ISC", # flake8-implicit-str-concat + "LOG", # flake8-logging + "N", # pep8-naming + "PERF", # perflint + "PGH", # pygrep-hooks + "PIE", # flake8-pie + "PL", # pylint + "PT", # flake8-pytest-style + "PTH", # flake8-use-pathlib + "PYI", # flake8-pyi + "Q", # flake8-quotes + "RET", # flake8-return + "RSE", # flake8-rse + "RUF", # Ruff-specific rules + "S", # flake8-bandit + "SIM", # flake8-simplify + "SLF", # flake8-self + "T20", # flake8-print + "TCH", # flake8-type-checking + "TID", # flake8-tidy-imports + "TRY", # tryceratops + "UP", # pyupgrade + "W", # pycodestyle warnings + "YTT", # flake8-2020 ] +ignore = [ + "ANN101", # missing-type-self [deprecated] + "ANN102", # missing-type-cls [deprecated] + "D100", # undocumented-public-module + "D102", # undocumented-public-method + "D103", # undocumented-public-function + "D104", # undocumented-public-package + "D105", # undocumented-magic-method + "D107", # undocumented-public-init + "D203", # one-blank-line-before-class [conflicts with D211] + "D213", # multi-line-summary-second-line [conflicts with D212] + "D400", # ends-in-period [conflicts with D415] + "S101", # assert [conflicts with pytest] +] \ No newline at end of file diff --git a/synchronise.py b/synchronise.py index 626e4a8..eb89731 100644 --- a/synchronise.py +++ b/synchronise.py @@ -4,11 +4,11 @@ import time from guacamole_user_sync.ldap import LDAPClient -from guacamole_user_sync.models import LDAPException, LDAPQuery, PostgreSQLException +from guacamole_user_sync.models import LDAPError, LDAPQuery, PostgreSQLError from guacamole_user_sync.postgresql import PostgreSQLClient, SchemaVersion -def main( +def main( # noqa: PLR0913 ldap_bind_dn: str | None, ldap_bind_password: str | None, ldap_group_base_dn: str, @@ -61,7 +61,7 @@ def main( ) # Wait before repeating - logger.info(f"Waiting {repeat_interval} seconds.") + logger.info("Waiting %s seconds.", repeat_interval) time.sleep(repeat_interval) @@ -76,37 +76,45 @@ def synchronise( try: ldap_groups = ldap_client.search_groups(ldap_group_query) ldap_users = ldap_client.search_users(ldap_user_query) - except LDAPException: + except LDAPError: logger.warning("LDAP server query failed") return try: postgresql_client.ensure_schema(SchemaVersion.v1_5_5) postgresql_client.update(groups=ldap_groups, users=ldap_users) - except PostgreSQLException: + except PostgreSQLError: logger.warning("PostgreSQL update failed") return if __name__ == "__main__": if not (ldap_host := os.getenv("LDAP_HOST", None)): - raise ValueError("LDAP_HOST is not defined") + msg = "LDAP_HOST is not defined" + raise ValueError(msg) if not (ldap_group_base_dn := os.getenv("LDAP_GROUP_BASE_DN", None)): - raise ValueError("LDAP_GROUP_BASE_DN is not defined") + msg = "LDAP_GROUP_BASE_DN is not defined" + raise ValueError(msg) if not (ldap_group_filter := os.getenv("LDAP_GROUP_FILTER", None)): - raise ValueError("LDAP_GROUP_FILTER is not defined") + msg = "LDAP_GROUP_FILTER is not defined" + raise ValueError(msg) if not (ldap_user_base_dn := os.getenv("LDAP_USER_BASE_DN", None)): - raise ValueError("LDAP_USER_BASE_DN is not defined") + msg = "LDAP_USER_BASE_DN is not defined" + raise ValueError(msg) if not (ldap_user_filter := os.getenv("LDAP_USER_FILTER", None)): - raise ValueError("LDAP_USER_FILTER is not defined") + msg = "LDAP_USER_FILTER is not defined" + raise ValueError(msg) if not (postgresql_host_name := os.getenv("POSTGRESQL_HOST", None)): - raise ValueError("POSTGRESQL_HOST is not defined") + msg = "POSTGRESQL_HOST is not defined" + raise ValueError(msg) if not (postgresql_password := os.getenv("POSTGRESQL_PASSWORD", None)): - raise ValueError("POSTGRESQL_PASSWORD is not defined") + msg = "POSTGRESQL_PASSWORD is not defined" + raise ValueError(msg) if not (postgresql_user_name := os.getenv("POSTGRESQL_USERNAME", None)): - raise ValueError("POSTGRESQL_USERNAME is not defined") + msg = "POSTGRESQL_USERNAME is not defined" + raise ValueError(msg) logging.basicConfig( level=( @@ -126,14 +134,14 @@ def synchronise( ldap_group_filter=ldap_group_filter, ldap_group_name_attr=os.getenv("LDAP_GROUP_NAME_ATTR", "cn"), ldap_host=ldap_host, - ldap_port=int(os.getenv("LDAP_PORT", 389)), + ldap_port=int(os.getenv("LDAP_PORT", "389")), ldap_user_base_dn=ldap_user_base_dn, ldap_user_filter=ldap_user_filter, ldap_user_name_attr=os.getenv("LDAP_USER_NAME_ATTR", "userPrincipalName"), postgresql_database_name=os.getenv("POSTGRESQL_DB_NAME", "guacamole"), postgresql_host_name=postgresql_host_name, postgresql_password=postgresql_password, - postgresql_port=int(os.getenv("POSTGRESQL_PORT", 5432)), + postgresql_port=int(os.getenv("POSTGRESQL_PORT", "5432")), postgresql_user_name=postgresql_user_name, - repeat_interval=int(os.getenv("REPEAT_INTERVAL", 300)), + repeat_interval=int(os.getenv("REPEAT_INTERVAL", "300")), ) diff --git a/tests/conftest.py b/tests/conftest.py index 8a339f1..12afadf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,13 @@ -from datetime import datetime +from datetime import UTC, datetime import pytest from guacamole_user_sync.models import LDAPGroup, LDAPQuery, LDAPSearchResult, LDAPUser from guacamole_user_sync.postgresql.orm import ( GuacamoleEntity, + GuacamoleEntityType, GuacamoleUser, GuacamoleUserGroup, - guacamole_entity_type, ) @@ -138,42 +138,50 @@ def ldap_response_users_fixture() -> LDAPSearchResult: @pytest.fixture -def postgresql_model_guacamoleentity_USER_GROUP_fixture() -> list[GuacamoleEntity]: +def postgresql_model_guacamoleentity_user_group_fixture() -> list[GuacamoleEntity]: return [ GuacamoleEntity( - entity_id=1, name="defendants", type=guacamole_entity_type.USER_GROUP + entity_id=1, + name="defendants", + type=GuacamoleEntityType.USER_GROUP, ), GuacamoleEntity( - entity_id=2, name="everyone", type=guacamole_entity_type.USER_GROUP + entity_id=2, + name="everyone", + type=GuacamoleEntityType.USER_GROUP, ), GuacamoleEntity( - entity_id=3, name="plaintiffs", type=guacamole_entity_type.USER_GROUP + entity_id=3, + name="plaintiffs", + type=GuacamoleEntityType.USER_GROUP, ), ] @pytest.fixture -def postgresql_model_guacamoleentity_USER_fixture() -> list[GuacamoleEntity]: +def postgresql_model_guacamoleentity_user_fixture() -> list[GuacamoleEntity]: return [ GuacamoleEntity( - entity_id=4, name="aulus.agerius@rome.la", type=guacamole_entity_type.USER + entity_id=4, + name="aulus.agerius@rome.la", + type=GuacamoleEntityType.USER, ), GuacamoleEntity( entity_id=5, name="numerius.negidius@rome.la", - type=guacamole_entity_type.USER, + type=GuacamoleEntityType.USER, ), ] @pytest.fixture def postgresql_model_guacamoleentity_fixture( - postgresql_model_guacamoleentity_USER_GROUP_fixture: list[GuacamoleEntity], - postgresql_model_guacamoleentity_USER_fixture: list[GuacamoleEntity], + postgresql_model_guacamoleentity_user_group_fixture: list[GuacamoleEntity], + postgresql_model_guacamoleentity_user_fixture: list[GuacamoleEntity], ) -> list[GuacamoleEntity]: return ( - postgresql_model_guacamoleentity_USER_GROUP_fixture - + postgresql_model_guacamoleentity_USER_fixture + postgresql_model_guacamoleentity_user_group_fixture + + postgresql_model_guacamoleentity_user_fixture ) @@ -186,7 +194,7 @@ def postgresql_model_guacamoleuser_fixture() -> list[GuacamoleUser]: full_name="Aulus Agerius", password_hash=b"PASSWORD_HASH", password_salt=b"PASSWORD_SALT", - password_date=datetime(1, 1, 1), + password_date=datetime(1, 1, 1, tzinfo=UTC), ), GuacamoleUser( user_id=2, @@ -194,7 +202,7 @@ def postgresql_model_guacamoleuser_fixture() -> list[GuacamoleUser]: full_name="Numerius Negidius", password_hash=b"PASSWORD_HASH", password_salt=b"PASSWORD_SALT", - password_date=datetime(1, 1, 1), + password_date=datetime(1, 1, 1, tzinfo=UTC), ), ] diff --git a/tests/mocks.py b/tests/mocks.py index 7721ce7..02a7b24 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -1,85 +1,108 @@ -from typing import Any, Generic, Type, TypeVar +from typing import Any import ldap -from sqlalchemy.sql.expression import TextClause +from sqlalchemy import TextClause from guacamole_user_sync.models import LDAPSearchResult +from guacamole_user_sync.postgresql.orm import GuacamoleBase class MockLDAPObject: + """Mock LDAPObject.""" + def __init__(self, uri: str) -> None: self.uri = uri self.bind_dn = "" self.bind_password = "" def simple_bind_s(self, bind_dn: str, bind_password: str) -> None: - if bind_password == "incorrect-password": + if bind_password == "incorrect-password": # noqa: S105 raise ldap.INVALID_CREDENTIALS self.bind_dn = bind_dn self.bind_password = bind_password class MockAsyncSearchList: + """Mock AsyncSearchList.""" + def __init__( - self, partial: bool, results: LDAPSearchResult, *args: Any, **kwargs: Any + self, + partial: bool, # noqa: FBT001 + results: LDAPSearchResult, + *args: Any, # noqa: ANN401, ARG002 + **kwargs: Any, # noqa: ANN401, ARG002 ) -> None: self.allResults = results self.partial = partial - def startSearch(self, *args: Any, **kwargs: Any) -> None: + def startSearch( # noqa: N802 + self, + *args: Any, # noqa: ANN401 + **kwargs: Any, # noqa: ANN401 + ) -> None: pass - def processResults(self, *args: Any, **kwargs: Any) -> bool: + def processResults( # noqa: N802 + self, + *args: Any, # noqa: ANN401, ARG002 + **kwargs: Any, # noqa: ANN401, ARG002 + ) -> bool: return self.partial class MockAsyncSearchListFullResults(MockAsyncSearchList): + """Mock AsyncSearchList with full results.""" + def __init__(self, results: LDAPSearchResult) -> None: super().__init__(results=results, partial=False) class MockAsyncSearchListPartialResults(MockAsyncSearchList): + """Mock AsyncSearchList with partial results.""" + def __init__(self, results: LDAPSearchResult) -> None: super().__init__(results=results, partial=True) -T = TypeVar("T") - - -class MockPostgreSQLBackend(Generic[T]): +class MockPostgreSQLBackend: + """Mock PostgreSQLBackend.""" - def __init__(self, *data_lists: Any, **kwargs: Any) -> None: - self.contents: dict[Type[T], list[T]] = {} + def __init__(self, *data_lists: Any, **kwargs: Any) -> None: # noqa: ANN401, ARG002 + self.contents: dict[type[GuacamoleBase], list[GuacamoleBase]] = {} for data_list in data_lists: self.add_all(data_list) - def add_all(self, items: list[T]) -> None: + def add_all(self, items: list[GuacamoleBase]) -> None: cls = type(items[0]) if cls not in self.contents: self.contents[cls] = [] self.contents[cls] += items - def delete(self, *args: Any, **kwargs: Any) -> None: + def delete(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401 pass def execute_commands(self, commands: list[TextClause]) -> None: for command in commands: - print(f"Executing {command}") + print(f"Executing {command}") # noqa: T201 - def query(self, table: Type[T], **filter_kwargs: Any) -> Any: + def query( + self, + table: type[GuacamoleBase], + **filter_kwargs: Any, # noqa: ANN401 + ) -> list[GuacamoleBase]: if table not in self.contents: self.contents[table] = [] - results = [item for item in self.contents[table]] + results = list(self.contents[table]) if "entity_id" in filter_kwargs: results = [ - item for item in results if item.entity_id == filter_kwargs["entity_id"] # type: ignore + item for item in results if item.entity_id == filter_kwargs["entity_id"] ] if "name" in filter_kwargs: - results = [item for item in results if item.name == filter_kwargs["name"]] # type: ignore + results = [item for item in results if item.name == filter_kwargs["name"]] if "type" in filter_kwargs: - results = [item for item in results if item.type == filter_kwargs["type"]] # type: ignore + results = [item for item in results if item.type == filter_kwargs["type"]] return results diff --git a/tests/test_about.py b/tests/test_about.py index d3ffaa1..e87cf36 100644 --- a/tests/test_about.py +++ b/tests/test_about.py @@ -1,5 +1,8 @@ from guacamole_user_sync import version -def test_about() -> None: - assert version == "0.6.0" +class TestAbout: + """Test about.py.""" + + def test_version(self) -> None: + assert version == "0.6.0" diff --git a/tests/test_ldap.py b/tests/test_ldap.py index 61371d5..af61854 100644 --- a/tests/test_ldap.py +++ b/tests/test_ldap.py @@ -7,7 +7,7 @@ from guacamole_user_sync.ldap import LDAPClient from guacamole_user_sync.models import ( - LDAPException, + LDAPError, LDAPGroup, LDAPQuery, LDAPSearchResult, @@ -22,11 +22,13 @@ class TestLDAPClient: + """Test LDAPClient.""" + def test_constructor(self) -> None: client = LDAPClient(hostname="test-host") assert client.hostname == "test-host" - def test_connect(self, monkeypatch: Any) -> None: + def test_connect(self, monkeypatch: pytest.MonkeyPatch) -> None: def mock_initialize(uri: str) -> MockLDAPObject: return MockLDAPObject(uri) @@ -37,70 +39,88 @@ def mock_initialize(uri: str) -> MockLDAPObject: assert isinstance(cnxn, MockLDAPObject) assert cnxn.uri == "ldap://test-host" - def test_connect_with_bind(self, monkeypatch: Any) -> None: + def test_connect_with_bind(self, monkeypatch: pytest.MonkeyPatch) -> None: def mock_initialize(uri: str) -> MockLDAPObject: return MockLDAPObject(uri) monkeypatch.setattr(ldap, "initialize", mock_initialize) client = LDAPClient( - hostname="test-host", bind_dn="bind-dn", bind_password="bind_password" + hostname="test-host", + bind_dn="bind-dn", + bind_password="bind_password", # noqa: S106 ) cnxn = client.connect() assert isinstance(cnxn, MockLDAPObject) assert cnxn.bind_dn == "bind-dn" - assert cnxn.bind_password == "bind_password" + assert cnxn.bind_password == "bind_password" # noqa: S105 - def test_connect_with_failed_bind(self, monkeypatch: Any, caplog: Any) -> None: + def test_connect_with_failed_bind( + self, + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, + ) -> None: def mock_initialize(uri: str) -> MockLDAPObject: return MockLDAPObject(uri) monkeypatch.setattr(ldap, "initialize", mock_initialize) client = LDAPClient( - hostname="test-host", bind_dn="bind-dn", bind_password="incorrect-password" + hostname="test-host", + bind_dn="bind-dn", + bind_password="incorrect-password", # noqa: S106 ) - with pytest.raises(LDAPException): + with pytest.raises(LDAPError): client.connect() assert "Connection credentials were incorrect." in caplog.text - def test_search_exception_server_down(self, monkeypatch: Any, caplog: Any) -> None: - def mock_raise_server_down(*args: Any) -> None: + def test_search_exception_server_down( + self, + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, + ) -> None: + def mock_raise_server_down(*args: Any) -> None: # noqa: ANN401, ARG001 raise ldap.SERVER_DOWN monkeypatch.setattr( - ldap.asyncsearch.List, "startSearch", mock_raise_server_down + ldap.asyncsearch.List, + "startSearch", + mock_raise_server_down, ) client = LDAPClient(hostname="test-host") - with pytest.raises(LDAPException): + with pytest.raises(LDAPError): client.search(query=LDAPQuery(base_dn="", filter="", id_attr="")) assert "Server could not be reached." in caplog.text def test_search_exception_sizelimit_exceeded( - self, monkeypatch: Any, caplog: Any + self, + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, ) -> None: - def mock_raise_sizelimit_exceeded(*args: Any) -> None: + def mock_raise_sizelimit_exceeded(*args: Any) -> None: # noqa: ANN401, ARG001 raise ldap.SIZELIMIT_EXCEEDED monkeypatch.setattr( - ldap.asyncsearch.List, "startSearch", mock_raise_sizelimit_exceeded + ldap.asyncsearch.List, + "startSearch", + mock_raise_sizelimit_exceeded, ) client = LDAPClient(hostname="test-host") - with pytest.raises(LDAPException): + with pytest.raises(LDAPError): client.search(query=LDAPQuery(base_dn="", filter="", id_attr="")) assert "Server-side size limit exceeded." in caplog.text def test_search_failure_partial( self, - caplog: Any, + caplog: pytest.LogCaptureFixture, ldap_response_groups_fixture: LDAPSearchResult, ) -> None: caplog.set_level(logging.DEBUG) with mock.patch( - "guacamole_user_sync.ldap.ldap_client.AsyncSearchList" + "guacamole_user_sync.ldap.ldap_client.AsyncSearchList", ) as mock_async_search_list: mock_async_search_list.return_value = MockAsyncSearchListPartialResults( - results=ldap_response_groups_fixture[0:1] + results=ldap_response_groups_fixture[0:1], ) client = LDAPClient(hostname="test-host") client.search(query=LDAPQuery(base_dn="", filter="", id_attr="")) @@ -109,31 +129,31 @@ def test_search_failure_partial( def test_search_no_results( self, - monkeypatch: Any, - caplog: Any, + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, ) -> None: - def mock_raise_no_results(*args: Any) -> None: + def mock_raise_no_results(*args: Any) -> None: # noqa: ANN401, ARG001 raise ldap.NO_SUCH_OBJECT monkeypatch.setattr(ldap.asyncsearch.List, "startSearch", mock_raise_no_results) client = LDAPClient(hostname="test-host") - with pytest.raises(LDAPException): + with pytest.raises(LDAPError): client.search(query=LDAPQuery(base_dn="", filter="", id_attr="")) assert "Server returned no results." in caplog.text def test_search_groups( self, - caplog: Any, + caplog: pytest.LogCaptureFixture, ldap_query_groups_fixture: LDAPQuery, ldap_response_groups_fixture: LDAPSearchResult, ldap_model_groups_fixture: list[LDAPGroup], ) -> None: caplog.set_level(logging.DEBUG) with mock.patch( - "guacamole_user_sync.ldap.ldap_client.AsyncSearchList" + "guacamole_user_sync.ldap.ldap_client.AsyncSearchList", ) as mock_async_search_list: mock_async_search_list.return_value = MockAsyncSearchListFullResults( - results=ldap_response_groups_fixture + results=ldap_response_groups_fixture, ) client = LDAPClient(hostname="test-host") users = client.search_groups(query=ldap_query_groups_fixture) @@ -144,17 +164,17 @@ def test_search_groups( def test_search_users( self, - caplog: Any, + caplog: pytest.LogCaptureFixture, ldap_query_users_fixture: LDAPQuery, ldap_response_users_fixture: LDAPSearchResult, ldap_model_users_fixture: list[LDAPUser], ) -> None: caplog.set_level(logging.DEBUG) with mock.patch( - "guacamole_user_sync.ldap.ldap_client.AsyncSearchList" + "guacamole_user_sync.ldap.ldap_client.AsyncSearchList", ) as mock_async_search_list: mock_async_search_list.return_value = MockAsyncSearchListFullResults( - results=ldap_response_users_fixture + results=ldap_response_users_fixture, ) client = LDAPClient(hostname="test-host") users = client.search_users(query=ldap_query_users_fixture) diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index bb6eb9f..3c8f6e4 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -3,21 +3,24 @@ from unittest import mock import pytest -from sqlalchemy import text +from sqlalchemy import URL, Engine, text from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg -from sqlalchemy.engine import URL, Engine # type: ignore from sqlalchemy.exc import OperationalError from sqlalchemy.orm import Session -from sqlalchemy.pool.impl import QueuePool -from sqlalchemy.sql.elements import BinaryExpression - -from guacamole_user_sync.models import LDAPGroup, LDAPUser, PostgreSQLException -from guacamole_user_sync.postgresql import PostgreSQLBackend, PostgreSQLClient +from sqlalchemy.pool import QueuePool +from sqlalchemy.sql.elements import BinaryExpression, TextClause + +from guacamole_user_sync.models import LDAPGroup, LDAPUser, PostgreSQLError +from guacamole_user_sync.postgresql import ( + PostgreSQLBackend, + PostgreSQLClient, + PostgreSQLConnectionDetails, +) from guacamole_user_sync.postgresql.orm import ( GuacamoleEntity, + GuacamoleEntityType, GuacamoleUser, GuacamoleUserGroup, - guacamole_entity_type, ) from guacamole_user_sync.postgresql.sql import SchemaVersion @@ -25,53 +28,57 @@ class TestPostgreSQLBackend: - def mock_backend( - self, test_session: bool = False - ) -> tuple[PostgreSQLBackend, mock.MagicMock]: + """Test PostgreSQLBackend.""" + + def mock_backend(self, session: Session | None = None) -> PostgreSQLBackend: + return PostgreSQLBackend( + connection_details=PostgreSQLConnectionDetails( + database_name="database_name", + host_name="host_name", + port=1234, + user_name="user_name", + user_password="user_password", # noqa: S106 + ), + session=session, + ) + + def mock_session(self) -> mock.MagicMock: mock_session = mock.MagicMock() mock_session.__enter__.return_value = mock_session mock_session.filter.return_value = mock_session mock_session.query.return_value = mock_session - backend = PostgreSQLBackend( - database_name="database_name", - host_name="host_name", - port=1234, - user_name="user_name", - user_password="user_password", - session=mock_session, - ) - if test_session: - backend._session = None - return (backend, mock_session) + return mock_session def test_constructor(self) -> None: - backend, _ = self.mock_backend() + backend = self.mock_backend() assert isinstance(backend, PostgreSQLBackend) - assert backend.database_name == "database_name" - assert backend.host_name == "host_name" - assert backend.port == 1234 - assert backend.user_name == "user_name" - assert backend.user_password == "user_password" + assert isinstance(backend.connection_details, PostgreSQLConnectionDetails) + assert backend.connection_details.database_name == "database_name" + assert backend.connection_details.host_name == "host_name" + assert backend.connection_details.port == 1234 # noqa: PLR2004 + assert backend.connection_details.user_name == "user_name" + assert backend.connection_details.user_password == "user_password" # noqa: S105 def test_engine(self) -> None: - backend, _ = self.mock_backend() + backend = self.mock_backend() assert isinstance(backend.engine, Engine) assert isinstance(backend.engine.pool, QueuePool) assert isinstance(backend.engine.dialect, PGDialect_psycopg) assert isinstance(backend.engine.url, URL) assert backend.engine.logging_name is None assert not backend.engine.echo - assert not backend.engine.hide_parameters # type: ignore + assert not backend.engine.hide_parameters def test_session(self) -> None: - backend, _ = self.mock_backend(test_session=True) + backend = self.mock_backend() assert isinstance(backend.session(), Session) def test_add_all( self, postgresql_model_guacamoleentity_fixture: list[GuacamoleEntity], ) -> None: - backend, session = self.mock_backend() + session = self.mock_session() + backend = self.mock_backend(session=session) backend.add_all(postgresql_model_guacamoleentity_fixture) # Check method calls @@ -83,14 +90,12 @@ def test_add_all( assert len(execute_args) == 1 assert len(execute_args[0]) == len(postgresql_model_guacamoleentity_fixture) - def test_delete( - self, - ) -> None: - backend, session = self.mock_backend() + def test_delete(self) -> None: + session = self.mock_session() + backend = self.mock_backend(session=session) backend.delete(GuacamoleEntity) # Check method calls - print(session.mock_calls) session.query.assert_called_once() session.filter.assert_not_called() session.delete.assert_called_once() @@ -101,16 +106,15 @@ def test_delete( assert len(query_args) == 1 assert isinstance(query_args[0], type(GuacamoleEntity)) - def test_delete_with_filter( - self, - ) -> None: - backend, session = self.mock_backend() + def test_delete_with_filter(self) -> None: + session = self.mock_session() + backend = self.mock_backend(session=session) backend.delete( - GuacamoleEntity, GuacamoleEntity.type == guacamole_entity_type.USER + GuacamoleEntity, + GuacamoleEntity.type == GuacamoleEntityType.USER, ) # Check method calls - print(session.mock_calls) session.query.assert_called_once() session.filter.assert_called_once() session.delete.assert_called_once() @@ -124,22 +128,35 @@ def test_delete_with_filter( assert len(filter_args) == 1 assert isinstance(filter_args[0], BinaryExpression) - def test_execute_commands( - self, - ) -> None: + def test_execute_commands(self) -> None: command = text("SELECT * FROM guacamole_entity;") - backend, session = self.mock_backend() + session = self.mock_session() + backend = self.mock_backend(session=session) backend.execute_commands([command]) session.execute.assert_called_once_with(command) - def test_query( + def test_execute_commands_exception( self, + caplog: pytest.LogCaptureFixture, ) -> None: - backend, session = self.mock_backend() + command = text("SELECT * FROM guacamole_entity;") + session = self.mock_session() + session.execute.side_effect = OperationalError( + statement="exception reason", + params=None, + orig=None, + ) + backend = self.mock_backend(session=session) + with pytest.raises(OperationalError, match="SQL: exception reason"): + backend.execute_commands([command]) + assert "Unable to execute PostgreSQL commands." in caplog.text + + def test_query(self) -> None: + session = self.mock_session() + backend = self.mock_backend(session=session) backend.query(GuacamoleEntity) # Check method calls - print(session.mock_calls) session.query.assert_called_once() session.filter.assert_not_called() session.delete.assert_not_called() @@ -150,14 +167,12 @@ def test_query( assert len(query_args) == 1 assert isinstance(query_args[0], type(GuacamoleEntity)) - def test_query_with_filter( - self, - ) -> None: - backend, session = self.mock_backend() - backend.query(GuacamoleEntity, type=guacamole_entity_type.USER) + def test_query_with_filter(self) -> None: + session = self.mock_session() + backend = self.mock_backend(session=session) + backend.query(GuacamoleEntity, type=GuacamoleEntityType.USER) # Check method calls - print(session.mock_calls) session.query.assert_called_once() session.filter_by.assert_called_once() session.__exit__.assert_called_once() @@ -169,10 +184,12 @@ def test_query_with_filter( filter_by_kwargs = session.filter_by.call_args.kwargs assert len(filter_by_kwargs) == 1 assert "type" in filter_by_kwargs - assert filter_by_kwargs["type"] == guacamole_entity_type.USER + assert filter_by_kwargs["type"] == GuacamoleEntityType.USER class TestPostgreSQLClient: + """Test PostgreSQLClient.""" + client_kwargs: ClassVar[dict[str, Any]] = { "database_name": "database_name", "host_name": "host_name", @@ -188,14 +205,14 @@ def test_constructor(self) -> None: def test_assign_users_to_groups( self, - caplog: Any, + caplog: pytest.LogCaptureFixture, ldap_model_groups_fixture: list[LDAPGroup], ldap_model_users_fixture: list[LDAPUser], postgresql_model_guacamoleentity_fixture: list[GuacamoleEntity], postgresql_model_guacamoleusergroup_fixture: list[GuacamoleUserGroup], ) -> None: # Create a mock backend - mock_backend = MockPostgreSQLBackend( # type: ignore + mock_backend = MockPostgreSQLBackend( postgresql_model_guacamoleentity_fixture, postgresql_model_guacamoleusergroup_fixture, ) @@ -205,13 +222,14 @@ def test_assign_users_to_groups( # Patch PostgreSQLBackend with mock.patch( - "guacamole_user_sync.postgresql.postgresql_client.PostgreSQLBackend" + "guacamole_user_sync.postgresql.postgresql_client.PostgreSQLBackend", ) as mock_postgresql_backend: mock_postgresql_backend.return_value = mock_backend client = PostgreSQLClient(**self.client_kwargs) client.assign_users_to_groups( - ldap_model_groups_fixture, ldap_model_users_fixture + ldap_model_groups_fixture, + ldap_model_users_fixture, ) for output_line in ( "Ensuring that 2 user(s) are correctly assigned among 3 group(s)", @@ -223,14 +241,14 @@ def test_assign_users_to_groups( def test_assign_users_to_groups_missing_ldap_user( self, - caplog: Any, + caplog: pytest.LogCaptureFixture, ldap_model_groups_fixture: list[LDAPGroup], ldap_model_users_fixture: list[LDAPUser], postgresql_model_guacamoleentity_fixture: list[GuacamoleEntity], postgresql_model_guacamoleusergroup_fixture: list[GuacamoleUserGroup], ) -> None: # Create a mock backend - mock_backend = MockPostgreSQLBackend( # type: ignore + mock_backend = MockPostgreSQLBackend( postgresql_model_guacamoleentity_fixture, postgresql_model_guacamoleusergroup_fixture, ) @@ -240,13 +258,14 @@ def test_assign_users_to_groups_missing_ldap_user( # Patch PostgreSQLBackend with mock.patch( - "guacamole_user_sync.postgresql.postgresql_client.PostgreSQLBackend" + "guacamole_user_sync.postgresql.postgresql_client.PostgreSQLBackend", ) as mock_postgresql_backend: mock_postgresql_backend.return_value = mock_backend client = PostgreSQLClient(**self.client_kwargs) client.assign_users_to_groups( - ldap_model_groups_fixture, ldap_model_users_fixture[0:1] + ldap_model_groups_fixture, + ldap_model_users_fixture[0:1], ) for output_line in ( "Ensuring that 1 user(s) are correctly assigned among 3 group(s)", @@ -255,19 +274,19 @@ def test_assign_users_to_groups_missing_ldap_user( ): assert output_line in caplog.text - def test_assign_users_to_groups_missing_user_entity( + def test_assign_users_to_groups_missing_user_entity( # noqa: PLR0913 self, - caplog: Any, + caplog: pytest.LogCaptureFixture, ldap_model_groups_fixture: list[LDAPGroup], ldap_model_users_fixture: list[LDAPUser], - postgresql_model_guacamoleentity_USER_fixture: list[GuacamoleEntity], - postgresql_model_guacamoleentity_USER_GROUP_fixture: list[GuacamoleEntity], + postgresql_model_guacamoleentity_user_fixture: list[GuacamoleEntity], + postgresql_model_guacamoleentity_user_group_fixture: list[GuacamoleEntity], postgresql_model_guacamoleusergroup_fixture: list[GuacamoleUserGroup], ) -> None: # Create a mock backend - mock_backend = MockPostgreSQLBackend( # type: ignore - postgresql_model_guacamoleentity_USER_fixture[1:], - postgresql_model_guacamoleentity_USER_GROUP_fixture, + mock_backend = MockPostgreSQLBackend( + postgresql_model_guacamoleentity_user_fixture[1:], + postgresql_model_guacamoleentity_user_group_fixture, postgresql_model_guacamoleusergroup_fixture, ) @@ -276,30 +295,31 @@ def test_assign_users_to_groups_missing_user_entity( # Patch PostgreSQLBackend with mock.patch( - "guacamole_user_sync.postgresql.postgresql_client.PostgreSQLBackend" + "guacamole_user_sync.postgresql.postgresql_client.PostgreSQLBackend", ) as mock_postgresql_backend: mock_postgresql_backend.return_value = mock_backend client = PostgreSQLClient(**self.client_kwargs) client.assign_users_to_groups( - ldap_model_groups_fixture, ldap_model_users_fixture + ldap_model_groups_fixture, + ldap_model_users_fixture, ) for output_line in ( - "Could not find entity ID for LDAP user aulus.agerius", + "Could not find entity ID for LDAP user 'aulus.agerius'", ): assert output_line in caplog.text def test_assign_users_to_groups_missing_usergroup( self, - caplog: Any, + caplog: pytest.LogCaptureFixture, ldap_model_groups_fixture: list[LDAPGroup], ldap_model_users_fixture: list[LDAPUser], postgresql_model_guacamoleentity_fixture: list[GuacamoleEntity], postgresql_model_guacamoleusergroup_fixture: list[GuacamoleUserGroup], ) -> None: # Create a mock backend - mock_backend = MockPostgreSQLBackend( # type: ignore + mock_backend = MockPostgreSQLBackend( postgresql_model_guacamoleentity_fixture, postgresql_model_guacamoleusergroup_fixture[1:], ) @@ -309,13 +329,14 @@ def test_assign_users_to_groups_missing_usergroup( # Patch PostgreSQLBackend with mock.patch( - "guacamole_user_sync.postgresql.postgresql_client.PostgreSQLBackend" + "guacamole_user_sync.postgresql.postgresql_client.PostgreSQLBackend", ) as mock_postgresql_backend: mock_postgresql_backend.return_value = mock_backend client = PostgreSQLClient(**self.client_kwargs) client.assign_users_to_groups( - ldap_model_groups_fixture, ldap_model_users_fixture + ldap_model_groups_fixture, + ldap_model_users_fixture, ) for output_line in ( @@ -325,13 +346,13 @@ def test_assign_users_to_groups_missing_usergroup( ): assert output_line in caplog.text - def test_ensure_schema(self, capsys: Any) -> None: + def test_ensure_schema(self, capsys: pytest.CaptureFixture) -> None: # Create a mock backend - mock_backend = MockPostgreSQLBackend() # type: ignore + mock_backend = MockPostgreSQLBackend() # Patch PostgreSQLBackend with mock.patch( - "guacamole_user_sync.postgresql.postgresql_client.PostgreSQLBackend" + "guacamole_user_sync.postgresql.postgresql_client.PostgreSQLBackend", ) as mock_postgresql_backend: mock_postgresql_backend.return_value = mock_backend @@ -414,47 +435,51 @@ def test_ensure_schema(self, capsys: Any) -> None: f"Executing CREATE INDEX IF NOT EXISTS {index_name}" in captured.out ) - def test_ensure_schema_exception(self, capsys: Any) -> None: + def test_ensure_schema_exception(self) -> None: # Create a mock backend - def execute_commands_exception(commands: Any) -> None: + def execute_commands_exception( + commands: list[TextClause], # noqa: ARG001 + ) -> None: raise OperationalError(statement="statement", params=None, orig=None) - mock_backend = MockPostgreSQLBackend() # type: ignore - mock_backend.execute_commands = execute_commands_exception # type: ignore + mock_backend = MockPostgreSQLBackend() + mock_backend.execute_commands = execute_commands_exception # type: ignore[method-assign] # Patch PostgreSQLBackend with mock.patch( - "guacamole_user_sync.postgresql.postgresql_client.PostgreSQLBackend" + "guacamole_user_sync.postgresql.postgresql_client.PostgreSQLBackend", ) as mock_postgresql_backend: mock_postgresql_backend.return_value = mock_backend client = PostgreSQLClient(**self.client_kwargs) with pytest.raises( - PostgreSQLException, match="Unable to ensure PostgreSQL schema." + PostgreSQLError, + match="Unable to ensure PostgreSQL schema.", ): client.ensure_schema(SchemaVersion.v1_5_5) def test_update( self, - caplog: Any, + caplog: pytest.LogCaptureFixture, ldap_model_groups_fixture: list[LDAPGroup], ldap_model_users_fixture: list[LDAPUser], ) -> None: # Create a mock backend - mock_backend = MockPostgreSQLBackend() # type: ignore + mock_backend = MockPostgreSQLBackend() # Capture logs at debug level and above caplog.set_level(logging.DEBUG) # Patch PostgreSQLBackend with mock.patch( - "guacamole_user_sync.postgresql.postgresql_client.PostgreSQLBackend" + "guacamole_user_sync.postgresql.postgresql_client.PostgreSQLBackend", ) as mock_postgresql_backend: mock_postgresql_backend.return_value = mock_backend client = PostgreSQLClient(**self.client_kwargs) client.update( - groups=ldap_model_groups_fixture, users=ldap_model_users_fixture + groups=ldap_model_groups_fixture, + users=ldap_model_users_fixture, ) for output_line in ( "Ensuring that 3 group(s) are registered", @@ -469,7 +494,7 @@ def test_update( "... 3 user group entit(y|ies) will be added", "There are 3 valid user group entit(y|ies)", "There are 0 user entit(y|ies) currently registered", - "... 0 user entit(y|ies) will be added", + "... 2 user entit(y|ies) will be added", "There are 2 valid user entit(y|ies)", "Ensuring that 2 user(s) are correctly assigned among 3 group(s)", "Working on group 'defendants'", @@ -488,13 +513,13 @@ def test_update( def test_update_group_entities( self, - caplog: Any, - postgresql_model_guacamoleentity_USER_GROUP_fixture: list[GuacamoleEntity], + caplog: pytest.LogCaptureFixture, + postgresql_model_guacamoleentity_user_group_fixture: list[GuacamoleEntity], postgresql_model_guacamoleusergroup_fixture: list[GuacamoleUserGroup], ) -> None: # Create a mock backend - mock_backend = MockPostgreSQLBackend( # type: ignore - postgresql_model_guacamoleentity_USER_GROUP_fixture, + mock_backend = MockPostgreSQLBackend( + postgresql_model_guacamoleentity_user_group_fixture, postgresql_model_guacamoleusergroup_fixture[0:1], ) @@ -503,7 +528,7 @@ def test_update_group_entities( # Patch PostgreSQLBackend with mock.patch( - "guacamole_user_sync.postgresql.postgresql_client.PostgreSQLBackend" + "guacamole_user_sync.postgresql.postgresql_client.PostgreSQLBackend", ) as mock_postgresql_backend: mock_postgresql_backend.return_value = mock_backend @@ -518,19 +543,19 @@ def test_update_group_entities( def test_update_groups( self, - caplog: Any, + caplog: pytest.LogCaptureFixture, ldap_model_groups_fixture: list[LDAPGroup], - postgresql_model_guacamoleentity_USER_GROUP_fixture: list[GuacamoleEntity], + postgresql_model_guacamoleentity_user_group_fixture: list[GuacamoleEntity], ) -> None: # Create a mock backend - mock_backend = MockPostgreSQLBackend( # type: ignore - postgresql_model_guacamoleentity_USER_GROUP_fixture[0:1], + mock_backend = MockPostgreSQLBackend( + postgresql_model_guacamoleentity_user_group_fixture[0:1], [ GuacamoleEntity( entity_id=99, name="to-be-deleted", - type=guacamole_entity_type.USER_GROUP, - ) + type=GuacamoleEntityType.USER_GROUP, + ), ], ) @@ -539,7 +564,7 @@ def test_update_groups( # Patch PostgreSQLBackend with mock.patch( - "guacamole_user_sync.postgresql.postgresql_client.PostgreSQLBackend" + "guacamole_user_sync.postgresql.postgresql_client.PostgreSQLBackend", ) as mock_postgresql_backend: mock_postgresql_backend.return_value = mock_backend @@ -552,14 +577,14 @@ def test_update_groups( def test_update_user_entities( self, - caplog: Any, + caplog: pytest.LogCaptureFixture, ldap_model_users_fixture: list[LDAPUser], postgresql_model_guacamoleuser_fixture: list[GuacamoleUser], - postgresql_model_guacamoleentity_USER_fixture: list[GuacamoleEntity], + postgresql_model_guacamoleentity_user_fixture: list[GuacamoleEntity], ) -> None: # Create a mock backend - mock_backend = MockPostgreSQLBackend( # type: ignore - postgresql_model_guacamoleentity_USER_fixture, + mock_backend = MockPostgreSQLBackend( + postgresql_model_guacamoleentity_user_fixture, postgresql_model_guacamoleuser_fixture[0:1], ) @@ -568,7 +593,7 @@ def test_update_user_entities( # Patch PostgreSQLBackend with mock.patch( - "guacamole_user_sync.postgresql.postgresql_client.PostgreSQLBackend" + "guacamole_user_sync.postgresql.postgresql_client.PostgreSQLBackend", ) as mock_postgresql_backend: mock_postgresql_backend.return_value = mock_backend @@ -583,17 +608,19 @@ def test_update_user_entities( def test_update_users( self, - caplog: Any, + caplog: pytest.LogCaptureFixture, ldap_model_users_fixture: list[LDAPUser], - postgresql_model_guacamoleentity_USER_fixture: list[GuacamoleEntity], + postgresql_model_guacamoleentity_user_fixture: list[GuacamoleEntity], ) -> None: # Create a mock backend - mock_backend = MockPostgreSQLBackend( # type: ignore - postgresql_model_guacamoleentity_USER_fixture[0:1], + mock_backend = MockPostgreSQLBackend( + postgresql_model_guacamoleentity_user_fixture[0:1], [ GuacamoleEntity( - entity_id=99, name="to-be-deleted", type=guacamole_entity_type.USER - ) + entity_id=99, + name="to-be-deleted", + type=GuacamoleEntityType.USER, + ), ], ) @@ -602,7 +629,7 @@ def test_update_users( # Patch PostgreSQLBackend with mock.patch( - "guacamole_user_sync.postgresql.postgresql_client.PostgreSQLBackend" + "guacamole_user_sync.postgresql.postgresql_client.PostgreSQLBackend", ) as mock_postgresql_backend: mock_postgresql_backend.return_value = mock_backend