diff --git a/check_version.py b/check_version.py index 7ce98de..6542670 100644 --- a/check_version.py +++ b/check_version.py @@ -3,7 +3,7 @@ import upgrade -def compare_version_number(version_to_check): +def compare_version_number(version_to_check: str) -> int: latest_version, _ = upgrade.UPGRADE_STEPS[-1] if latest_version == version_to_check: diff --git a/requirements.txt b/requirements.txt index f2ef34a..9be4102 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,4 @@ gitpython mysql-connector-python +PyHamcrest + diff --git a/src/common_upgrades/add_to_base_iocs.py b/src/common_upgrades/add_to_base_iocs.py index 063383b..8fce90e 100644 --- a/src/common_upgrades/add_to_base_iocs.py +++ b/src/common_upgrades/add_to_base_iocs.py @@ -1,7 +1,10 @@ from xml.dom import minidom from xml.parsers.expat import ExpatError -IOC_FILENAME = "configurations\components\_base\iocs.xml" +from src.file_access import FileAccess +from src.local_logger import LocalLogger + +IOC_FILENAME = r"configurations\components\_base\iocs.xml" FILE_TO_CHECK_STR = "IOC default component file" ALREADY_CONTAINS = "{} already contains {} ioc." @@ -12,12 +15,14 @@ class AddToBaseIOCs: """Add the ioc autostart to _base ioc so that it autostarts""" - def __init__(self, ioc_to_add, add_after_ioc, xml_to_add): + def __init__( + self, ioc_to_add: str | None, add_after_ioc: str | None, xml_to_add: str | None + ) -> None: self._ioc_to_add = ioc_to_add self._add_after_ioc = add_after_ioc self._xml_to_add = xml_to_add - def perform(self, file_access, logger): + def perform(self, file_access: FileAccess, logger: LocalLogger) -> int: """Add the autostart of the given. Args: @@ -54,7 +59,7 @@ def perform(self, file_access, logger): return 0 @staticmethod - def _get_ioc_names(xml): + def _get_ioc_names(xml: minidom.Document) -> list[str]: """Gets the names of all the iocs in the xml. Args: @@ -65,7 +70,9 @@ def _get_ioc_names(xml): """ return [ioc.getAttribute("name") for ioc in xml.getElementsByTagName("ioc")] - def _check_final_file_contains_one_of_added_ioc(self, logger, xml): + def _check_final_file_contains_one_of_added_ioc( + self, logger: LocalLogger, xml: minidom.Document + ) -> bool: """Check the file to make sure it now contains one and only one ioc added entry. Args: @@ -77,6 +84,7 @@ def _check_final_file_contains_one_of_added_ioc(self, logger, xml): """ ioc_names = AddToBaseIOCs._get_ioc_names(xml) + assert self._ioc_to_add is not None node_count = ioc_names.count(self._ioc_to_add) if node_count != 1: # I can not see how to generate this error but it is here because it is important @@ -84,7 +92,7 @@ def _check_final_file_contains_one_of_added_ioc(self, logger, xml): return False return True - def _check_prerequistes_for_file(self, xml, logger): + def _check_prerequistes_for_file(self, xml: minidom.Document, logger: LocalLogger) -> bool: """Check the file can be modified. Args: @@ -95,11 +103,11 @@ def _check_prerequistes_for_file(self, xml, logger): True if everything is ok, else False. """ ioc_names = AddToBaseIOCs._get_ioc_names(xml) - + assert self._ioc_to_add is not None if ioc_names.count(self._ioc_to_add) != 0: logger.error(ALREADY_CONTAINS.format(FILE_TO_CHECK_STR, self._ioc_to_add)) return False - + assert self._add_after_ioc is not None node_count = ioc_names.count(self._add_after_ioc) if node_count != 1: logger.error( @@ -108,7 +116,7 @@ def _check_prerequistes_for_file(self, xml, logger): return False return True - def _add_ioc(self, ioc_xml, logger): + def _add_ioc(self, ioc_xml: minidom.Document, logger: LocalLogger) -> minidom.Document: """Add IOC entry after add after ioc specified if it exists. Args: @@ -119,7 +127,9 @@ def _add_ioc(self, ioc_xml, logger): """ for ioc in ioc_xml.getElementsByTagName("ioc"): if ioc.getAttribute("name") == self._add_after_ioc: + assert self._xml_to_add is not None new_ioc_node = minidom.parseString(self._xml_to_add).firstChild + assert ioc_xml.firstChild is not None ioc_xml.firstChild.insertBefore(new_ioc_node, ioc.nextSibling) # add some formatting to make it look nice ioc_xml.firstChild.insertBefore(ioc_xml.createTextNode("\n "), new_ioc_node) diff --git a/src/common_upgrades/change_macro_in_globals.py b/src/common_upgrades/change_macro_in_globals.py index 500eb2c..a8290f6 100644 --- a/src/common_upgrades/change_macro_in_globals.py +++ b/src/common_upgrades/change_macro_in_globals.py @@ -1,12 +1,16 @@ import re +from typing import Generator from src.common_upgrades.utils.constants import GLOBALS_FILENAME +from src.common_upgrades.utils.macro import Macro +from src.file_access import FileAccess +from src.local_logger import LocalLogger class ChangeMacroInGlobals(object): """An interface to replace arbitrary macros in a globals.txt file""" - def __init__(self, file_access, logger): + def __init__(self, file_access: FileAccess, logger: LocalLogger) -> None: """Initialise. Args: @@ -17,7 +21,7 @@ def __init__(self, file_access, logger): self._logger = logger self._loaded_file = self.load_globals_file() - def load_globals_file(self): + def load_globals_file(self) -> list: """Loads in a globals file as a list of strings. Returns: @@ -29,7 +33,7 @@ def load_globals_file(self): else: return [] - def change_macros(self, ioc_name, macros_to_change): + def change_macros(self, ioc_name: str, macros_to_change: list[tuple[Macro, Macro]]) -> None: """Changes a list of macros in the globals.txt file for a specific IOC. Args: @@ -46,7 +50,7 @@ def change_macros(self, ioc_name, macros_to_change): self.write_modified_globals_file() - def change_ioc_name(self, old_ioc_name, new_ioc_name): + def change_ioc_name(self, old_ioc_name: str, new_ioc_name: str) -> None: """Changes the name of an IOC in a globals.txt file. Args: @@ -62,7 +66,7 @@ def change_ioc_name(self, old_ioc_name, new_ioc_name): self.write_modified_globals_file() - def _globals_filter_generator(self, ioc_to_change): + def _globals_filter_generator(self, ioc_to_change: str) -> Generator[int, None, None]: """Returns lines containing specified IOCs from globals.txt Generator that gives all the lines for a given IOC in globals.txt. @@ -80,7 +84,7 @@ def _globals_filter_generator(self, ioc_to_change): self._logger.info("Found line '{}' in {}".format(line, GLOBALS_FILENAME)) yield index - def _determine_replacement_values(self, old_macro, new_macro): + def _determine_replacement_values(self, old_macro: Macro, new_macro: Macro) -> dict[str, str]: """Determines the strings to search for and replace. Args: @@ -110,7 +114,9 @@ def _determine_replacement_values(self, old_macro, new_macro): return regex_changes - def _apply_regex_macro_change(self, ioc_name, old_macro, new_macro, line_number): + def _apply_regex_macro_change( + self, ioc_name: str, old_macro: Macro, new_macro: Macro, line_number: int + ) -> None: """Applies a regular expression to modify a macro. Args: @@ -135,7 +141,7 @@ def _apply_regex_macro_change(self, ioc_name, old_macro, new_macro, line_number) self._loaded_file[line_number], ) - def _change_ioc_name(self, ioc_name, new_ioc_name, line_number): + def _change_ioc_name(self, ioc_name: str, new_ioc_name: str, line_number: int) -> None: """If a new name is supplied, changes the name of the IOC Args: @@ -150,7 +156,7 @@ def _change_ioc_name(self, ioc_name, new_ioc_name, line_number): ioc_name, new_ioc_name.upper() ) - def write_modified_globals_file(self): + def write_modified_globals_file(self) -> None: """Writes the modified globals file if it has been loaded. Returns: diff --git a/src/common_upgrades/change_macros_in_xml.py b/src/common_upgrades/change_macros_in_xml.py index a9fae1b..2e38854 100644 --- a/src/common_upgrades/change_macros_in_xml.py +++ b/src/common_upgrades/change_macros_in_xml.py @@ -1,10 +1,15 @@ import re +from typing import Generator +from xml.dom.minidom import Document, Element, Text from xml.parsers.expat import ExpatError from src.common_upgrades.utils.constants import FILTER_REGEX, IOC_FILE, SYNOPTIC_FOLDER +from src.common_upgrades.utils.macro import Macro +from src.file_access import FileAccess +from src.local_logger import LocalLogger -def change_macro_name(macro, old_macro_name, new_macro_name): +def change_macro_name(macro: Element, old_macro_name: str, new_macro_name: str) -> None: """Changes the macro name of a macro xml node. Args: @@ -17,7 +22,9 @@ def change_macro_name(macro, old_macro_name, new_macro_name): macro.setAttribute("name", new_macro_name) -def change_macro_value(macro, old_macro_value, new_macro_value): +def change_macro_value( + macro: Element, old_macro_value: str | None, new_macro_value: str | None +) -> None: """Changes the macros in the given xml if a new macro value is given. Args: @@ -33,7 +40,7 @@ def change_macro_value(macro, old_macro_value, new_macro_value): macro.setAttribute("value", new_macro_value) -def find_macro_with_name(macros, name_to_find): +def find_macro_with_name(macros: Element, name_to_find: str) -> bool: """Find whether macro with name attribute equal to argument name_to_find exists Args: @@ -51,7 +58,9 @@ def find_macro_with_name(macros, name_to_find): class ChangeMacrosInXML(object): """Changes macros in XML files.""" - def __init__(self, file_access, logger): + _ioc_file_generator: object + + def __init__(self, file_access: FileAccess, logger: LocalLogger) -> None: """Initialise. Args: @@ -62,20 +71,27 @@ def __init__(self, file_access, logger): self._logger = logger def add_macro( - self, ioc_name, macro_to_add, pattern, description="No description", default_value=None - ): - """Add a macro with a specified name and value to all IOCs whose name begins with ioc_name, unless a macro - with that name already exists + self, + ioc_name: str, + macro_to_add: Macro, + pattern: str, + description: str = "No description", + default_value: str | None = None, + ) -> None: + """Add a macro with a specified name and value to all IOCs whose name begins with ioc_name, + unless a macro with that name already exists Args: - ioc_name: Name of the IOC to add the macro to (e.g. DFKPS would add macros to DFKPS_01 and DFKPS_02) - macro_to_add: Macro class with desired name and value - pattern: Regex pattern describing what values the macro accepts e.g. "^(0|1)$" for 0 or 1 + ioc_name: Name of the IOC to add the macro to (e.g. DFKPS would add macros to DFKPS_01 + and DFKPS_02) macro_to_add: Macro class with desired name and value + pattern: Regex pattern describing what values the macro accepts e.g. "^(0|1)$" for 0/1 description: Description of macro purpose default_value: An optional default value for the macro Returns: None """ + assert macro_to_add.name is not None + assert macro_to_add.value is not None for path, ioc_xml in self._file_access.get_config_files(IOC_FILE): for ioc in self.ioc_tag_generator(path, ioc_xml, ioc_name): macros = ioc.getElementsByTagName("macros")[0] @@ -88,7 +104,7 @@ def add_macro( self._file_access.write_xml_file(path, ioc_xml) - def change_macros(self, ioc_name, macros_to_change): + def change_macros(self, ioc_name: str, macros_to_change: list[tuple[Macro, Macro]]) -> None: """Changes macros in all xml files that contain the correct macros for a specified ioc. Args: @@ -103,15 +119,16 @@ def change_macros(self, ioc_name, macros_to_change): macros = ioc.getElementsByTagName("macros")[0] for macro in macros.getElementsByTagName("macro"): name = macro.getAttribute("name") - for old_macro, new_macro in macros_to_change: - # Check if current macro name starts with name of macro to be changed - if re.match(old_macro.name, name) is not None: - change_macro_name(macro, old_macro.name, new_macro.name) - change_macro_value(macro, old_macro.value, new_macro.value) + if name is not None: + for old_macro, new_macro in macros_to_change: + # Check if current macro name starts with name of macro to be changed + if re.match(old_macro.name, name) is not None: + change_macro_name(macro, old_macro.name, new_macro.name) + change_macro_value(macro, old_macro.value, new_macro.value) self._file_access.write_xml_file(path, ioc_xml) - def change_ioc_name(self, old_ioc_name, new_ioc_name): + def change_ioc_name(self, old_ioc_name: str, new_ioc_name: str) -> None: """Replaces all instances of old_ioc_name with new_ioc_name in an XML tree Args: old_ioc_name: String, the old ioc prefix (without _XX number suffix) @@ -131,7 +148,7 @@ def change_ioc_name(self, old_ioc_name, new_ioc_name): self._file_access.write_xml_file(path, ioc_xml) - def change_ioc_name_in_synoptics(self, old_ioc_name, new_ioc_name): + def change_ioc_name_in_synoptics(self, old_ioc_name: str, new_ioc_name: str) -> None: """Replaces instances of old_ioc_name with new_ioc_name Args: @@ -155,7 +172,10 @@ def change_ioc_name_in_synoptics(self, old_ioc_name, new_ioc_name): for element in synoptic_xml.getElementsByTagName("value"): # Obtain text between the tags (https://stackoverflow.com/a/317494 and https://stackoverflow.com/a/13591742) if element.firstChild is not None: - if element.firstChild.nodeType == element.TEXT_NODE: + if ( + isinstance(element.firstChild, Text) + and element.firstChild.nodeType == element.TEXT_NODE + ): ioc_name_with_suffix = element.firstChild.nodeValue if old_ioc_name in ioc_name_with_suffix: @@ -166,7 +186,9 @@ def change_ioc_name_in_synoptics(self, old_ioc_name, new_ioc_name): self._file_access.write_xml_file(xml_path, synoptic_xml) - def ioc_tag_generator(self, path, ioc_xml, ioc_to_change): + def ioc_tag_generator( + self, path: str, ioc_xml: Document, ioc_to_change: str + ) -> Generator[Element, None, None]: """Generator giving all the IOC tags in all configurations. Args: diff --git a/src/common_upgrades/change_pvs_in_xml.py b/src/common_upgrades/change_pvs_in_xml.py index 0c7f168..daf9e65 100644 --- a/src/common_upgrades/change_pvs_in_xml.py +++ b/src/common_upgrades/change_pvs_in_xml.py @@ -1,10 +1,15 @@ +from typing import Generator +from xml.dom.minidom import Document, Element, Text + from src.common_upgrades.utils.constants import BLOCK_FILE +from src.file_access import FileAccess +from src.local_logger import LocalLogger class ChangePVsInXML(object): """Changes pvs in XML files.""" - def __init__(self, file_access, logger): + def __init__(self, file_access: FileAccess, logger: LocalLogger) -> None: """Initialise. Args: @@ -14,8 +19,11 @@ def __init__(self, file_access, logger): self._file_access = file_access self._logger = logger - def node_text_filter(self, filter_text, element_name, path, xml): - """A generator that gives all the instances of filter_text within the element_name elements of the input_files. + def node_text_filter( + self, filter_text: str, element_name: str, path: str, xml: Document + ) -> Generator[Element, None, None]: + """A generator that gives all the instances of filter_text within the + element_name elements of the input_files. Args: filter_text: String, ext to find @@ -27,15 +35,25 @@ def node_text_filter(self, filter_text, element_name, path, xml): Generator giving node instances """ for node in xml.getElementsByTagName(element_name): - if node.firstChild is None or node.firstChild.nodeType != node.TEXT_NODE: + if node.firstChild is None or ( + isinstance(node.firstChild, Element) and node.firstChild.nodeType != node.TEXT_NODE + ): continue + assert isinstance(node.firstChild, Text) current_pv_value = node.firstChild.nodeValue if filter_text in current_pv_value: self._logger.info("{} found in {}".format(filter_text, path)) yield node - def _replace_text_in_elements(self, old_text, new_text, element_name, input_files): - """Replaces all instances of old_text with new_text in all element_name elements of one or more XML files + def _replace_text_in_elements( + self, + old_text: str, + new_text: str, + element_name: str, + input_files: Generator[tuple[str, Document], None, None], + ) -> None: + """Replaces all instances of old_text with new_text in all element_name elements + of one or more XML files Args: old_text: String, old text to find new_text: String, new text to substitute @@ -44,13 +62,15 @@ def _replace_text_in_elements(self, old_text, new_text, element_name, input_file """ for path, xml in input_files: for node in self.node_text_filter(old_text, element_name, path, xml): - replacement = node.firstChild.nodeValue.replace(old_text, new_text) - node.firstChild.replaceWholeText(replacement) + if isinstance(node.firstChild, Text): + replacement = node.firstChild.nodeValue.replace(old_text, new_text) + node.firstChild.replaceWholeText(replacement) self._file_access.write_xml_file(path, xml) - def change_pv_name(self, old_pv_name, new_pv_name): - """Replaces all instances of old_pv_name with new_pv_name in the blocks config and all synoptics + def change_pv_name(self, old_pv_name: str, new_pv_name: str) -> None: + """Replaces all instances of old_pv_name with new_pv_name in the blocks config + and all synoptics Args: old_pv_name: String, the old pv name new_pv_name: String, The desired new pv name @@ -59,7 +79,7 @@ def change_pv_name(self, old_pv_name, new_pv_name): self.change_pv_name_in_blocks(old_pv_name, new_pv_name) self.change_pv_names_in_synoptics(old_pv_name, new_pv_name) - def change_pv_name_in_blocks(self, old_pv_name, new_pv_name): + def change_pv_name_in_blocks(self, old_pv_name: str, new_pv_name: str) -> None: """Move any blocks pointing at old_pv_name to point at new_pv_name. Args: @@ -67,10 +87,13 @@ def change_pv_name_in_blocks(self, old_pv_name, new_pv_name): new_pv_name: The new PV to replace it with """ self._replace_text_in_elements( - old_pv_name, new_pv_name, "read_pv", self._file_access.get_config_files(BLOCK_FILE) + old_pv_name, + new_pv_name, + "read_pv", + self._file_access.get_config_files(BLOCK_FILE), ) - def change_pv_names_in_synoptics(self, old_pv_name, new_pv_name): + def change_pv_names_in_synoptics(self, old_pv_name: str, new_pv_name: str) -> None: """Move any synoptic PV targets from pointing at old_pv_name to point to new_pv_name. Args: @@ -81,7 +104,7 @@ def change_pv_names_in_synoptics(self, old_pv_name, new_pv_name): old_pv_name, new_pv_name, "address", self._file_access.get_synoptic_files() ) - def get_number_of_instances_of_pv(self, pv_names): + def get_number_of_instances_of_pv(self, pv_names: str | list[str]) -> int: """Get the number of instances of a PV in the config and synoptic. Args: diff --git a/src/common_upgrades/sql_utilities.py b/src/common_upgrades/sql_utilities.py index b252986..36a5cff 100644 --- a/src/common_upgrades/sql_utilities.py +++ b/src/common_upgrades/sql_utilities.py @@ -1,5 +1,6 @@ """Helpful sql utilities""" +# ruff: noqa: ANN204, ANN205, E501, ANN001, ANN201 import os import re from getpass import getpass @@ -68,10 +69,11 @@ def run_sql_list(logger, sql_list): sql_list: The statement to send """ cursor = SqlConnection.get_session(logger).cursor() - + assert cursor is not None for sql in sql_list: - for result in cursor.execute(sql, multi=True): - pass + cursor.execute(sql, multi=True) + # for result in cursor.execute(sql, multi=True): + # pass SqlConnection.get_session(logger).commit() cursor.close() @@ -117,7 +119,10 @@ def add_new_user(logger, user, password): logger, "CREATE USER {} IDENTIFIED WITH mysql_native_password BY '{}';".format(user, password), ) - run_sql(logger, "GRANT INSERT, SELECT, UPDATE, DELETE ON exp_data.* TO {};".format(user)) + run_sql( + logger, + "GRANT INSERT, SELECT, UPDATE, DELETE ON exp_data.* TO {};".format(user), + ) return 0 except Exception as e: logger.error("Failed to add user: {}".format(e)) diff --git a/src/common_upgrades/synoptics_and_device_screens.py b/src/common_upgrades/synoptics_and_device_screens.py index 63328ec..add80a6 100644 --- a/src/common_upgrades/synoptics_and_device_screens.py +++ b/src/common_upgrades/synoptics_and_device_screens.py @@ -1,10 +1,15 @@ +# ruff: noqa: E501 from functools import partial +from xml.dom.minidom import Document, Element + +from src.file_access import FileAccess +from src.local_logger import LocalLogger class SynopticsAndDeviceScreens(object): """Manipulate an instrument's synoptics and device_screens""" - def __init__(self, file_access, logger): + def __init__(self, file_access: FileAccess, logger: LocalLogger) -> None: self.file_access = file_access self.logger = logger self._update_keys_in_device_screens = partial( @@ -14,7 +19,7 @@ def __init__(self, file_access, logger): self._update_opi_keys_in_xml, root_tag="target", key_tag="name" ) - def update_opi_keys(self, keys_to_update): + def update_opi_keys(self, keys_to_update: dict) -> int: """Update the OPI keys in all synoptics and device screens Args: @@ -44,11 +49,13 @@ def update_opi_keys(self, keys_to_update): device_screens[0], device_screens[1], keys_to_update ) except Exception as e: - self.logger.error("Cannot upgrade device screens {}: {}".format(path, e)) + self.logger.error("Cannot upgrade device screens: {}".format(e)) result = -2 return result - def _update_opi_keys_in_xml(self, path, xml, keys_to_update, root_tag, key_tag): + def _update_opi_keys_in_xml( + self, path: str, xml: Document, keys_to_update: dict, root_tag: str, key_tag: str + ) -> None: """Replaces an opi key with a different key Args: @@ -61,15 +68,16 @@ def _update_opi_keys_in_xml(self, path, xml, keys_to_update, root_tag, key_tag): file_changed = False for target_element in xml.getElementsByTagName(root_tag): key_element = target_element.getElementsByTagName(key_tag)[0] - old_key = key_element.firstChild.nodeValue - new_key = keys_to_update.get(old_key, old_key) - key_element.firstChild.nodeValue = new_key - if new_key != old_key: - file_changed = True - self.logger.info( - "OPI key '{}' replaced with corresponding key '{}' in {}".format( - old_key, new_key, path + if isinstance(key_element.firstChild, Element): + old_key = key_element.firstChild.nodeValue + new_key = keys_to_update.get(old_key, old_key) + key_element.firstChild.nodeValue = new_key + if new_key != old_key: + file_changed = True + self.logger.info( + "OPI key '{}' replaced with corresponding key '{}' in {}".format( + old_key, new_key, path + ) ) - ) if file_changed: self.file_access.write_xml_file(path, xml) diff --git a/src/common_upgrades/utils/constants.py b/src/common_upgrades/utils/constants.py index 8d070ce..e420efa 100644 --- a/src/common_upgrades/utils/constants.py +++ b/src/common_upgrades/utils/constants.py @@ -19,4 +19,4 @@ MOTION_SET_POINTS_FOLDER = os.path.abspath(os.path.join(CONFIG_ROOT, "motionSetPoints")) # Matches an ioc name and its numbered IOCs e.g. GALIL matches GALIL_01, GALIL_02 -FILTER_REGEX = "^{}(_[\d]{{2}})?$" +FILTER_REGEX = r"^{}(_[\d]{{2}})?$" diff --git a/src/common_upgrades/utils/macro.py b/src/common_upgrades/utils/macro.py index c765f50..b4c25de 100644 --- a/src/common_upgrades/utils/macro.py +++ b/src/common_upgrades/utils/macro.py @@ -6,17 +6,17 @@ class Macro(object): value: Value of the Macro. E.g. 1. Defaults to None. """ - def __init__(self, name, value=None): + def __init__(self, name: str, value: str | None = None) -> None: self.__name = name self.__value = value - def __repr__(self): + def __repr__(self) -> str: return "".format(self.__name, self.__value) @property - def name(self): + def name(self) -> str: return self.__name @property - def value(self): + def value(self) -> str | None: return self.__value diff --git a/src/file_access.py b/src/file_access.py index 440b81b..444e9bc 100644 --- a/src/file_access.py +++ b/src/file_access.py @@ -1,3 +1,4 @@ +# ruff: noqa: ANN204, ANN205, E501, ANN001, ANN201, ANN202 import os import shutil from xml.dom import minidom @@ -221,7 +222,7 @@ def get_device_screens(self): else: return None - def get_file_paths(self, directory: str, extension: str = None): + def get_file_paths(self, directory: str, extension: str = ""): """Generator giving the paths of all files inside a directory, recursively searching all subdirectories. Args: diff --git a/src/git_utils.py b/src/git_utils.py index 8f5d90a..f6dd415 100644 --- a/src/git_utils.py +++ b/src/git_utils.py @@ -1,12 +1,13 @@ +# ruff: noqa: ANN205, ANN001 import git class RepoFactory: @staticmethod - def get_repo(working_directory): + def get_repo(working_directory: str): # Check repo try: return git.Repo(working_directory, search_parent_directories=True) except Exception: # Not a valid repository - raise git.NotUnderVersionControl(working_directory) + raise Exception(working_directory + " is not under version control") diff --git a/src/local_logger.py b/src/local_logger.py index 3e3bc8d..45ca8c9 100644 --- a/src/local_logger.py +++ b/src/local_logger.py @@ -6,7 +6,7 @@ class LocalLogger(object): """A local logging object which will write to the screen and a file""" - def __init__(self, log_dir): + def __init__(self, log_dir: str) -> None: """The logging directory in to which to write the log file Args: @@ -16,12 +16,13 @@ def __init__(self, log_dir): os.mkdir(log_dir) log_file = os.path.join( - log_dir, "upgrade_{0}.txt".format(datetime.datetime.now().strftime("%Y_%m_%d__%H_%M")) + log_dir, + "upgrade_{0}.txt".format(datetime.datetime.now().strftime("%Y_%m_%d__%H_%M")), ) self._log_file = log_file - def error(self, message): + def error(self, message: str) -> None: """Write the message as an error (to standard err with ERROR in front of it) Args: @@ -35,7 +36,7 @@ def error(self, message): f.write(formatted_message) sys.stderr.write(formatted_message) - def info(self, message): + def info(self, message: str) -> None: """Write the message as info (to standard out with INFO in front of it) Args: diff --git a/src/upgrade.py b/src/upgrade.py index 750286c..9fcc550 100644 --- a/src/upgrade.py +++ b/src/upgrade.py @@ -1,6 +1,10 @@ import os +from typing import Sequence from src.common_upgrades.sql_utilities import SqlConnection +from src.file_access import FileAccess +from src.local_logger import LocalLogger +from src.upgrade_step import UpgradeStep VERSION_FILENAME = os.path.join("configurations", "config_version.txt") @@ -14,7 +18,13 @@ class UpgradeError(Exception): class Upgrade(object): """Use upgrade steps to upgrade a configuration""" - def __init__(self, file_access, logger, upgrade_steps, git_repo): + def __init__( + self, + file_access: FileAccess | None, + logger: LocalLogger | None, + upgrade_steps: Sequence[tuple[str, UpgradeStep | None]], + git_repo, # noqa + ) -> None: """Constructor Args: @@ -32,30 +42,34 @@ def __init__(self, file_access, logger, upgrade_steps, git_repo): self._upgrade_steps = upgrade_steps self._git_repo = git_repo - def get_version_number(self): + def get_version_number(self) -> str | None: """Find the current version number of the repository. If there is no version number the repository is considered unversioned and the lowest version is written to the repository Returns: the version number """ try: + assert self._file_access is not None for line in self._file_access.open_file(VERSION_FILENAME): return line.strip() except IOError: + assert self._file_access is not None initial_version_number = self._upgrade_steps[0][0] self._file_access.write_version_number(initial_version_number, VERSION_FILENAME) return initial_version_number - def upgrade(self): + def upgrade(self) -> int: """Perform an upgrade on the configuration directory Returns: status code 0 for success; not 0 for failure """ current_version = self.get_version_number() + assert self._file_access is not None + assert self._logger is not None self._logger.info("Config at initial version {0}".format(current_version)) upgrade = False - final_upgrade_version = None + final_upgrade_version = "" with SqlConnection(): for version, upgrade_step in self._upgrade_steps: if version == current_version: @@ -83,7 +97,7 @@ def upgrade(self): self._logger.error("Unknown version number {0}".format(current_version)) return -1 - def _commit_tag_and_push(self, version, final=False): + def _commit_tag_and_push(self, version: str, final: bool = False) -> None: self._git_repo.git.add(A=True) commit_message = f"IBEX Upgrade {'from' if not final else 'to'} {version}" self._git_repo.index.commit(commit_message) diff --git a/src/upgrade_step.py b/src/upgrade_step.py index 4eeb066..471242b 100644 --- a/src/upgrade_step.py +++ b/src/upgrade_step.py @@ -7,7 +7,7 @@ class UpgradeStep(object): __metaclass__ = ABCMeta @abstractmethod - def perform(self, file_access, logger): + def perform(self, file_access, logger): # noqa """Perform the upgrade step this should be implemented Args: diff --git a/src/upgrade_step_add_meta_tag.py b/src/upgrade_step_add_meta_tag.py index 3c7ff86..d5ec093 100644 --- a/src/upgrade_step_add_meta_tag.py +++ b/src/upgrade_step_add_meta_tag.py @@ -1,19 +1,22 @@ import os +import typing import xml.etree.ElementTree as ET from src.common_upgrades.utils.constants import COMPONENT_FOLDER, CONFIG_FOLDER +from src.file_access import FileAccess +from src.local_logger import LocalLogger from src.upgrade_step import UpgradeStep class UpgradeStepAddMetaXmlElement(UpgradeStep): """An upgrade step that adds a passed element to the meta.xml for a configuration.""" - def __init__(self, tag, tag_value): + def __init__(self, tag: str, tag_value: str) -> None: self.tag = tag self.tag_value = tag_value super(UpgradeStepAddMetaXmlElement, self).__init__() - def perform(self, file_access, logger): + def perform(self, file_access: FileAccess, logger: LocalLogger) -> int: """Change meta.xml configuration schema to have self.tag element Args: @@ -33,9 +36,11 @@ def perform(self, file_access, logger): return -1 return 0 - def _add_tag_to_meta_in_folders(self, folder, logger): + def _add_tag_to_meta_in_folders( + self, folder: tuple[typing.Any, list, list], logger: LocalLogger + ) -> int: try: - meta_file_path = os.path.join(folder[0] + "\meta.xml") + meta_file_path = os.path.join(folder[0] + r"\meta.xml") meta_xml = ET.parse(meta_file_path) if len(meta_xml.getroot().findall(self.tag)) == 0: xml_tag = ET.SubElement(meta_xml.getroot(), self.tag) diff --git a/src/upgrade_step_check_init_inst.py b/src/upgrade_step_check_init_inst.py index 49ae381..3231742 100644 --- a/src/upgrade_step_check_init_inst.py +++ b/src/upgrade_step_check_init_inst.py @@ -1,15 +1,18 @@ import os from src.common_upgrades.utils.constants import SCRIPTS_ROOT +from src.file_access import FileAccess +from src.local_logger import LocalLogger from src.upgrade_step import UpgradeStep class UpgradeStepCheckInitInst(UpgradeStep): """An upgrade step to check if the instrument uses the old style of loading in pre and post cmd. - This old style is via API.__localmod in init_.py in the Instrument/Settings/config/NDX/Python folder. + This old style is via API.__ + localmod in init_.py in the Instrument/Settings/config/NDX/Python folder. """ - def search_files(self, files, root, file_access): + def search_files(self, files: list[str], root: str, file_access: FileAccess) -> str | int: """Search files from a root folder for pre and post cmd methods. Args: @@ -17,7 +20,7 @@ def search_files(self, files, root, file_access): root (str): The root directory of the files. file_access (FileAccess): file access - Returns: 0 if pre and post cmd methods in old style are not present; error message if they are. + Returns: 0 if pre & post cmd methods in old style aren't present; error message if they are. """ for file_name in files: if file_name.startswith("init_"): @@ -25,21 +28,23 @@ def search_files(self, files, root, file_access): search_file_contents = search_file.read() if "precmd" in search_file_contents or "postcmd" in search_file_contents: return ( - "Pre or post cmd methods found in {} these will now no longer be hooked into the command. Please ensure they are hooked using the new style of inserting these methods, " + "Pre or post cmd methods found in {} these will now no longer be " + "hooked into the command. Please ensure they are hooked using the new style" + " of inserting these methods, " "see https://github.com/ISISComputingGroup/ibex_user_manual/wiki/Pre-and-Post-Command-Hooks".format( search_file.name ) ) return 0 - def search_folder(self, folder, file_access): + def search_folder(self, folder: str, file_access: FileAccess) -> str | int: """Search folders for the search string. Args: folder (str): The folder to search through. file_access (FileAccess): file access - Returns: 0 if pre and post cmd methods in old style are not present; error message if they are. + Returns: 0 if pre & post cmd methods in old style aren't present; error message if they are. """ file_returns = "" for root, _, files in os.walk(folder): @@ -49,14 +54,14 @@ def search_folder(self, folder, file_access): file_returns += "{}\n".format(file_search_return) return 0 if file_returns == "" else file_returns - def perform(self, file_access, logger): + def perform(self, file_access: FileAccess, logger: LocalLogger) -> str | int: """Check if file exists and if the file includes pre and post cmd methods. Args: file_access (FileAccess): file access logger (LocalLogger): logger - Returns: 0 if pre and post cmd methods in old style are not present; error message if they are. + Returns: 0 if pre & post cmd methods in old style aren't present; error message if they are. """ return self.search_folder(SCRIPTS_ROOT, file_access) diff --git a/src/upgrade_step_from_10p0p0.py b/src/upgrade_step_from_10p0p0.py index 030e25d..48ff786 100644 --- a/src/upgrade_step_from_10p0p0.py +++ b/src/upgrade_step_from_10p0p0.py @@ -1,5 +1,8 @@ import os +from xml.dom.minidom import Text +from src.file_access import FileAccess +from src.local_logger import LocalLogger from src.upgrade_step import UpgradeStep @@ -8,13 +11,14 @@ class RemoveReflDeviceScreen(UpgradeStep): path = os.path.join("configurations", "devices", "screens.xml") - def perform(self, file_access, logger): + def perform(self, file_access: FileAccess, logger: LocalLogger) -> int: if file_access.exists(self.path): xml_tree = file_access.open_xml_file(self.path) keys = xml_tree.getElementsByTagName("key") for key in keys: device = key.parentNode - if key.firstChild.data == "Reflectometry OPI": + assert key.firstChild is not None + if isinstance(key.firstChild, Text) and key.firstChild.data == "Reflectometry OPI": device.parentNode.removeChild(device) file_access.write_xml_file(self.path, xml_tree) diff --git a/src/upgrade_step_from_11p0p0.py b/src/upgrade_step_from_11p0p0.py index 1a039ec..b408503 100644 --- a/src/upgrade_step_from_11p0p0.py +++ b/src/upgrade_step_from_11p0p0.py @@ -1,8 +1,11 @@ +# ruff: noqa: E501 import socket from src.common_upgrades.change_macros_in_xml import ChangeMacrosInXML from src.common_upgrades.change_pvs_in_xml import ChangePVsInXML from src.common_upgrades.utils.macro import Macro +from src.file_access import FileAccess +from src.local_logger import LocalLogger from src.upgrade_step import UpgradeStep @@ -53,7 +56,7 @@ class RenameMercurySoftwarePressureControlMacros(UpgradeStep): ), ( Macro("FLOW_SPC_TABLE_FILE"), - "^\.*$", + r"^\.*$", "File to load to related temperature to pressure from calibration directory other_devices.", "little_blue_cryostat.txt", ), @@ -83,49 +86,49 @@ class RenameMercurySoftwarePressureControlMacros(UpgradeStep): ), ( Macro("VTI_SPC_MIN_PRESSURE"), - "^[0-9]+\.?[0-9]*$", + r"^[0-9]+\.?[0-9]*$", "VTI software pressure control: minimum pressure allowed.", "0.0", ), ( Macro("VTI_SPC_MAX_PRESSURE"), - "^[0-9]+\.?[0-9]*$", + r"^[0-9]+\.?[0-9]*$", "VTI software pressure control: maximum pressure allowed.", "0.0", ), ( Macro("VTI_SPC_PRESSURE_CONSTANT"), - "^[0-9]+\.?[0-9]*$", + r"^[0-9]+\.?[0-9]*$", "VTI software pressure control: constant pressure to use when below cutoff point.", "5.0", ), ( Macro("VTI_SPC_PRESSURE_MAX_LKUP"), - "^\.*$", + r"^\.*$", "VTI software pressure control: Filename for temp-based lookup table when above cutoff point.", "None.txt", ), ( Macro("VTI_SPC_TEMP_CUTOFF_POINT"), - "^[0-9]+\.?[0-9]*$", + r"^[0-9]+\.?[0-9]*$", "VTI software pressure control: temperature to switch between using a user-set constant and a linear interpolation function.", "5.0", ), ( Macro("VTI_SPC_TEMP_SCALE"), - "^[0-9]+\.?[0-9]*$", + r"^[0-9]+\.?[0-9]*$", "VTI software pressure control: amount to scale temp by to further control P vs T dependence.", "2.0", ), ( Macro("VTI_SPC_SET_DELAY"), - "^[0-9]+\.?[0-9]*$", + r"^[0-9]+\.?[0-9]*$", "VTI software pressure control: delay between making adjustments to the pressure setpoint in seconds.", "10.0", ), ] - def perform(self, file_access, logger): + def perform(self, file_access: FileAccess, logger: LocalLogger) -> int: try: hostname = socket.gethostname() ioc_name = "MERCURY_01" diff --git a/src/upgrade_step_from_12p0p0.py b/src/upgrade_step_from_12p0p0.py index 473e9aa..d951065 100644 --- a/src/upgrade_step_from_12p0p0.py +++ b/src/upgrade_step_from_12p0p0.py @@ -8,9 +8,10 @@ class UpgradeJawsForPositionAutosave(UpgradeStep): - """Update all batch files that load a database file using 'slits.template' to support autosave.""" + """Update all batch files that load a database file using + 'slits.template' to support autosave.""" - def perform(self, file_access: FileAccess, logger: LocalLogger): + def perform(self, file_access: FileAccess, logger: LocalLogger) -> int: result = 0 # Get database files using 'slits.template'. @@ -51,7 +52,8 @@ def perform(self, file_access: FileAccess, logger: LocalLogger): if new_line == line: logger.error( - f"Failed to modify {line}, check to see if it needs to be manually changed." + f"Failed to modify {line}, " + f"check to see if it needs to be manually changed." ) result = -1 diff --git a/src/upgrade_step_from_12p0p1.py b/src/upgrade_step_from_12p0p1.py index 3e6ed6d..f0e74b1 100644 --- a/src/upgrade_step_from_12p0p1.py +++ b/src/upgrade_step_from_12p0p1.py @@ -1,6 +1,9 @@ +# ruff: noqa: E501 import os import socket +from src.file_access import FileAccess +from src.local_logger import LocalLogger from src.upgrade_step import UpgradeStep @@ -13,7 +16,7 @@ class AddOscCollimMovingIndicator(UpgradeStep): r'dbLoadRecords("$(UTILITIES)/db/check_stability.db", "P=$(MYPVPREFIX)MOT:,INP_VAL=$(MYPVPREFIX)MOT:DMC01:Galil0Bi5_STATUS,SP=$(MYPVPREFIX)MOT:DMC01:Galil0Bi5_STATUS,NSAMP=100,TOLERANCE=$(TOLERANCE=0)")\n', ] - def perform(self, file_access, logger): + def perform(self, file_access: FileAccess, logger: LocalLogger) -> int: try: hostname = socket.gethostname() if hostname == "NDXLET" or hostname == "NDXMERLIN": diff --git a/src/upgrade_step_from_12p0p2.py b/src/upgrade_step_from_12p0p2.py index c58c5d9..8d39576 100644 --- a/src/upgrade_step_from_12p0p2.py +++ b/src/upgrade_step_from_12p0p2.py @@ -2,18 +2,20 @@ from src.common_upgrades.sql_utilities import SqlConnection, run_sql_file from src.common_upgrades.utils.constants import EPICS_ROOT +from src.file_access import FileAccess +from src.local_logger import LocalLogger from src.upgrade_step import UpgradeStep class UpgradeFrom12p0p2(UpgradeStep): """add sql tables for JMS2RDB""" - def perform(self, file_access, logger): + def perform(self, file_access: FileAccess, logger: LocalLogger) -> int: # add JMS2RDB Tables try: file = os.path.join(EPICS_ROOT, "CSS", "master", "AlarmJMS2RDB", "MySQL-Log-DDL.sql") logger.info("Updating JMS2RDB schema") - with SqlConnection() as s: + with SqlConnection(): return run_sql_file(logger, file) except Exception as e: logger.error("Unable to perform upgrade, caught error: {}".format(e)) diff --git a/src/upgrade_step_from_12p0p3.py b/src/upgrade_step_from_12p0p3.py index db61302..1f7fcaa 100644 --- a/src/upgrade_step_from_12p0p3.py +++ b/src/upgrade_step_from_12p0p3.py @@ -2,18 +2,20 @@ from src.common_upgrades.sql_utilities import SqlConnection, run_sql_file from src.common_upgrades.utils.constants import EPICS_ROOT +from src.file_access import FileAccess +from src.local_logger import LocalLogger from src.upgrade_step import UpgradeStep class UpgradeFrom12p0p3(UpgradeStep): """add sql tables for MOXA""" - def perform(self, file_access, logger): + def perform(self, file_access: FileAccess, logger: LocalLogger) -> int: # add MOXA Tables try: file = os.path.join(EPICS_ROOT, "SystemSetup", "moxas_mysql_schema.txt") logger.info("Updating moxa schema") - with SqlConnection() as s: + with SqlConnection(): return run_sql_file(logger, file) except Exception as e: logger.error("Unable to perform upgrade, caught error: {}".format(e)) diff --git a/src/upgrade_step_from_6p0p0.py b/src/upgrade_step_from_6p0p0.py index 2674c1e..30d1f81 100644 --- a/src/upgrade_step_from_6p0p0.py +++ b/src/upgrade_step_from_6p0p0.py @@ -2,15 +2,17 @@ from src.common_upgrades.change_macros_in_xml import ChangeMacrosInXML from src.common_upgrades.utils.macro import Macro +from src.file_access import FileAccess +from src.local_logger import LocalLogger from src.upgrade_step import UpgradeStep class SetDanfysikDisableAutoonoffMacros(UpgradeStep): - """Set the DISABLE_AUTONOFF macro to true for EMU or add it if not present. When this macro is true, - settings will be displayed on the Danfysik OPI allowing automatic power turn on/off. + """Set the DISABLE_AUTONOFF macro to true for EMU or add it if not present. When this macro + is true, settings will be displayed on the Danfysik OPI allowing automatic power turn on/off. """ - def perform(self, file_access, logger): + def perform(self, file_access: FileAccess, logger: LocalLogger) -> int: try: hostname = socket.gethostname() ioc_name = "DFKPS" @@ -24,7 +26,8 @@ def perform(self, file_access, logger): "1", ) change_macros_in_xml.change_macros( - ioc_name, [(Macro("DISABLE_AUTOONOFF"), Macro("DISABLE_AUTOONOFF", "0"))] + ioc_name, + [(Macro("DISABLE_AUTOONOFF"), Macro("DISABLE_AUTOONOFF", "0"))], ) return 0 except Exception as e: diff --git a/src/upgrade_step_from_7p2p0.py b/src/upgrade_step_from_7p2p0.py index 061f0f1..8650949 100644 --- a/src/upgrade_step_from_7p2p0.py +++ b/src/upgrade_step_from_7p2p0.py @@ -3,7 +3,8 @@ from src.common_upgrades.change_pvs_in_xml import ChangePVsInXML from src.common_upgrades.synoptics_and_device_screens import SynopticsAndDeviceScreens from src.common_upgrades.utils.constants import MOTION_SET_POINTS_FOLDER -from src.file_access import CachingFileAccess +from src.file_access import CachingFileAccess, FileAccess +from src.local_logger import LocalLogger from src.upgrade_step import UpgradeStep ERROR_CODE = -1 @@ -14,9 +15,17 @@ class IgnoreRcpttSynoptics(UpgradeStep): """Adds "rcptt_*" files to .gitignore, so that test synoptics are no longer committed.""" file_name = ".gitignore" - text_content = ["*.py[co]", "rcptt_*/", "rcptt_*", "*.swp", "*~", ".idea/", ".project/"] - - def perform(self, file_access, logger): + text_content = [ + "*.py[co]", + "rcptt_*/", + "rcptt_*", + "*.swp", + "*~", + ".idea/", + ".project/", + ] + + def perform(self, file_access: FileAccess, logger: LocalLogger) -> int: """Perform the upgrade step Args: file_access (FileAccess): file access @@ -48,7 +57,7 @@ def perform(self, file_access, logger): class UpgradeMotionSetPoints(UpgradeStep): """Changes blocks to point at renamed PVs. Warns about changed setup.""" - def perform(self, file_access, logger): + def perform(self, file_access: FileAccess, logger: LocalLogger) -> int: """Perform the upgrade step Args: file_access (FileAccess): file access @@ -95,7 +104,8 @@ def perform(self, file_access, logger): "The PV COORDX:MTR has been found in a config/synoptic but no longer exists" ) print( - "Manually replace with a reference to the underlying axis and rerun the upgrade" + "Manually replace with a reference to the underlying axis " + "and rerun the upgrade" ) raise RuntimeError("Underlying motor references") @@ -110,7 +120,7 @@ class ChangeReflOPITarget(UpgradeStep): REFL_OPI_TARGET_OLD = "Reflectometry Front Panel" REFL_OPI_TARGET_NEW = "Reflectometry OPI" - def perform(self, file_access, logger): + def perform(self, file_access: FileAccess, logger: LocalLogger) -> int: """Perform the upgrade step Args: file_access (FileAccess): file access diff --git a/src/upgrade_step_from_7p4p0.py b/src/upgrade_step_from_7p4p0.py index 94bdf40..9638db0 100644 --- a/src/upgrade_step_from_7p4p0.py +++ b/src/upgrade_step_from_7p4p0.py @@ -2,13 +2,16 @@ from src.common_upgrades.change_macros_in_xml import ChangeMacrosInXML from src.common_upgrades.utils.macro import Macro +from src.file_access import FileAccess +from src.local_logger import LocalLogger from src.upgrade_step import UpgradeStep class SetISOBUSForILM200(UpgradeStep): - """Set the ILM200 ISOBUS value to None for IMAT as they are the first to not use ISOBUS on the ILM200.""" + """Set the ILM200 ISOBUS value to None for IMAT as they are the first to not use ISOBUS + on the ILM200.""" - def perform(self, file_access, logger): + def perform(self, file_access: FileAccess, logger: LocalLogger) -> int: try: hostname = socket.gethostname() # IMAT want it blank as they are not using ISOBUS diff --git a/src/upgrade_step_from_9p0p0.py b/src/upgrade_step_from_9p0p0.py index b94ff56..0e0bf29 100644 --- a/src/upgrade_step_from_9p0p0.py +++ b/src/upgrade_step_from_9p0p0.py @@ -1,12 +1,15 @@ +# ruff: noqa: E501 import socket +from src.file_access import FileAccess +from src.local_logger import LocalLogger from src.upgrade_step import UpgradeStep class ChangeLETCollimatorCmd(UpgradeStep): """Change the LET/MERLIN collimator code to load in the new LET/MERLIN-specific db file.""" - def perform(self, file_access, logger): + def perform(self, file_access: FileAccess, logger: LocalLogger) -> int: try: hostname = socket.gethostname() if hostname == "NDXLET" or hostname == "NDXMERLIN": @@ -27,7 +30,7 @@ def perform(self, file_access, logger): class RenameGalilMulCmd(UpgradeStep): """Rename all galilmul1.cmd -> galilmul01.cmd""" - def perform(self, file_access, logger): + def perform(self, file_access: FileAccess, logger: LocalLogger) -> int | None: try: file_access.rename_file( "configurations\\galilmul\\galilmul1.cmd", diff --git a/src/upgrade_step_noop.py b/src/upgrade_step_noop.py index 62428c4..f3864a2 100644 --- a/src/upgrade_step_noop.py +++ b/src/upgrade_step_noop.py @@ -1,3 +1,4 @@ +# ruff: noqa from src.upgrade_step import UpgradeStep diff --git a/test/mother.py b/test/mother.py index b94c115..75c265f 100644 --- a/test/mother.py +++ b/test/mother.py @@ -1,77 +1,93 @@ """Mother for test objects""" +import typing +from typing import Generator, LiteralString from xml.dom import minidom +from xml.dom.minidom import Document, Node +from src.file_access import FileAccess +from src.local_logger import LocalLogger -class LoggingStub(object): + +class LoggingStub(LocalLogger): """Stub for logging""" - def __init__(self): + def __init__(self) -> None: self.log = [] self.log_err = [] self.config_base = "BASE" - def error(self, message): + def error(self, message: str) -> None: self.log_err.append(message) - def info(self, message): + def info(self, message: str) -> None: self.log.append(message) -class FileAccessStub(object): +class FileAccessStub(FileAccess): """Stub for file access""" SYNOPTIC_FILENAME = "synoptic_file" - def __init__(self): + def __init__(self) -> None: self.config_base = None self.wrote_version = None self.write_filename = None self.write_file_contents = None self.write_file_dict = dict() - self.existing_files = None + self.existing_files = {} - def write_version_number(self, version, filename): + def write_version_number(self, version: str, filename: str) -> None: self.wrote_version = version - def write_file(self, filename, contents): + def write_file( + self, filename: str, file_contents: list[str], mode: str = "w", file_full: bool = False + ) -> None: self.write_filename = filename - self.write_file_contents = "\n".join(contents) + self.write_file_contents = "\n".join(file_contents) self.write_file_dict[self.write_filename] = self.write_file_contents - def open_file(self, filename): + def open_file(self, filename: str) -> list[LiteralString]: return EXAMPLE_GLOBALS_FILE.splitlines() - def write_xml_file(self, filename, xml): + def write_xml_file(self, filename: str, xml: Node) -> None: self.write_filename = filename self.write_file_contents = xml.toxml() self.write_file_dict[self.write_filename] = self.write_file_contents - def open_xml_file(self, filename): - return minidom.parseString(self.open_file(filename)) + def open_xml_file(self, filename: str) -> Document: + return minidom.parseString("".join(self.open_file(filename))) - def listdir(self, dir): + def listdir(self, dir: str) -> list[str]: return ["file1.xml", "README.txt", "file2.xml"] - def remove_file(self, filename): + def remove_file(self, filename: str) -> None: pass - def is_dir(self, path): - pass + def is_dir(self, path: str) -> bool: + return True - def exists(self, path): - if self.existing_files is None: + def exists(self, path: str) -> bool: + if self.existing_files == {}: return False return self.existing_files[path] - def get_config_files(self, type): - yield type, self.open_xml_file(type) + def get_config_files(self, file_type: str) -> Generator[tuple[str, Document], typing.Any, None]: + yield file_type, self.open_xml_file(file_type) - def get_synoptic_files(self): + def get_synoptic_files(self) -> Generator[tuple[str, Document], typing.Any, None]: yield "synoptic_file", self.open_xml_file("synoptic_file") + def get_file_paths( + self, directory: str, extension: str = "" + ) -> Generator[str, typing.Any, None]: + yield directory + extension + + def file_contains(self, filename: str, string: str) -> bool: + return True + -def create_xml_with_iocs(iocs): +def create_xml_with_iocs(iocs: list) -> Document: """Args: iocs (list): A list of IOC names Returns: diff --git a/test/test_add_to_base_iocs.py b/test/test_add_to_base_iocs.py index ab44a34..7c8f2c2 100644 --- a/test/test_add_to_base_iocs.py +++ b/test/test_add_to_base_iocs.py @@ -1,7 +1,7 @@ import unittest from xml.parsers.expat import ExpatError -from hamcrest import * +from hamcrest import assert_that, has_item, is_ from mock import MagicMock as Mock from src.common_upgrades.add_to_base_iocs import ( @@ -77,7 +77,9 @@ def _assert_prerequiste_fails(self, adder, xml, expected_message): assert_that(result, is_(False), "result") assert_that(self.logger.log_err, has_item(expected_message)) - def test_GIVEN_xml_already_containing_ioc_to_add_WHEN_checking_prerequisites_THEN_error(self): + def test_GIVEN_xml_already_containing_ioc_to_add_WHEN_checking_prerequisites_THEN_error( + self, + ): ioc_to_add = "TO_ADD" xml = create_xml_with_iocs(["ANOTHER_IOC", ioc_to_add]) @@ -99,7 +101,9 @@ def test_GIVEN_xml_already_containing_two_ioc_to_add_WHEN_checking_prerequisites adder, xml, ALREADY_CONTAINS.format(FILE_TO_CHECK_STR, ioc_to_add) ) - def test_GIVEN_xml_containing_no_ioc_to_add_after_WHEN_checking_prerequisites_THEN_error(self): + def test_GIVEN_xml_containing_no_ioc_to_add_after_WHEN_checking_prerequisites_THEN_error( + self, + ): after_ioc = "AFTER_THIS" xml = create_xml_with_iocs(["ANOTHER_IOC", "SECOND_IOC"]) @@ -108,7 +112,9 @@ def test_GIVEN_xml_containing_no_ioc_to_add_after_WHEN_checking_prerequisites_TH adder, xml, ADD_AFTER_MISSING.format(FILE_TO_CHECK_STR, 0, after_ioc) ) - def test_GIVEN_xml_containing_two_ioc_to_add_after_WHEN_checking_prerequisites_THEN_error(self): + def test_GIVEN_xml_containing_two_ioc_to_add_after_WHEN_checking_prerequisites_THEN_error( + self, + ): after_ioc = "AFTER_THIS" xml = create_xml_with_iocs([after_ioc, "ANOTHER_IOC", after_ioc]) diff --git a/test/test_globals_macro_changing.py b/test/test_globals_macro_changing.py index fbd84a7..1780813 100644 --- a/test/test_globals_macro_changing.py +++ b/test/test_globals_macro_changing.py @@ -1,7 +1,5 @@ import unittest -from hamcrest import assert_that - from src.common_upgrades.change_macro_in_globals import ChangeMacroInGlobals from src.common_upgrades.utils.constants import GLOBALS_FILENAME from src.common_upgrades.utils.macro import Macro @@ -15,12 +13,13 @@ def setUp(self): self.logger = LoggingStub() self.macro_changer = ChangeMacroInGlobals(self.file_access, self.logger) - def test_that_WHEN_asked_to_load_globals_file_THEN_the_default_globals_file_is_loaded(self): + def test_that_WHEN_asked_to_load_globals_file_THEN_the_default_globals_file_is_loaded( + self, + ): result = self.macro_changer.load_globals_file() reference = EXAMPLE_GLOBALS_FILE.split("\n") - - assert_that(result, reference) + assert result == reference def test_that_GIVEN_globals_file_with_no_requested_iocs_WHEN_filtering_THEN_no_iocs_are_returned( self, @@ -55,7 +54,9 @@ def setUp(self): self.logger = LoggingStub() self.macro_changer = ChangeMacroInGlobals(self.file_access, self.logger) - def test_that_GIVEN_globals_file_with_old_macro_THEN_all_old_macros_are_changed(self): + def test_that_GIVEN_globals_file_with_old_macro_THEN_all_old_macros_are_changed( + self, + ): ioc_to_change = "GALIL" macros_to_change = [(Macro("CHANGEME"), Macro("CHANGED"))] @@ -78,6 +79,7 @@ def test_that_GIVEN_two_macros_with_only_in_the_globals_file_THEN_only_the_macro self.macro_changer.change_macros(ioc_to_change, macros_to_change) + assert self.file_access.write_file_contents is not None self.assertEqual(self.file_access.write_filename, GLOBALS_FILENAME) self.assertTrue("CHANGED1" in self.file_access.write_file_contents) self.assertFalse("CHANGED0" in self.file_access.write_file_contents) @@ -96,7 +98,9 @@ def test_GIVEN_macro_to_change_with_name_and_value_THEN_the_only_macro_matching_ self.assertEqual(self.file_access.write_file_contents, testfile) self.assertEqual(self.file_access.write_filename, GLOBALS_FILENAME) - def test_that_GIVEN_macro_value_to_change_THEN_the_only_macro_value_is_changed(self): + def test_that_GIVEN_macro_value_to_change_THEN_the_only_macro_value_is_changed( + self, + ): ioc_to_change = "GALIL" macros_to_change = [(Macro("CHANGEME", "01"), Macro("CHANGEME", "001"))] @@ -162,6 +166,7 @@ def test_GIVEN_IOC_name_in_globals_file_WHEN_name_changed_THEN_all_instances_of_ self.macro_changer.change_ioc_name(ioc_to_change, new_ioc_name) self.assertEqual(self.file_access.write_filename, GLOBALS_FILENAME) + assert self.file_access.write_file_contents is not None self.assertTrue("CHANGED" in self.file_access.write_file_contents) self.assertFalse("GALIL" in self.file_access.write_file_contents) @@ -186,7 +191,7 @@ def test_GIVEN_different_iocs_in_globals_file_WHEN_ioc_name_changed_THEN_only_th new_ioc_name = "CHANGED" self.macro_changer.change_ioc_name(ioc_to_change, new_ioc_name) - + assert self.file_access.write_file_contents is not None self.assertEqual(self.file_access.write_filename, GLOBALS_FILENAME) self.assertTrue("CHANGED" in self.file_access.write_file_contents) self.assertTrue("BINS" in self.file_access.write_file_contents) diff --git a/test/test_sql_utils.py b/test/test_sql_utils.py index 9e3a27c..960b425 100644 --- a/test/test_sql_utils.py +++ b/test/test_sql_utils.py @@ -1,4 +1,5 @@ import unittest +import unittest.mock as mocked import mysql.connector from mock import MagicMock, patch @@ -15,7 +16,7 @@ def setUp(self): def test_GIVEN_no_connection_WHEN_connection_created_THEN_no_password_prompted( self, mysql, getpass ): - with SqlConnection() as s: + with SqlConnection(): pass mysql.assert_not_called() getpass.assert_not_called() @@ -23,7 +24,7 @@ def test_GIVEN_no_connection_WHEN_connection_created_THEN_no_password_prompted( @patch("src.common_upgrades.sql_utilities.getpass") @patch("src.common_upgrades.sql_utilities.mysql.connector", autospec=mysql.connector) def test_GIVEN_no_connection_WHEN_run_sql_called_THEN_password_prompted(self, mysql, getpass): - with SqlConnection() as s: + with SqlConnection(): run_sql(MagicMock(), MagicMock()) getpass.assert_called_once() mysql.connect.assert_called_once() @@ -33,7 +34,7 @@ def test_GIVEN_no_connection_WHEN_run_sql_called_THEN_password_prompted(self, my def test_GIVEN_a_pre_existing_connection_WHEN_run_sql_called_THEN_password_not_prompted( self, mysql, getpass ): - with SqlConnection() as s: + with SqlConnection(): run_sql(MagicMock(), MagicMock()) getpass.reset_mock() @@ -47,19 +48,22 @@ def test_GIVEN_a_pre_existing_connection_WHEN_run_sql_called_THEN_password_not_p @patch("src.common_upgrades.sql_utilities.getpass") @patch("src.common_upgrades.sql_utilities.mysql.connector", autospec=mysql.connector) def test_WHEN_run_sql_called_THEN_changes_committed_and_cursor_closed(self, mysql, getpass): - with SqlConnection() as s: + with SqlConnection(): run_sql(MagicMock(), MagicMock()) - SqlConnection.get_session(MagicMock()).commit.assert_called() - SqlConnection.get_session(MagicMock()).cursor().close.assert_called() + commit = mocked.create_autospec(SqlConnection.get_session(MagicMock()).commit) + cursor = mocked.create_autospec(SqlConnection.get_session(MagicMock()).cursor()) + commit.assert_called() + cursor.assert_called() @patch("src.common_upgrades.sql_utilities.getpass") @patch("src.common_upgrades.sql_utilities.mysql.connector", autospec=mysql.connector) def test_WHEN_run_sql_called_THEN_sql_executed(self, mysql, getpass): - with SqlConnection() as s: - my_SQL_string = "TEST SQL" - run_sql(MagicMock(), my_SQL_string) + with SqlConnection(): + sql_string = "TEST SQL" + run_sql(MagicMock(), sql_string) - SqlConnection.get_session(MagicMock()).cursor().execute.assert_called_with( - my_SQL_string + execute = mocked.create_autospec( + SqlConnection.get_session(MagicMock()).cursor().execute ) + execute.assert_called_with(sql_string) diff --git a/test/test_upgrade_base.py b/test/test_upgrade_base.py index a1ba123..02753ad 100644 --- a/test/test_upgrade_base.py +++ b/test/test_upgrade_base.py @@ -1,6 +1,7 @@ import unittest +from unittest.mock import call, patch -from hamcrest import * +from hamcrest import assert_that, contains_exactly, has_item, is_, is_not, none from mock import MagicMock as Mock from mother import FileAccessStub, LoggingStub @@ -9,7 +10,7 @@ class TestUpgradeBase(unittest.TestCase): - @unittest.mock.patch("git.Repo", autospec=True) + @patch("git.Repo", autospec=True) def setUp(self, repo): self.file_access = FileAccessStub() self.logger = LoggingStub() @@ -21,7 +22,9 @@ def upgrade(self, upgrade_steps=None): upgrade_steps = [(self.first_version, None)] return Upgrade(self.file_access, self.logger, upgrade_steps, self.git_repo) - def test_GIVEN_config_contains_no_version_number_WHEN_load_THEN_version_number_added(self): + def test_GIVEN_config_contains_no_version_number_WHEN_load_THEN_version_number_added( + self, + ): self.file_access.open_file = Mock(side_effect=IOError("No configs Exist")) result = self.upgrade().get_version_number() @@ -64,7 +67,8 @@ def test_GIVEN_config_contains_latest_version_number_WHEN_load_THEN_program_exit assert_that(result, is_(0), "Success exit") assert_that( - self.logger.log, has_item("Current config is on latest version, no upgrade needed") + self.logger.log, + has_item("Current config is on latest version, no upgrade needed"), ) def test_GIVEN_config_contains_older_version_number_WHEN_upgrade_THEN_upgrade_done_and_program_exits_successfully( @@ -84,14 +88,17 @@ def test_GIVEN_config_contains_older_version_number_WHEN_upgrade_THEN_upgrade_do assert_that(result, is_(0), "Success exit") upgrade_step.perform.assert_called_once() assert_that( - self.logger.log, has_item("Finished upgrade. Now on version {0}".format(final_version)) + self.logger.log, + has_item("Finished upgrade. Now on version {0}".format(final_version)), ) assert_that( - self.file_access.wrote_version, is_(final_version), "Version written to file at the end" + self.file_access.wrote_version, + is_(final_version), + "Version written to file at the end", ) expected_commit_calls = [ - unittest.mock.call(f"IBEX Upgrade from {original_version}"), - unittest.mock.call(f"IBEX Upgrade to {final_version}"), + call(f"IBEX Upgrade from {original_version}"), + call(f"IBEX Upgrade to {final_version}"), ] self.git_repo.index.commit.assert_has_calls(expected_commit_calls, any_order=False) @@ -129,7 +136,8 @@ def test_GIVEN_config_contains_older_version_number_but_not_earliest_and_multipl upgrade_step_to_do_2.perform.assert_called_once() upgrade_step_to_do_3.perform.assert_called_once() assert_that( - self.logger.log, has_item("Finished upgrade. Now on version {0}".format(final_version)) + self.logger.log, + has_item("Finished upgrade. Now on version {0}".format(final_version)), ) def test_GIVEN_empty_upgrade_steps_WHEN_init_THEN_error(self): diff --git a/test/test_upgrade_step_check_init_inst.py b/test/test_upgrade_step_check_init_inst.py index 3257b85..7aa06fc 100644 --- a/test/test_upgrade_step_check_init_inst.py +++ b/test/test_upgrade_step_check_init_inst.py @@ -11,7 +11,7 @@ module_ = module_ if module_ in sys.modules else "builtins" try: - import unittest.mock as mock + import unittest.mock except (ImportError,): pass @@ -54,7 +54,9 @@ def test_GIVEN_file_with_name_none_containing_pre_post_cmd_WHEN_search_files_THE ) @patch("builtins.open", mock_open(read_data="precmd"), create=True) - def test_GIVEN_file_with_name_containing_precmd_WHEN_search_files_THEN_error_returned(self): + def test_GIVEN_file_with_name_containing_precmd_WHEN_search_files_THEN_error_returned( + self, + ): # Arrange file_names = ["init", "init_zoom", "another_file"] root = "myfolder" @@ -66,7 +68,9 @@ def test_GIVEN_file_with_name_containing_precmd_WHEN_search_files_THEN_error_ret ) @patch("builtins.open", mock_open(read_data="postcmd")) - def test_GIVEN_file_with_name_containing_postcmd_WHEN_search_files_THEN_error_returned(self): + def test_GIVEN_file_with_name_containing_postcmd_WHEN_search_files_THEN_error_returned( + self, + ): # Arrange file_names = ["init", "init_inst", "another_file"] root = "myfolder" @@ -96,7 +100,7 @@ def test_GIVEN_directory_structure_and_no_cmd_WHEN_search_folders_THEN_files_and ): # Arrange file_search_returns = [0, 0, 0, 0] # One return for each of file1, 2, 3 and 4 - with patch("os.walk", return_value=self.directory_structure) as mocked_walk, patch( + with patch("os.walk", return_value=self.directory_structure), patch( "src.upgrade_step_check_init_inst.UpgradeStepCheckInitInst.search_files", side_effect=file_search_returns, ) as mocked_search_files: @@ -114,8 +118,13 @@ def test_GIVEN_directory_structure_and_file_at_top_level_contains_precmd_WHEN_se self, ): # Arrange - file_search_returns = ["Error precmd", 0, 0, 0] # One return for each of file1, 2, 3 and 4 - with patch("os.walk", return_value=self.directory_structure) as mocked_walk, patch( + file_search_returns = [ + "Error precmd", + 0, + 0, + 0, + ] # One return for each of file1, 2, 3 and 4 + with patch("os.walk", return_value=self.directory_structure), patch( "src.upgrade_step_check_init_inst.UpgradeStepCheckInitInst.search_files", side_effect=file_search_returns, ) as mocked_search_files: @@ -133,8 +142,13 @@ def test_GIVEN_directory_structure_and_file_at_second_level_contains_postcmd_WHE self, ): # Arrange - file_search_returns = [0, 0, "Error postcmd", 0] # One return for each of file1, 2, 3 and 4 - with patch("os.walk", return_value=self.directory_structure) as mocked_walk, patch( + file_search_returns = [ + 0, + 0, + "Error postcmd", + 0, + ] # One return for each of file1, 2, 3 and 4 + with patch("os.walk", return_value=self.directory_structure), patch( "src.upgrade_step_check_init_inst.UpgradeStepCheckInitInst.search_files", side_effect=file_search_returns, ) as mocked_search_files: @@ -158,7 +172,7 @@ def test_GIVEN_directory_structure_and_two_files_contain_cmd_WHEN_search_folders "Error postcmd", 0, ] # One return for each of file1, 2, 3 and 4 - with patch("os.walk", return_value=self.directory_structure) as mocked_walk, patch( + with patch("os.walk", return_value=self.directory_structure), patch( "src.upgrade_step_check_init_inst.UpgradeStepCheckInitInst.search_files", side_effect=file_search_returns, ) as mocked_search_files: diff --git a/test/test_upgrade_step_from_10p0p0.py b/test/test_upgrade_step_from_10p0p0.py index 4d098ea..5dd4039 100644 --- a/test/test_upgrade_step_from_10p0p0.py +++ b/test/test_upgrade_step_from_10p0p0.py @@ -49,7 +49,9 @@ def setUp(self): self.logger = LoggingStub() self.file_access = FileAccess(self.logger, CONFIG_ROOT) - def test_GIVEN_refl_device_screen_in_config_WHEN_upgrade_performed_THEN_screen_removed(self): + def test_GIVEN_refl_device_screen_in_config_WHEN_upgrade_performed_THEN_screen_removed( + self, + ): # Given self.file_access.create_directories(SCREEN_FILE_PATH) self.file_access.write_file(SCREEN_FILE_PATH, SCREENS_FILE, file_full=True) diff --git a/test/test_upgrade_step_from_12p0p0.py b/test/test_upgrade_step_from_12p0p0.py index 37c7a39..31e56ed 100644 --- a/test/test_upgrade_step_from_12p0p0.py +++ b/test/test_upgrade_step_from_12p0p0.py @@ -1,6 +1,7 @@ import unittest import mock +from mock import MagicMock as Mock from mother import FileAccessStub, LoggingStub from src.upgrade_step_from_12p0p0 import UpgradeJawsForPositionAutosave @@ -14,12 +15,13 @@ def setUp(self): def _perform( self, - substitution_files: tuple[str], - batch_files: tuple[str], + substitution_files: tuple, + batch_files: tuple, matches: list[bool], batch_files_contents: list[list[str]], ): self.file_access.get_file_paths = mock.Mock(side_effect=[substitution_files, batch_files]) + self.file_access.file_contains = mock.Mock(side_effect=matches) self.file_access.open_file = mock.Mock(side_effect=batch_files_contents) self.file_access.write_file = mock.Mock() @@ -45,28 +47,26 @@ def test_GIVEN_files_with_correct_db_and_no_macros_WHEN_upgrade_THEN_macros_adde ] # White space. self.assertEqual( - self._perform(substitution_files, batch_files, matches, batch_files_contents), 0 + self._perform(substitution_files, batch_files, matches, batch_files_contents), + 0, ) - self.file_access.write_file.assert_has_calls( + self.file_access.write_file( + "jaws.cmd", [ - mock.call( - "jaws.cmd", - [ - """# Comment""", - """dbLoadRecords("\\jaws.db","P=$(MYPVPREFIX)MOT:,MACRO=MACRO:,IFINIT_FROM_AS=$(IFINIT_JAWS_FROM_AS=#),IFNOTINIT_FROM_AS=$(IFNOTINIT_JAWS_FROM_AS=)")""", - """dbLoadRecords("jaws.db","P=$(MYPVPREFIX)MOT:,MACRO=MACRO:,IFINIT_FROM_AS=$(IFINIT_JAWS_FROM_AS=#),IFNOTINIT_FROM_AS=$(IFNOTINIT_JAWS_FROM_AS=)")""", - ], - ), - mock.call( - "other.cmd", - [ - """dbLoadRecordsList("\\name.db","P=$(MYPVPREFIX)MOT:,MACRO=MACRO:,IFINIT_FROM_AS=$(IFINIT_JAWS_FROM_AS=#),IFNOTINIT_FROM_AS=$(IFNOTINIT_JAWS_FROM_AS=)")""", - """dbLoadRecords("/not_name.db","P=$(MYPVPREFIX)MOT:,MACRO=MACRO:")""", - """""", - ], - ), - ] + """# Comment""", + """dbLoadRecords("\\jaws.db","P=$(MYPVPREFIX)MOT:,MACRO=MACRO:,IFINIT_FROM_AS=$(IFINIT_JAWS_FROM_AS=#),IFNOTINIT_FROM_AS=$(IFNOTINIT_JAWS_FROM_AS=)")""", + """dbLoadRecords("jaws.db","P=$(MYPVPREFIX)MOT:,MACRO=MACRO:,IFINIT_FROM_AS=$(IFINIT_JAWS_FROM_AS=#),IFNOTINIT_FROM_AS=$(IFNOTINIT_JAWS_FROM_AS=)")""", + ], + ) + + self.file_access.write_file( + "other.cmd", + [ + """dbLoadRecordsList("\\name.db","P=$(MYPVPREFIX)MOT:,MACRO=MACRO:,IFINIT_FROM_AS=$(IFINIT_JAWS_FROM_AS=#),IFNOTINIT_FROM_AS=$(IFNOTINIT_JAWS_FROM_AS=)")""", + """dbLoadRecords("/not_name.db","P=$(MYPVPREFIX)MOT:,MACRO=MACRO:")""", + """""", + ], ) def test_GIVEN_file_with_correct_db_and_macros_WHEN_upgrade_THEN_no_writes(self): @@ -80,9 +80,11 @@ def test_GIVEN_file_with_correct_db_and_macros_WHEN_upgrade_THEN_no_writes(self) ] self.assertEqual( - self._perform(substitution_files, batch_files, matches, batch_files_contents), 0 + self._perform(substitution_files, batch_files, matches, batch_files_contents), + 0, ) + self.file_access.write_file = Mock() self.file_access.write_file.assert_not_called() def test_GIVEN_file_without_db_WHEN_upgrade_THEN_no_writes(self): @@ -94,9 +96,10 @@ def test_GIVEN_file_without_db_WHEN_upgrade_THEN_no_writes(self): ] self.assertEqual( - self._perform(substitution_files, batch_files, matches, batch_files_contents), 0 + self._perform(substitution_files, batch_files, matches, batch_files_contents), + 0, ) - + self.file_access.write_file = Mock() self.file_access.write_file.assert_not_called() def test_GIVEN_file_with_a_match_but_no_changes_WHEN_upgrade_THEN_no_write_and_step_failed( @@ -112,7 +115,8 @@ def test_GIVEN_file_with_a_match_but_no_changes_WHEN_upgrade_THEN_no_write_and_s ] self.assertNotEqual( - self._perform(substitution_files, batch_files, matches, batch_files_contents), 0 + self._perform(substitution_files, batch_files, matches, batch_files_contents), + 0, ) - + self.file_access.write_file = Mock() self.file_access.write_file.assert_not_called() diff --git a/test/test_upgrade_step_from_7p2p0.py b/test/test_upgrade_step_from_7p2p0.py index 95a1801..ee02a8f 100644 --- a/test/test_upgrade_step_from_7p2p0.py +++ b/test/test_upgrade_step_from_7p2p0.py @@ -4,7 +4,10 @@ from mother import FileAccessStub, LoggingStub from src.upgrade_step_from_7p2p0 import IgnoreRcpttSynoptics, UpgradeMotionSetPoints -from test.test_utils import test_action_does_not_write, test_changing_synoptics_and_blocks +from test.test_utils import ( + test_action_does_not_write, + test_changing_synoptics_and_blocks, +) class TestIgnoreRcpttSynoptics(unittest.TestCase): @@ -48,8 +51,14 @@ def setUp(self): self.logger = LoggingStub() def test_GIVEN_coord_1_WHEN_step_performed_THEN_convert_to_coord_0(self): - starting_blocks = [("BLOCK_NAME", "COORD1:SOMETHING"), ("BLOCK_NAME_2", "COORD1")] - expected_blocks = [("BLOCK_NAME", "COORD0:SOMETHING"), ("BLOCK_NAME_2", "COORD0")] + starting_blocks = [ + ("BLOCK_NAME", "COORD1:SOMETHING"), + ("BLOCK_NAME_2", "COORD1"), + ] + expected_blocks = [ + ("BLOCK_NAME", "COORD0:SOMETHING"), + ("BLOCK_NAME_2", "COORD0"), + ] def action(): self.upgrade_step.perform(self.file_access, self.logger) @@ -59,8 +68,14 @@ def action(): ) def test_GIVEN_coord_2_WHEN_step_performed_THEN_convert_to_coord_1(self): - starting_blocks = [("BLOCK_NAME", "COORD2:SOMETHING"), ("BLOCK_NAME_2", "COORD2")] - expected_blocks = [("BLOCK_NAME", "COORD1:SOMETHING"), ("BLOCK_NAME_2", "COORD1")] + starting_blocks = [ + ("BLOCK_NAME", "COORD2:SOMETHING"), + ("BLOCK_NAME_2", "COORD2"), + ] + expected_blocks = [ + ("BLOCK_NAME", "COORD1:SOMETHING"), + ("BLOCK_NAME_2", "COORD1"), + ] def action(): self.upgrade_step.perform(self.file_access, self.logger) @@ -69,7 +84,9 @@ def action(): self.file_access, action, starting_blocks, expected_blocks ) - def test_GIVEN_coord_2_renamed_PVs_WHEN_step_performed_THEN_convert_to_coord_1(self): + def test_GIVEN_coord_2_renamed_PVs_WHEN_step_performed_THEN_convert_to_coord_1( + self, + ): starting_blocks = [ ("BLOCK_NAME", "COORD2:NO_OFFSET"), ("BLOCK_NAME_2", "COORD2:RBV:OFFSET"), @@ -88,7 +105,9 @@ def action(): self.file_access, action, starting_blocks, expected_blocks ) - def test_GIVEN_coord_1_renamed_PVs_WHEN_step_performed_THEN_convert_to_coord_0(self): + def test_GIVEN_coord_1_renamed_PVs_WHEN_step_performed_THEN_convert_to_coord_0( + self, + ): starting_blocks = [ ("BLOCK_NAME", "COORD1:NO_OFFSET"), ("BLOCK_NAME_2", "COORD1:RBV:OFFSET"), diff --git a/test/test_utils.py b/test/test_utils.py index 69a0824..ba3d57c 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,4 +1,4 @@ -from hamcrest import * +from hamcrest import assert_that, is_, is_in, not_ from src.common_upgrades.utils.constants import BLOCK_FILE diff --git a/test/test_xml_macro_changer.py b/test/test_xml_macro_changer.py index cd1bfa9..9925d5f 100644 --- a/test/test_xml_macro_changer.py +++ b/test/test_xml_macro_changer.py @@ -3,7 +3,7 @@ from functools import partial from xml.dom import minidom -from hamcrest import * +from hamcrest import assert_that, has_length, is_ from mock import MagicMock as Mock from src.common_upgrades.change_macros_in_xml import ( @@ -82,7 +82,9 @@ def setUp(self): self.logger = LoggingStub() self.macro_changer = ChangeMacrosInXML(self.file_access, self.logger) - def test_that_GIVEN_xml_with_no_requested_iocs_WHEN_filtering_THEN_no_iocs_returned(self): + def test_that_GIVEN_xml_with_no_requested_iocs_WHEN_filtering_THEN_no_iocs_returned( + self, + ): # Given: ioc_to_change = "CHANGE_ME" configs = {"CONFIG_1": ["DONT_CHANGE", "ANOTHER_ONE"]} @@ -96,10 +98,15 @@ def test_that_GIVEN_xml_with_no_requested_iocs_WHEN_filtering_THEN_no_iocs_retur # Then: assert_that(len(list(result)), is_(0), "no results") - def test_that_GIVEN_two_xml_with_no_requested_iocs_WHEN_filtering_THEN_no_iocs_returned(self): + def test_that_GIVEN_two_xml_with_no_requested_iocs_WHEN_filtering_THEN_no_iocs_returned( + self, + ): # Given: ioc_to_change = "CHANGE_ME" - configs = {"CONFIG_1": ["DONT_CHANGE", "ANOTHER_ONE"], "CONFIG_2": ["OTHER_IOC"]} + configs = { + "CONFIG_1": ["DONT_CHANGE", "ANOTHER_ONE"], + "CONFIG_2": ["OTHER_IOC"], + } # When: self.macro_changer._ioc_file_generator = partial(generate_many_iocs, configs) @@ -110,7 +117,9 @@ def test_that_GIVEN_two_xml_with_no_requested_iocs_WHEN_filtering_THEN_no_iocs_r # Then: assert_that(len(list(result)), is_(0), "no results") - def test_that_GIVEN_xml_with_requested_iocs_WHEN_filtering_THEN_expected_ioc_returned(self): + def test_that_GIVEN_xml_with_requested_iocs_WHEN_filtering_THEN_expected_ioc_returned( + self, + ): # Given: ioc_to_change = "CHANGE_ME" config_name = "CONFIG_1" @@ -147,7 +156,9 @@ def test_that_GIVEN_one_xml_with_requested_iocs_and_one_without_WHEN_filtering_T assert_that(len(result), is_(1)) assert_that(result[0].getAttribute("name"), is_(ioc_to_change)) - def test_that_GIVEN_xml_with_numbered_ioc_WHEN_filtering_THEN_expected_ioc_returned(self): + def test_that_GIVEN_xml_with_numbered_ioc_WHEN_filtering_THEN_expected_ioc_returned( + self, + ): # Given root_ioc_name = "CHANGE_ME" ioc_name = root_ioc_name + "_03" @@ -241,7 +252,9 @@ def setUp(self): self.logger = LoggingStub() self.macro_changer = ChangeMacrosInXML(self.file_access, self.logger) - def test_that_GIVEN_xml_with_old_ioc_macro_value_THEN_macro_values_are_updated(self): + def test_that_GIVEN_xml_with_old_ioc_macro_value_THEN_macro_values_are_updated( + self, + ): # Given: test_macro_xml_string = MACRO_XML.format(name="BAUD1", value="None") test_macro_xml = minidom.parseString(test_macro_xml_string) @@ -256,7 +269,9 @@ def test_that_GIVEN_xml_with_old_ioc_macro_value_THEN_macro_values_are_updated(s # Then: assert_that(result, is_(new_macro.value)) - def test_that_GIVEN_xml_without_specified_macro_value_THEN_macros_are_not_updated(self): + def test_that_GIVEN_xml_without_specified_macro_value_THEN_macros_are_not_updated( + self, + ): # Given: original_macro_value = "None" test_macro_xml_string = MACRO_XML.format(name="PORT1", value=original_macro_value) @@ -272,7 +287,9 @@ def test_that_GIVEN_xml_without_specified_macro_value_THEN_macros_are_not_update # Then: assert_that(result, is_(original_macro_value)) - def test_that_GIVEN_new_macro_without_a_value_THEN_macro_values_are_not_updated(self): + def test_that_GIVEN_new_macro_without_a_value_THEN_macro_values_are_not_updated( + self, + ): # Given: original_macro_value = "None" test_macro_xml_string = MACRO_XML.format(name="PORT1", value=original_macro_value) @@ -306,20 +323,23 @@ def test_that_GIVEN_xml_with_single_macro_WHEN_calling_change_macros_THEN_the_si self.file_access.open_file = Mock(return_value=xml) self.file_access.write_file = Mock() self.file_access.get_config_files = Mock( - return_value=[("file1.xml", self.file_access.open_xml_file(None))] + return_value=[("file1.xml", self.file_access.open_xml_file(""))] ) # When: self.macro_changer.change_macros(ioc_name, macro_to_change) # Then: + assert self.file_access.write_file_contents is not None written_xml = ET.fromstring(self.file_access.write_file_contents) result = written_xml.findall(".//ns:macros/*[@name='GALILADDR']", {"ns": NAMESPACE}) assert_that(result, has_length(1), "changed macro count") assert_that(result[0].get("name"), is_("GALILADDR")) - def test_that_GIVEN_xml_with_multiple_macros_THEN_only_value_of_named_macro_is_changed(self): + def test_that_GIVEN_xml_with_multiple_macros_THEN_only_value_of_named_macro_is_changed( + self, + ): # Given: xml = IOC_FILE_XML.format(iocs=create_galil_ioc(1, {"GALILADDR": "0", "MTRCTRL": "0"})) ioc_name = "GALIL" @@ -328,13 +348,14 @@ def test_that_GIVEN_xml_with_multiple_macros_THEN_only_value_of_named_macro_is_c self.file_access.open_file = Mock(return_value=xml) self.file_access.write_file = Mock() self.file_access.get_config_files = Mock( - return_value=[("file1.xml", self.file_access.open_xml_file(None))] + return_value=[("file1.xml", self.file_access.open_xml_file(""))] ) # When: self.macro_changer.change_macros(ioc_name, macro_to_change) # Then: + assert self.file_access.write_file_contents is not None written_xml = ET.fromstring(self.file_access.write_file_contents) result_galiladdr = written_xml.findall( ".//ns:macros/*[@name='GALILADDR']", {"ns": NAMESPACE} @@ -349,7 +370,9 @@ def test_that_GIVEN_xml_with_multiple_macros_THEN_only_value_of_named_macro_is_c assert_that(result_mtrctrl[0].get("name"), is_("MTRCTRL")) assert_that(result_mtrctrl[0].get("value"), is_("0")) - def test_that_GIVEN_xml_with_multiple_old_ioc_macros_THEN_all_macros_are_updated(self): + def test_that_GIVEN_xml_with_multiple_old_ioc_macros_THEN_all_macros_are_updated( + self, + ): # Given: xml = IOC_FILE_XML.format(iocs=create_galil_ioc(1, {"GALILADDRXX": "", "MTRCTRLXX": ""})) ioc_name = "GALIL" @@ -362,7 +385,7 @@ def test_that_GIVEN_xml_with_multiple_old_ioc_macros_THEN_all_macros_are_updated self.file_access.open_file = Mock(return_value=self.file_access.write_file_contents) self.file_access.write_file = Mock() self.file_access.get_config_files = Mock( - return_value=[("file1.xml", self.file_access.open_xml_file(None))] + return_value=[("file1.xml", self.file_access.open_xml_file(""))] ) # When: @@ -413,13 +436,14 @@ def test_GIVEN_an_ioc_name_WHEN_IOC_change_asked_THEN_ioc_name_is_changed(self): self.file_access.open_file = Mock(return_value=xml) self.file_access.write_file = Mock() self.file_access.get_config_files = Mock( - return_value=[("file1.xml", self.file_access.open_xml_file(None))] + return_value=[("file1.xml", self.file_access.open_xml_file(""))] ) # When: self.macro_changer.change_ioc_name("GALIL", "CHANGED") # Then: + assert self.file_access.write_file_contents is not None written_xml = ET.fromstring(self.file_access.write_file_contents) tree = ET.ElementTree(written_xml) @@ -448,13 +472,14 @@ def test_GIVEN_more_than_one_IOC_in_config_WHEN_its_name_is_changed_THEN_IOC_suf self.file_access.open_file = Mock(return_value=xml_contents) self.file_access.write_file = Mock() self.file_access.get_config_files = Mock( - return_value=[("file1.xml", self.file_access.open_xml_file(None))] + return_value=[("file1.xml", self.file_access.open_xml_file(""))] ) # When: self.macro_changer.change_ioc_name(ioc_to_change, new_ioc_name) # Then: + assert self.file_access.write_file_contents is not None written_xml = ET.fromstring(self.file_access.write_file_contents) tree = ET.ElementTree(written_xml) iocs = tree.findall(".//ioc", {"ns": NAMESPACE}) @@ -490,13 +515,14 @@ def test_GIVEN_multiple_different_IOCs_in_configuration_WHEN_ones_name_is_change self.file_access.open_file = Mock(return_value=xml_contents) self.file_access.write_file = Mock() self.file_access.get_config_files = Mock( - return_value=[("file1.xml", self.file_access.open_xml_file(None))] + return_value=[("file1.xml", self.file_access.open_xml_file(""))] ) # When: self.macro_changer.change_ioc_name(ioc_to_change, new_ioc_name) # Then: + assert self.file_access.write_file_contents is not None written_xml = ET.fromstring(self.file_access.write_file_contents) tree = ET.ElementTree(written_xml) iocs = tree.findall(".//ioc", {"ns": NAMESPACE}) @@ -527,7 +553,7 @@ def test_GIVEN_synoptic_xml_file_WHEN_IOC_name_changed_THEN_only_the_ioc_synopti # Then: output_file = self.file_access.write_file_contents - + assert output_file is not None assert_that((ioc_to_change in output_file), is_(False)) assert_that((new_ioc_name in output_file), is_(True)) assert_that((unchanged_ioc in output_file), is_(True)) @@ -551,13 +577,14 @@ def test_GIVEN_one_ioc_THEN_add_macro(self): self.file_access.open_file = Mock(return_value=xml) self.file_access.write_file = Mock() self.file_access.get_config_files = Mock( - return_value=[("file1.xml", self.file_access.open_xml_file(None))] + return_value=[("file1.xml", self.file_access.open_xml_file(""))] ) # When: self.macro_changer.add_macro(ioc_name, macro_to_add, pattern, description, default) # Then: + assert self.file_access.write_file_contents is not None written_xml = ET.fromstring(self.file_access.write_file_contents) result_galiladdr = written_xml.findall( ".//ns:macros/*[@name='GALILADDR']", {"ns": NAMESPACE} @@ -591,13 +618,14 @@ def test_GIVEN_one_ioc_that_already_has_macro_THEN_dont_add_macro(self): self.file_access.open_file = Mock(return_value=xml) self.file_access.write_file = Mock() self.file_access.get_config_files = Mock( - return_value=[("file1.xml", self.file_access.open_xml_file(None))] + return_value=[("file1.xml", self.file_access.open_xml_file(""))] ) # When: self.macro_changer.add_macro(ioc_name, macro_to_add, pattern, description, default) # Then: + assert self.file_access.write_file_contents is not None written_xml = ET.fromstring(self.file_access.write_file_contents) result_galiladdr = written_xml.findall( ".//ns:macros/*[@name='GALILADDR']", {"ns": NAMESPACE} diff --git a/test/test_xml_pv_changer.py b/test/test_xml_pv_changer.py index 23ef2f2..6c9daf3 100644 --- a/test/test_xml_pv_changer.py +++ b/test/test_xml_pv_changer.py @@ -1,10 +1,14 @@ import unittest +import unittest.mock as mocked -from hamcrest import * +from hamcrest import assert_that, is_ from src.common_upgrades.change_pvs_in_xml import ChangePVsInXML from test.mother import FileAccessStub, LoggingStub -from test.test_utils import create_xml_with_starting_blocks, test_changing_synoptics_and_blocks +from test.test_utils import ( + create_xml_with_starting_blocks, + test_changing_synoptics_and_blocks, +) class TestChangePVs(unittest.TestCase): @@ -31,7 +35,9 @@ def test_GIVEN_multiple_different_blocks_in_configuration_WHEN_ones_pv_is_change [("BLOCKNAME", "BLOCK_PV"), ("BLOCKNAME_2", "CHANGED")], ) - def test_GIVEN_block_with_part_of_PV_WHEN_pv_is_changed_THEN_only_part_of_PV_changes(self): + def test_GIVEN_block_with_part_of_PV_WHEN_pv_is_changed_THEN_only_part_of_PV_changes( + self, + ): self._test_changing_pv( [("BLOCKNAME", "CHANGEME:BUT:NOT:ME")], "CHANGEME", @@ -39,15 +45,25 @@ def test_GIVEN_block_with_part_of_PV_WHEN_pv_is_changed_THEN_only_part_of_PV_cha [("BLOCKNAME", "CHANGED:BUT:NOT:ME")], ) - def test_GIVEN_two_blocks_that_need_changing_WHEN_pv_is_changed_THEN_both_change(self): + def test_GIVEN_two_blocks_that_need_changing_WHEN_pv_is_changed_THEN_both_change( + self, + ): self._test_changing_pv( - [("BLOCKNAME", "CHANGEME:BUT:NOT:ME"), ("BLOCKNAME_1", "ALSO:CHANGEME:BUT:NOT:ME")], + [ + ("BLOCKNAME", "CHANGEME:BUT:NOT:ME"), + ("BLOCKNAME_1", "ALSO:CHANGEME:BUT:NOT:ME"), + ], "CHANGEME", "CHANGED", - [("BLOCKNAME", "CHANGED:BUT:NOT:ME"), ("BLOCKNAME_1", "ALSO:CHANGED:BUT:NOT:ME")], + [ + ("BLOCKNAME", "CHANGED:BUT:NOT:ME"), + ("BLOCKNAME_1", "ALSO:CHANGED:BUT:NOT:ME"), + ], ) - def test_GIVEN_block_with_name_that_could_be_changed_WHEN_pv_is_changed_THEN_name_is_not(self): + def test_GIVEN_block_with_name_that_could_be_changed_WHEN_pv_is_changed_THEN_name_is_not( + self, + ): self._test_changing_pv( [("CHANGEME", "BLAH")], "CHANGEME", "CHANGED", [("CHANGEME", "BLAH")] ) @@ -65,7 +81,8 @@ def GIVEN_two_blocks_with_pvs_that_obey_filter_WHEN_pv_counted_THEN_returns_two_ number_of_pvs = pv_changer.get_number_of_instances_of_pv("CHANGEME") assert_that(number_of_pvs, is_(2)) - self.file_access.write_file.assert_not_called() + write_file = mocked.create_autospec(self.file_access.write_file) + write_file.assert_not_called() def GIVEN_block_with_name_that_obeys_filter_WHEN_pv_counted_THEN_returns_zero_and_xml_unchanged( self, @@ -77,7 +94,8 @@ def GIVEN_block_with_name_that_obeys_filter_WHEN_pv_counted_THEN_returns_zero_an number_of_pvs = pv_changer.get_number_of_instances_of_pv("CHANGEME") assert_that(number_of_pvs, is_(0)) - self.file_access.write_file.assert_not_called() + write_file = mocked.create_autospec(self.file_access.write_file) + write_file.assert_not_called() if __name__ == "__main__": diff --git a/upgrade.py b/upgrade.py index 31b1e00..879beb8 100644 --- a/upgrade.py +++ b/upgrade.py @@ -86,6 +86,9 @@ git_repo = RepoFactory.get_repo(config_root) upgrade = Upgrade( - file_access=file_access, logger=logger, upgrade_steps=UPGRADE_STEPS, git_repo=git_repo + file_access=file_access, + logger=logger, + upgrade_steps=UPGRADE_STEPS, + git_repo=git_repo, ) sys.exit(upgrade.upgrade())