diff --git a/manager/status_handler.py b/manager/status_handler.py index 723e71b35..e923cffc3 100644 --- a/manager/status_handler.py +++ b/manager/status_handler.py @@ -3,9 +3,13 @@ """ from datetime import datetime from datetime import timezone +from uuid import UUID from peewee import DataError +from peewee import EXCLUDED +from peewee import fn from peewee import IntegrityError +from peewee import ValuesList from psycopg2 import IntegrityError as psycopg2IntegrityError from .base import cyndi_join @@ -22,8 +26,11 @@ from common.peewee_model import InsightsRule from common.peewee_model import RHAccount from common.peewee_model import Status +from common.peewee_model import SystemCveData from common.peewee_model import SystemPlatform from common.peewee_model import SystemVulnerabilities +from common.peewee_model import SystemVulnerablePackage +from common.peewee_model import VulnerablePackageCVE LOGGER = get_logger(__name__) @@ -40,7 +47,6 @@ def handle_get(cls, **kwargs): # pylint: disable=unused-argument status_list = [] for status in query: status_list.append(status) - LOGGER.debug(status_list) return {"data": status_list, "meta": {"total_items": len(status_list)}}, 200 @@ -50,8 +56,7 @@ class PatchStatus(PatchRequest): _endpoint_name = r"/v1/status" @staticmethod - def _prepare_data(data, rh_account_id): - + def _prepare_data(data): if "inventory_id" in data: in_inventory_id_list = parse_str_or_list(data["inventory_id"]) else: @@ -59,53 +64,99 @@ def _prepare_data(data, rh_account_id): in_inventory_id_list = None in_cve_list = parse_str_or_list(data["cve"]) + in_status_id = data.get("status_id") + if data.get("status_text") and data["status_text"].strip(): + in_status_text = data["status_text"].strip() + else: + in_status_text = None + + return in_inventory_id_list, in_cve_list, in_status_id, in_status_text - status_to_cves_map = {} - status_text_to_cves_map = {} - if "status_id" in data: - # single status for all CVEs - status_to_cves_map[data["status_id"]] = in_cve_list - if "status_text" in data: - # single status for all CVEs - try: - key = data["status_text"].strip() if data["status_text"].strip() else None - except AttributeError: - key = None - status_text_to_cves_map[key] = in_cve_list - # if neither of status_id or status_text is set => inherit from CVE-level - if not status_to_cves_map and not status_text_to_cves_map: - # use CVE-level status if status not specified - cve_details = (CveAccountData.select(CveMetadata.cve, CveAccountData.status_id, CveAccountData.status_text) - .join(CveMetadata, on=(CveAccountData.cve_id == CveMetadata.id)) - .where((CveAccountData.rh_account_id == rh_account_id) & - (CveMetadata.cve << in_cve_list)) + @staticmethod + def _apply_system_list_filter(query, rh_account_id, in_inventory_id_list): + query = cyndi_join(query) + query = query.where((SystemPlatform.rh_account_id == rh_account_id) & + (SystemPlatform.when_deleted.is_null(True))) + if in_inventory_id_list is not None: + query = query.where(SystemPlatform.inventory_id << in_inventory_id_list) + return query + + @classmethod + def _get_current_status(cls, rh_account_id, in_inventory_id_list, in_cve_list): + # pair status + system_cve_details = (SystemCveData.select(SystemPlatform.inventory_id, CveMetadata.cve, + SystemCveData.status_id, SystemCveData.status_text) + .join(CveMetadata, on=(SystemCveData.cve_id == CveMetadata.id)) + .join(SystemPlatform, on=(SystemCveData.system_id == SystemPlatform.id)) + .where(CveMetadata.cve << in_cve_list) + .dicts()) + system_cve_details = cls._apply_system_list_filter(system_cve_details, rh_account_id, in_inventory_id_list) + current_status = {} + for system_cve_detail in system_cve_details: + current_status.setdefault(system_cve_detail["cve"], {})[system_cve_detail["inventory_id"]] = \ + (system_cve_detail["status_id"], system_cve_detail["status_text"]) + + # global status + cve_details = (CveAccountData.select(CveMetadata.cve, CveAccountData.status_id, CveAccountData.status_text) + .join(CveMetadata, on=(CveAccountData.cve_id == CveMetadata.id)) + .where((CveAccountData.rh_account_id == rh_account_id) & + (CveMetadata.cve << in_cve_list)) + .dicts()) + for cve_detail in cve_details: + current_status.setdefault(cve_detail["cve"], {})["global"] = (cve_detail["status_id"], cve_detail["status_text"]) + return current_status + + @classmethod + def _get_affected_pairs(cls, rh_account_id, in_inventory_id_list, in_cve_list): + affected_pairs = set() + fixable_pairs = (SystemVulnerabilities.select(SystemPlatform.inventory_id, CveMetadata.cve) + .join(CveMetadata, on=(SystemVulnerabilities.cve_id == CveMetadata.id)) + .join(SystemPlatform, on=(SystemVulnerabilities.system_id == SystemPlatform.id)) + .where(SystemVulnerabilities.rh_account_id == rh_account_id) + .where((SystemVulnerabilities.cve_id << + (CveMetadata.select(CveMetadata.id).where( + CveMetadata.cve << in_cve_list))) & + ((SystemVulnerabilities.when_mitigated.is_null(True)) | + ((SystemVulnerabilities.mitigation_reason.is_null(True)) & + (SystemVulnerabilities.rule_id << (InsightsRule.select(InsightsRule.id) + .where((InsightsRule.active == True) & (InsightsRule.rule_only == False))))))) + .dicts()) + fixable_pairs = cls._apply_system_list_filter(fixable_pairs, rh_account_id, in_inventory_id_list) + for pair in fixable_pairs: + affected_pairs.add((pair["inventory_id"], pair["cve"])) + + unfixable_pairs = (SystemVulnerablePackage.select(SystemPlatform.inventory_id, CveMetadata.cve) + .join(VulnerablePackageCVE, on=(SystemVulnerablePackage.vulnerable_package_id == VulnerablePackageCVE.vulnerable_package_id)) + .join(SystemPlatform, on=(SystemVulnerablePackage.system_id == SystemPlatform.id)) + .join(CveMetadata, on=(VulnerablePackageCVE.cve_id == CveMetadata.id)) + .where(SystemVulnerablePackage.rh_account_id == rh_account_id) + .where((VulnerablePackageCVE.cve_id << + (CveMetadata.select(CveMetadata.id).where( + CveMetadata.cve << in_cve_list)))) .dicts()) + unfixable_pairs = cls._apply_system_list_filter(unfixable_pairs, rh_account_id, in_inventory_id_list) + for pair in unfixable_pairs: + affected_pairs.add((pair["inventory_id"], pair["cve"])) + return affected_pairs - found_cves = set() - for cve in cve_details: - status_to_cves_map.setdefault(cve["status_id"], []).append(cve["cve"]) - status_text_to_cves_map.setdefault(cve["status_text"], []).append(cve["cve"]) - found_cves.add(cve["cve"]) - # not found CVEs have 0 status by default, status_text is null - for cve in in_cve_list: - if cve not in found_cves: - status_to_cves_map.setdefault(0, []).append(cve) - status_text_to_cves_map.setdefault(None, []).append(cve) + @classmethod + def _get_target_status(cls, inventory_id, cve, current_status, in_status_id, in_status_text): + # set global CVE status_id if there is no status_id in request + global_status_id, global_status_text = current_status.get(cve, {}).get("global", (0, None)) + current_status_id, current_status_text = current_status.get(cve, {}).get(inventory_id, (0, None)) + + if in_status_id is None and in_status_text is None: + target_status_id = global_status_id + target_status_text = global_status_text + else: + target_status_id = current_status_id + target_status_text = current_status_text - return in_inventory_id_list, status_to_cves_map, status_text_to_cves_map + if in_status_id is not None: + target_status_id = in_status_id + target_status_text = in_status_text - @staticmethod - def _build_update_condition(rh_account_id, systems, status_cve_list): - # pylint: disable=singleton-comparison - return ((SystemVulnerabilities.rh_account_id == rh_account_id) & - (SystemVulnerabilities.system_id << systems) & - (SystemVulnerabilities.cve_id << - (CveMetadata.select(CveMetadata.id).where( - CveMetadata.cve << status_cve_list))) & - ((SystemVulnerabilities.when_mitigated.is_null(True)) | - ((SystemVulnerabilities.mitigation_reason.is_null(True)) & - (SystemVulnerabilities.rule_id << (InsightsRule.select(InsightsRule.id) - .where((InsightsRule.active == True) & (InsightsRule.rule_only == False))))))) + return target_status_id, target_status_text @classmethod @RBAC.need_permissions(RbacRoutePermissions.SYSTEM_CVE_STATUS_EDIT) @@ -113,51 +164,75 @@ def handle_patch(cls, **kwargs): """Update the "status" field for a system/cve combination""" # pylint: disable=singleton-comparison data = kwargs["data"] - try: rh_account_id = get_or_create_account() - in_inventory_id_list, status_to_cves_map, status_text_to_cves_map = cls._prepare_data(data, rh_account_id) - systems = (SystemPlatform.select(SystemPlatform.id) - .where((SystemPlatform.rh_account_id == rh_account_id) & - (SystemPlatform.opt_out == False) & - (SystemPlatform.stale == False) & - (SystemPlatform.when_deleted.is_null(True)) & - (SystemPlatform.host_type.is_null(True)))) - if in_inventory_id_list is not None: - systems = systems.where(SystemPlatform.inventory_id << in_inventory_id_list) - rows_modified = set() - # set statuses and their CVE lists - for status_id, status_cve_list in status_to_cves_map.items(): - status_id_update = (SystemVulnerabilities.update(status_id=status_id) - .where(cls._build_update_condition(rh_account_id, systems, status_cve_list)) - .returning(SystemVulnerabilities.id)) - rows_modified.update([row.id for row in status_id_update]) - - for status_text, status_cve_list in status_text_to_cves_map.items(): - status_text_update = (SystemVulnerabilities.update(status_text=status_text) - .where(cls._build_update_condition(rh_account_id, systems, status_cve_list)) - .returning(SystemVulnerabilities.id)) - rows_modified.update([row.id for row in status_text_update]) - - if rows_modified: - RHAccount.update(last_status_change=datetime.now(timezone.utc)).where(RHAccount.id == rh_account_id).execute() - updated_details = (SystemVulnerabilities.select(SystemPlatform.inventory_id, CveMetadata.cve) - .join(CveMetadata, on=(SystemVulnerabilities.cve_id == CveMetadata.id)) - .join(SystemPlatform, on=(SystemVulnerabilities.system_id == SystemPlatform.id)) - .where((SystemVulnerabilities.id << list(rows_modified)) & (SystemVulnerabilities.rh_account_id == rh_account_id)) - .dicts()) - updated_details = cyndi_join(updated_details) - updated = [] - for updated_row in updated_details: - updated.append({"inventory_id": updated_row["inventory_id"], "cve": updated_row["cve"]}) - if not updated: + in_inventory_id_list, in_cve_list, in_status_id, in_status_text = cls._prepare_data(data) + + # current status for system-CVE pairs and CVEs + current_status = cls._get_current_status(rh_account_id, in_inventory_id_list, in_cve_list) + + # get system-CVE pairs for which status should be changed (may result in inserting, updating or deleting rows in status table) + affected_pairs = cls._get_affected_pairs(rh_account_id, in_inventory_id_list, in_cve_list) + + if not affected_pairs: # sysid/cve/acct combination does not exist return cls.format_exception("inventory_id/cve must exist and inventory_id must be visible to user", 404) + + to_upsert = [] + to_delete = [] + updated = [] + for inventory_id, cve in affected_pairs: + target_status_id, target_status_text = cls._get_target_status(inventory_id, cve, current_status, in_status_id, in_status_text) + current_status_row = current_status.get(cve, {}).get(inventory_id) + if not current_status_row: # insert new statuses + if target_status_id != 0 or target_status_text is not None: + to_upsert.append((UUID(inventory_id), cve, target_status_id, target_status_text)) + updated.append({"inventory_id": inventory_id, "cve": cve}) + else: # update existing statuses + if target_status_id != 0 or target_status_text is not None: + if target_status_id != current_status_row[0] or target_status_text != current_status_row[1]: + to_upsert.append((UUID(inventory_id), cve, target_status_id, target_status_text)) + updated.append({"inventory_id": inventory_id, "cve": cve}) + current_status.get(cve, {}).pop(inventory_id, None) + + for cve, systems in current_status.items(): # delete statuses that are set to 0, or no longer relevant + for sys in systems: + if sys != "global": + to_delete.append((UUID(sys), cve)) + updated.append({"inventory_id": sys, "cve": cve}) + + if to_upsert: + values_list = ValuesList(to_upsert, columns=("inventory_id", "cve", "status_id", "status_text")) + SystemCveData.insert_from( + (values_list.select(SystemPlatform.id, CveMetadata.id, values_list.c.status_id, values_list.c.status_text) + .join(SystemPlatform, on=(values_list.c.inventory_id == SystemPlatform.inventory_id)) + .join(CveMetadata, on=(values_list.c.cve == CveMetadata.cve))), + fields=[SystemCveData.system_id, SystemCveData.cve_id, SystemCveData.status_id, SystemCveData.status_text] + ).on_conflict( + conflict_target=[SystemCveData.system_id, SystemCveData.cve_id], + update={SystemCveData.status_id: EXCLUDED.status_id, + SystemCveData.status_text: EXCLUDED.status_text} + ).execute() + + if to_delete: + values_list = ValuesList(to_delete, columns=("inventory_id", "cve")) + SystemCveData.delete().where( + fn.EXISTS( + values_list.select(SystemPlatform.id, CveMetadata.id) + .join(SystemPlatform, on=(values_list.c.inventory_id == SystemPlatform.inventory_id)) + .join(CveMetadata, on=(values_list.c.cve == CveMetadata.cve)) + .where(SystemCveData.system_id == SystemPlatform.id) + .where(SystemCveData.cve_id == CveMetadata.id) + ) + ).execute() + + if updated: + RHAccount.update(last_status_change=datetime.now(timezone.utc)).where(RHAccount.id == rh_account_id).execute() except (IntegrityError, psycopg2IntegrityError, DataError) as value_error: # usually means bad-status-id LOGGER.error(str(value_error)) DB.rollback() - return cls.format_exception(f"status_id={list(status_to_cves_map.keys())} is invalid", 400) + return cls.format_exception(f"status_id={in_status_id} is invalid", 400) except ValueError as value_error: LOGGER.exception("Error during setting status (ValueError):") DB.rollback()