From 361448257d451fb9203147defaad440ec8d3ec60 Mon Sep 17 00:00:00 2001 From: multiflexi Date: Wed, 31 Jul 2024 20:21:29 +0200 Subject: [PATCH 1/2] Fix error when user has no organization --- src/core/model/asset.py | 397 ++++++++++++++++++++++-- src/core/model/notification_template.py | 194 +++++++++++- src/core/model/user.py | 362 +++++++++++++++++++-- 3 files changed, 896 insertions(+), 57 deletions(-) diff --git a/src/core/model/asset.py b/src/core/model/asset.py index bde4a2247..18d313888 100644 --- a/src/core/model/asset.py +++ b/src/core/model/asset.py @@ -1,3 +1,5 @@ +"""Module for Asset model.""" + import uuid from marshmallow import fields, post_load from sqlalchemy import orm, func, or_, text @@ -12,43 +14,135 @@ class NewAssetCpeSchema(AssetCpeSchema): + """A schema class for creating a new AssetCpe object. + + This schema inherits from the AssetCpeSchema class and provides a method + for creating an AssetCpe object from the given data. + + Attributes: + data (dict): The data used to create the AssetCpe object. + + Returns: + AssetCpe: The created AssetCpe object. + """ @post_load def make(self, data, **kwargs): + """Use decorator to create an instance of the AssetCpe class from the provided data. + + Parameters: + data: A dictionary containing the data to initialize the AssetCpe instance. + + Returns: + An instance of the AssetCpe class initialized with the provided data. + """ return AssetCpe(**data) class AssetCpe(db.Model): + """Represents an AssetCpe object. + + Attributes: + id (int): The unique identifier of the AssetCpe. + value (str): The value of the AssetCpe. + asset_id (int): The foreign key referencing the associated Asset object. + + Methods: + __init__(value): Initializes a new instance of the AssetCpe class. + """ + id = db.Column(db.Integer, primary_key=True) value = db.Column(db.String()) - asset_id = db.Column(db.Integer, db.ForeignKey('asset.id')) + asset_id = db.Column(db.Integer, db.ForeignKey("asset.id")) def __init__(self, value): + """Initialize a new instance of the Asset class. + + Parameters: + value: The value of the asset. + """ self.id = None self.value = value class NewAssetSchema(AssetSchema): + """Schema for creating a new asset. + + Attributes: + asset_cpes (List[NewAssetCpeSchema]): A list of nested schemas for asset CPES. + + Methods: + make(data, **kwargs): Post-load method that creates an Asset instance from the given data. + """ + asset_cpes = fields.Nested(NewAssetCpeSchema, many=True) @post_load def make(self, data, **kwargs): + """Post-load method that creates an Asset instance from the given data. + + Parameters: + data (dict): The data to create the Asset instance from. + **kwargs: Additional keyword arguments. + + Returns: + Asset: The created Asset instance. + """ return Asset(**data) class Asset(db.Model): + """Represents an asset in the system. + + Attributes: + id (int): The unique identifier of the asset. + name (str): The name of the asset. + serial (str): The serial number of the asset. + description (str): The description of the asset. + asset_group_id (str): The ID of the asset group that the asset belongs to. + asset_group (AssetGroup): The asset group that the asset belongs to. + asset_cpes (list[AssetCpe]): The list of asset CPEs associated with the asset. + vulnerabilities (list[AssetVulnerability]): The list of vulnerabilities associated with the asset. + vulnerabilities_count (int): The count of vulnerabilities associated with the asset. + + Methods: + __init__(self, id, name, serial, description, asset_group_id, asset_cpes): Initializes a new instance of the Asset class. + reconstruct(self): Reconstructs the asset object after it is loaded from the database. + get_by_cpe(cls, cpes): Retrieves a list of assets based on the given CPEs. + remove_vulnerability(cls, report_item_id): Removes a vulnerability associated with the asset. + add_vulnerability(self, report_item): Adds a vulnerability to the asset. + update_vulnerabilities(self): Updates the vulnerabilities associated with the asset. + solve_vulnerability(cls, user, group_id, asset_id, report_item_id, solved): Solves or unsolves a vulnerability associated with the + asset. + get(cls, group_id, search, sort, vulnerable): Retrieves a list of assets based on the given criteria. + get_all_json(cls, user, group_id, search, sort, vulnerable): Retrieves a list of assets in JSON format based on the given criteria. + add(cls, user, group_id, data): Adds a new asset to the system. + update(cls, user, group_id, asset_id, data): Updates an existing asset in the system. + delete(cls, user, group_id, id): Deletes an asset from the system. + """ + id = db.Column(db.Integer, primary_key=True) name = db.Column(db.String(), nullable=False) serial = db.Column(db.String()) description = db.Column(db.String()) - asset_group_id = db.Column(db.String, db.ForeignKey('asset_group.id')) + asset_group_id = db.Column(db.String, db.ForeignKey("asset_group.id")) asset_group = db.relationship("AssetGroup") asset_cpes = db.relationship("AssetCpe", cascade="all, delete-orphan") vulnerabilities = db.relationship("AssetVulnerability", cascade="all, delete-orphan") vulnerabilities_count = db.Column(db.Integer, default=0) def __init__(self, id, name, serial, description, asset_group_id, asset_cpes): + """Initialize a new instance of the Asset class. + + Parameters: + id (int): The ID of the asset. + name (str): The name of the asset. + serial (str): The serial number of the asset. + description (str): The description of the asset. + asset_group_id (int): The ID of the asset group the asset belongs to. + asset_cpes (list): A list of Common Platform Enumeration (CPE) values associated with the asset. + """ self.id = None self.name = name self.serial = serial @@ -61,16 +155,31 @@ def __init__(self, id, name, serial, description, asset_group_id, asset_cpes): @orm.reconstructor def reconstruct(self): + """Use decorator to create an instance of the AssetCpe class from the provided data. + + Parameters: + data: A dictionary containing the data to initialize the AssetCpe instance. + + Returns: + An instance of the AssetCpe class initialized with the provided data. + """ self.title = self.name self.subtitle = self.description self.tag = "mdi-laptop" @classmethod def get_by_cpe(cls, cpes): + """Get assets by Common Platform Enumeration (CPE). + Parameters: + cpes: A list of CPE values. + + Returns: + A list of assets matching the given CPE values. + """ if len(cpes) > 0: query_string = "SELECT DISTINCT asset_id FROM asset_cpe WHERE value LIKE ANY(:cpes) OR {}" - params = {'cpes': cpes} + params = {"cpes": cpes} inner_query = "" for i in range(len(cpes)): @@ -88,12 +197,22 @@ def get_by_cpe(cls, cpes): @classmethod def remove_vulnerability(cls, report_item_id): + """Remove a vulnerability from the asset. + + Parameters: + report_item_id: The ID of the report item associated with the vulnerability. + """ vulnerabilities = AssetVulnerability.get_by_report(report_item_id) for vulnerability in vulnerabilities: vulnerability.asset.vulnerabilities_count -= 1 db.session.delete(vulnerability) def add_vulnerability(self, report_item): + """Add a vulnerability to the asset. + + Parameters: + report_item: The report item representing the vulnerability. + """ for vulnerability in self.vulnerabilities: if vulnerability.report_item.id == report_item.id: return @@ -103,6 +222,7 @@ def add_vulnerability(self, report_item): self.vulnerabilities_count += 1 def update_vulnerabilities(self): + """Update the vulnerabilities associated with the asset.""" cpes = [] for cpe in self.asset_cpes: cpes.append(cpe.value) @@ -126,6 +246,15 @@ def update_vulnerabilities(self): @classmethod def solve_vulnerability(cls, user, group_id, asset_id, report_item_id, solved): + """Solves a vulnerability for a specific asset. + + Parameters: + user (User): The user performing the action. + group_id (int): The ID of the asset group. + asset_id (int): The ID of the asset. + report_item_id (int): The ID of the report item. + solved (bool): Indicates whether the vulnerability is solved or not. + """ asset = cls.query.get(asset_id) if AssetGroup.access_allowed(user, asset.asset_group_id): for vulnerability in asset.vulnerabilities: @@ -141,6 +270,18 @@ def solve_vulnerability(cls, user, group_id, asset_id, report_item_id, solved): @classmethod def get(cls, group_id, search, sort, vulnerable): + """Retrieve assets based on the provided parameters. + + Parameters: + group_id (int): The ID of the asset group. + search (str): The search string to filter assets by name, description, serial, or CPE value. + sort (str): The sorting option for the assets. Can be "ALPHABETICAL" or "VULNERABILITIES_COUNT". + vulnerable (str): Flag to filter assets by vulnerability count. Can be "true" or None. + + Returns: + assets (list): A list of assets that match the provided parameters. + count (int): The total count of assets that match the provided parameters. + """ query = cls.query.filter(Asset.asset_group_id == group_id) if vulnerable is not None: @@ -148,12 +289,15 @@ def get(cls, group_id, search, sort, vulnerable): query = query.filter(Asset.vulnerabilities_count > 0) if search is not None: - search_string = '%' + search.lower() + '%' - query = query.join(AssetCpe, Asset.id == AssetCpe.asset_id).filter(or_( - func.lower(Asset.name).like(search_string), - func.lower(Asset.description).like(search_string), - func.lower(Asset.serial).like(search_string), - func.lower(AssetCpe.value).like(search_string))) + search_string = "%" + search.lower() + "%" + query = query.join(AssetCpe, Asset.id == AssetCpe.asset_id).filter( + or_( + func.lower(Asset.name).like(search_string), + func.lower(Asset.description).like(search_string), + func.lower(Asset.serial).like(search_string), + func.lower(AssetCpe.value).like(search_string), + ) + ) if sort is not None: if sort == "ALPHABETICAL": @@ -165,14 +309,33 @@ def get(cls, group_id, search, sort, vulnerable): @classmethod def get_all_json(cls, user, group_id, search, sort, vulnerable): + """Get all assets in JSON format. + + Parameters: + user (User): The user object. + group_id (int): The ID of the asset group. + search (str): The search query for filtering assets. + sort (str): The sorting criteria for assets. + vulnerable (bool): Flag indicating whether to include vulnerable assets. + + Returns: + dict: A dictionary containing the total count of assets and a list of asset items in JSON format. + """ if AssetGroup.access_allowed(user, group_id): assets, count = cls.get(group_id, search, sort, vulnerable) asset_schema = AssetPresentationSchema(many=True) items = asset_schema.dump(assets) - return {'total_count': count, 'items': items} + return {"total_count": count, "items": items} @classmethod def add(cls, user, group_id, data): + """Add a new asset to the database. + + Parameters: + user (User): The user adding the asset. + group_id (int): The ID of the asset group to which the asset belongs. + data (dict): The data of the asset to be added. + """ schema = NewAssetSchema() asset = schema.load(data) asset.asset_group_id = group_id @@ -183,6 +346,14 @@ def add(cls, user, group_id, data): @classmethod def update(cls, user, group_id, asset_id, data): + """Update an asset with the provided data. + + Parameters: + user (User): The user performing the update. + group_id (int): The ID of the asset group. + asset_id (int): The ID of the asset to update. + data (dict): The data to update the asset with. + """ asset = cls.query.get(asset_id) if AssetGroup.access_allowed(user, asset.asset_group_id): schema = NewAssetSchema() @@ -196,6 +367,13 @@ def update(cls, user, group_id, asset_id, data): @classmethod def delete(cls, user, group_id, id): + """Delete an asset. + + Parameters: + user (User): The user performing the delete operation. + group_id (int): The ID of the asset group. + id (int): The ID of the asset to be deleted. + """ asset = cls.query.get(id) if AssetGroup.access_allowed(user, asset.asset_group_id): db.session.delete(asset) @@ -203,32 +381,106 @@ def delete(cls, user, group_id, id): class AssetVulnerability(db.Model): + """Represents a vulnerability associated with an asset. + + Attributes: + id (int): The unique identifier of the vulnerability. + solved (bool): Indicates whether the vulnerability has been solved or not. + asset_id (int): The ID of the asset associated with the vulnerability. + report_item_id (int): The ID of the report item associated with the vulnerability. + report_item (ReportItem): The report item associated with the vulnerability. + + Methods: + __init__(asset_id, report_item_id): Initializes a new instance of the AssetVulnerability class. + get_by_report(report_id): Retrieves all vulnerabilities associated with a specific report. + """ + id = db.Column(db.Integer, primary_key=True) solved = db.Column(db.Boolean, default=False) - asset_id = db.Column(db.Integer, db.ForeignKey('asset.id')) - report_item_id = db.Column(db.Integer, db.ForeignKey('report_item.id')) + asset_id = db.Column(db.Integer, db.ForeignKey("asset.id")) + report_item_id = db.Column(db.Integer, db.ForeignKey("report_item.id")) report_item = db.relationship("ReportItem") def __init__(self, asset_id, report_item_id): + """Initialize a new instance of the Asset class. + + Parameters: + asset_id (int): The ID of the asset. + report_item_id (int): The ID of the report item. + """ self.id = None self.asset_id = asset_id self.report_item_id = report_item_id @classmethod def get_by_report(cls, report_id): + """Get assets by report ID. + + Parameters: + report_id (int): The ID of the report. + + Returns: + List[Asset]: A list of assets associated with the given report ID. + """ return cls.query.filter_by(report_item_id=report_id).all() class NewAssetGroupGroupSchema(AssetGroupSchema): + """Schema for creating a new asset group with additional fields. + + Attributes: + users (list): A list of user IDs associated with the asset group. + templates (list): A list of notification template IDs associated with the asset group. + + Methods: + make(data, **kwargs): Post-load method that creates an AssetGroup object from the given data. + + Returns: + AssetGroup: An instance of the AssetGroup class. + """ + users = fields.Nested(UserIdSchema, many=True) templates = fields.Nested(NotificationTemplateIdSchema, many=True) @post_load def make(self, data, **kwargs): + """Use decorator to create an instance of the AssetGroup class from the provided data. + + Parameters: + data: A dictionary containing the data to initialize the AssetGroup instance. + + Returns: + An instance of the AssetGroup class. + """ return AssetGroup(**data) class AssetGroup(db.Model): + """AssetGroup class represents a group of assets in the system. + + Attributes: + id (str): The unique identifier of the asset group. + name (str): The name of the asset group. + description (str): The description of the asset group. + templates (list): The list of notification templates associated with the asset group. + organizations (list): The list of organizations associated with the asset group. + users (list): The list of users associated with the asset group. + title (str): The title of the asset group. + subtitle (str): The subtitle of the asset group. + tag (str): The tag of the asset group. + + Methods: + __init__(self, id, name, description, users, templates): Initializes a new instance of the AssetGroup class. + reconstruct(self): Reconstructs the asset group object. + find(cls, group_id): Finds an asset group by its ID. + access_allowed(cls, user, group_id): Checks if the user has access to the asset group. + get(cls, search, organization): Retrieves asset groups based on search criteria and organization. + get_all_json(cls, user, search): Retrieves all asset groups in JSON format. + add(cls, user, data): Adds a new asset group. + delete(cls, user, group_id): Deletes an asset group. + update(cls, user, group_id, data): Updates an asset group. + """ + id = db.Column(db.String(64), primary_key=True) name = db.Column(db.String(), nullable=False) description = db.Column(db.String()) @@ -239,6 +491,15 @@ class AssetGroup(db.Model): users = db.relationship("User", secondary="asset_group_user") def __init__(self, id, name, description, users, templates): + """Initialize an instance of the Asset class. + + Parameters: + id (str): The unique identifier of the asset. + name (str): The name of the asset. + description (str): The description of the asset. + users (list): A list of User objects associated with the asset. + templates (list): A list of NotificationTemplate objects associated with the asset. + """ self.id = str(uuid.uuid4()) self.name = name self.description = description @@ -257,40 +518,84 @@ def __init__(self, id, name, description, users, templates): @orm.reconstructor def reconstruct(self): + """Reconstruct the asset object. + + This method is called when the asset object is loaded from the database. It sets the `title` attribute to the value of the `name` + attribute and the `subtitle` attribute to the value of the `description` attribute. It also sets the `tag` attribute + to "mdi-folder-multiple". + """ self.title = self.name self.subtitle = self.description self.tag = "mdi-folder-multiple" @classmethod def find(cls, group_id): + """Find an asset group by its ID. + + Parameters: + group_id (int): The ID of the asset group to find. + + Returns: + group (AssetGroup): The asset group with the specified ID. + """ group = cls.query.get(group_id) return group @classmethod def access_allowed(cls, user, group_id): + """ + Check if the access is allowed for a user in a specific group. + + Parameters: + user: The user object representing the user. + group_id: The ID of the group to check access for. + + Returns: + True if the access is allowed, False otherwise. + """ group = cls.query.get(group_id) return any(org in user.organizations for org in group.organizations) @classmethod def get(cls, search, organization): + """ + Get assets based on search criteria and organization. + + Parameters: + search (str): A string representing the search criteria. + organization: An organization object to filter the assets by. + + Returns: + A tuple containing a list of assets and the count of assets. + """ query = cls.query if organization is not None: query = query.join(AssetGroupOrganization, AssetGroup.id == AssetGroupOrganization.asset_group_id) if search is not None: - search_string = '%' + search.lower() + '%' - query = query.filter(or_( - func.lower(AssetGroup.name).like(search_string), - func.lower(AssetGroup.description).like(search_string))) + search_string = "%" + search.lower() + "%" + query = query.filter(or_(func.lower(AssetGroup.name).like(search_string), func.lower(AssetGroup.description).like(search_string))) return query.order_by(db.asc(AssetGroup.name)).all(), query.count() @classmethod def get_all_json(cls, user, search): - groups, count = cls.get(search, user.organizations[0]) + """Get all assets in JSON format. + + Parameters: + user (User): The user object. + search (str): The search query. + + Returns: + dict: A dictionary containing the total count of assets and a list of asset groups in JSON format. + """ + if user.organizations: + groups, count = cls.get(search, user.organizations[0]) + else: + return {"total_count": 0, "items": []} permissions = user.get_permissions() - if 'MY_ASSETS_CONFIG' not in permissions: + if "MY_ASSETS_CONFIG" not in permissions: for group in groups[:]: if len(group.users) > 0: found = False @@ -304,10 +609,16 @@ def get_all_json(cls, user, search): count -= 1 group_schema = AssetGroupPresentationSchema(many=True) - return {'total_count': count, 'items': group_schema.dump(groups)} + return {"total_count": count, "items": group_schema.dump(groups)} @classmethod def add(cls, user, data): + """Add a new asset group to the database. + + Parameters: + user: The user object representing the user adding the asset group. + data: The data containing the information for the new asset group. + """ new_group_schema = NewAssetGroupGroupSchema() group = new_group_schema.load(data) group.organizations = user.organizations @@ -324,6 +635,13 @@ def add(cls, user, data): @classmethod def delete(cls, user, group_id): + """Delete a group if the user belongs to any of the organizations associated with the group. + + Parameters: + cls (class): The class object. + user (User): The user object. + group_id (int): The ID of the group to be deleted. + """ group = cls.query.get(group_id) if any(org in user.organizations for org in group.organizations): db.session.delete(group) @@ -331,6 +649,14 @@ def delete(cls, user, group_id): @classmethod def update(cls, user, group_id, data): + """Update an asset group with the provided data. + + Parameters: + cls: The class object. + user: The user performing the update. + group_id: The ID of the asset group to update. + data: The data to update the asset group with. + """ new_group_schema = NewAssetGroupGroupSchema() updated_group = new_group_schema.load(data) group = cls.query.get(group_id) @@ -351,15 +677,36 @@ def update(cls, user, group_id, data): class AssetGroupOrganization(db.Model): - asset_group_id = db.Column(db.String, db.ForeignKey('asset_group.id'), primary_key=True) - organization_id = db.Column(db.Integer, db.ForeignKey('organization.id'), primary_key=True) + """AssetGroupOrganization represents the relationship between an asset group and an organization. + + Attributes: + asset_group_id (str): The ID of the asset group. + organization_id (int): The ID of the organization. + """ + + asset_group_id = db.Column(db.String, db.ForeignKey("asset_group.id"), primary_key=True) + organization_id = db.Column(db.Integer, db.ForeignKey("organization.id"), primary_key=True) class AssetGroupUser(db.Model): - asset_group_id = db.Column(db.String, db.ForeignKey('asset_group.id'), primary_key=True) - user_id = db.Column(db.Integer, db.ForeignKey('user.id'), primary_key=True) + """AssetGroupUser model represents the association between an AssetGroup and a User. + + Attributes: + asset_group_id (str): The ID of the associated AssetGroup. + user_id (int): The ID of the associated User. + """ + + asset_group_id = db.Column(db.String, db.ForeignKey("asset_group.id"), primary_key=True) + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), primary_key=True) class AssetGroupNotificationTemplate(db.Model): - asset_group_id = db.Column(db.String, db.ForeignKey('asset_group.id'), primary_key=True) - notification_template_id = db.Column(db.Integer, db.ForeignKey('notification_template.id'), primary_key=True) + """AssetGroupNotificationTemplate model represents the association between an Asset Group and a Notification Template. + + Attributes: + asset_group_id (str): The ID of the associated Asset Group. + notification_template_id (int): The ID of the associated Notification Template. + """ + + asset_group_id = db.Column(db.String, db.ForeignKey("asset_group.id"), primary_key=True) + notification_template_id = db.Column(db.Integer, db.ForeignKey("notification_template.id"), primary_key=True) diff --git a/src/core/model/notification_template.py b/src/core/model/notification_template.py index 072f7edf1..9dbcef91a 100644 --- a/src/core/model/notification_template.py +++ b/src/core/model/notification_template.py @@ -1,3 +1,5 @@ +"""Module for NotificationTemplate model.""" + from sqlalchemy import orm, func, or_ from marshmallow import post_load, fields @@ -6,34 +8,115 @@ class NewEmailRecipientSchema(EmailRecipientSchema): + """This class represents a schema for creating a new email recipient. + + Attributes: + Inherits EmailRecipientSchema. + Methods: + make(data, **kwargs): A method decorated with @post_load that creates an EmailRecipient instance from the given data. + Returns: + An instance of EmailRecipient. + """ @post_load def make(self, data, **kwargs): + """Create an instance of EmailRecipient using the provided data. + + Parameters: + data (dict): A dictionary containing the data for creating the EmailRecipient instance. + **kwargs: Additional keyword arguments. + Returns: + EmailRecipient: An instance of EmailRecipient created using the provided data. + """ return EmailRecipient(**data) class EmailRecipient(db.Model): + """Represents an email recipient. + + Attributes: + id (int): The unique identifier of the recipient. + email (str): The email address of the recipient. + name (str): The name of the recipient. + notification_template_id (int): The ID of the associated notification template. + Methods: + __init__(email, name): Initializes a new instance of the EmailRecipient class. + """ + id = db.Column(db.Integer, primary_key=True) email = db.Column(db.String(), nullable=False) name = db.Column(db.String()) - notification_template_id = db.Column(db.Integer, db.ForeignKey('notification_template.id')) + notification_template_id = db.Column(db.Integer, db.ForeignKey("notification_template.id")) def __init__(self, email, name): + """Initialize a NotificationTemplate object. + + Parameters: + email (str): The email address associated with the template. + name (str): The name of the template. + Attributes: + id (None): The ID of the template (initially set to None). + email (str): The email address associated with the template. + name (str): The name of the template. + """ self.id = None self.email = email self.name = name class NewNotificationTemplateSchema(NotificationTemplateSchema): + """NewNotificationTemplateSchema class is a schema for creating a new notification template. + + Attributes: + recipients (list): A list of NewEmailRecipientSchema objects representing the recipients of the notification. + Methods: + make(data, **kwargs): A post-load method that creates a NotificationTemplate object from the given data. + """ + recipients = fields.Nested(NewEmailRecipientSchema, many=True) @post_load def make(self, data, **kwargs): + """Create a new `NotificationTemplate` object based on the given data. + + Parameters: + data (dict): A dictionary containing the data for the notification template. + **kwargs: Additional keyword arguments. + Returns: + NotificationTemplate: A new `NotificationTemplate` object. + """ return NotificationTemplate(**data) class NotificationTemplate(db.Model): + """NotificationTemplate class represents a template for notifications. + + Attributes: + id (int): The unique identifier of the template. + name (str): The name of the template. + description (str): The description of the template. + message_title (str): The title of the notification message. + message_body (str): The body of the notification message. + recipients (list): The list of email recipients for the notification. + organizations (list): The list of organizations associated with the template. + Methods: + __init__(self, id, name, description, message_title, message_body, recipients): + Initializes a new instance of the NotificationTemplate class. + find(cls, id): + Finds a notification template by its ID. + get(cls, search, organization): + Retrieves notification templates based on search criteria and organization. + get_all_json(cls, user, search): + Retrieves all notification templates in JSON format. + add(cls, user, data): + Adds a new notification template. + delete(cls, user, template_id): + Deletes a notification template. + update(cls, user, template_id, data): + Updates a notification template. + """ + id = db.Column(db.Integer, primary_key=True) name = db.Column(db.String(), nullable=False) description = db.Column(db.String()) @@ -45,6 +128,26 @@ class NotificationTemplate(db.Model): organizations = db.relationship("Organization", secondary="notification_template_organization") def __init__(self, id, name, description, message_title, message_body, recipients): + """Initialize a NotificationTemplate object. + + Parameters: + id (int): The ID of the notification template. + name (str): The name of the notification template. + description (str): The description of the notification template. + message_title (str): The title of the notification message. + message_body (str): The body of the notification message. + recipients (list): A list of recipients for the notification. + Attributes: + id (int): The ID of the notification template. + name (str): The name of the notification template. + description (str): The description of the notification template. + message_title (str): The title of the notification message. + message_body (str): The body of the notification message. + recipients (list): A list of recipients for the notification. + title (str): The title of the notification template. + subtitle (str): The subtitle of the notification template. + tag (str): The tag of the notification template. + """ self.id = None self.name = name self.description = description @@ -57,39 +160,86 @@ def __init__(self, id, name, description, message_title, message_body, recipient @orm.reconstructor def reconstruct(self): + """Reconstruct the notification template. + + This method updates the title, subtitle, and tag attributes of the notification template object. + The title is set to the name attribute, the subtitle is set to the description attribute, + and the tag is set to "mdi-email-outline". + """ self.title = self.name self.subtitle = self.description self.tag = "mdi-email-outline" @classmethod def find(cls, id): + """Find a notification template by its ID. + + Parameters: + cls: The class object. + id: The ID of the notification template. + Returns: + The notification template with the specified ID. + """ group = cls.query.get(id) return group @classmethod def get(cls, search, organization): + """Retrieve notification templates based on search criteria and organization. + + Parameters: + search (str): The search string to filter notification templates by name or description. + organization (str): The organization to filter notification templates. + Returns: + tuple: A tuple containing: + A list of notification templates matching the search criteria and organization. + The count of notification templates matching the search criteria and organization. + """ query = cls.query if organization is not None: - query = query.join(NotificationTemplateOrganization, - NotificationTemplate.id == NotificationTemplateOrganization.notification_template_id) + query = query.join( + NotificationTemplateOrganization, NotificationTemplate.id == NotificationTemplateOrganization.notification_template_id + ) if search is not None: - search_string = '%' + search.lower() + '%' - query = query.filter(or_( - func.lower(NotificationTemplate.name).like(search_string), - func.lower(NotificationTemplate.description).like(search_string))) + search_string = "%" + search.lower() + "%" + query = query.filter( + or_( + func.lower(NotificationTemplate.name).like(search_string), + func.lower(NotificationTemplate.description).like(search_string), + ) + ) return query.order_by(db.asc(NotificationTemplate.name)).all(), query.count() @classmethod def get_all_json(cls, user, search): - templates, count = cls.get(search, user.organizations[0]) + """Retrieve all notification templates in JSON format. + + Parameters: + cls (class): The class itself. + user (User): The user object. + search (str): The search query. + Returns: + dict: A dictionary containing the total count and a list of template items in JSON format. + """ + if user.organizations: + templates, count = cls.get(search, user.organizations[0]) + else: + return {"total_count": 0, "items": []} template_schema = NotificationTemplatePresentationSchema(many=True) - return {'total_count': count, 'items': template_schema.dump(templates)} + return {"total_count": count, "items": template_schema.dump(templates)} @classmethod def add(cls, user, data): + """Add a new notification template to the database. + + Parameters: + cls: The class object. + user: The user object. + data: The data for the new notification template. + """ new_template_schema = NewNotificationTemplateSchema() template = new_template_schema.load(data) template.organizations = user.organizations @@ -98,6 +248,13 @@ def add(cls, user, data): @classmethod def delete(cls, user, template_id): + """Delete a notification template. + + Parameters: + cls (class): The class itself. + user (User): The user performing the delete operation. + template_id (int): The ID of the template to be deleted. + """ template = cls.query.get(template_id) if any(org in user.organizations for org in template.organizations): db.session.delete(template) @@ -105,6 +262,14 @@ def delete(cls, user, template_id): @classmethod def update(cls, user, template_id, data): + """Update a notification template. + + Parameters: + cls: The class object. + user: The user performing the update. + template_id: The ID of the template to update. + data: The updated template data. + """ new_template_schema = NewNotificationTemplateSchema() updated_template = new_template_schema.load(data) template = cls.query.get(template_id) @@ -118,5 +283,12 @@ def update(cls, user, template_id, data): class NotificationTemplateOrganization(db.Model): - notification_template_id = db.Column(db.Integer, db.ForeignKey('notification_template.id'), primary_key=True) - organization_id = db.Column(db.Integer, db.ForeignKey('organization.id'), primary_key=True) + """Model class representing the association table between NotificationTemplate and Organization. + + Attributes: + notification_template_id (int): The ID of the notification template. + organization_id (int): The ID of the organization. + """ + + notification_template_id = db.Column(db.Integer, db.ForeignKey("notification_template.id"), primary_key=True) + organization_id = db.Column(db.Integer, db.ForeignKey("organization.id"), primary_key=True) diff --git a/src/core/model/user.py b/src/core/model/user.py index 10fdeb7d2..fd0f19808 100644 --- a/src/core/model/user.py +++ b/src/core/model/user.py @@ -1,3 +1,5 @@ +"""User module for the model representing a user in the system.""" + from marshmallow import fields, post_load from sqlalchemy import func, or_, orm @@ -13,37 +15,130 @@ class NewUserSchema(UserSchemaBase): + """NewUserSchema class for defining the schema of a new user. + + Attributes: + roles (Nested): A nested field representing the roles of the user. + permissions (Nested): A nested field representing the permissions of the user. + organizations (Nested): A nested field representing the organizations the user belongs to. + Methods: + make(self, data, **kwargs): A post-load method that creates a User object from the given data. + Returns: + User: A User object created from the given data. + """ + roles = fields.Nested(RoleIdSchema, many=True) permissions = fields.Nested(PermissionIdSchema, many=True) organizations = fields.Nested(OrganizationIdSchema, many=True) @post_load def make(self, data, **kwargs): + """Create a new User object based on the provided data. + + Parameters: + data (dict): A dictionary containing the user data. + **kwargs: Additional keyword arguments. + Returns: + User: A new User object initialized with the provided data. + """ return User(**data) class UpdateUserSchema(UserSchemaBase): - password = fields.Str(load_default=None, allow_none=True) + """Schema for updating user information. + + Attributes: + password (str): The user's password. If not provided, the password will not be updated. + roles (list): A list of role IDs assigned to the user. + permissions (list): A list of permission IDs assigned to the user. + organizations (list): A list of organization IDs associated with the user. + """ + password = fields.Str(load_default=None, allow_none=True) roles = fields.Nested(RoleIdSchema, many=True) permissions = fields.Nested(PermissionIdSchema, many=True) organizations = fields.Nested(OrganizationIdSchema, many=True) class User(db.Model): + """User class represents a user in the system. + + Attributes: + id (int): The unique identifier of the user. + username (str): The username of the user. + name (str): The name of the user. + password (str): The password of the user. + organizations (list): The organizations the user belongs to. + roles (list): The roles assigned to the user. + permissions (list): The permissions granted to the user. + profile_id (int): The ID of the user's profile. + profile (UserProfile): The profile of the user. + title (str): The title of the user. + subtitle (str): The subtitle of the user. + tag (str): The tag associated with the user. + Methods: + __init__(self, id, username, name, password, organizations, roles, permissions): + Initializes a new User object. + reconstruct(self): + Reconstructs the user object. + find(cls, username): + Finds a user by username. + find_by_id(cls, user_id): + Finds a user by ID. + get_all(cls): + Retrieves all users. + get(cls, search, organization): + Retrieves users based on search criteria and organization. + get_all_json(cls, search): + Retrieves all users in JSON format. + get_all_external_json(cls, user, search): + Retrieves all users in JSON format for an external user. + add_new(cls, data): + Adds a new user. + add_new_external(cls, user, permissions, data): + Adds a new external user. + update(cls, user_id, data): + Updates a user. + update_external(cls, user, permissions, user_id, data): + Updates an external user. + delete(cls, id): + Deletes a user. + delete_external(cls, user, id): + Deletes an external user. + get_permissions(self): + Retrieves all permissions of the user. + get_current_organization_name(self): + Retrieves the name of the current organization the user belongs to. + get_profile_json(cls, user): + Retrieves the user's profile in JSON format. + update_profile(cls, user, data): + Updates the user's profile. + """ + id = db.Column(db.Integer, primary_key=True) username = db.Column(db.String(64), unique=True, nullable=False) name = db.Column(db.String(), nullable=False) password = db.Column(db.String(), nullable=False) organizations = db.relationship("Organization", secondary="user_organization") - roles = db.relationship(Role, secondary='user_role') - permissions = db.relationship(Permission, secondary='user_permission') + roles = db.relationship(Role, secondary="user_role") + permissions = db.relationship(Permission, secondary="user_permission") - profile_id = db.Column(db.Integer, db.ForeignKey('user_profile.id')) + profile_id = db.Column(db.Integer, db.ForeignKey("user_profile.id")) profile = db.relationship("UserProfile", cascade="all") def __init__(self, id, username, name, password, organizations, roles, permissions): + """Initialize a User object with the given parameters. + + Parameters: + id (int): The user's ID. + username (str): The user's username. + name (str): The user's name. + password (str): The user's password. + organizations (list): A list of organizations the user belongs to. + roles (list): A list of roles assigned to the user. + permissions (list): A list of permissions granted to the user. + """ self.id = None self.username = username self.name = name @@ -70,53 +165,113 @@ def __init__(self, id, username, name, password, organizations, roles, permissio @orm.reconstructor def reconstruct(self): + """Reconstruct the user object. + + This method updates the `title`, `subtitle`, and `tag` attributes of the user object + based on the current `name` and `username` values. + """ self.title = self.name self.subtitle = self.username self.tag = "mdi-account" @classmethod def find(cls, username): + """Find a user by their username. + + Parameters: + cls: The class object. + username: The username of the user to find. + Returns: + The user object if found, None otherwise. + """ user = cls.query.filter_by(username=username).first() return user @classmethod def find_by_id(cls, user_id): + """Find a user by their ID. + + Parameters: + cls: The class object. + user_id: The ID of the user to find. + Returns: + The user object if found, None otherwise. + """ user = cls.query.get(user_id) return user @classmethod def get_all(cls): + """Retrieve all instances of the User class from the database. + + Returns: + list: A list of User instances, ordered by name in ascending order. + """ return cls.query.order_by(db.asc(User.name)).all() @classmethod def get(cls, search, organization): + """Retrieve users based on search criteria and organization. + + Parameters: + cls: The class object. + search (str): The search string to filter users by name or username. + organization: The organization to filter users by. + Returns: + tuple: A tuple containing two elements: + A list of users matching the search criteria and organization, ordered by name. + The total count of users matching the search criteria and organization. + """ query = cls.query if organization is not None: query = query.join(UserOrganization, User.id == UserOrganization.user_id) if search is not None: - search_string = '%' + search.lower() + '%' - query = query.filter(or_( - func.lower(User.name).like(search_string), - func.lower(User.username).like(search_string))) + search_string = "%" + search.lower() + "%" + query = query.filter(or_(func.lower(User.name).like(search_string), func.lower(User.username).like(search_string))) return query.order_by(db.asc(User.name)).all(), query.count() @classmethod def get_all_json(cls, search): + """Retrieve all users matching the given search criteria and returns them as a JSON object. + + Parameters: + cls: The class object. + search: The search criteria. + Returns: + A JSON object containing the total count of users and a list of user items matching the search criteria. + """ users, count = cls.get(search, None) user_schema = UserPresentationSchema(many=True) - return {'total_count': count, 'items': user_schema.dump(users)} + return {"total_count": count, "items": user_schema.dump(users)} @classmethod def get_all_external_json(cls, user, search): - users, count = cls.get(search, user.organizations[0]) + """Retrieve all external JSON data for a given user. + + Parameters: + cls (class): The class object. + user (User): The user object. + search (str): The search query. + Returns: + dict: A dictionary containing the total count and items of the retrieved data. + """ + if user.organizations: + users, count = cls.get(search, user.organizations[0]) + else: + return {"total_count": 0, "items": []} user_schema = UserPresentationSchema(many=True) - return {'total_count': count, 'items': user_schema.dump(users)} + return {"total_count": count, "items": user_schema.dump(users)} @classmethod def add_new(cls, data): + """Add a new user to the database. + + Parameters: + data: A dictionary containing the user data. + """ new_user_schema = NewUserSchema() user = new_user_schema.load(data) db.session.add(user) @@ -124,6 +279,14 @@ def add_new(cls, data): @classmethod def add_new_external(cls, user, permissions, data): + """Add a new external user to the system. + + Parameters: + cls: The class object. + user: The user object. + permissions: The list of permissions. + data: The data for the new user. + """ new_user_schema = NewUserSchema() new_user = new_user_schema.load(data) new_user.roles = [] @@ -138,6 +301,13 @@ def add_new_external(cls, user, permissions, data): @classmethod def update(cls, user_id, data): + """Update a user with the given user_id using the provided data. + + Parameters: + cls (class): The class object. + user_id (int): The ID of the user to be updated. + data (dict): The data containing the updated user information. + """ schema = UpdateUserSchema() updated_user = schema.load(data) user = cls.query.get(user_id) @@ -164,6 +334,15 @@ def update(cls, user_id, data): @classmethod def update_external(cls, user, permissions, user_id, data): + """Update an external user with the provided data. + + Parameters: + cls (class): The class object. + user (User): The current user performing the update. + permissions (list): The list of permissions. + user_id (int): The ID of the user to be updated. + data (dict): The data to update the user with. + """ schema = NewUserSchema() updated_user = schema.load(data) existing_user = cls.query.get(user_id) @@ -182,18 +361,36 @@ def update_external(cls, user, permissions, user_id, data): @classmethod def delete(cls, id): + """Delete a user from the database. + + Parameters: + cls (class): The class representing the user model. + id (int): The ID of the user to be deleted. + """ user = cls.query.get(id) db.session.delete(user) db.session.commit() @classmethod def delete_external(cls, user, id): + """Delete an external user from the database. + + Parameters: + cls (class): The class object. + user (User): The user performing the deletion. + id (int): The ID of the user to be deleted. + """ existing_user = cls.query.get(id) if any(org in user.organizations for org in existing_user.organizations): db.session.delete(existing_user) db.session.commit() def get_permissions(self): + """Return a list of all permissions associated with the user. + + Returns: + list: A list of permission IDs. + """ all_permissions = set() for permission in self.permissions: @@ -205,6 +402,11 @@ def get_permissions(self): return list(all_permissions) def get_current_organization_name(self): + """Return the name of the current organization. + + Returns: + str: The name of the current organization. If no organization is available, an empty string is returned. + """ if len(self.organizations) > 0: return self.organizations[0].name else: @@ -212,11 +414,28 @@ def get_current_organization_name(self): @classmethod def get_profile_json(cls, user): + """Return the JSON representation of a user's profile. + + Parameters: + cls (class): The class object. + user (User): The user object. + Returns: + dict: The JSON representation of the user's profile. + """ profile_schema = UserProfileSchema() return profile_schema.dump(user.profile) @classmethod def update_profile(cls, user, data): + """Update the user's profile with the provided data. + + Parameters: + cls (class): The class object. + user (User): The user object to update the profile for. + data (dict): The data containing the updated profile information. + Returns: + dict: The updated profile information in JSON format. + """ new_profile_schema = NewUserProfileSchema() updated_profile = new_profile_schema.load(data) @@ -225,6 +444,7 @@ def update_profile(cls, user, data): user.profile.language = updated_profile.language user.profile.word_lists = [] from model.word_list import WordList + for word_list in updated_profile.word_lists: if WordList.allowed_with_acl(word_list.id, user, True, False, False): user.profile.word_lists.append(word_list) @@ -237,46 +457,119 @@ def update_profile(cls, user, data): class UserOrganization(db.Model): - user_id = db.Column(db.Integer, db.ForeignKey('user.id'), primary_key=True) - organization_id = db.Column(db.Integer, db.ForeignKey('organization.id'), primary_key=True) + """Represents the association table between User and Organization. + + Attributes: + user_id (int): The ID of the user. + organization_id (int): The ID of the organization. + """ + + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), primary_key=True) + organization_id = db.Column(db.Integer, db.ForeignKey("organization.id"), primary_key=True) class UserRole(db.Model): - user_id = db.Column(db.Integer, db.ForeignKey('user.id'), primary_key=True) - role_id = db.Column(db.Integer, db.ForeignKey('role.id'), primary_key=True) + """Model class representing the association table between User and Role. + + Attributes: + user_id (int): The ID of the user. + role_id (int): The ID of the role. + """ + + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), primary_key=True) + role_id = db.Column(db.Integer, db.ForeignKey("role.id"), primary_key=True) class UserPermission(db.Model): - user_id = db.Column(db.Integer, db.ForeignKey('user.id'), primary_key=True) - permission_id = db.Column(db.String, db.ForeignKey('permission.id'), primary_key=True) + """Represents the association table between User and Permission models. + + Attributes: + user_id (int): The ID of the user. + permission_id (str): The ID of the permission. + """ + + user_id = db.Column(db.Integer, db.ForeignKey("user.id"), primary_key=True) + permission_id = db.Column(db.String, db.ForeignKey("permission.id"), primary_key=True) class NewHotkeySchema(HotkeySchema): + """Represents a schema for creating a new hotkey. + + Methods: + make(data, **kwargs): Creates a new Hotkey instance based on the provided data. + """ @post_load def make(self, data, **kwargs): + """Create a new Hotkey instance based on the given data. + + Parameters: + data (dict): A dictionary containing the data for the Hotkey. + **kwargs: Additional keyword arguments. + Returns: + Hotkey: A new Hotkey instance. + """ return Hotkey(**data) class NewUserProfileSchema(UserProfileSchema): + """Schema for creating a new user profile. + + Attributes: + word_lists (List[Nested[WordListIdSchema]]): A list of nested schemas for word lists. + hotkeys (List[Nested[NewHotkeySchema]]): A list of nested schemas for hotkeys. + Methods: + make(data, **kwargs): Post-load method that creates a UserProfile instance from the given data. + """ + word_lists = fields.List(fields.Nested(WordListIdSchema)) hotkeys = fields.List(fields.Nested(NewHotkeySchema)) @post_load def make(self, data, **kwargs): + """Create a new UserProfile instance based on the given data. + + Parameters: + data (dict): A dictionary containing the data for the UserProfile. + **kwargs: Additional keyword arguments. + Returns: + UserProfile: A new UserProfile instance. + """ return UserProfile(**data) class UserProfile(db.Model): + """Represent a user profile. + + Attributes: + id (int): The unique identifier for the user profile. + spellcheck (bool): Indicates whether spellcheck is enabled for the user. + dark_theme (bool): Indicates whether dark theme is enabled for the user. + language (str): The language code for the user's preferred language. + hotkeys (list): A list of hotkeys associated with the user profile. + word_lists (list): A list of word lists associated with the user profile. + Methods: + __init__(spellcheck, dark_theme, language, hotkeys, word_lists): Initializes a new instance of the UserProfile class. + """ + id = db.Column(db.Integer, primary_key=True) spellcheck = db.Column(db.Boolean, default=True) dark_theme = db.Column(db.Boolean, default=False) language = db.Column(db.String(2)) hotkeys = db.relationship("Hotkey", cascade="all, delete-orphan") - word_lists = db.relationship('WordList', secondary='user_profile_word_list') + word_lists = db.relationship("WordList", secondary="user_profile_word_list") def __init__(self, spellcheck, dark_theme, language, hotkeys, word_lists): + """Initialize a User object with the given parameters. + + Parameters: + spellcheck (bool): Indicates whether spellcheck is enabled for the user. + dark_theme (bool): Indicates whether the user has enabled dark theme. + language (str): The language preference of the user. + hotkeys (list): A list of hotkeys configured by the user. + word_lists (list): A list of WordList objects associated with the user. + """ self.id = None self.spellcheck = spellcheck self.dark_theme = dark_theme @@ -285,24 +578,51 @@ def __init__(self, spellcheck, dark_theme, language, hotkeys, word_lists): self.word_lists = [] from model.word_list import WordList + for word_list in word_lists: self.word_lists.append(WordList.find(word_list.id)) class UserProfileWordList(db.Model): - user_profile_id = db.Column(db.Integer, db.ForeignKey('user_profile.id'), primary_key=True) - word_list_id = db.Column(db.Integer, db.ForeignKey('word_list.id'), primary_key=True) + """Model class representing the association table between UserProfile and WordList. + + Attributes: + user_profile_id (int): The ID of the user profile. + word_list_id (int): The ID of the word list. + """ + + user_profile_id = db.Column(db.Integer, db.ForeignKey("user_profile.id"), primary_key=True) + word_list_id = db.Column(db.Integer, db.ForeignKey("word_list.id"), primary_key=True) class Hotkey(db.Model): + """Represents a hotkey for a user. + + Attributes: + id (int): The unique identifier for the hotkey. + key_code (int): The code of the key associated with the hotkey. + key (str): The key associated with the hotkey. + alias (str): The alias for the hotkey. + user_profile_id (int): The foreign key referencing the user profile. + Methods: + __init__(key_code, key, alias): Initializes a new instance of the Hotkey class. + """ + id = db.Column(db.Integer, primary_key=True) key_code = db.Column(db.Integer) key = db.Column(db.String) alias = db.Column(db.String) - user_profile_id = db.Column(db.Integer, db.ForeignKey('user_profile.id')) + user_profile_id = db.Column(db.Integer, db.ForeignKey("user_profile.id")) def __init__(self, key_code, key, alias): + """Initialize a User object. + + Parameters: + key_code (str): The key code of the user. + key (str): The key of the user. + alias (str): The alias of the user. + """ self.id = None self.key_code = key_code self.key = key From 1e88d525c98b3b121ece80d1f4ef7cb8283972b6 Mon Sep 17 00:00:00 2001 From: multiflexi Date: Wed, 31 Jul 2024 20:30:26 +0200 Subject: [PATCH 2/2] f-strings --- src/core/model/asset.py | 4 ++-- src/core/model/notification_template.py | 2 +- src/core/model/user.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/core/model/asset.py b/src/core/model/asset.py index 18d313888..9d95346aa 100644 --- a/src/core/model/asset.py +++ b/src/core/model/asset.py @@ -289,7 +289,7 @@ def get(cls, group_id, search, sort, vulnerable): query = query.filter(Asset.vulnerabilities_count > 0) if search is not None: - search_string = "%" + search.lower() + "%" + search_string = f"%{search.lower()}%" query = query.join(AssetCpe, Asset.id == AssetCpe.asset_id).filter( or_( func.lower(Asset.name).like(search_string), @@ -574,7 +574,7 @@ def get(cls, search, organization): query = query.join(AssetGroupOrganization, AssetGroup.id == AssetGroupOrganization.asset_group_id) if search is not None: - search_string = "%" + search.lower() + "%" + search_string = f"%{search.lower()}%" query = query.filter(or_(func.lower(AssetGroup.name).like(search_string), func.lower(AssetGroup.description).like(search_string))) return query.order_by(db.asc(AssetGroup.name)).all(), query.count() diff --git a/src/core/model/notification_template.py b/src/core/model/notification_template.py index 9dbcef91a..2942c8f6a 100644 --- a/src/core/model/notification_template.py +++ b/src/core/model/notification_template.py @@ -203,7 +203,7 @@ def get(cls, search, organization): ) if search is not None: - search_string = "%" + search.lower() + "%" + search_string = f"%{search.lower()}%" query = query.filter( or_( func.lower(NotificationTemplate.name).like(search_string), diff --git a/src/core/model/user.py b/src/core/model/user.py index fd0f19808..6e8a038ae 100644 --- a/src/core/model/user.py +++ b/src/core/model/user.py @@ -228,7 +228,7 @@ def get(cls, search, organization): query = query.join(UserOrganization, User.id == UserOrganization.user_id) if search is not None: - search_string = "%" + search.lower() + "%" + search_string = f"%{search.lower()}%" query = query.filter(or_(func.lower(User.name).like(search_string), func.lower(User.username).like(search_string))) return query.order_by(db.asc(User.name)).all(), query.count()