From 6a75d45dc1d3190277101af210187b4bf1177ead Mon Sep 17 00:00:00 2001 From: Christian Henkel <6976069+ct2034@users.noreply.github.com> Date: Sat, 5 Oct 2024 20:20:46 +0200 Subject: [PATCH] Black for formatting (#54) * black for auto formatting * isort and black are happy with each other * formatting by black and isort --------- Signed-off-by: Christian Henkel --- .pre-commit-config.yaml | 6 + docs/source/conf.py | 30 +- pyproject.toml | 11 +- src/as2fm/as2fm_common/common.py | 34 +- .../as2fm_common/ecmascript_interpretation.py | 17 +- .../convince_to_plain_jani.py | 46 +- .../jani_entries/jani_assignment.py | 5 +- .../jani_entries/jani_automaton.py | 40 +- .../jani_entries/jani_composition.py | 39 +- .../jani_entries/jani_constant.py | 19 +- .../jani_convince_expression_expansion.py | 210 +++++---- .../jani_generator/jani_entries/jani_edge.py | 23 +- .../jani_entries/jani_expression.py | 119 +++-- .../jani_generator/jani_entries/jani_guard.py | 14 +- .../jani_generator/jani_entries/jani_model.py | 96 ++-- .../jani_entries/jani_property.py | 47 +- .../jani_generator/jani_entries/jani_utils.py | 12 +- .../jani_generator/jani_entries/jani_value.py | 7 +- .../jani_entries/jani_variable.py | 58 ++- src/as2fm/jani_generator/main.py | 37 +- .../ros_helpers/ros_action_handler.py | 163 ++++--- .../ros_helpers/ros_communication_handler.py | 61 +-- .../ros_helpers/ros_service_handler.py | 82 ++-- .../jani_generator/ros_helpers/ros_timer.py | 163 ++++--- .../scxml_helpers/scxml_event.py | 25 +- .../scxml_helpers/scxml_event_processor.py | 131 +++--- .../scxml_helpers/scxml_expression.py | 96 ++-- .../scxml_helpers/scxml_tags.py | 429 +++++++++++------- .../scxml_helpers/scxml_to_jani.py | 34 +- .../scxml_helpers/top_level_interpreter.py | 106 ++--- src/as2fm/jani_visualizer/main.py | 51 ++- src/as2fm/jani_visualizer/visualizer.py | 94 ++-- src/as2fm/scxml_converter/bt_converter.py | 81 ++-- .../scxml_converter/scxml_entries/__init__.py | 96 ++-- .../scxml_converter/scxml_entries/bt_utils.py | 56 ++- .../scxml_entries/ros_utils.py | 282 +++++++----- .../scxml_converter/scxml_entries/scxml_bt.py | 22 +- .../scxml_entries/scxml_data.py | 67 ++- .../scxml_entries/scxml_executable_entries.py | 114 +++-- .../scxml_entries/scxml_param.py | 23 +- .../scxml_entries/scxml_root.py | 130 +++--- .../scxml_entries/scxml_ros_action_client.py | 111 +++-- .../scxml_entries/scxml_ros_action_server.py | 58 ++- .../scxml_ros_action_server_thread.py | 119 +++-- .../scxml_entries/scxml_ros_base.py | 219 +++++---- .../scxml_entries/scxml_ros_field.py | 33 +- .../scxml_entries/scxml_ros_service.py | 29 +- .../scxml_entries/scxml_ros_timer.py | 15 +- .../scxml_entries/scxml_ros_topic.py | 11 +- .../scxml_entries/scxml_state.py | 103 +++-- .../scxml_entries/scxml_transition.py | 100 ++-- .../scxml_converter/scxml_entries/utils.py | 100 ++-- .../scxml_entries/xml_utils.py | 63 ++- src/as2fm/trace_visualizer/main.py | 31 +- src/as2fm/trace_visualizer/visualizer.py | 186 ++++---- ...test_unittest_ecmascript_interpretation.py | 11 +- test/as2fm_common/test_utilities_smc_storm.py | 27 +- .../test_systemtest_convince_to_plain_jani.py | 5 +- .../test_systemtest_scxml_to_jani.py | 143 +++--- .../test_unittest_jani_model_loading.py | 7 +- .../jani_generator/test_unittest_ros_timer.py | 15 +- test/jani_visualizer/jani_visualizer_test.py | 20 +- .../test_systemtest_scxml_entries.py | 182 +++++--- test/scxml_converter/test_systemtest_xml.py | 91 ++-- .../test_unittest_scxml_data.py | 42 +- .../test_unittest_scxml_utils.py | 21 +- test/scxml_converter/test_utils.py | 4 +- .../trace_visualizer/trace_visualizer_test.py | 37 +- 68 files changed, 2915 insertions(+), 2044 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d5dd70d4..659538fb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,6 +8,12 @@ repos: - id: end-of-file-fixer - id: check-yaml - id: check-added-large-files + # black for auto-formatting + - repo: https://github.com/psf/black + rev: 24.8.0 + hooks: + - id: black + language_version: python3.10 # same as lint.yml # - repo: https://github.com/pycqa/pylint # rev: v3.3.1 diff --git a/docs/source/conf.py b/docs/source/conf.py index 85246466..3c533021 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -4,19 +4,19 @@ # mypy: ignore-errors -project = 'CONVINCE Model Checking Components' -copyright = '2024' -author = 'CONVINCE Consortium' +project = "CONVINCE Model Checking Components" +copyright = "2024" +author = "CONVINCE Consortium" -release = '0.1' -version = '0.1.0' +release = "0.1" +version = "0.1.0" # -- General configuration extensions = [ - 'sphinx.ext.autosummary', - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', + "sphinx.ext.autosummary", + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", # 'myst_parser', # 'autodoc2', ] @@ -26,14 +26,14 @@ # 'sphinx': ('https://www.sphinx-doc.org/en/master/', None), # 'networkx': ('https://networkx.org/documentation/stable/', None), # } -intersphinx_disabled_domains = ['std'] +intersphinx_disabled_domains = ["std"] -templates_path = ['_templates'] +templates_path = ["_templates"] # -- Options for HTML output -html_theme = 'sphinx_rtd_theme' -html_logo = 'convince_logo_horizontal_200p.png' +html_theme = "sphinx_rtd_theme" +html_logo = "convince_logo_horizontal_200p.png" html_theme_options = { # 'analytics_id': 'G-XXXXXXXXXX', # Provided by Google in your dashboard # 'analytics_anonymize_ip': False, @@ -50,11 +50,11 @@ # 'includehidden': True, # 'titles_only': False } -html_static_path = ['_static'] +html_static_path = ["_static"] # html_css_files = [ # 'css/custom.css', # ] -html_style = 'css/custom.css' +html_style = "css/custom.css" # -- Options for EPUB output -epub_show_urls = 'footnote' +epub_show_urls = "footnote" diff --git a/pyproject.toml b/pyproject.toml index a01b649a..a8691e2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,10 +62,6 @@ dev = [ "bumpver" ] -[isort] -profile = "google" -line_length = 100 - [tool.pylint.main] disable = [ "C0114", # Missing module docstring @@ -73,3 +69,10 @@ disable = [ "R0401", # Cyclic import "W0511", # TODO comments (we need them) ] + +[tool.black] +line-length = 100 + +[tool.isort] +profile = "black" +line_length = 100 diff --git a/src/as2fm/as2fm_common/common.py b/src/as2fm/as2fm_common/common.py index 58c6ec31..758cac01 100644 --- a/src/as2fm/as2fm_common/common.py +++ b/src/as2fm/as2fm_common/common.py @@ -50,8 +50,8 @@ def remove_namespace(tag: str) -> str: :param tag: The tag to remove the namespace from. :return: The tag without the namespace. """ - if '}' in tag: - tag_wo_ns = tag.split('}')[-1] + if "}" in tag: + tag_wo_ns = tag.split("}")[-1] else: tag_wo_ns = tag return tag_wo_ns @@ -61,9 +61,9 @@ def get_default_expression_for_type(field_type: Type[ValidTypes]) -> ValidTypes: """Generate a default expression for a field type.""" assert field_type in get_args(ValidTypes), f"Error: Unsupported data type {field_type}." if field_type is MutableSequence[int]: - return array('i') + return array("i") elif field_type is MutableSequence[float]: - return array('d') + return array("d") else: return field_type() @@ -71,9 +71,9 @@ def get_default_expression_for_type(field_type: Type[ValidTypes]) -> ValidTypes: def value_to_type(value: ValidTypes) -> Type[ValidTypes]: """Convert a value to a type.""" if isinstance(value, array): - if value.typecode == 'i': + if value.typecode == "i": return MutableSequence[int] - elif value.typecode == 'd': + elif value.typecode == "d": return MutableSequence[float] else: raise ValueError(f"Type of array '{value.typecode}' not supported.") @@ -99,22 +99,26 @@ def value_to_string(value: ValidTypes) -> str: def string_to_value(value_str: str, value_type: Type[ValidTypes]) -> ValidTypes: """Convert a string to a value of the desired type.""" value_str = value_str.strip() - assert isinstance(value_str, str), \ - f"Error: provided value is of type {type(value_str)}, expected a string." + assert isinstance( + value_str, str + ), f"Error: provided value is of type {type(value_str)}, expected a string." assert len(value_str) > 0, "Error: provided value is an empty string, cannot convert." - is_array_value = value_str.startswith('[') and value_str.endswith(']') + is_array_value = value_str.startswith("[") and value_str.endswith("]") if not is_array_value: - assert value_type in (bool, int, float), \ - f"Error: the value {value_str} shall be converted to a base type." + assert value_type in ( + bool, + int, + float, + ), f"Error: the value {value_str} shall be converted to a base type." return value_type(value_str) else: - str_entries = value_str.strip('[]').split(',') - if str_entries == ['']: + str_entries = value_str.strip("[]").split(",") + if str_entries == [""]: str_entries = [] if value_type is MutableSequence[int]: - return array('i', [int(v) for v in str_entries]) + return array("i", [int(v) for v in str_entries]) elif value_type is MutableSequence[float]: - return array('d', [float(v) for v in str_entries]) + return array("d", [float(v) for v in str_entries]) else: raise ValueError(f"Unsupported value type {value_type}.") diff --git a/src/as2fm/as2fm_common/ecmascript_interpretation.py b/src/as2fm/as2fm_common/ecmascript_interpretation.py index 86d7dc1a..1437f3db 100644 --- a/src/as2fm/as2fm_common/ecmascript_interpretation.py +++ b/src/as2fm/as2fm_common/ecmascript_interpretation.py @@ -28,7 +28,8 @@ def interpret_ecma_script_expr( - expr: str, variables: Optional[Dict[str, ValidTypes]] = None) -> object: + expr: str, variables: Optional[Dict[str, ValidTypes]] = None +) -> object: """ Interpret the ECMA script expression. @@ -54,12 +55,16 @@ def interpret_ecma_script_expr( if all(isinstance(x, int) for x in res_as_list): return array("i", res_as_list) else: - return array('d', res_as_list) + return array("d", res_as_list) else: - raise ValueError(f"Expected expr. {expr} to be of type {BasicJsTypes} or " - f"an array, got '{type(expr_result._obj)}'") + raise ValueError( + f"Expected expr. {expr} to be of type {BasicJsTypes} or " + f"an array, got '{type(expr_result._obj)}'" + ) elif isinstance(expr_result, array): return expr_result else: - raise ValueError(f"Expected expr. {expr} to be of type {BasicJsTypes} or " - f"JsObjectWrapper, got '{type(expr_result)}'") + raise ValueError( + f"Expected expr. {expr} to be of type {BasicJsTypes} or " + f"JsObjectWrapper, got '{type(expr_result)}'" + ) diff --git a/src/as2fm/jani_generator/convince_jani_helpers/convince_to_plain_jani.py b/src/as2fm/jani_generator/convince_jani_helpers/convince_to_plain_jani.py index 85de77ca..a7124226 100644 --- a/src/as2fm/jani_generator/convince_jani_helpers/convince_to_plain_jani.py +++ b/src/as2fm/jani_generator/convince_jani_helpers/convince_to_plain_jani.py @@ -22,8 +22,12 @@ from os import path from typing import List -from as2fm.jani_generator.jani_entries import (JaniAutomaton, JaniComposition, - JaniModel, JaniProperty) +from as2fm.jani_generator.jani_entries import ( + JaniAutomaton, + JaniComposition, + JaniModel, + JaniProperty, +) def to_cm(value: float) -> int: @@ -61,8 +65,9 @@ def __convince_env_model_to_jani(base_model: JaniModel, env_model: dict): # The robot pose should be stored using integers -> centimeters and degrees base_model.add_variable(f"robots.{robot_name}.pose.x_cm", int, to_cm(robot_pose["x"])) base_model.add_variable(f"robots.{robot_name}.pose.y_cm", int, to_cm(robot_pose["y"])) - base_model.add_variable(f"robots.{robot_name}.pose.theta_deg", int, - to_deg(robot_pose["theta"])) + base_model.add_variable( + f"robots.{robot_name}.pose.theta_deg", int, to_deg(robot_pose["theta"]) + ) base_model.add_variable(f"robots.{robot_name}.pose.x", float, transient=True) base_model.add_variable(f"robots.{robot_name}.pose.y", float, transient=True) base_model.add_variable(f"robots.{robot_name}.pose.theta", float, transient=True) @@ -70,14 +75,18 @@ def __convince_env_model_to_jani(base_model: JaniModel, env_model: dict): base_model.add_variable(f"robots.{robot_name}.goal.y", float, transient=True) base_model.add_variable(f"robots.{robot_name}.goal.theta", float, transient=True) robot_shape = robot["shape"] - base_model.add_constant(f"robots.{robot_name}.shape.radius", float, - float(robot_shape["radius"])) - base_model.add_constant(f"robots.{robot_name}.shape.height", float, - float(robot_shape["height"])) - base_model.add_constant(f"robots.{robot_name}.linear_velocity", float, - float(robot["linear_velocity"])) - base_model.add_constant(f"robots.{robot_name}.angular_velocity", float, - float(robot["angular_velocity"])) + base_model.add_constant( + f"robots.{robot_name}.shape.radius", float, float(robot_shape["radius"]) + ) + base_model.add_constant( + f"robots.{robot_name}.shape.height", float, float(robot_shape["height"]) + ) + base_model.add_constant( + f"robots.{robot_name}.linear_velocity", float, float(robot["linear_velocity"]) + ) + base_model.add_constant( + f"robots.{robot_name}.angular_velocity", float, float(robot["angular_velocity"]) + ) if "obstacles" in env_model: # Extract the obstacles from the env_model # TODO @@ -111,8 +120,9 @@ def __convince_properties_to_jani(base_model: JaniModel, properties: List[dict]) assert isinstance(base_model, JaniModel), "The base_model should be a JaniModel instance" for property_dict in properties: assert isinstance(property_dict, dict), "The properties list should contain dictionaries" - base_model.add_jani_property(JaniProperty(property_dict["name"], - property_dict["expression"])) + base_model.add_jani_property( + JaniProperty(property_dict["name"], property_dict["expression"]) + ) def convince_jani_parser(base_model: JaniModel, convince_jani_path: str): @@ -122,14 +132,14 @@ def convince_jani_parser(base_model: JaniModel, convince_jani_path: str): # Check if the convince_jani_path is a file assert path.isfile(convince_jani_path), "The convince_jani_path should be a file" # Read the convince-jani file - with open(convince_jani_path, "r", encoding='utf-8') as file: + with open(convince_jani_path, "r", encoding="utf-8") as file: convince_jani_json = json.load(file) # ---- Metadata ---- base_model.set_name(convince_jani_json["name"]) # Make sure we are loading a convince-jani file - assert "features" in convince_jani_json and \ - "convince_extensions" in convince_jani_json["features"], \ - "The provided file is not a convince-jani file (missing feature entry)" + assert ( + "features" in convince_jani_json and "convince_extensions" in convince_jani_json["features"] + ), "The provided file is not a convince-jani file (missing feature entry)" # Extract the environment model from the convince-jani file # ---- Environment Model ---- __convince_env_model_to_jani(base_model, convince_jani_json["rob_env_model"]) diff --git a/src/as2fm/jani_generator/jani_entries/jani_assignment.py b/src/as2fm/jani_generator/jani_entries/jani_assignment.py index 6acbf47f..f6afca24 100644 --- a/src/as2fm/jani_generator/jani_entries/jani_assignment.py +++ b/src/as2fm/jani_generator/jani_entries/jani_assignment.py @@ -20,8 +20,7 @@ from typing import Dict from as2fm.jani_generator.jani_entries import JaniConstant, JaniExpression -from as2fm.jani_generator.jani_entries.jani_convince_expression_expansion import \ - expand_expression +from as2fm.jani_generator.jani_entries.jani_convince_expression_expansion import expand_expression class JaniAssignment: @@ -43,5 +42,5 @@ def as_dict(self, constants: Dict[str, JaniConstant]): return { "ref": self._var_name.as_dict(), "value": expanded_value.as_dict(), - "index": self._index + "index": self._index, } diff --git a/src/as2fm/jani_generator/jani_entries/jani_automaton.py b/src/as2fm/jani_generator/jani_entries/jani_automaton.py index 757e0839..eb4563c0 100644 --- a/src/as2fm/jani_generator/jani_entries/jani_automaton.py +++ b/src/as2fm/jani_generator/jani_entries/jani_automaton.py @@ -17,8 +17,7 @@ from typing import Any, Dict, List, Optional, Set -from as2fm.jani_generator.jani_entries import (JaniConstant, JaniEdge, - JaniExpression, JaniVariable) +from as2fm.jani_generator.jani_entries import JaniConstant, JaniEdge, JaniExpression, JaniVariable class JaniAutomaton: @@ -36,8 +35,7 @@ def __init__(self, *, automaton_dict: Optional[Dict[str, Any]] = None): if automaton_dict is None: return self._name = automaton_dict["name"] - self._generate_locations( - automaton_dict["locations"], automaton_dict["initial-locations"]) + self._generate_locations(automaton_dict["locations"], automaton_dict["initial-locations"]) self._generate_variables(automaton_dict.get("variables", [])) self._generate_edges(automaton_dict["edges"]) @@ -56,15 +54,16 @@ def get_initial_locations(self) -> Set[str]: return self._initial_locations def make_initial(self, location_name: str): - assert location_name in self._locations, \ - f"Location {location_name} must exist in the automaton" + assert ( + location_name in self._locations + ), f"Location {location_name} must exist in the automaton" self._initial_locations.add(location_name) def unset_initial(self, location_name: str): - assert location_name in self._locations, \ - f"Location {location_name} must exist in the automaton" - assert location_name in self._initial_locations, \ - f"Location {location_name} must be initial" + assert ( + location_name in self._locations + ), f"Location {location_name} must exist in the automaton" + assert location_name in self._initial_locations, f"Location {location_name} must be initial" self._initial_locations.remove(location_name) def add_variable(self, variable: JaniVariable): @@ -90,8 +89,9 @@ def remove_empty_self_loop_edges(self): """Remove all self-loop edges from the automaton.""" self._edges = [edge for edge in self._edges if not edge.is_empty_self_loop()] - def _generate_locations(self, - location_list: List[Dict[str, Any]], initial_locations: List[str]): + def _generate_locations( + self, location_list: List[Dict[str, Any]], initial_locations: List[str] + ): for location in location_list: self._locations.add(location["name"]) for init_location in initial_locations: @@ -106,8 +106,13 @@ def _generate_variables(self, variable_list: List[dict]): if "transient" in variable: is_transient = variable["transient"] var_type = JaniVariable.python_type_from_json(variable["type"]) - self._local_variables.update({variable["name"]: JaniVariable( - variable["name"], var_type, init_expr, is_transient)}) + self._local_variables.update( + { + variable["name"]: JaniVariable( + variable["name"], var_type, init_expr, is_transient + ) + } + ) def _generate_edges(self, edge_list: List[dict]): for edge in edge_list: @@ -122,7 +127,7 @@ def get_actions(self) -> Set[str]: actions.add(action) return actions - def merge(self, other: 'JaniAutomaton'): + def merge(self, other: "JaniAutomaton"): assert self._name == other.get_name(), "Automaton names must match" self._locations.update(other._locations) self._initial_locations.update(other._initial_locations) @@ -134,9 +139,10 @@ def as_dict(self, constant: Dict[str, JaniConstant]): "name": self._name, "locations": [{"name": location} for location in sorted(self._locations)], "initial-locations": sorted(list(self._initial_locations)), - "edges": [edge.as_dict(constant) for edge in self._edges] + "edges": [edge.as_dict(constant) for edge in self._edges], } if len(self._local_variables) > 0: automaton_dict.update( - {"variables": [jani_var.as_dict() for jani_var in self._local_variables.values()]}) + {"variables": [jani_var.as_dict() for jani_var in self._local_variables.values()]} + ) return automaton_dict diff --git a/src/as2fm/jani_generator/jani_entries/jani_composition.py b/src/as2fm/jani_generator/jani_entries/jani_composition.py index 51543515..89efc0eb 100644 --- a/src/as2fm/jani_generator/jani_entries/jani_composition.py +++ b/src/as2fm/jani_generator/jani_entries/jani_composition.py @@ -31,19 +31,18 @@ def __init__(self, composition_dict: Optional[Dict[str, Any]] = None): return self._elements = self._generate_elements(composition_dict["elements"]) self._syncs = self._generate_syncs(composition_dict["syncs"]) - self._element_to_id = {element: idx for idx, - element in enumerate(self._elements)} + self._element_to_id = {element: idx for idx, element in enumerate(self._elements)} assert self.is_valid(), "Invalid composition from dict." def add_element(self, element: str): """Append a new automaton name in the composition.""" - assert element not in self._elements, \ - f"Element {element} already exists in the composition" + assert element not in self._elements, f"Element {element} already exists in the composition" self._elements.append(element) self._element_to_id[element] = len(self._elements) - 1 for sync in self._syncs: - assert len(sync["synchronise"]) == len(self._elements) - 1, \ - "Unexpected number of syncs found in the composition during the update" + assert ( + len(sync["synchronise"]) == len(self._elements) - 1 + ), "Unexpected number of syncs found in the composition during the update" sync["synchronise"].append(None) def get_elements(self): @@ -59,20 +58,18 @@ def add_sync(self, sync_name: str, syncs: Dict[str, str]): # Generate the synchronize list sync_list: List[Optional[str]] = [None] * len(self._elements) for automata, action in syncs.items(): - assert automata in self._element_to_id, \ - f"Automaton {automata} does not exist in the composition" + assert ( + automata in self._element_to_id + ), f"Automaton {automata} does not exist in the composition" sync_list[self._element_to_id[automata]] = action - self._syncs.append({ - "result": sync_name, - "synchronise": sync_list - }) + self._syncs.append({"result": sync_name, "synchronise": sync_list}) def get_syncs_for_element(self, element: str) -> List[str]: """Get the existing syncs for a specific element (=automaton).""" - assert element in self._element_to_id, \ - f"Element {element} does not exist in the composition" - syncs_w_none = [sync['synchronise'][self._element_to_id[element]] - for sync in self._syncs] + assert ( + element in self._element_to_id + ), f"Element {element} does not exist in the composition" + syncs_w_none = [sync["synchronise"][self._element_to_id[element]] for sync in self._syncs] return [sync for sync in syncs_w_none if sync is not None] def is_valid(self) -> bool: @@ -95,11 +92,9 @@ def _generate_syncs(self, syncs_list): generated_syncs = [] for sync in syncs_list: assert len(self._elements) == len( - sync["synchronise"]), "The number of elements and synchronise should be the same" - sync_dict = { - "synchronise": sync["synchronise"], - "result": None - } + sync["synchronise"] + ), "The number of elements and synchronise should be the same" + sync_dict = {"synchronise": sync["synchronise"], "result": None} if "result" in sync: sync_dict["result"] = sync["result"] generated_syncs.append(sync_dict) @@ -110,5 +105,5 @@ def as_dict(self): self._syncs = sorted(self._syncs, key=lambda x: x["result"]) return { "elements": [{"automaton": element} for element in self._elements], - "syncs": self._syncs + "syncs": self._syncs, } diff --git a/src/as2fm/jani_generator/jani_entries/jani_constant.py b/src/as2fm/jani_generator/jani_entries/jani_constant.py index 7105a6a2..ab20caa0 100644 --- a/src/as2fm/jani_generator/jani_entries/jani_constant.py +++ b/src/as2fm/jani_generator/jani_entries/jani_constant.py @@ -34,17 +34,14 @@ def from_dict(constant_dict: dict) -> "JaniConstant": # Check if conversion from string to constant_type is possible try: const_value_cast = constant_type(constant_value) - return JaniConstant(constant_name, - constant_type, - JaniExpression(const_value_cast)) + return JaniConstant(constant_name, constant_type, JaniExpression(const_value_cast)) except ValueError: # If no conversion possible, raise an error (constant names are not supported) raise ValueError( f"Value {constant_value} for constant {constant_name} " - f"is not a valid value for type {constant_type}.") - return JaniConstant(constant_name, - constant_type, - JaniExpression(constant_value)) + f"is not a valid value for type {constant_type}." + ) + return JaniConstant(constant_name, constant_type, JaniExpression(constant_value)) def __init__(self, c_name: str, c_type: Type, c_value: Optional[JaniExpression]): assert isinstance(c_value, JaniExpression), "Value should be a JaniExpression" @@ -60,8 +57,7 @@ def value(self) -> Optional[ValidTypes]: if self._value is None: return None jani_value = self._value.value - assert jani_value.is_valid(), \ - "The expression can't be evaluated to a constant value" + assert jani_value.is_valid(), "The expression can't be evaluated to a constant value" return jani_value.value() @staticmethod @@ -106,10 +102,7 @@ def jani_type_to_string(c_type: Type[ValidTypes]) -> str: raise ValueError(f"Type {c_type} not supported by Jani") def as_dict(self): - const_dict = { - "name": self._name, - "type": JaniConstant.jani_type_to_string(self._type) - } + const_dict = {"name": self._name, "type": JaniConstant.jani_type_to_string(self._type)} if self._value is not None: const_dict["value"] = self._value.as_dict() return const_dict diff --git a/src/as2fm/jani_generator/jani_entries/jani_convince_expression_expansion.py b/src/as2fm/jani_generator/jani_entries/jani_convince_expression_expansion.py index 24d33f98..67a88fed 100644 --- a/src/as2fm/jani_generator/jani_entries/jani_convince_expression_expansion.py +++ b/src/as2fm/jani_generator/jani_entries/jani_convince_expression_expansion.py @@ -18,14 +18,29 @@ from math import pi from typing import Callable, Dict, Union -from as2fm.jani_generator.jani_entries import (JaniConstant, JaniExpression, - JaniValue) +from as2fm.jani_generator.jani_entries import JaniConstant, JaniExpression, JaniValue from as2fm.jani_generator.jani_entries.jani_expression_generator import ( - abs_operator, and_operator, ceil_operator, cos_operator, divide_operator, - equal_operator, floor_operator, greater_equal_operator, if_operator, - log_operator, lower_operator, max_operator, min_operator, minus_operator, - modulo_operator, multiply_operator, or_operator, plus_operator, - pow_operator, sin_operator) + abs_operator, + and_operator, + ceil_operator, + cos_operator, + divide_operator, + equal_operator, + floor_operator, + greater_equal_operator, + if_operator, + log_operator, + lower_operator, + max_operator, + min_operator, + minus_operator, + modulo_operator, + multiply_operator, + or_operator, + plus_operator, + pow_operator, + sin_operator, +) # Map each operator to the corresponding one in Jani OPERATORS_TO_JANI_MAP: Dict[str, str] = { @@ -70,25 +85,26 @@ # Custom operators (CONVINCE, specific to mobile 2D robot use case) def intersection_operator(left, right) -> JaniExpression: - return JaniExpression({ - "op": "intersect", - "robot": JaniExpression(left), - "barrier": JaniExpression(right)}) + return JaniExpression( + {"op": "intersect", "robot": JaniExpression(left), "barrier": JaniExpression(right)} + ) def distance_operator(left, right) -> JaniExpression: - return JaniExpression({ - "op": "distance", - "robot": JaniExpression(left), - "barrier": JaniExpression(right)}) + return JaniExpression( + {"op": "distance", "robot": JaniExpression(left), "barrier": JaniExpression(right)} + ) def distance_to_point_operator(robot, target_x, target_y) -> JaniExpression: - return JaniExpression({ - "op": "distance_to_point", - "robot": JaniExpression(robot), - "x": JaniExpression(target_x), - "y": JaniExpression(target_y)}) + return JaniExpression( + { + "op": "distance_to_point", + "robot": JaniExpression(robot), + "x": JaniExpression(target_x), + "y": JaniExpression(target_y), + } + ) def norm2d_operator(x=None, y=None, *, exp=None) -> JaniExpression: @@ -120,8 +136,9 @@ def cross2d_operator(x1=None, y1=None, x2=None, y2=None, *, exp=None) -> JaniExp exp_y1 = y1 exp_x2 = x2 exp_y2 = y2 - assert all(exp is not None for exp in [exp_x1, exp_y1, exp_x2, exp_y2]), \ - "The 2D vectors components must be provided" + assert all( + exp is not None for exp in [exp_x1, exp_y1, exp_x2, exp_y2] + ), "The 2D vectors components must be provided" return minus_operator(multiply_operator(exp_x1, exp_y2), multiply_operator(exp_y1, exp_x2)) @@ -138,8 +155,9 @@ def dot2d_operator(x1=None, y1=None, x2=None, y2=None, *, exp=None) -> JaniExpre exp_y1 = y1 exp_x2 = x2 exp_y2 = y2 - assert all(exp is not None for exp in [exp_x1, exp_y1, exp_x2, exp_y2]), \ - "The 2D vectors components must be provided" + assert all( + exp is not None for exp in [exp_x1, exp_y1, exp_x2, exp_y2] + ), "The 2D vectors components must be provided" return plus_operator(multiply_operator(exp_x1, exp_x2), multiply_operator(exp_y1, exp_y2)) @@ -197,8 +215,8 @@ def to_rad_operator(value=None, *, exp=None) -> JaniExpression: # Functionalities for interpolation def __expression_interpolation_single_boundary( - jani_constants: Dict[str, JaniConstant], - robot_name: str, boundary_id: int) -> JaniExpression: + jani_constants: Dict[str, JaniConstant], robot_name: str, boundary_id: int +) -> JaniExpression: n_vertices = jani_constants["boundaries.count"].value() # Variables names robot_radius = f"robots.{robot_name}.shape.radius" @@ -226,7 +244,8 @@ def __expression_interpolation_single_boundary( boundary_norm_exp = norm2d_operator(ab_x, ab_y) # Distance from the robot to the boundary perpendicular to the boundary segment v_dist_exp = divide_operator( - abs_operator(cross2d_operator(ab_x, ab_y, ea_x, ea_y)), boundary_norm_exp) + abs_operator(cross2d_operator(ab_x, ab_y, ea_x, ea_y)), boundary_norm_exp + ) # Distance between the boundary extreme points and the robot parallel to the boundary segment ha_dist_exp = divide_operator(dot2d_operator(ab_x, ab_y, ea_x, ea_y), boundary_norm_exp) hb_dist_exp = divide_operator(dot2d_operator(ba_x, ba_y, eb_x, eb_y), boundary_norm_exp) @@ -235,53 +254,79 @@ def __expression_interpolation_single_boundary( is_parallel_exp = equal_operator(cross2d_operator(ab_x, ab_y, es_x, es_y), 0.0) # Interpolation factors ha_interp_exp = if_operator( - and_operator(greater_equal_operator(ha_dist_exp, 0.0), - lower_operator(ha_dist_exp, robot_radius)), - divide_operator(minus_operator(multiply_operator(boundary_norm_exp, robot_radius), - dot2d_operator(ab_x, ab_y, ea_x, ea_y)), - dot2d_operator(ba_x, ba_y, es_x, es_y)), - 1.0) + and_operator( + greater_equal_operator(ha_dist_exp, 0.0), lower_operator(ha_dist_exp, robot_radius) + ), + divide_operator( + minus_operator( + multiply_operator(boundary_norm_exp, robot_radius), + dot2d_operator(ab_x, ab_y, ea_x, ea_y), + ), + dot2d_operator(ba_x, ba_y, es_x, es_y), + ), + 1.0, + ) hb_interp_exp = if_operator( - and_operator(greater_equal_operator(hb_dist_exp, 0.0), - lower_operator(hb_dist_exp, robot_radius)), - divide_operator(minus_operator(multiply_operator(boundary_norm_exp, robot_radius), - dot2d_operator(ba_x, ba_y, eb_x, eb_y)), - dot2d_operator(ab_x, ab_y, es_x, es_y)), - 1.0) - h_interp_exp = if_operator(is_perpendicular_exp, - 1.0, min_operator(ha_interp_exp, hb_interp_exp)) + and_operator( + greater_equal_operator(hb_dist_exp, 0.0), lower_operator(hb_dist_exp, robot_radius) + ), + divide_operator( + minus_operator( + multiply_operator(boundary_norm_exp, robot_radius), + dot2d_operator(ba_x, ba_y, eb_x, eb_y), + ), + dot2d_operator(ab_x, ab_y, es_x, es_y), + ), + 1.0, + ) + h_interp_exp = if_operator( + is_perpendicular_exp, 1.0, min_operator(ha_interp_exp, hb_interp_exp) + ) v_interp_exp = if_operator( or_operator(is_parallel_exp, greater_equal_operator(v_dist_exp, robot_radius)), 1.0, - divide_operator(minus_operator(multiply_operator(boundary_norm_exp, robot_radius), - abs_operator(cross2d_operator(ab_x, ab_y, ea_x, ea_y))), - abs_operator(cross2d_operator(ab_x, ab_y, es_x, es_y)))) + divide_operator( + minus_operator( + multiply_operator(boundary_norm_exp, robot_radius), + abs_operator(cross2d_operator(ab_x, ab_y, ea_x, ea_y)), + ), + abs_operator(cross2d_operator(ab_x, ab_y, es_x, es_y)), + ), + ) return if_operator( - greater_equal_operator(max_operator(v_dist_exp, max_operator(ha_dist_exp, hb_dist_exp)), - robot_radius), - 0.0, min_operator(h_interp_exp, v_interp_exp)) + greater_equal_operator( + max_operator(v_dist_exp, max_operator(ha_dist_exp, hb_dist_exp)), robot_radius + ), + 0.0, + min_operator(h_interp_exp, v_interp_exp), + ) def __expression_interpolation_next_boundaries( - jani_constants: Dict[str, JaniConstant], robot_name, boundary_id) -> JaniExpression: + jani_constants: Dict[str, JaniConstant], robot_name, boundary_id +) -> JaniExpression: n_vertices = jani_constants["boundaries.count"].value() - assert isinstance(n_vertices, int) and n_vertices > 1, \ - f"The number of boundaries ({n_vertices}) must greater than 1" + assert ( + isinstance(n_vertices, int) and n_vertices > 1 + ), f"The number of boundaries ({n_vertices}) must greater than 1" if boundary_id >= n_vertices: return JaniExpression(0.0) return max_operator( __expression_interpolation_single_boundary(jani_constants, robot_name, boundary_id), - __expression_interpolation_next_boundaries(jani_constants, robot_name, boundary_id + 1)) + __expression_interpolation_next_boundaries(jani_constants, robot_name, boundary_id + 1), + ) def __expression_interpolation_next_obstacles( - jani_constants, robot_name, obstacle_id) -> JaniExpression: + jani_constants, robot_name, obstacle_id +) -> JaniExpression: # TODO return JaniExpression(0.0) def __expression_interpolation( - jani_expression: JaniExpression, jani_constants: Dict[str, JaniConstant]) -> JaniExpression: + jani_expression: JaniExpression, jani_constants: Dict[str, JaniConstant] +) -> JaniExpression: assert isinstance(jani_expression, JaniExpression), "The input must be a JaniExpression" assert jani_expression.op == "intersect" robot_op = jani_expression.operands["robot"] @@ -293,17 +338,19 @@ def __expression_interpolation( if barrier_name == "all": return max_operator( __expression_interpolation_next_boundaries(jani_constants, robot_name, 0), - __expression_interpolation_next_obstacles(jani_constants, robot_name, 0)) + __expression_interpolation_next_obstacles(jani_constants, robot_name, 0), + ) if barrier_name == "boundaries": return __expression_interpolation_next_boundaries(jani_constants, robot_name, 0) if barrier_name == "obstacles": return __expression_interpolation_next_obstacles(jani_constants, robot_name, 0) - raise NotImplementedError(f"The barrier type \"{barrier_name}\" is not implemented") + raise NotImplementedError(f'The barrier type "{barrier_name}" is not implemented') # Functionalities for validity check def __expression_distance_single_boundary( - jani_constants: Dict[str, JaniConstant], robot_name, boundary_id) -> JaniExpression: + jani_constants: Dict[str, JaniConstant], robot_name, boundary_id +) -> JaniExpression: n_vertices = jani_constants["boundaries.count"].value() # Variables names robot_radius = f"robots.{robot_name}.shape.radius" @@ -326,8 +373,9 @@ def __expression_distance_single_boundary( # Boundary length boundary_norm_exp = norm2d_operator(ab_x, ab_y) # Distance from the robot to the boundary perpendicular to the boundary segment - v_dist_exp = divide_operator(abs_operator(cross2d_operator(ab_x, ab_y, ra_x, ra_y)), - boundary_norm_exp) + v_dist_exp = divide_operator( + abs_operator(cross2d_operator(ab_x, ab_y, ra_x, ra_y)), boundary_norm_exp + ) # Distance between the boundary extreme points and the robot parallel to the boundary segment ha_dist_exp = divide_operator(dot2d_operator(ab_x, ab_y, ra_x, ra_y), boundary_norm_exp) hb_dist_exp = divide_operator(dot2d_operator(ba_x, ba_y, rb_x, rb_y), boundary_norm_exp) @@ -336,15 +384,18 @@ def __expression_distance_single_boundary( def __expression_distance_next_boundaries( - jani_constants: Dict[str, JaniConstant], robot_name, boundary_id) -> JaniExpression: + jani_constants: Dict[str, JaniConstant], robot_name, boundary_id +) -> JaniExpression: n_vertices = jani_constants["boundaries.count"].value() - assert isinstance(n_vertices, int) and n_vertices > 1, \ - f"The number of boundaries ({n_vertices}) must greater than 1" + assert ( + isinstance(n_vertices, int) and n_vertices > 1 + ), f"The number of boundaries ({n_vertices}) must greater than 1" if boundary_id >= n_vertices: return JaniExpression(True) return min_operator( __expression_distance_single_boundary(jani_constants, robot_name, boundary_id), - __expression_distance_next_boundaries(jani_constants, robot_name, boundary_id + 1)) + __expression_distance_next_boundaries(jani_constants, robot_name, boundary_id + 1), + ) def __expression_distance_next_obstacles(jani_constants, robot_name, obstacle_id) -> JaniExpression: @@ -353,7 +404,8 @@ def __expression_distance_next_obstacles(jani_constants, robot_name, obstacle_id def __expression_distance( - jani_expression: JaniExpression, jani_constants: Dict[str, JaniConstant]) -> JaniExpression: + jani_expression: JaniExpression, jani_constants: Dict[str, JaniConstant] +) -> JaniExpression: assert isinstance(jani_expression, JaniExpression), "The input must be a JaniExpression" assert jani_expression.op == "distance" robot_op = jani_expression.operands["robot"] @@ -365,7 +417,8 @@ def __expression_distance( if barrier_name == "all": return min_operator( __expression_distance_next_boundaries(jani_constants, robot_name, 0), - __expression_distance_next_obstacles(jani_constants, robot_name, 0)) + __expression_distance_next_obstacles(jani_constants, robot_name, 0), + ) if barrier_name == "boundaries": return __expression_distance_next_boundaries(jani_constants, robot_name, 0) if barrier_name == "obstacles": @@ -374,7 +427,8 @@ def __expression_distance( def __expression_distance_to_point( - jani_expression: JaniExpression, jani_constants: Dict[str, JaniConstant]) -> JaniExpression: + jani_expression: JaniExpression, jani_constants: Dict[str, JaniConstant] +) -> JaniExpression: assert isinstance(jani_expression, JaniExpression), "The input must be a JaniExpression" assert jani_expression.op == "distance_to_point" robot_op = jani_expression.operands["robot"] @@ -384,26 +438,30 @@ def __expression_distance_to_point( target_y_cm = to_cm_operator(expand_expression(jani_expression.operands["y"], jani_constants)) robot_x_cm = f"robots.{robot_name}.pose.x_cm" robot_y_cm = f"robots.{robot_name}.pose.y_cm" - return to_m_operator(norm2d_operator(minus_operator(robot_x_cm, target_x_cm), - minus_operator(robot_y_cm, target_y_cm))) + return to_m_operator( + norm2d_operator( + minus_operator(robot_x_cm, target_x_cm), minus_operator(robot_y_cm, target_y_cm) + ) + ) def __substitute_expression_op(expression: JaniExpression) -> JaniExpression: assert isinstance(expression, JaniExpression), "The input must be a JaniExpression" - assert expression.op in OPERATORS_TO_JANI_MAP, \ - f"The operator {expression.op} is not supported" + assert expression.op in OPERATORS_TO_JANI_MAP, f"The operator {expression.op} is not supported" expression.op = OPERATORS_TO_JANI_MAP[expression.op] return expression def expand_expression( - expression: Union[JaniExpression, JaniValue], - jani_constants: Dict[str, JaniConstant]) -> JaniExpression: + expression: Union[JaniExpression, JaniValue], jani_constants: Dict[str, JaniConstant] +) -> JaniExpression: # Given a CONVINCE JaniExpression, expand it to a plain JaniExpression - assert isinstance(expression, JaniExpression), \ - f"The expression should be a JaniExpression instance, found {type(expression)} instead." - assert expression.is_valid(), \ - "The expression is not valid: it defines no value, nor variable, nor operation to be done." + assert isinstance( + expression, JaniExpression + ), f"The expression should be a JaniExpression instance, found {type(expression)} instead." + assert ( + expression.is_valid() + ), "The expression is not valid: it defines no value, nor variable, nor operation to be done." if expression.op is None: # It is either a variable/constant identifier or a value return expression @@ -444,5 +502,5 @@ def expand_expression( "log": log_operator, "pow": pow_operator, "min": min_operator, - "max": max_operator + "max": max_operator, } diff --git a/src/as2fm/jani_generator/jani_entries/jani_edge.py b/src/as2fm/jani_generator/jani_entries/jani_edge.py index f21292e0..b4296d5b 100644 --- a/src/as2fm/jani_generator/jani_entries/jani_edge.py +++ b/src/as2fm/jani_generator/jani_entries/jani_edge.py @@ -17,10 +17,13 @@ from typing import Dict, Optional -from as2fm.jani_generator.jani_entries import (JaniAssignment, JaniConstant, - JaniExpression, JaniGuard) -from as2fm.jani_generator.jani_entries.jani_convince_expression_expansion import \ - expand_expression +from as2fm.jani_generator.jani_entries import ( + JaniAssignment, + JaniConstant, + JaniExpression, + JaniGuard, +) +from as2fm.jani_generator.jani_entries.jani_convince_expression_expansion import expand_expression class JaniEdge: @@ -37,7 +40,7 @@ def __init__(self, edge_dict: dict): jani_destination = { "location": dest["location"], "probability": None, - "assignments": [] + "assignments": [], } if "probability" in dest: jani_destination["probability"] = JaniExpression(dest["probability"]["exp"]) @@ -57,18 +60,18 @@ def get_action(self) -> Optional[str]: def is_empty_self_loop(self) -> bool: """Check if the edge is an empty self loop (i.e. has no assignments).""" - return len(self.destinations) == 1 and self.location == self.destinations[0]["location"] \ + return ( + len(self.destinations) == 1 + and self.location == self.destinations[0]["location"] and len(self.destinations[0]["assignments"]) == 0 + ) def set_action(self, action_name: str): """Set the action name.""" self.action = action_name def as_dict(self, constants: Dict[str, JaniConstant]): - edge_dict = { - "location": self.location, - "destinations": [] - } + edge_dict = {"location": self.location, "destinations": []} if self.action is not None: edge_dict.update({"action": self.action}) if self.guard is not None: diff --git a/src/as2fm/jani_generator/jani_entries/jani_expression.py b/src/as2fm/jani_generator/jani_entries/jani_expression.py index b2b6fd4b..ee1cf07c 100644 --- a/src/as2fm/jani_generator/jani_entries/jani_expression.py +++ b/src/as2fm/jani_generator/jani_entries/jani_expression.py @@ -27,9 +27,10 @@ class JaniExpressionType(Enum): """Enumeration of the different types of Jani expressions.""" + IDENTIFIER = 1 # Reference to a constant or variable id - LITERAL = 2 # Reference to a literal value - OPERATOR = 3 # Reference to an operator (a composition of expressions) + LITERAL = 2 # Reference to a literal value + OPERATOR = 3 # Reference to an operator (a composition of expressions) class JaniExpression: @@ -44,7 +45,8 @@ class JaniExpression: - op: a string representing an operator - operands: a dictionary of operands, related to the specified operator """ - def __init__(self, expression: Union[SupportedExp, 'JaniExpression', JaniValue]): + + def __init__(self, expression: Union[SupportedExp, "JaniExpression", JaniValue]): self.identifier: Optional[str] = None self.value: Optional[JaniValue] = None self.op: Optional[str] = None @@ -57,8 +59,9 @@ def __init__(self, expression: Union[SupportedExp, 'JaniExpression', JaniValue]) elif isinstance(expression, JaniValue): self.value = expression else: - assert isinstance(expression, SupportedExp), \ - f"Unexpected expression type: {type(expression)} should be a dict or a base type." + assert isinstance( + expression, SupportedExp + ), f"Unexpected expression type: {type(expression)} should be a dict or a base type." if isinstance(expression, str): # If it is a reference to a constant or variable, we do not need to expand further self.identifier = expression @@ -74,9 +77,9 @@ def __init__(self, expression: Union[SupportedExp, 'JaniExpression', JaniValue]) self.op = expression["op"] self.operands = self._get_operands(expression) - def _get_operands(self, expression_dict: dict) -> Dict[str, 'JaniExpression']: + def _get_operands(self, expression_dict: dict) -> Dict[str, "JaniExpression"]: assert self.op is not None, "Operator not set" - if (self.op in ("intersect", "distance")): + if self.op in ("intersect", "distance"): # intersect: returns a value in [0.0, 1.0], indicating where on the robot trajectory # the intersection occurs. # 0.0 means no intersection occurs (destination reached), 1.0 means the @@ -84,55 +87,96 @@ def _get_operands(self, expression_dict: dict) -> Dict[str, 'JaniExpression']: # the barrier. return { "robot": JaniExpression(expression_dict["robot"]), - "barrier": JaniExpression(expression_dict["barrier"])} - if (self.op in ("distance_to_point")): + "barrier": JaniExpression(expression_dict["barrier"]), + } + if self.op in ("distance_to_point"): # distance between robot outer radius and point x-y coords return { "robot": JaniExpression(expression_dict["robot"]), "x": JaniExpression(expression_dict["x"]), - "y": JaniExpression(expression_dict["y"])} - if (self.op in ( - "&&", "||", "and", "or", "∨", "∧", - "⇒", "=>", "=", "≠", "!=", "+", "-", "*", "%", - "pow", "log", "/", "min", "max", - "<", "≤", ">", "≥", "<=", ">=", "==")): + "y": JaniExpression(expression_dict["y"]), + } + if self.op in ( + "&&", + "||", + "and", + "or", + "∨", + "∧", + "⇒", + "=>", + "=", + "≠", + "!=", + "+", + "-", + "*", + "%", + "pow", + "log", + "/", + "min", + "max", + "<", + "≤", + ">", + "≥", + "<=", + ">=", + "==", + ): return { "left": JaniExpression(expression_dict["left"]), - "right": JaniExpression(expression_dict["right"])} - if (self.op in ("!", "¬", "sin", "cos", "floor", "ceil", - "abs", "to_cm", "to_m", "to_deg", "to_rad")): - return { - "exp": JaniExpression(expression_dict["exp"])} - if (self.op in ("ite")): + "right": JaniExpression(expression_dict["right"]), + } + if self.op in ( + "!", + "¬", + "sin", + "cos", + "floor", + "ceil", + "abs", + "to_cm", + "to_m", + "to_deg", + "to_rad", + ): + return {"exp": JaniExpression(expression_dict["exp"])} + if self.op in ("ite"): return { "if": JaniExpression(expression_dict["if"]), "then": JaniExpression(expression_dict["then"]), - "else": JaniExpression(expression_dict["else"])} + "else": JaniExpression(expression_dict["else"]), + } # Array-specific expressions - if (self.op == "ac"): + if self.op == "ac": return { "var": JaniExpression(expression_dict["var"]), "length": JaniExpression(expression_dict["length"]), - "exp": JaniExpression(expression_dict["exp"])} - if (self.op == "aa"): - return { "exp": JaniExpression(expression_dict["exp"]), - "index": JaniExpression(expression_dict["index"])} - if (self.op == "av"): + } + if self.op == "aa": return { - "elements": JaniExpression(expression_dict["elements"])} + "exp": JaniExpression(expression_dict["exp"]), + "index": JaniExpression(expression_dict["index"]), + } + if self.op == "av": + return {"elements": JaniExpression(expression_dict["elements"])} # Convince specific expressions - if (self.op in ("norm2d")): + if self.op in ("norm2d"): return { "x": JaniExpression(expression_dict["x"]), - "y": JaniExpression(expression_dict["y"])} - if (self.op in ("dot2d", "cross2d")): + "y": JaniExpression(expression_dict["y"]), + } + if self.op in ("dot2d", "cross2d"): return { "x1": JaniExpression(expression_dict["x1"]), "y1": JaniExpression(expression_dict["y1"]), "x2": JaniExpression(expression_dict["x2"]), - "y2": JaniExpression(expression_dict["y2"])} - assert False, f"Unknown operator \"{self.op}\" found." + "y2": JaniExpression(expression_dict["y2"]), + } + assert False, f'Unknown operator "{self.op}" found.' def get_expression_type(self) -> JaniExpressionType: """Get the type of the expression.""" @@ -181,7 +225,7 @@ def as_identifier(self) -> Optional[str]: assert self.is_valid(), "Expression is not valid" return self.identifier - def as_operator(self) -> Optional[Tuple[str, Dict[str, 'JaniExpression']]]: + def as_operator(self) -> Optional[Tuple[str, Dict[str, "JaniExpression"]]]: """Provide the expression as an operator, if possible. None otherwise.""" assert self.is_valid(), "Expression is not valid" if self.op is None: @@ -198,8 +242,9 @@ def as_dict(self) -> Union[str, int, float, bool, dict]: "op": self.op, } for op_key, op_value in self.operands.items(): - assert isinstance(op_value, JaniExpression), \ - f"Expected an expression, found {type(op_value)} for {op_key}" + assert isinstance( + op_value, JaniExpression + ), f"Expected an expression, found {type(op_value)} for {op_key}" assert hasattr(op_value, "identifier"), f"Identifier not set for {op_key}" op_dict.update({op_key: op_value.as_dict()}) return op_dict diff --git a/src/as2fm/jani_generator/jani_entries/jani_guard.py b/src/as2fm/jani_generator/jani_entries/jani_guard.py index 22b2e1d5..19efac7d 100644 --- a/src/as2fm/jani_generator/jani_entries/jani_guard.py +++ b/src/as2fm/jani_generator/jani_entries/jani_guard.py @@ -25,7 +25,7 @@ class JaniGuard: - def __init__(self, guard_exp: Optional[Union['JaniGuard', JaniExpression, dict]]): + def __init__(self, guard_exp: Optional[Union["JaniGuard", JaniExpression, dict]]): """ Construct a new JaniGuard object. @@ -42,8 +42,10 @@ def __init__(self, guard_exp: Optional[Union['JaniGuard', JaniExpression, dict]] assert "exp" in guard_exp, "Expected guard expression to be in the 'exp' dict entry" self._expression = JaniExpression(guard_exp["exp"]) else: - raise ValueError(f"Unexpected guard_exp type {type(guard_exp)}. " - "Should be None, JaniExpression or Dict.") + raise ValueError( + f"Unexpected guard_exp type {type(guard_exp)}. " + "Should be None, JaniExpression or Dict." + ) def get_expression(self) -> Optional[JaniExpression]: return self._expression @@ -52,8 +54,8 @@ def as_dict(self, _: Optional[dict] = None): d = {} if self._expression: exp = self._expression.as_dict() - if (isinstance(exp, dict) and list(exp.keys()) == ['exp']): - d['exp'] = exp['exp'] + if isinstance(exp, dict) and list(exp.keys()) == ["exp"]: + d["exp"] = exp["exp"] else: - d['exp'] = exp + d["exp"] = exp return d diff --git a/src/as2fm/jani_generator/jani_entries/jani_model.py b/src/as2fm/jani_generator/jani_entries/jani_model.py index 4ca67604..fc2995e1 100644 --- a/src/as2fm/jani_generator/jani_entries/jani_model.py +++ b/src/as2fm/jani_generator/jani_entries/jani_model.py @@ -20,10 +20,15 @@ from typing import Dict, List, Optional, Type, Union -from as2fm.jani_generator.jani_entries import (JaniAutomaton, JaniComposition, - JaniConstant, JaniExpression, - JaniProperty, JaniValue, - JaniVariable) +from as2fm.jani_generator.jani_entries import ( + JaniAutomaton, + JaniComposition, + JaniConstant, + JaniExpression, + JaniProperty, + JaniValue, + JaniVariable, +) ValidValue = Union[int, float, bool, dict, JaniExpression] @@ -33,6 +38,7 @@ class JaniModel: Class representing a complete Jani Model, containing all necessary information to generate a plain Jani file. """ + @staticmethod def from_dict(model_dict: dict) -> "JaniModel": model = JaniModel() @@ -77,18 +83,29 @@ def get_features(self) -> List[str]: def add_jani_variable(self, variable: JaniVariable): self._variables.update({variable.name(): variable}) - def add_variable(self, variable_name: str, variable_type: Type, - variable_init_expression: Optional[ValidValue] = None, - transient: bool = False): + def add_variable( + self, + variable_name: str, + variable_type: Type, + variable_init_expression: Optional[ValidValue] = None, + transient: bool = False, + ): if variable_init_expression is None or isinstance(variable_init_expression, JaniExpression): self.add_jani_variable( - JaniVariable(variable_name, variable_type, variable_init_expression, transient)) + JaniVariable(variable_name, variable_type, variable_init_expression, transient) + ) else: - assert JaniValue(variable_init_expression).is_valid(), \ - f"Invalid value for variable {variable_name}" + assert JaniValue( + variable_init_expression + ).is_valid(), f"Invalid value for variable {variable_name}" self.add_jani_variable( - JaniVariable(variable_name, variable_type, - JaniExpression(variable_init_expression), transient)) + JaniVariable( + variable_name, + variable_type, + JaniExpression(variable_init_expression), + transient, + ) + ) def add_jani_constant(self, constant: JaniConstant): self._constants.update({constant.name(): constant}) @@ -97,10 +114,12 @@ def add_constant(self, constant_name: str, constant_type: Type, constant_value: if isinstance(constant_value, JaniExpression): self.add_jani_constant(JaniConstant(constant_name, constant_type, constant_value)) else: - assert JaniValue(constant_value).is_valid(), \ - f"Invalid value for constant {constant_name}" + assert JaniValue( + constant_value + ).is_valid(), f"Invalid value for constant {constant_name}" self.add_jani_constant( - JaniConstant(constant_name, constant_type, JaniExpression(constant_value))) + JaniConstant(constant_name, constant_type, JaniExpression(constant_value)) + ) def add_jani_automaton(self, automaton: JaniAutomaton): self._automata.append(automaton) @@ -136,8 +155,9 @@ def remove_edges_with_action(self, action: str): def _generate_missing_syncs(self): """Automatically generate the syncs that are not explicitly defined.""" - assert len(self._automata) == len(self._system.get_elements()), \ - "We expect there to be explicit syncs for all automata." + assert len(self._automata) == len( + self._system.get_elements() + ), "We expect there to be explicit syncs for all automata." for automaton in self._automata: existing_syncs = self._system.get_syncs_for_element(automaton.get_name()) for action in automaton.get_actions(): @@ -155,21 +175,29 @@ def as_dict(self): available_actions = set() for automaton in self._automata: available_actions.update(automaton.get_actions()) - model_dict.update({ - "jani-version": 1, - "name": self._name, - "type": self._type, - "features": self._features, - "metadata": { - "description": "Autogenerated with CONVINCE toolchain", - }, - "variables": [jani_variable.as_dict() for jani_variable in self._variables.values()], - "constants": [jani_constant.as_dict() for jani_constant in self._constants.values()], - "actions": [{"name": action} for action in sorted(list(available_actions))], - "automata": [jani_automaton.as_dict(self._constants) for - jani_automaton in self._automata], - "system": self._system.as_dict(), - "properties": [jani_property.as_dict(self._constants) for - jani_property in self._properties] - }) + model_dict.update( + { + "jani-version": 1, + "name": self._name, + "type": self._type, + "features": self._features, + "metadata": { + "description": "Autogenerated with CONVINCE toolchain", + }, + "variables": [ + jani_variable.as_dict() for jani_variable in self._variables.values() + ], + "constants": [ + jani_constant.as_dict() for jani_constant in self._constants.values() + ], + "actions": [{"name": action} for action in sorted(list(available_actions))], + "automata": [ + jani_automaton.as_dict(self._constants) for jani_automaton in self._automata + ], + "system": self._system.as_dict(), + "properties": [ + jani_property.as_dict(self._constants) for jani_property in self._properties + ], + } + ) return model_dict diff --git a/src/as2fm/jani_generator/jani_entries/jani_property.py b/src/as2fm/jani_generator/jani_entries/jani_property.py index 632447e5..d6f1d0e1 100644 --- a/src/as2fm/jani_generator/jani_entries/jani_property.py +++ b/src/as2fm/jani_generator/jani_entries/jani_property.py @@ -21,24 +21,26 @@ from typing import Any, Dict, Union from as2fm.jani_generator.jani_entries import JaniConstant, JaniExpression -from as2fm.jani_generator.jani_entries.jani_convince_expression_expansion import \ - expand_expression +from as2fm.jani_generator.jani_entries.jani_convince_expression_expansion import expand_expression class FilterProperty: """All Property operators must occur in a FilterProperty object.""" + def __init__(self, property_filter_exp: Dict[str, Any]): assert isinstance(property_filter_exp, dict), "Unexpected FilterProperty initialization" - assert "op" in property_filter_exp and property_filter_exp["op"] == "filter", \ - "Unexpected FilterProperty initialization" + assert ( + "op" in property_filter_exp and property_filter_exp["op"] == "filter" + ), "Unexpected FilterProperty initialization" self._fun = property_filter_exp["fun"] raw_states = property_filter_exp["states"] assert isinstance(raw_states, dict) and raw_states["op"] == "initial" self._process_values(property_filter_exp["values"]) def _process_values(self, prop_values: Dict[str, Any]) -> None: - self._values: Union[ProbabilityProperty, RewardProperty, NumPathsProperty] = \ + self._values: Union[ProbabilityProperty, RewardProperty, NumPathsProperty] = ( ProbabilityProperty(prop_values) + ) if self._values.is_valid(): return self._values = RewardProperty(prop_values) @@ -48,20 +50,20 @@ def _process_values(self, prop_values: Dict[str, Any]) -> None: assert self._values.is_valid(), "Unexpected values in FilterProperty" def as_dict(self, constants: Dict[str, JaniConstant]): - assert isinstance(self._values, ProbabilityProperty), \ - "Only ProbabilityProperty is supported in FilterProperty" + assert isinstance( + self._values, ProbabilityProperty + ), "Only ProbabilityProperty is supported in FilterProperty" return { "op": "filter", "fun": self._fun, - "states": { - "op": "initial" - }, - "values": self._values.as_dict(constants) + "states": {"op": "initial"}, + "values": self._values.as_dict(constants), } class ProbabilityProperty: """Pmin / Pmax""" + def __init__(self, prop_values: Dict[str, Any]): self._valid = False if "op" in prop_values and "exp" in prop_values: @@ -74,14 +76,12 @@ def is_valid(self) -> bool: return self._valid def as_dict(self, constants: Dict[str, JaniConstant]): - return { - "op": self._op, - "exp": self._exp.as_dict(constants) - } + return {"op": self._op, "exp": self._exp.as_dict(constants)} class RewardProperty: """E properties""" + def __init__(self, prop_values: Dict[str, Any]): self._valid = False @@ -91,6 +91,7 @@ def is_valid(self) -> bool: class NumPathsProperty: """This address properties where we want the property verified on all / at least one case.""" + def __init__(self, prop_values: Dict[str, Any]): self._valid = False @@ -100,6 +101,7 @@ def is_valid(self) -> bool: class PathProperty: """Mainly Until properties. Need to check support of Next and Global properties in Jani.""" + def __init__(self, prop_values: Dict[str, Any]): self._valid = False if "op" not in prop_values: @@ -111,7 +113,7 @@ def __init__(self, prop_values: Dict[str, Any]): elif self._op in ("U", "W"): self._operands = { "left": JaniExpression(prop_values["left"]), - "right": JaniExpression(prop_values["right"]) + "right": JaniExpression(prop_values["right"]), } else: print(f"Warning: Unsupported PathProperty operator {self._op}") @@ -129,8 +131,12 @@ def is_valid(self) -> bool: def as_dict(self, constants: Dict[str, JaniConstant]): ret_dict = {"op": self._op} - ret_dict.update({operand: expand_expression(expr, constants).as_dict() for - operand, expr in self._operands.items()}) + ret_dict.update( + { + operand: expand_expression(expr, constants).as_dict() + for operand, expr in self._operands.items() + } + ) if self._bounds is not None: ret_dict["step-bounds"] = self._bounds.as_dict(constants) return ret_dict @@ -171,7 +177,4 @@ def __init__(self, name, expression): self._expression = FilterProperty(expression) def as_dict(self, constants: Dict[str, JaniConstant]): - return { - "name": self._name, - "expression": self._expression.as_dict(constants) - } + return {"name": self._name, "expression": self._expression.as_dict(constants)} diff --git a/src/as2fm/jani_generator/jani_entries/jani_utils.py b/src/as2fm/jani_generator/jani_entries/jani_utils.py index a3aa14cb..b3f8df30 100644 --- a/src/as2fm/jani_generator/jani_entries/jani_utils.py +++ b/src/as2fm/jani_generator/jani_entries/jani_utils.py @@ -17,8 +17,7 @@ from typing import Any, Dict, MutableSequence, Optional, Tuple, Type, get_args -from as2fm.as2fm_common.common import (get_default_expression_for_type, - is_array_type) +from as2fm.as2fm_common.common import get_default_expression_for_type, is_array_type from as2fm.jani_generator.jani_entries import JaniAutomaton @@ -32,8 +31,9 @@ def get_variable_type(jani_automaton: JaniAutomaton, variable_name: Optional[str """ assert variable_name is not None, "Variable name must be provided." variable = jani_automaton.get_variables().get(variable_name) - assert variable is not None, \ - f"Variable {variable_name} not found in {jani_automaton.get_variables()}." + assert ( + variable is not None + ), f"Variable {variable_name} not found in {jani_automaton.get_variables()}." return variable.get_type() @@ -64,9 +64,9 @@ def get_array_type_and_size(jani_automaton: JaniAutomaton, var_name: str) -> Tup init_operator = variable.get_init_expr().as_operator() assert init_operator is not None, f"Expected init expr of {var_name} to be an operator expr." if init_operator[0] == "av": - max_size = len(init_operator[1]['elements'].as_literal().value()) + max_size = len(init_operator[1]["elements"].as_literal().value()) elif init_operator[0] == "ac": - max_size = init_operator[1]['length'].as_literal().value() + max_size = init_operator[1]["length"].as_literal().value() else: raise ValueError(f"Unexpected operator {init_operator[0]} for {var_name} init expr.") return (array_type, max_size) diff --git a/src/as2fm/jani_generator/jani_entries/jani_value.py b/src/as2fm/jani_generator/jani_entries/jani_value.py index dba3fe1a..a3b49d5c 100644 --- a/src/as2fm/jani_generator/jani_entries/jani_value.py +++ b/src/as2fm/jani_generator/jani_entries/jani_value.py @@ -23,14 +23,17 @@ class JaniValue: """Class containing Jani Constant Values""" + def __init__(self, value): self._value = value def is_valid(self) -> bool: if isinstance(self._value, dict): if "constant" in self._value: - assert self._value["constant"] in ("e", "π"), \ - f"Unknown constant value {self._value['constant']}. Only 'e' and 'π' supported." + assert self._value["constant"] in ( + "e", + "π", + ), f"Unknown constant value {self._value['constant']}. Only 'e' and 'π' supported." return True elif isinstance(self._value, list): return all(JaniValue(v).is_valid() for v in self._value) diff --git a/src/as2fm/jani_generator/jani_entries/jani_variable.py b/src/as2fm/jani_generator/jani_entries/jani_variable.py index d9d0fa07..e5947548 100644 --- a/src/as2fm/jani_generator/jani_entries/jani_variable.py +++ b/src/as2fm/jani_generator/jani_entries/jani_variable.py @@ -30,34 +30,43 @@ def from_dict(variable_dict: dict) -> "JaniVariable": initial_value = variable_dict.get("initial-value", None) variable_type: type = JaniVariable.python_type_from_json(variable_dict["type"]) if initial_value is None: - return JaniVariable(variable_name, - variable_type, - None, - variable_dict.get("transient", False)) + return JaniVariable( + variable_name, variable_type, None, variable_dict.get("transient", False) + ) if isinstance(initial_value, str): # Check if conversion from string to variable_type is possible try: init_value_cast = variable_type(initial_value) - return JaniVariable(variable_name, - variable_type, - JaniExpression(init_value_cast), - variable_dict.get("transient", False)) + return JaniVariable( + variable_name, + variable_type, + JaniExpression(init_value_cast), + variable_dict.get("transient", False), + ) except ValueError: # If no conversion possible, raise an error (variable names are not supported) raise ValueError( f"Initial value {initial_value} for variable {variable_name} " - f"is not a valid value for type {variable_type}.") - return JaniVariable(variable_name, - variable_type, - JaniExpression(initial_value), - variable_dict.get("transient", False)) - - def __init__(self, v_name: str, v_type: Type[ValidTypes], - init_value: Optional[Union[JaniExpression, JaniValue]] = None, - v_transient: bool = False): - assert init_value is None or isinstance(init_value, (JaniExpression, JaniValue)), \ - f"Expected {v_name} init_value {init_value} to be of type " \ + f"is not a valid value for type {variable_type}." + ) + return JaniVariable( + variable_name, + variable_type, + JaniExpression(initial_value), + variable_dict.get("transient", False), + ) + + def __init__( + self, + v_name: str, + v_type: Type[ValidTypes], + init_value: Optional[Union[JaniExpression, JaniValue]] = None, + v_transient: bool = False, + ): + assert init_value is None or isinstance(init_value, (JaniExpression, JaniValue)), ( + f"Expected {v_name} init_value {init_value} to be of type " f"(JaniExpression, JaniValue), found {type(init_value)} instead." + ) self._name: str = v_name self._type: Type[ValidTypes] = v_type self._transient: bool = v_transient @@ -74,11 +83,14 @@ def __init__(self, v_name: str, v_type: Type[ValidTypes], self._init_expr = JaniExpression(0.0) else: raise ValueError( - f"JaniVariable {self._name} of type {self._type} needs an initial value") + f"JaniVariable {self._name} of type {self._type} needs an initial value" + ) assert v_type in get_args(ValidTypes), f"Type {v_type} not supported by Jani" if not self._transient and self._type in (float, MutableSequence[float]): - print(f"Warning: Variable {self._name} is not transient and has type float." - "This is not supported by STORM.") + print( + f"Warning: Variable {self._name} is not transient and has type float." + "This is not supported by STORM." + ) def name(self): """Get name.""" @@ -97,7 +109,7 @@ def as_dict(self): d = { "name": self._name, "type": JaniVariable.python_type_to_json(self._type), - "transient": self._transient + "transient": self._transient, } if self._init_expr is not None: d["initial-value"] = self._init_expr.as_dict() diff --git a/src/as2fm/jani_generator/main.py b/src/as2fm/jani_generator/main.py index 535c2e0c..f39a406b 100644 --- a/src/as2fm/jani_generator/main.py +++ b/src/as2fm/jani_generator/main.py @@ -23,8 +23,7 @@ from as2fm.jani_generator.convince_jani_helpers import convince_jani_parser from as2fm.jani_generator.jani_entries import JaniModel -from as2fm.jani_generator.scxml_helpers.top_level_interpreter import \ - interpret_top_level_xml +from as2fm.jani_generator.scxml_helpers.top_level_interpreter import interpret_top_level_xml def main_convince_to_plain_jani(_args: Optional[Sequence[str]] = None) -> None: @@ -34,20 +33,16 @@ def main_convince_to_plain_jani(_args: Optional[Sequence[str]] = None) -> None: :param args: The arguments to parse. If None, sys.argv is used. :return: None """ - parser = argparse.ArgumentParser( - description='Convert CONVINCE JANI to plain JANI.') - parser.add_argument( - '--convince_jani', help='The convince-jani file.', type=str, required=True) - parser.add_argument( - '--output', help='The output Plain JANI file.', type=str, required=True) + parser = argparse.ArgumentParser(description="Convert CONVINCE JANI to plain JANI.") + parser.add_argument("--convince_jani", help="The convince-jani file.", type=str, required=True) + parser.add_argument("--output", help="The output Plain JANI file.", type=str, required=True) args = parser.parse_args(_args) start_time = timeit.default_timer() model_loaded = False jani_model = JaniModel() if args.convince_jani is not None: - assert os.path.isfile( - args.convince_jani), f"File {args.convince_jani} does not exist." + assert os.path.isfile(args.convince_jani), f"File {args.convince_jani} does not exist." # Check the file's extension _, extension = os.path.splitext(args.convince_jani) assert extension == ".jani", f"File {args.convince_jani} is not a JANI file." @@ -55,9 +50,8 @@ def main_convince_to_plain_jani(_args: Optional[Sequence[str]] = None) -> None: model_loaded = True assert model_loaded, "No input file was provided. Check your input." # Write the loaded model to the output file - with open(args.output, "w", encoding='utf-8') as output_file: - json.dump(jani_model.as_dict(), output_file, - indent=4, ensure_ascii=False) + with open(args.output, "w", encoding="utf-8") as output_file: + json.dump(jani_model.as_dict(), output_file, indent=4, ensure_ascii=False) print(f"Converted jani model written to {args.output}.") print(f"Conversion took {timeit.default_timer() - start_time} seconds.") @@ -77,14 +71,17 @@ def main_scxml_to_jani(_args: Optional[Sequence[str]] = None) -> None: :param args: The arguments to parse. If None, sys.argv is used. :return: None """ - parser = argparse.ArgumentParser( - description="Convert SCXML robot system models to JANI model.") - parser.add_argument("--generated-scxml-dir", type=str, default="", - help="Path to the folder containing the generated plain-SCXML files.") - parser.add_argument("--jani-out-file", type=str, default="main.jani", - help="Path to the generated jani file.") + parser = argparse.ArgumentParser(description="Convert SCXML robot system models to JANI model.") + parser.add_argument( + "--generated-scxml-dir", + type=str, + default="", + help="Path to the folder containing the generated plain-SCXML files.", + ) parser.add_argument( - "main_xml", type=str, help="The path to the main XML file to interpret.") + "--jani-out-file", type=str, default="main.jani", help="Path to the generated jani file." + ) + parser.add_argument("main_xml", type=str, help="The path to the main XML file to interpret.") args = parser.parse_args(_args) main_xml_file = args.main_xml diff --git a/src/as2fm/jani_generator/ros_helpers/ros_action_handler.py b/src/as2fm/jani_generator/ros_helpers/ros_action_handler.py index 362d3ffe..ddffe768 100644 --- a/src/as2fm/jani_generator/ros_helpers/ros_action_handler.py +++ b/src/as2fm/jani_generator/ros_helpers/ros_action_handler.py @@ -19,25 +19,38 @@ from typing import Callable, Dict, List, Tuple -from as2fm.jani_generator.ros_helpers.ros_communication_handler import \ - RosCommunicationHandler -from as2fm.scxml_converter.scxml_entries import (ScxmlAssign, ScxmlData, - ScxmlDataModel, ScxmlIf, - ScxmlParam, ScxmlRoot, - ScxmlSend, ScxmlState, - ScxmlTransition) +from as2fm.jani_generator.ros_helpers.ros_communication_handler import RosCommunicationHandler +from as2fm.scxml_converter.scxml_entries import ( + ScxmlAssign, + ScxmlData, + ScxmlDataModel, + ScxmlIf, + ScxmlParam, + ScxmlRoot, + ScxmlSend, + ScxmlState, + ScxmlTransition, +) from as2fm.scxml_converter.scxml_entries.ros_utils import ( - generate_action_feedback_event, generate_action_feedback_handle_event, + generate_action_feedback_event, + generate_action_feedback_handle_event, generate_action_goal_accepted_event, generate_action_goal_handle_accepted_event, generate_action_goal_handle_event, generate_action_goal_handle_rejected_event, - generate_action_goal_rejected_event, generate_action_goal_req_event, - generate_action_result_event, generate_action_result_handle_event, - get_action_goal_id_definition, get_action_type_params, - sanitize_ros_interface_name) + generate_action_goal_rejected_event, + generate_action_goal_req_event, + generate_action_result_event, + generate_action_result_handle_event, + get_action_goal_id_definition, + get_action_type_params, + sanitize_ros_interface_name, +) from as2fm.scxml_converter.scxml_entries.utils import ( - PLAIN_FIELD_EVENT_PREFIX, PLAIN_SCXML_EVENT_PREFIX, ROS_FIELD_PREFIX) + PLAIN_FIELD_EVENT_PREFIX, + PLAIN_SCXML_EVENT_PREFIX, + ROS_FIELD_PREFIX, +) class RosActionHandler(RosCommunicationHandler): @@ -50,8 +63,8 @@ def get_interface_prefix() -> str: return "action_handler_" def _generate_goal_request_transition( - self, goal_state: ScxmlState, client_id: str, goal_id: int, req_params: Dict[str, str] - ) -> ScxmlTransition: + self, goal_state: ScxmlState, client_id: str, goal_id: int, req_params: Dict[str, str] + ) -> ScxmlTransition: """ Generate a scxml transition that, given a client request, sends an event to the server. @@ -68,18 +81,24 @@ def _generate_goal_request_transition( # Add preliminary assignments (part of the hack mentioned in self.to_scxml()) field_w_pref = ROS_FIELD_PREFIX + field_name goal_req_transition.append_body_executable_entry( - ScxmlAssign(field_w_pref, PLAIN_FIELD_EVENT_PREFIX + field_name)) + ScxmlAssign(field_w_pref, PLAIN_FIELD_EVENT_PREFIX + field_name) + ) send_params.append(ScxmlParam(field_w_pref, expr=field_w_pref)) # Add the send to the server goal_req_transition.append_body_executable_entry( - ScxmlSend(action_srv_handle_event, send_params)) + ScxmlSend(action_srv_handle_event, send_params) + ) return goal_req_transition def _generate_srv_event_transition( - self, goal_state: ScxmlState, client_to_goal_id: List[Tuple[str, int]], - event_fields: Dict[str, str], srv_event_function: Callable[[str], str], - client_event_function: Callable[[str, str], str], - additional_data: List[str]) -> ScxmlTransition: + self, + goal_state: ScxmlState, + client_to_goal_id: List[Tuple[str, int]], + event_fields: Dict[str, str], + srv_event_function: Callable[[str], str], + client_event_function: Callable[[str, str], str], + additional_data: List[str], + ) -> ScxmlTransition: """ Generate a scxml transition that triggers the client related to the input event's goal_id. @@ -95,50 +114,66 @@ def _generate_srv_event_transition( scxml_transition = ScxmlTransition(goal_state.get_id(), [srv_event_name]) for entry_name in extra_entries: scxml_transition.append_body_executable_entry( - ScxmlAssign(entry_name, PLAIN_SCXML_EVENT_PREFIX + entry_name)) + ScxmlAssign(entry_name, PLAIN_SCXML_EVENT_PREFIX + entry_name) + ) out_params: List[ScxmlParam] = [] for entry_name in additional_data: out_params.append(ScxmlParam(entry_name, expr=entry_name)) for field_name in event_fields: field_w_pref = ROS_FIELD_PREFIX + field_name scxml_transition.append_body_executable_entry( - ScxmlAssign(field_w_pref, PLAIN_FIELD_EVENT_PREFIX + field_name)) + ScxmlAssign(field_w_pref, PLAIN_FIELD_EVENT_PREFIX + field_name) + ) out_params.append(ScxmlParam(field_w_pref, expr=field_w_pref)) condition_send_pairs: List[Tuple[str, List[ScxmlSend]]] = [] for client_id, goal_id in client_to_goal_id: client_event = client_event_function(self._interface_name, client_id) - condition_send_pairs.append((f"{goal_id_name} == {goal_id}", - [ScxmlSend(client_event, out_params)])) + condition_send_pairs.append( + (f"{goal_id_name} == {goal_id}", [ScxmlSend(client_event, out_params)]) + ) scxml_transition.append_body_executable_entry(ScxmlIf(condition_send_pairs)) return scxml_transition def _generate_goal_accept_transition( - self, goal_state: ScxmlState, client_to_goal_id: List[Tuple[str, int]] - ) -> ScxmlTransition: + self, goal_state: ScxmlState, client_to_goal_id: List[Tuple[str, int]] + ) -> ScxmlTransition: """ Generate a scxml transition that sends an event to the client to report an accepted goal. :param client_to_goal_id: List of tuples (client_id, goal_id) relating clients to goal ids. """ return self._generate_srv_event_transition( - goal_state, client_to_goal_id, {}, generate_action_goal_accepted_event, - generate_action_goal_handle_accepted_event, []) + goal_state, + client_to_goal_id, + {}, + generate_action_goal_accepted_event, + generate_action_goal_handle_accepted_event, + [], + ) def _generate_goal_reject_transition( - self, goal_state: ScxmlState, client_to_goal_id: List[Tuple[str, int]] - ) -> ScxmlTransition: + self, goal_state: ScxmlState, client_to_goal_id: List[Tuple[str, int]] + ) -> ScxmlTransition: """ Generate a scxml transition that sends an event to the client to report a rejected goal. :param client_to_goal_id: List of tuples (client_id, goal_id) relating clients to goal ids. """ return self._generate_srv_event_transition( - goal_state, client_to_goal_id, {}, generate_action_goal_rejected_event, - generate_action_goal_handle_rejected_event, []) + goal_state, + client_to_goal_id, + {}, + generate_action_goal_rejected_event, + generate_action_goal_handle_rejected_event, + [], + ) def _generate_feedback_response_transition( - self, goal_state: ScxmlState, client_to_goal_id: List[Tuple[str, int]], - feedback_params: Dict[str, str]) -> ScxmlTransition: + self, + goal_state: ScxmlState, + client_to_goal_id: List[Tuple[str, int]], + feedback_params: Dict[str, str], + ) -> ScxmlTransition: """ Generate a scxml transition that sends an event to the client to report feedback. @@ -146,12 +181,20 @@ def _generate_feedback_response_transition( :param feedback_params: Dictionary of the parameters of the feedback. """ return self._generate_srv_event_transition( - goal_state, client_to_goal_id, feedback_params, generate_action_feedback_event, - generate_action_feedback_handle_event, []) + goal_state, + client_to_goal_id, + feedback_params, + generate_action_feedback_event, + generate_action_feedback_handle_event, + [], + ) def _generate_result_response_transition( - self, goal_state: ScxmlState, client_to_goal_id: List[Tuple[str, int]], - result_params: Dict[str, str]) -> ScxmlTransition: + self, + goal_state: ScxmlState, + client_to_goal_id: List[Tuple[str, int]], + result_params: Dict[str, str], + ) -> ScxmlTransition: """ Generate a scxml transition that sends an event to the client to report the result. @@ -159,8 +202,13 @@ def _generate_result_response_transition( :param result_params: Dictionary of the parameters of the result. """ return self._generate_srv_event_transition( - goal_state, client_to_goal_id, result_params, generate_action_result_event, - generate_action_result_handle_event, ["code"]) + goal_state, + client_to_goal_id, + result_params, + generate_action_result_event, + generate_action_result_handle_event, + ["code"], + ) def to_scxml(self) -> ScxmlRoot: """ @@ -176,33 +224,44 @@ def to_scxml(self) -> ScxmlRoot: # Design choice: we generate a unique goal_id for each client, and we use it to identify # the recipient of the response. client_to_goal_id: List[Tuple[str, int]] = [ - (client_id, goal_id) for goal_id, client_id in enumerate(self._clients_automata)] + (client_id, goal_id) for goal_id, client_id in enumerate(self._clients_automata) + ] goal_params, feedback_params, result_params = get_action_type_params(self._interface_type) # Hack: Using support variables in the data model to avoid having _event in send params goal_id_def = get_action_goal_id_definition() action_fields_as_data = self._generate_datamodel_from_ros_fields( - goal_params | feedback_params | result_params) + goal_params | feedback_params | result_params + ) action_fields_as_data.append(ScxmlData(goal_id_def[0], "0", goal_id_def[1])) action_fields_as_data.append(ScxmlData("code", "0", "int32")) # Make sure the service name has no slashes and spaces - scxml_root_name = \ - self.get_interface_prefix() + sanitize_ros_interface_name(self._interface_name) + scxml_root_name = self.get_interface_prefix() + sanitize_ros_interface_name( + self._interface_name + ) wait_state = ScxmlState("waiting") goal_requested_state = ScxmlState("goal_requested") for client_id, goal_id in client_to_goal_id: wait_state.add_transition( self._generate_goal_request_transition( - goal_requested_state, client_id, goal_id, goal_params)) + goal_requested_state, client_id, goal_id, goal_params + ) + ) goal_requested_state.add_transition( - self._generate_goal_accept_transition(wait_state, client_to_goal_id)) + self._generate_goal_accept_transition(wait_state, client_to_goal_id) + ) goal_requested_state.add_transition( - self._generate_goal_reject_transition(wait_state, client_to_goal_id)) - wait_state.add_transition(self._generate_feedback_response_transition( - wait_state, client_to_goal_id, feedback_params)) - wait_state.add_transition(self._generate_result_response_transition( - wait_state, client_to_goal_id, result_params)) + self._generate_goal_reject_transition(wait_state, client_to_goal_id) + ) + wait_state.add_transition( + self._generate_feedback_response_transition( + wait_state, client_to_goal_id, feedback_params + ) + ) + wait_state.add_transition( + self._generate_result_response_transition(wait_state, client_to_goal_id, result_params) + ) scxml_root = ScxmlRoot(scxml_root_name) scxml_root.set_data_model(ScxmlDataModel(action_fields_as_data)) scxml_root.add_state(wait_state, initial=True) diff --git a/src/as2fm/jani_generator/ros_helpers/ros_communication_handler.py b/src/as2fm/jani_generator/ros_helpers/ros_communication_handler.py index 7c99f3c7..56bbe4e9 100644 --- a/src/as2fm/jani_generator/ros_helpers/ros_communication_handler.py +++ b/src/as2fm/jani_generator/ros_helpers/ros_communication_handler.py @@ -19,12 +19,10 @@ from typing import Dict, Iterator, List, Optional, Type -from as2fm.as2fm_common.common import (get_default_expression_for_type, - value_to_string) +from as2fm.as2fm_common.common import get_default_expression_for_type, value_to_string from as2fm.jani_generator.jani_entries import JaniModel from as2fm.scxml_converter.scxml_entries import ScxmlData, ScxmlRoot -from as2fm.scxml_converter.scxml_entries.utils import ( - ROS_FIELD_PREFIX, get_data_type_from_string) +from as2fm.scxml_converter.scxml_entries.utils import ROS_FIELD_PREFIX, get_data_type_from_string class RosCommunicationHandler: @@ -58,10 +56,12 @@ def _set_name_and_type(self, interface_name: str, interface_type: str) -> None: self._interface_name = interface_name self._interface_type = interface_type else: - assert self._interface_name == interface_name, \ - f"Error: Interface name {interface_name} does not match {self._interface_name}." - assert self._interface_type == interface_type, \ - f"Error: Interface type {interface_type} does not match {self._interface_type}." + assert ( + self._interface_name == interface_name + ), f"Error: Interface name {interface_name} does not match {self._interface_name}." + assert ( + self._interface_type == interface_type + ), f"Error: Interface type {interface_type} does not match {self._interface_type}." def _assert_validity(self): """ @@ -69,10 +69,12 @@ def _assert_validity(self): """ assert self._interface_name is not None, "Interface name not set." assert self._interface_type is not None, "Interface type not set." - assert self._server_automaton is not None, \ - f"ROS server not provided for {self._interface_name}." - assert len(self._clients_automata) > 0, \ - f"No ROS clients provided for {self._interface_name}." + assert ( + self._server_automaton is not None + ), f"ROS server not provided for {self._interface_name}." + assert ( + len(self._clients_automata) > 0 + ), f"No ROS clients provided for {self._interface_name}." def set_server(self, interface_name: str, interface_type: str, automaton_name: str) -> None: """ @@ -84,8 +86,9 @@ def set_server(self, interface_name: str, interface_type: str, automaton_name: s :automaton_name: The name of the JANI automaton that implements this server. """ self._set_name_and_type(interface_name, interface_type) - assert self._server_automaton is None, \ - f"Found more than one server for interface {interface_name}." + assert ( + self._server_automaton is None + ), f"Found more than one server for interface {interface_name}." self._server_automaton = automaton_name def add_client(self, interface_name: str, interface_type: str, automaton_name: str) -> None: @@ -98,8 +101,9 @@ def add_client(self, interface_name: str, interface_type: str, automaton_name: s :automaton_name: The name of the JANI automaton that implements this client. """ self._set_name_and_type(interface_name, interface_type) - assert automaton_name not in self._clients_automata, \ - f"Service client for {automaton_name} already declared for service {interface_name}." + assert ( + automaton_name not in self._clients_automata + ), f"Service client for {automaton_name} already declared for service {interface_name}." self._clients_automata.append(automaton_name) def to_scxml(self) -> ScxmlRoot: @@ -130,9 +134,12 @@ def _generate_datamodel_from_ros_fields(self, fields: Dict[str, str]) -> List[Sc def update_ros_communication_handlers( - automaton_name: str, handler_class: Type[RosCommunicationHandler], - handlers_dict: Dict[str, RosCommunicationHandler], - servers_dict: Dict[str, tuple], clients_dict: Dict[str, tuple]): + automaton_name: str, + handler_class: Type[RosCommunicationHandler], + handlers_dict: Dict[str, RosCommunicationHandler], + servers_dict: Dict[str, tuple], + clients_dict: Dict[str, tuple], +): """ Update the ROS communication handlers with the given clients and servers. @@ -141,8 +148,9 @@ def update_ros_communication_handlers( :param servers_dict: The dictionary of servers to add. :param clients_dict: The dictionary of clients to add. """ - assert issubclass(handler_class, RosCommunicationHandler), \ - f"The handler class {handler_class} must be a subclass of RosCommunicationHandler." + assert issubclass( + handler_class, RosCommunicationHandler + ), f"The handler class {handler_class} must be a subclass of RosCommunicationHandler." for service_name, service_type in servers_dict.values(): if service_name not in handlers_dict: handlers_dict[service_name] = handler_class() @@ -150,12 +158,12 @@ def update_ros_communication_handlers( for service_name, service_type in clients_dict.values(): if service_name not in handlers_dict: handlers_dict[service_name] = handler_class() - handlers_dict[service_name].add_client( - service_name, service_type, automaton_name) + handlers_dict[service_name].add_client(service_name, service_type, automaton_name) def generate_plain_scxml_from_handlers( - handlers_dict: Dict[str, RosCommunicationHandler]) -> Iterator[ScxmlRoot]: + handlers_dict: Dict[str, RosCommunicationHandler] +) -> Iterator[ScxmlRoot]: """ Generate the plain SCXML models from the ROS communication handlers. @@ -172,8 +180,9 @@ def remove_empty_self_loops_from_interface_handlers_in_jani(jani_model: JaniMode :param jani_model: The Jani model to modify. """ - handlers_prefixes = [handler.get_interface_prefix() - for handler in RosCommunicationHandler.__subclasses__()] + handlers_prefixes = [ + handler.get_interface_prefix() for handler in RosCommunicationHandler.__subclasses__() + ] for automaton in jani_model.get_automata(): # Modify the automaton in place for prefix in handlers_prefixes: diff --git a/src/as2fm/jani_generator/ros_helpers/ros_service_handler.py b/src/as2fm/jani_generator/ros_helpers/ros_service_handler.py index d37f3139..5c68225e 100644 --- a/src/as2fm/jani_generator/ros_helpers/ros_service_handler.py +++ b/src/as2fm/jani_generator/ros_helpers/ros_service_handler.py @@ -19,18 +19,25 @@ from typing import Dict, List -from as2fm.jani_generator.ros_helpers.ros_communication_handler import \ - RosCommunicationHandler -from as2fm.scxml_converter.scxml_entries import (ScxmlAssign, ScxmlDataModel, - ScxmlParam, ScxmlRoot, - ScxmlSend, ScxmlState, - ScxmlTransition) +from as2fm.jani_generator.ros_helpers.ros_communication_handler import RosCommunicationHandler +from as2fm.scxml_converter.scxml_entries import ( + ScxmlAssign, + ScxmlDataModel, + ScxmlParam, + ScxmlRoot, + ScxmlSend, + ScxmlState, + ScxmlTransition, +) from as2fm.scxml_converter.scxml_entries.ros_utils import ( - generate_srv_request_event, generate_srv_response_event, - generate_srv_server_request_event, generate_srv_server_response_event, - get_srv_type_params, sanitize_ros_interface_name) -from as2fm.scxml_converter.scxml_entries.utils import ( - PLAIN_FIELD_EVENT_PREFIX, ROS_FIELD_PREFIX) + generate_srv_request_event, + generate_srv_response_event, + generate_srv_server_request_event, + generate_srv_server_response_event, + get_srv_type_params, + sanitize_ros_interface_name, +) +from as2fm.scxml_converter.scxml_entries.utils import PLAIN_FIELD_EVENT_PREFIX, ROS_FIELD_PREFIX class RosServiceHandler(RosCommunicationHandler): @@ -42,8 +49,9 @@ class RosServiceHandler(RosCommunicationHandler): def get_interface_prefix() -> str: return "srv_handler_" - def generate_transition_to_processing_state(self, client_id: str, - req_fields: Dict[str, str]) -> ScxmlTransition: + def generate_transition_to_processing_state( + self, client_id: str, req_fields: Dict[str, str] + ) -> ScxmlTransition: """ Generate a transition from the waiting state to the processing state for a given client. @@ -60,11 +68,13 @@ def generate_transition_to_processing_state(self, client_id: str, return ScxmlTransition( f"processing_client_{client_id}", [generate_srv_request_event(self._interface_name, client_id)], - body=assignments + [ScxmlSend( - generate_srv_server_request_event(self._interface_name), event_params)]) + body=assignments + + [ScxmlSend(generate_srv_server_request_event(self._interface_name), event_params)], + ) - def generate_transition_from_processing_state(self, client_id: str, - res_fields: Dict[str, str]) -> ScxmlTransition: + def generate_transition_from_processing_state( + self, client_id: str, res_fields: Dict[str, str] + ) -> ScxmlTransition: """ Generate a transition from the processing state to the waiting state for a given client. """ @@ -75,9 +85,15 @@ def generate_transition_from_processing_state(self, client_id: str, assignments.append(ScxmlAssign(field_w_pref, PLAIN_FIELD_EVENT_PREFIX + field_name)) event_params.append(ScxmlParam(field_w_pref, expr=field_w_pref)) return ScxmlTransition( - "waiting", [generate_srv_server_response_event(self._interface_name)], - body=assignments + [ScxmlSend( - generate_srv_response_event(self._interface_name, client_id), event_params)]) + "waiting", + [generate_srv_server_response_event(self._interface_name)], + body=assignments + + [ + ScxmlSend( + generate_srv_response_event(self._interface_name, client_id), event_params + ) + ], + ) def to_scxml(self) -> ScxmlRoot: """ @@ -93,17 +109,23 @@ def to_scxml(self) -> ScxmlRoot: # Hack: Using support variables in the data model to avoid having _event in send params req_fields_as_data = self._generate_datamodel_from_ros_fields(req_params | res_params) # Make sure the service name has no slashes and spaces - scxml_root_name = \ - self.get_interface_prefix() + sanitize_ros_interface_name(self._interface_name) - wait_state = ScxmlState("waiting", - body=[ - self.generate_transition_to_processing_state( - client_id, req_params) - for client_id in self._clients_automata]) + scxml_root_name = self.get_interface_prefix() + sanitize_ros_interface_name( + self._interface_name + ) + wait_state = ScxmlState( + "waiting", + body=[ + self.generate_transition_to_processing_state(client_id, req_params) + for client_id in self._clients_automata + ], + ) processing_states = [ - ScxmlState(f"processing_client_{client_id}", - body=[self.generate_transition_from_processing_state(client_id, res_params)]) - for client_id in self._clients_automata] + ScxmlState( + f"processing_client_{client_id}", + body=[self.generate_transition_from_processing_state(client_id, res_params)], + ) + for client_id in self._clients_automata + ] # Prepare the ScxmlRoot object and return it scxml_root = ScxmlRoot(scxml_root_name) scxml_root.set_data_model(ScxmlDataModel(req_fields_as_data)) diff --git a/src/as2fm/jani_generator/ros_helpers/ros_timer.py b/src/as2fm/jani_generator/ros_helpers/ros_timer.py index 72551304..120120b1 100644 --- a/src/as2fm/jani_generator/ros_helpers/ros_timer.py +++ b/src/as2fm/jani_generator/ros_helpers/ros_timer.py @@ -20,17 +20,33 @@ from math import floor, gcd from typing import List, Optional, Tuple -from as2fm.jani_generator.jani_entries import (JaniAssignment, JaniAutomaton, - JaniEdge, JaniExpression, - JaniGuard, JaniVariable) +from as2fm.jani_generator.jani_entries import ( + JaniAssignment, + JaniAutomaton, + JaniEdge, + JaniExpression, + JaniGuard, + JaniVariable, +) from as2fm.jani_generator.jani_entries.jani_expression_generator import ( - and_operator, equal_operator, lower_operator, modulo_operator, - not_operator, plus_operator) -from as2fm.scxml_converter.scxml_entries import (ScxmlAssign, ScxmlData, - ScxmlDataModel, - ScxmlExecutionBody, ScxmlIf, - ScxmlRoot, ScxmlSend, - ScxmlState, ScxmlTransition) + and_operator, + equal_operator, + lower_operator, + modulo_operator, + not_operator, + plus_operator, +) +from as2fm.scxml_converter.scxml_entries import ( + ScxmlAssign, + ScxmlData, + ScxmlDataModel, + ScxmlExecutionBody, + ScxmlIf, + ScxmlRoot, + ScxmlSend, + ScxmlState, + ScxmlTransition, +) TIME_UNITS = { "s": 1, @@ -53,8 +69,7 @@ def convert_time_between_units(time: int, from_unit: str, to_unit: str) -> int: return time new_time = time * TIME_UNITS[from_unit] / TIME_UNITS[to_unit] # make sure we do not lose precision - assert int(new_time) == new_time, \ - f"Conversion from {from_unit} to {to_unit} is not exact." + assert int(new_time) == new_time, f"Conversion from {from_unit} to {to_unit} is not exact." return int(new_time) @@ -78,8 +93,7 @@ def __init__(self, name: str, freq: float) -> None: self.name = name self.freq = freq self.period = 1.0 / freq - self.period_int, self.unit, self.factor = _to_best_int_period( - self.period) + self.period_int, self.unit, self.factor = _to_best_int_period(self.period) def get_common_time_step(timers: List[RosTimer]) -> Tuple[int, str]: @@ -96,13 +110,15 @@ def get_common_time_step(timers: List[RosTimer]) -> Tuple[int, str]: if TIME_UNITS[timer.unit] < TIME_UNITS[common_unit]: common_unit = timer.unit timer_periods = [ - convert_time_between_units(timer.period_int, timer.unit, common_unit) for timer in timers] + convert_time_between_units(timer.period_int, timer.unit, common_unit) for timer in timers + ] common_period = gcd(*timer_periods) return common_period, common_unit -def make_global_timer_automaton(timers: List[RosTimer], - max_time_ns: int) -> Optional[JaniAutomaton]: +def make_global_timer_automaton( + timers: List[RosTimer], max_time_ns: int +) -> Optional[JaniAutomaton]: """ Create a global timer automaton from a list of ROS timers. @@ -113,17 +129,18 @@ def make_global_timer_automaton(timers: List[RosTimer], return None global_timer_period, global_timer_period_unit = get_common_time_step(timers) timers_map = { - timer.name: convert_time_between_units(timer.period_int, timer.unit, - global_timer_period_unit) + timer.name: convert_time_between_units( + timer.period_int, timer.unit, global_timer_period_unit + ) for timer in timers } try: - max_time = convert_time_between_units( - max_time_ns, "ns", global_timer_period_unit) + max_time = convert_time_between_units(max_time_ns, "ns", global_timer_period_unit) except AssertionError: raise ValueError( f"Max time {max_time_ns} cannot be converted to {global_timer_period_unit}. " - "The max_time must have a unit that is greater or equal to the smallest timer period.") + "The max_time must have a unit that is greater or equal to the smallest timer period." + ) # Automaton LOC_NAME = "loc" @@ -136,11 +153,9 @@ def make_global_timer_automaton(timers: List[RosTimer], # variables variable_names = [f"{timer.name}_needed" for timer in timers] - timer_automaton.add_variable( - JaniVariable("t", int, JaniExpression(0))) + timer_automaton.add_variable(JaniVariable("t", int, JaniExpression(0))) for variable_name in variable_names: - timer_automaton.add_variable( - JaniVariable(variable_name, bool, JaniExpression(True))) + timer_automaton.add_variable(JaniVariable(variable_name, bool, JaniExpression(True))) # it is initially true, because everything "x % 0 == 0" # edges @@ -148,11 +163,16 @@ def make_global_timer_automaton(timers: List[RosTimer], timer_assignments = [] for i, (timer, variable_name) in enumerate(zip(timers, variable_names)): period_in_global_unit = timers_map[timer.name] - timer_assignments.append(JaniAssignment({ - "ref": variable_name, - # t % {period_in_global_unit} == 0 - "value": equal_operator(modulo_operator("t", period_in_global_unit), 0), - "index": i+1})) # 1, because t is at index 0 + timer_assignments.append( + JaniAssignment( + { + "ref": variable_name, + # t % {period_in_global_unit} == 0 + "value": equal_operator(modulo_operator("t", period_in_global_unit), 0), + "index": i + 1, + } + ) + ) # 1, because t is at index 0 # guard for main edge # Max time not reached yet guard_exp = lower_operator("t", max_time) @@ -165,40 +185,38 @@ def make_global_timer_automaton(timers: List[RosTimer], # TODO: write test case for this (and switch to not(t1 or t2 or ... or tN) guard) assignments = [ # t = t + global_timer_period - JaniAssignment({ - "ref": "t", - "value": plus_operator("t", global_timer_period), - "index": 0}) + JaniAssignment({"ref": "t", "value": plus_operator("t", global_timer_period), "index": 0}) ] + timer_assignments - iterator_edge = JaniEdge({ - "location": LOC_NAME, - "guard": JaniGuard(guard_exp), - "destinations": [{ + iterator_edge = JaniEdge( + { "location": LOC_NAME, - "assignments": assignments - }], - "action": GLOBAL_TIMER_TICK_ACTION - } + "guard": JaniGuard(guard_exp), + "destinations": [{"location": LOC_NAME, "assignments": assignments}], + "action": GLOBAL_TIMER_TICK_ACTION, + } ) timer_automaton.add_edge(iterator_edge) # edges to sync with ROS timers for timer in timers: guard = JaniGuard(JaniExpression(f"{timer.name}_needed")) - timer_edge = JaniEdge({ - "location": LOC_NAME, - "action": f"{ROS_TIMER_RATE_EVENT_PREFIX}{timer.name}_on_receive", - "guard": guard, - "destinations": [{ + timer_edge = JaniEdge( + { "location": LOC_NAME, - "assignments": [ - JaniAssignment({ - "ref": f"{timer.name}_needed", - "value": JaniExpression(False) - }) - ] - }] - }) + "action": f"{ROS_TIMER_RATE_EVENT_PREFIX}{timer.name}_on_receive", + "guard": guard, + "destinations": [ + { + "location": LOC_NAME, + "assignments": [ + JaniAssignment( + {"ref": f"{timer.name}_needed", "value": JaniExpression(False)} + ) + ], + } + ], + } + ) timer_automaton.add_edge(timer_edge) return timer_automaton @@ -217,28 +235,39 @@ def make_global_timer_scxml(timers: List[RosTimer], max_time_ns: int) -> Optiona return None global_timer_period, global_timer_period_unit = get_common_time_step(timers) timers_map = { - timer.name: convert_time_between_units(timer.period_int, timer.unit, - global_timer_period_unit) + timer.name: convert_time_between_units( + timer.period_int, timer.unit, global_timer_period_unit + ) for timer in timers } try: - max_time = convert_time_between_units( - max_time_ns, "ns", global_timer_period_unit) + max_time = convert_time_between_units(max_time_ns, "ns", global_timer_period_unit) except AssertionError: raise ValueError( f"Max time {max_time_ns} cannot be converted to {global_timer_period_unit}. " - "The max_time must have a unit that is greater or equal to the smallest timer period.") + "The max_time must have a unit that is greater or equal to the smallest timer period." + ) scxml_root = ScxmlRoot("global_timer_automata") scxml_root.set_data_model(ScxmlDataModel([ScxmlData("current_time", "0", "int64")])) idle_state = ScxmlState("idle") global_timer_tick_body: ScxmlExecutionBody = [] - global_timer_tick_body.append(ScxmlAssign("current_time", - f"current_time + {global_timer_period}")) + global_timer_tick_body.append( + ScxmlAssign("current_time", f"current_time + {global_timer_period}") + ) for timer_name, timer_period in timers_map.items(): - global_timer_tick_body.append(ScxmlIf([(f"(current_time % {timer_period}) == 0", - [ScxmlSend(f"ros_time_rate.{timer_name}")])])) - timer_step_transition = ScxmlTransition("idle", [], f"current_time < {max_time}", - global_timer_tick_body) + global_timer_tick_body.append( + ScxmlIf( + [ + ( + f"(current_time % {timer_period}) == 0", + [ScxmlSend(f"ros_time_rate.{timer_name}")], + ) + ] + ) + ) + timer_step_transition = ScxmlTransition( + "idle", [], f"current_time < {max_time}", global_timer_tick_body + ) idle_state.add_transition(timer_step_transition) scxml_root.add_state(idle_state, initial=True) return scxml_root diff --git a/src/as2fm/jani_generator/scxml_helpers/scxml_event.py b/src/as2fm/jani_generator/scxml_helpers/scxml_event.py index e390efbd..114d99ac 100644 --- a/src/as2fm/jani_generator/scxml_helpers/scxml_event.py +++ b/src/as2fm/jani_generator/scxml_helpers/scxml_event.py @@ -20,8 +20,7 @@ import re from typing import Dict, List, Optional -from as2fm.jani_generator.ros_helpers.ros_timer import \ - ROS_TIMER_RATE_EVENT_PREFIX +from as2fm.jani_generator.ros_helpers.ros_timer import ROS_TIMER_RATE_EVENT_PREFIX class EventSender: @@ -44,9 +43,7 @@ def __init__(self, automaton_name: str, edge_action_name: str): class Event: - def __init__(self, - name: str, - data_struct: Optional[Dict[str, type]] = None): + def __init__(self, name: str, data_struct: Optional[Dict[str, type]] = None): self.name = name self.data_struct = data_struct # Map automaton -> event name @@ -93,25 +90,29 @@ def has_receivers(self) -> bool: def must_be_skipped_in_jani_conversion(self): """Indicate whether this must be considered in the conversion to jani.""" return ( - self.name.startswith(ROS_TIMER_RATE_EVENT_PREFIX) or + self.name.startswith(ROS_TIMER_RATE_EVENT_PREFIX) + or # If the event is a timer event, there is only a receiver # It is the edge that the user declared with the # `ros_rate_callback` tag. It will be handled in the # `scxml_event_processor` module differently. - self.is_bt_response_event() and len(self.senders) == 0 or - self.is_optional_action_event() and len(self.senders) == 0 + self.is_bt_response_event() + and len(self.senders) == 0 + or self.is_optional_action_event() + and len(self.senders) == 0 ) def is_bt_response_event(self): """Check if the event is a behavior tree response event (running, success, failure). They may have no sender if the plugin does not implement it.""" return self.name.startswith("bt_") and ( - self.name.endswith("_running") or - self.name.endswith("_success") or - self.name.endswith("_failure")) + self.name.endswith("_running") + or self.name.endswith("_success") + or self.name.endswith("_failure") + ) def is_optional_action_event(self): - return (self.is_action_feedback_event() or self.is_action_rejected_event()) + return self.is_action_feedback_event() or self.is_action_rejected_event() def is_action_feedback_event(self): """Check if the event is an action feedback event.""" diff --git a/src/as2fm/jani_generator/scxml_helpers/scxml_event_processor.py b/src/as2fm/jani_generator/scxml_helpers/scxml_event_processor.py index 7205e5e7..ffd58d9e 100644 --- a/src/as2fm/jani_generator/scxml_helpers/scxml_event_processor.py +++ b/src/as2fm/jani_generator/scxml_helpers/scxml_event_processor.py @@ -24,19 +24,19 @@ from as2fm.jani_generator.jani_entries.jani_automaton import JaniAutomaton from as2fm.jani_generator.jani_entries.jani_composition import JaniComposition from as2fm.jani_generator.jani_entries.jani_edge import JaniEdge -from as2fm.jani_generator.jani_entries.jani_expression_generator import \ - array_create_operator +from as2fm.jani_generator.jani_entries.jani_expression_generator import array_create_operator from as2fm.jani_generator.ros_helpers.ros_timer import ( - GLOBAL_TIMER_NAME, GLOBAL_TIMER_TICK_ACTION, ROS_TIMER_RATE_EVENT_PREFIX, - RosTimer) + GLOBAL_TIMER_NAME, + GLOBAL_TIMER_TICK_ACTION, + ROS_TIMER_RATE_EVENT_PREFIX, + RosTimer, +) from as2fm.jani_generator.scxml_helpers.scxml_event import EventsHolder def implement_scxml_events_as_jani_syncs( - events_holder: EventsHolder, - timers: List[RosTimer], - max_array_size: int, - jani_model: JaniModel) -> List[str]: + events_holder: EventsHolder, timers: List[RosTimer], max_array_size: int, jani_model: JaniModel +) -> List[str]: """ Implement the scxml events as jani syncs. @@ -71,70 +71,78 @@ def implement_scxml_events_as_jani_syncs( if add_timer_syncs: # Additional self-loop in the waiting state, allowing the global timer to tick # only if all events have been processed - event_automaton.add_edge(JaniEdge({ - "location": "waiting", - "destinations": [{ - "location": "waiting", - "probability": {"exp": 1.0}, - "assignments": [] - }], - "action": "global_timer_enable" - })) + event_automaton.add_edge( + JaniEdge( + { + "location": "waiting", + "destinations": [ + {"location": "waiting", "probability": {"exp": 1.0}, "assignments": []} + ], + "action": "global_timer_enable", + } + ) + ) timer_enable_syncs.update({event_name: "global_timer_enable"}) # Add the event handling automaton jc.add_element(event_name) if event.has_receivers(): # Add a "received" state and related transitions event_automaton.add_location("received") - event_automaton.add_edge(JaniEdge({ - "location": "waiting", - "destinations": [{ - "location": "received", - "probability": {"exp": 1.0}, - "assignments": [] - }], - "action": event_name_on_send - })) - event_automaton.add_edge(JaniEdge({ - "location": "received", - "destinations": [{ - "location": "waiting", - "probability": {"exp": 1.0}, - "assignments": [] - }], - "action": event_name_on_receive - })) + event_automaton.add_edge( + JaniEdge( + { + "location": "waiting", + "destinations": [ + {"location": "received", "probability": {"exp": 1.0}, "assignments": []} + ], + "action": event_name_on_send, + } + ) + ) + event_automaton.add_edge( + JaniEdge( + { + "location": "received", + "destinations": [ + {"location": "waiting", "probability": {"exp": 1.0}, "assignments": []} + ], + "action": event_name_on_receive, + } + ) + ) else: # Store the events without receivers events_without_receivers.append(event_name) # If there are no receivers, we add a self-loop - event_automaton.add_edge(JaniEdge({ - "location": "waiting", - "destinations": [{ - "location": "waiting", - "probability": {"exp": 1.0}, - "assignments": [] - }], - "action": event_name_on_send - })) + event_automaton.add_edge( + JaniEdge( + { + "location": "waiting", + "destinations": [ + {"location": "waiting", "probability": {"exp": 1.0}, "assignments": []} + ], + "action": event_name_on_send, + } + ) + ) jani_model.add_jani_automaton(event_automaton) # Verify and prepare the receiver syncs if event.has_receivers(): receivers_syncs = {event_name: event_name_on_receive} for receiver in event.get_receivers(): action_name = receiver.edge_action_name - assert action_name == event_name_on_receive, \ - f"Action name {action_name} must be {event_name_on_receive}." + assert ( + action_name == event_name_on_receive + ), f"Action name {action_name} must be {event_name_on_receive}." receivers_syncs.update({receiver.automaton_name: action_name}) jc.add_sync(event_name_on_receive, receivers_syncs) # Verify and prepare the sender syncs for sender in event.get_senders(): action_name = sender.edge_action_name - assert action_name == event_name_on_send, \ - f"Action name {action_name} must be {event_name_on_send}." - senders_syncs = { - event_name: action_name, - sender.automaton_name: action_name} + assert ( + action_name == event_name_on_send + ), f"Action name {action_name} must be {event_name_on_send}." + senders_syncs = {event_name: action_name, sender.automaton_name: action_name} jc.add_sync(action_name, senders_syncs) # Add the global data, if needed for p_name, p_type in event.get_data_structure().items(): @@ -147,26 +155,26 @@ def implement_scxml_events_as_jani_syncs( jani_model.add_variable( variable_name=f"{event_name}.{p_name}", variable_type=p_type, - variable_init_expression=init_value + variable_init_expression=init_value, ) # In case of arrays, add a variable representing the array size, too if is_array: jani_model.add_variable( variable_name=f"{event_name}.{p_name}.length", variable_type=int, - variable_init_expression=0 + variable_init_expression=0, ) # For each event, we add an extra boolean flag for data validity jani_model.add_variable( - variable_name=f"{event_name}.valid", - variable_type=bool, - variable_init_expression=False + variable_name=f"{event_name}.valid", variable_type=bool, variable_init_expression=False ) # Add syncs for global timer if add_timer_syncs: # Add sync action for global timer tick - jc.add_sync(GLOBAL_TIMER_TICK_ACTION, - timer_enable_syncs | {GLOBAL_TIMER_NAME: GLOBAL_TIMER_TICK_ACTION}) + jc.add_sync( + GLOBAL_TIMER_TICK_ACTION, + timer_enable_syncs | {GLOBAL_TIMER_NAME: GLOBAL_TIMER_TICK_ACTION}, + ) # Add syncs for rate timers for timer in timers: name = timer.name @@ -176,13 +184,14 @@ def implement_scxml_events_as_jani_syncs( except KeyError as e: raise RuntimeError( f"Was expecting an event for timer {name}, with name " - f"{ROS_TIMER_RATE_EVENT_PREFIX}{name}.") from e + f"{ROS_TIMER_RATE_EVENT_PREFIX}{name}." + ) from e action_name_receiver = f"{event_name}_on_receive" automaton_name = event.get_receivers()[0].automaton_name timer_trigger_syncs = { GLOBAL_TIMER_NAME: action_name_receiver, - automaton_name: action_name_receiver - } + automaton_name: action_name_receiver, + } # Make sure that all other events are processed before starting the timer callback timer_trigger_syncs.update(timer_enable_syncs) jc.add_sync(action_name_receiver, timer_trigger_syncs) diff --git a/src/as2fm/jani_generator/scxml_helpers/scxml_expression.py b/src/as2fm/jani_generator/scxml_helpers/scxml_expression.py index e6ca0e61..be790a0c 100644 --- a/src/as2fm/jani_generator/scxml_helpers/scxml_expression.py +++ b/src/as2fm/jani_generator/scxml_helpers/scxml_expression.py @@ -23,10 +23,15 @@ import esprima from as2fm.jani_generator.jani_entries.jani_convince_expression_expansion import ( - CALLABLE_OPERATORS_MAP, OPERATORS_TO_JANI_MAP) + CALLABLE_OPERATORS_MAP, + OPERATORS_TO_JANI_MAP, +) from as2fm.jani_generator.jani_entries.jani_expression import JaniExpression from as2fm.jani_generator.jani_entries.jani_expression_generator import ( - array_access_operator, array_create_operator, array_value_operator) + array_access_operator, + array_create_operator, + array_value_operator, +) from as2fm.jani_generator.jani_entries.jani_value import JaniValue JS_CALLABLE_PREFIX = "Math" @@ -39,7 +44,8 @@ class ArrayInfo: def parse_ecmascript_to_jani_expression( - ecmascript: str, array_info: Optional[ArrayInfo] = None) -> JaniExpression: + ecmascript: str, array_info: Optional[ArrayInfo] = None +) -> JaniExpression: """ Parse ecmascript to jani expression. @@ -60,7 +66,8 @@ def parse_ecmascript_to_jani_expression( def _parse_ecmascript_to_jani_expression( - ast: esprima.nodes.Script, array_info: Optional[ArrayInfo] = None) -> JaniExpression: + ast: esprima.nodes.Script, array_info: Optional[ArrayInfo] = None +) -> JaniExpression: """ Parse ecmascript to jani expression. @@ -72,21 +79,25 @@ def _parse_ecmascript_to_jani_expression( return JaniExpression(JaniValue(ast.value)) elif ast.type == "UnaryExpression": assert ast.prefix is True and ast.operator == "-", "Only unary minus is supported." - return JaniExpression({ - "op": OPERATORS_TO_JANI_MAP[ast.operator], - "left": JaniValue(0), - "right": _parse_ecmascript_to_jani_expression(ast.argument, array_info) - }) + return JaniExpression( + { + "op": OPERATORS_TO_JANI_MAP[ast.operator], + "left": JaniValue(0), + "right": _parse_ecmascript_to_jani_expression(ast.argument, array_info), + } + ) elif ast.type == "ArrayExpression": assert array_info is not None, "Array info must be provided for ArrayExpressions." entry_type: Type = array_info.array_type if len(ast.elements) == 0: - return array_create_operator("__array_iterator", array_info.array_max_size, - entry_type(0)) + return array_create_operator( + "__array_iterator", array_info.array_max_size, entry_type(0) + ) else: elements_to_add = array_info.array_max_size - len(ast.elements) - assert elements_to_add >= 0, \ - "Array size must be less than or equal to the recipient max size." + assert ( + elements_to_add >= 0 + ), "Array size must be less than or equal to the recipient max size." elements_list = [] for element in ast.elements: assert element.type == "Literal", "Array elements must be literals." @@ -96,9 +107,10 @@ def _parse_ecmascript_to_jani_expression( return array_value_operator(elements_list) elif ast.type == "Identifier": # If it is an identifier, we do not need to expand further - assert ast.name not in ("True", "False"), \ - f"Boolean {ast.name} mistaken for an identifier. "\ + assert ast.name not in ("True", "False"), ( + f"Boolean {ast.name} mistaken for an identifier. " "Did you mean to use 'true' or 'false' instead?" + ) return JaniExpression(ast.name) elif ast.type == "MemberExpression": object_expr = _parse_ecmascript_to_jani_expression(ast.object, array_info) @@ -110,37 +122,47 @@ def _parse_ecmascript_to_jani_expression( # Access to the member of an object through dot notation # Check the object_expr is an identifier object_expr_str = object_expr.as_identifier() - assert object_expr_str is not None, \ - "Only identifiers can be accessed through dot notation." - assert ast.property.type == "Identifier", \ - "Dot notation can be used only to access object's members." - field_complete_name = f'{object_expr_str}.{ast.property.name}' + assert ( + object_expr_str is not None + ), "Only identifiers can be accessed through dot notation." + assert ( + ast.property.type == "Identifier" + ), "Dot notation can be used only to access object's members." + field_complete_name = f"{object_expr_str}.{ast.property.name}" return JaniExpression(field_complete_name) elif ast.type == "ExpressionStatement": return _parse_ecmascript_to_jani_expression(ast.expression, array_info) elif ast.type == "BinaryExpression": # It is a more complex expression - assert ast.operator in OPERATORS_TO_JANI_MAP, \ - f"ecmascript to jani expression: unknown operator {ast.operator}" - return JaniExpression({ - "op": OPERATORS_TO_JANI_MAP[ast.operator], - "left": _parse_ecmascript_to_jani_expression(ast.left, array_info), - "right": _parse_ecmascript_to_jani_expression(ast.right, array_info) - }) + assert ( + ast.operator in OPERATORS_TO_JANI_MAP + ), f"ecmascript to jani expression: unknown operator {ast.operator}" + return JaniExpression( + { + "op": OPERATORS_TO_JANI_MAP[ast.operator], + "left": _parse_ecmascript_to_jani_expression(ast.left, array_info), + "right": _parse_ecmascript_to_jani_expression(ast.right, array_info), + } + ) elif ast.type == "CallExpression": # We expect function calls to be of the form Math.function_name(args) (JavaScript-like) # The "." operator is represented as a MemberExpression - assert ast.callee.type == "MemberExpression", \ - f"Functions callee is expected to be MemberExpressions, found {ast.callee}." - assert ast.callee.object.type == "Identifier", \ - f"Callee object is expected to be an Identifier, found {ast.callee.object}." - assert ast.callee.property.type == "Identifier", \ - f"Callee property is expected to be an Identifier, found {ast.callee.property}." - assert ast.callee.object.name == JS_CALLABLE_PREFIX, \ - f"Function calls prefix is expected to be 'Math', found {ast.callee.object.name}." + assert ( + ast.callee.type == "MemberExpression" + ), f"Functions callee is expected to be MemberExpressions, found {ast.callee}." + assert ( + ast.callee.object.type == "Identifier" + ), f"Callee object is expected to be an Identifier, found {ast.callee.object}." + assert ( + ast.callee.property.type == "Identifier" + ), f"Callee property is expected to be an Identifier, found {ast.callee.property}." + assert ( + ast.callee.object.name == JS_CALLABLE_PREFIX + ), f"Function calls prefix is expected to be 'Math', found {ast.callee.object.name}." function_name: str = ast.callee.property.name - assert function_name in CALLABLE_OPERATORS_MAP, \ - f"Unsupported function call {function_name}." + assert ( + function_name in CALLABLE_OPERATORS_MAP + ), f"Unsupported function call {function_name}." expression_args: List[JaniExpression] = [] for arg in ast.arguments: expression_args.append(_parse_ecmascript_to_jani_expression(arg, array_info)) diff --git a/src/as2fm/jani_generator/scxml_helpers/scxml_tags.py b/src/as2fm/jani_generator/scxml_helpers/scxml_tags.py index 793d35f3..8b7e336d 100644 --- a/src/as2fm/jani_generator/scxml_helpers/scxml_tags.py +++ b/src/as2fm/jani_generator/scxml_helpers/scxml_tags.py @@ -22,27 +22,47 @@ from hashlib import sha256 from typing import Dict, List, Optional, Set, Tuple, Union, get_args -from as2fm.as2fm_common.common import (check_value_type_compatible, - string_to_value, value_to_type) -from as2fm.as2fm_common.ecmascript_interpretation import \ - interpret_ecma_script_expr -from as2fm.jani_generator.jani_entries import (JaniAssignment, JaniAutomaton, - JaniEdge, JaniExpression, - JaniExpressionType, JaniGuard, - JaniValue, JaniVariable) +from as2fm.as2fm_common.common import check_value_type_compatible, string_to_value, value_to_type +from as2fm.as2fm_common.ecmascript_interpretation import interpret_ecma_script_expr +from as2fm.jani_generator.jani_entries import ( + JaniAssignment, + JaniAutomaton, + JaniEdge, + JaniExpression, + JaniExpressionType, + JaniGuard, + JaniValue, + JaniVariable, +) from as2fm.jani_generator.jani_entries.jani_expression_generator import ( - and_operator, max_operator, not_operator, plus_operator) + and_operator, + max_operator, + not_operator, + plus_operator, +) from as2fm.jani_generator.jani_entries.jani_utils import ( - get_all_variables_and_instantiations, get_array_type_and_size, - get_variable_type, is_variable_array) + get_all_variables_and_instantiations, + get_array_type_and_size, + get_variable_type, + is_variable_array, +) from as2fm.jani_generator.scxml_helpers.scxml_event import Event, EventsHolder from as2fm.jani_generator.scxml_helpers.scxml_expression import ( - ArrayInfo, parse_ecmascript_to_jani_expression) -from as2fm.scxml_converter.scxml_entries import (ScxmlAssign, ScxmlBase, - ScxmlData, ScxmlDataModel, - ScxmlExecutionBody, ScxmlIf, - ScxmlRoot, ScxmlSend, - ScxmlState, ScxmlTransition) + ArrayInfo, + parse_ecmascript_to_jani_expression, +) +from as2fm.scxml_converter.scxml_entries import ( + ScxmlAssign, + ScxmlBase, + ScxmlData, + ScxmlDataModel, + ScxmlExecutionBody, + ScxmlIf, + ScxmlRoot, + ScxmlSend, + ScxmlState, + ScxmlTransition, +) # The resulting types from the SCXML conversion to Jani ModelTupleType = Tuple[JaniAutomaton, EventsHolder] @@ -55,19 +75,22 @@ def _hash_element(element: Union[ET.Element, ScxmlBase, List[str]]) -> str: :return: The hash of the element. """ if isinstance(element, ET.Element): - s = ET.tostring(element, encoding='unicode', method='xml') + s = ET.tostring(element, encoding="unicode", method="xml") elif isinstance(element, ScxmlBase): - s = ET.tostring(element.as_xml(), encoding='unicode', method='xml') + s = ET.tostring(element.as_xml(), encoding="unicode", method="xml") elif isinstance(element, list): - s = '/'.join(f"{element}") + s = "/".join(f"{element}") else: raise ValueError(f"Element type {type(element)} not supported.") return sha256(s.encode()).hexdigest()[:8] def _interpret_scxml_assign( - elem: ScxmlAssign, jani_automaton: JaniAutomaton, event_substitution: Optional[str] = None, - assign_index: int = 0) -> List[JaniAssignment]: + elem: ScxmlAssign, + jani_automaton: JaniAutomaton, + event_substitution: Optional[str] = None, + assign_index: int = 0, +) -> List[JaniAssignment]: """Interpret SCXML assign element. :param element: The SCXML element to interpret. @@ -75,22 +98,22 @@ def _interpret_scxml_assign( :param event_substitution: The event to substitute in the expression. :return: The action or expression to be executed. """ - assert isinstance(elem, ScxmlAssign), \ - f"Expected ScxmlAssign, got {type(elem)}" + assert isinstance(elem, ScxmlAssign), f"Expected ScxmlAssign, got {type(elem)}" assignment_target = parse_ecmascript_to_jani_expression(elem.get_location()) target_expr_type = assignment_target.get_expression_type() - is_target_array = target_expr_type == JaniExpressionType.IDENTIFIER and \ - is_variable_array(jani_automaton, assignment_target.as_identifier()) + is_target_array = target_expr_type == JaniExpressionType.IDENTIFIER and is_variable_array( + jani_automaton, assignment_target.as_identifier() + ) array_info = None if is_target_array: var_info = get_array_type_and_size(jani_automaton, assignment_target.as_identifier()) array_info = ArrayInfo(*var_info) # Check if the target is an array, in case copy the length too assignment_value = parse_ecmascript_to_jani_expression( - elem.get_expr(), array_info).replace_event(event_substitution) + elem.get_expr(), array_info + ).replace_event(event_substitution) assignments: List[JaniAssignment] = [ - JaniAssignment({"ref": assignment_target, - "value": assignment_value, "index": assign_index}) + JaniAssignment({"ref": assignment_target, "value": assignment_value, "index": assign_index}) ] # Handle array types if is_target_array: @@ -100,49 +123,60 @@ def _interpret_scxml_assign( if value_expr_type == JaniExpressionType.IDENTIFIER: # Copy one array into another: simply copy the length from the source to the target value_identifier = assignment_value.as_identifier() - assignments.append(JaniAssignment({ - "ref": f"{target_identifier}.length", - "value": JaniExpression(f"{value_identifier}.length"), - "index": assign_index - })) + assignments.append( + JaniAssignment( + { + "ref": f"{target_identifier}.length", + "value": JaniExpression(f"{value_identifier}.length"), + "index": assign_index, + } + ) + ) elif value_expr_type == JaniExpressionType.OPERATOR: # Explicit array assignment: set the new length of the variable, too # This makes sense only if the operator is of type "av" (array value) op_type, operands = assignment_value.as_operator() - assert op_type == "av", \ - f"Array assignment expects an array value (av) operator, found {op_type}." - array_length = len(string_to_value( - elem.get_expr(), get_variable_type(jani_automaton, target_identifier))) - assignments.append(JaniAssignment({ - "ref": f"{target_identifier}.length", - "value": JaniValue(array_length), - "index": assign_index - })) + assert ( + op_type == "av" + ), f"Array assignment expects an array value (av) operator, found {op_type}." + array_length = len( + string_to_value( + elem.get_expr(), get_variable_type(jani_automaton, target_identifier) + ) + ) + assignments.append( + JaniAssignment( + { + "ref": f"{target_identifier}.length", + "value": JaniValue(array_length), + "index": assign_index, + } + ) + ) else: raise ValueError( - f"Cannot assign expression {elem.get_expr()} to the array {target_identifier}.") + f"Cannot assign expression {elem.get_expr()} to the array {target_identifier}." + ) elif target_expr_type == JaniExpressionType.OPERATOR: op_type, operands = assignment_target.as_operator() if op_type == "aa": # We are dealing with an array assignment. Update the length too - array_name = operands['exp'].as_identifier() + array_name = operands["exp"].as_identifier() assert array_name is not None, "Array assignments expects an array identifier exp." array_length_id = f"{array_name}.length" - array_idx = operands['index'] + array_idx = operands["index"] # Note: we do not make sure the max length increase is 1 (that is our assumption) # One way to do it could be to set the array length to -1 in case of broken assumptions new_length = max_operator(plus_operator(array_idx, 1), array_length_id) - assignments.append(JaniAssignment({ - "ref": array_length_id, - "value": new_length, - "index": assign_index - })) + assignments.append( + JaniAssignment({"ref": array_length_id, "value": new_length, "index": assign_index}) + ) return assignments def _merge_conditions( - previous_conditions: List[JaniExpression], - new_condition: Optional[JaniExpression] = None) -> JaniExpression: + previous_conditions: List[JaniExpression], new_condition: Optional[JaniExpression] = None +) -> JaniExpression: """This merges negated conditions of previous if-clauses with the condition of the current if-clause. This is necessary to properly implement the if-else semantics of SCXML by parallel outgoing transitions in Jani. @@ -170,7 +204,7 @@ def _append_scxml_body_to_jani_automaton( hash_str: str, guard_exp: Optional[JaniExpression], trigger_event: Optional[str], - max_array_size: int + max_array_size: int, ) -> Tuple[List[JaniEdge], List[str]]: """ Converts the body of an SCXML element to a set of locations and edges. @@ -178,46 +212,46 @@ def _append_scxml_body_to_jani_automaton( They need to be added to a JaniAutomaton later on. """ edge_action_name = f"{source}-{target}-{hash_str}" - trigger_event_action = \ + trigger_event_action = ( edge_action_name if trigger_event is None else f"{trigger_event}_on_receive" + ) new_edges = [] new_locations = [] if guard_exp is not None: guard_exp.replace_event(trigger_event) # First edge. Has to evaluate guard and trigger event of original transition. - new_edges.append(JaniEdge({ - "location": source, - "action": trigger_event_action, - "guard": JaniGuard(guard_exp), - "destinations": [{ - "location": None, - "assignments": [] - }] - })) + new_edges.append( + JaniEdge( + { + "location": source, + "action": trigger_event_action, + "guard": JaniGuard(guard_exp), + "destinations": [{"location": None, "assignments": []}], + } + ) + ) for i, ec in enumerate(body): if isinstance(ec, ScxmlAssign): - assign_idx = len(new_edges[-1].destinations[0]['assignments']) + assign_idx = len(new_edges[-1].destinations[0]["assignments"]) jani_assigns = _interpret_scxml_assign(ec, jani_automaton, trigger_event, assign_idx) - new_edges[-1].destinations[0]['assignments'].extend(jani_assigns) + new_edges[-1].destinations[0]["assignments"].extend(jani_assigns) elif isinstance(ec, ScxmlSend): event_name = ec.get_event() event_send_action_name = event_name + "_on_send" - interm_loc = f'{source}-{i}-{hash_str}' - new_edges[-1].destinations[0]['location'] = interm_loc - new_edge = JaniEdge({ - "location": interm_loc, - "action": event_send_action_name, - "guard": None, - "destinations": [{ - "location": None, - "assignments": [] - }] - }) + interm_loc = f"{source}-{i}-{hash_str}" + new_edges[-1].destinations[0]["location"] = interm_loc + new_edge = JaniEdge( + { + "location": interm_loc, + "action": event_send_action_name, + "guard": None, + "destinations": [{"location": None, "assignments": []}], + } + ) data_structure_for_event: Dict[str, type] = {} for param in ec.get_params(): - param_assign_name = f'{ec.get_event()}.{param.get_name()}' - expr = param.get_expr() if param.get_expr() is not None else \ - param.get_location() + param_assign_name = f"{ec.get_event()}.{param.get_name()}" + expr = param.get_expr() if param.get_expr() is not None else param.get_location() assert expr is not None, "Expected expression or location in param." # Update the events holder # TODO: get the expected type from a jani expression, w/o setting dummy values @@ -230,45 +264,50 @@ def _append_scxml_body_to_jani_automaton( array_info = None if isinstance(res_eval_value, ArrayType): array_info = ArrayInfo(get_args(res_eval_type)[0], max_array_size) - jani_expr = parse_ecmascript_to_jani_expression( - expr, array_info).replace_event(trigger_event) - new_edge.destinations[0]['assignments'].append(JaniAssignment({ - "ref": param_assign_name, - "value": jani_expr - })) + jani_expr = parse_ecmascript_to_jani_expression(expr, array_info).replace_event( + trigger_event + ) + new_edge.destinations[0]["assignments"].append( + JaniAssignment({"ref": param_assign_name, "value": jani_expr}) + ) # TODO: Try to reuse as much as possible from _interpret_scxml_assign # If we are sending an array, set the length as well jani_expr_type = jani_expr.get_expression_type() if jani_expr_type == JaniExpressionType.IDENTIFIER: variable_name = jani_expr.as_identifier() if is_variable_array(jani_automaton, variable_name): - new_edge.destinations[0]['assignments'].append(JaniAssignment({ - "ref": f'{param_assign_name}.length', - "value": f"{variable_name}.length"})) + new_edge.destinations[0]["assignments"].append( + JaniAssignment( + { + "ref": f"{param_assign_name}.length", + "value": f"{variable_name}.length", + } + ) + ) elif jani_expr_type == JaniExpressionType.OPERATOR: op_type, operands = jani_expr.as_operator() if op_type == "av": - assert isinstance(res_eval_value, ArrayType), \ - f"Expected array value, got {res_eval_value}." - new_edge.destinations[0]['assignments'].append(JaniAssignment({ - "ref": f'{param_assign_name}.length', - "value": JaniValue(len(res_eval_value))})) - new_edge.destinations[0]['assignments'].append(JaniAssignment({ - "ref": f'{ec.get_event()}.valid', - "value": True - })) + assert isinstance( + res_eval_value, ArrayType + ), f"Expected array value, got {res_eval_value}." + new_edge.destinations[0]["assignments"].append( + JaniAssignment( + { + "ref": f"{param_assign_name}.length", + "value": JaniValue(len(res_eval_value)), + } + ) + ) + new_edge.destinations[0]["assignments"].append( + JaniAssignment({"ref": f"{ec.get_event()}.valid", "value": True}) + ) if not events_holder.has_event(event_name): - send_event = Event( - event_name, - data_structure_for_event - ) + send_event = Event(event_name, data_structure_for_event) events_holder.add_event(send_event) else: send_event = events_holder.get_event(event_name) - send_event.set_data_structure( - data_structure_for_event - ) + send_event.set_data_structure(data_structure_for_event) send_event.add_sender_edge(jani_automaton.get_name(), event_send_action_name) new_edges.append(new_edge) @@ -277,16 +316,24 @@ def _append_scxml_body_to_jani_automaton( if_prefix = f"{source}_{hash_str}_{i}" interm_loc_before = f"{if_prefix}_before_if" interm_loc_after = f"{if_prefix}_after_if" - new_edges[-1].destinations[0]['location'] = interm_loc_before + new_edges[-1].destinations[0]["location"] = interm_loc_before previous_conditions: List[JaniExpression] = [] for if_idx, (cond_str, conditional_body) in enumerate(ec.get_conditional_executions()): current_cond = parse_ecmascript_to_jani_expression(cond_str) - jani_cond = _merge_conditions( - previous_conditions, current_cond).replace_event(trigger_event) + jani_cond = _merge_conditions(previous_conditions, current_cond).replace_event( + trigger_event + ) sub_edges, sub_locs = _append_scxml_body_to_jani_automaton( - jani_automaton, events_holder, conditional_body, interm_loc_before, - interm_loc_after, '-'.join([hash_str, _hash_element(ec), str(if_idx)]), - jani_cond, None, max_array_size) + jani_automaton, + events_holder, + conditional_body, + interm_loc_before, + interm_loc_after, + "-".join([hash_str, _hash_element(ec), str(if_idx)]), + jani_cond, + None, + max_array_size, + ) new_edges.extend(sub_edges) new_locations.extend(sub_locs) previous_conditions.append(current_cond) @@ -294,38 +341,45 @@ def _append_scxml_body_to_jani_automaton( else_execution_body = ec.get_else_execution() else_execution_id = str(len(ec.get_conditional_executions())) else_execution_body = [] if else_execution_body is None else else_execution_body - jani_cond = _merge_conditions( - previous_conditions).replace_event(trigger_event) + jani_cond = _merge_conditions(previous_conditions).replace_event(trigger_event) sub_edges, sub_locs = _append_scxml_body_to_jani_automaton( - jani_automaton, events_holder, ec.get_else_execution(), interm_loc_before, - interm_loc_after, '-'.join([hash_str, _hash_element(ec), else_execution_id]), - jani_cond, None, max_array_size) + jani_automaton, + events_holder, + ec.get_else_execution(), + interm_loc_before, + interm_loc_after, + "-".join([hash_str, _hash_element(ec), else_execution_id]), + jani_cond, + None, + max_array_size, + ) new_edges.extend(sub_edges) new_locations.extend(sub_locs) # Prepare the edge from the end of the if-else block - new_edges.append(JaniEdge({ - "location": interm_loc_after, - "action": edge_action_name, - "guard": None, - "destinations": [{ - "location": None, - "assignments": [] - }] - })) + new_edges.append( + JaniEdge( + { + "location": interm_loc_after, + "action": edge_action_name, + "guard": None, + "destinations": [{"location": None, "assignments": []}], + } + ) + ) new_locations.append(interm_loc_before) new_locations.append(interm_loc_after) - new_edges[-1].destinations[0]['location'] = target + new_edges[-1].destinations[0]["location"] = target return new_edges, new_locations class BaseTag: """Base class for all SCXML tags.""" + # class function to initialize the correct tag object @staticmethod - def from_element(element: ScxmlBase, - call_trace: List[ScxmlBase], - model: ModelTupleType, - max_array_size: int) -> 'BaseTag': + def from_element( + element: ScxmlBase, call_trace: List[ScxmlBase], model: ModelTupleType, max_array_size: int + ) -> "BaseTag": """Return the correct tag object based on the xml element. :param element: The xml element representing the tag. @@ -338,10 +392,13 @@ def from_element(element: ScxmlBase, raise NotImplementedError(f"Support for SCXML type >{type(element)}< not implemented.") return CLASS_BY_TYPE[type(element)](element, call_trace, model, max_array_size) - def __init__(self, element: ScxmlBase, - call_trace: List[ScxmlBase], - model: ModelTupleType, - max_array_size: int) -> None: + def __init__( + self, + element: ScxmlBase, + call_trace: List[ScxmlBase], + model: ModelTupleType, + max_array_size: int, + ) -> None: """Initialize the ScxmlTag object from an xml element. :param element: The xml element representing the tag. @@ -356,17 +413,17 @@ def __init__(self, element: ScxmlBase, scxml_children = self.get_children() self.children = [ BaseTag.from_element(child, call_trace + [element], model, max_array_size) - for child in scxml_children] + for child in scxml_children + ] - def get_children(self) -> Union[ - List[ScxmlBase], List[ScxmlTransition], List[Union[ScxmlDataModel, ScxmlState]]]: - """Method extracting all children from a specific Scxml Tag. - """ + def get_children( + self, + ) -> Union[List[ScxmlBase], List[ScxmlTransition], List[Union[ScxmlDataModel, ScxmlState]]]: + """Method extracting all children from a specific Scxml Tag.""" raise NotImplementedError("Method get_children not implemented.") def get_tag_name(self) -> str: - """Return the tag name to match against. - """ + """Return the tag name to match against.""" return self.element.get_tag_name() def write_model(self): @@ -407,19 +464,23 @@ def write_model(self): init_value = parse_ecmascript_to_jani_expression(scxml_data.get_expr(), array_info) expr_type = type(interpret_ecma_script_expr(scxml_data.get_expr())) assert check_value_type_compatible( - interpret_ecma_script_expr(scxml_data.get_expr()), expected_type), \ - f"Invalid value for {scxml_data.get_name()}: " \ + interpret_ecma_script_expr(scxml_data.get_expr()), expected_type + ), ( + f"Invalid value for {scxml_data.get_name()}: " f"Expected type {expected_type}, got {expr_type}." + ) # TODO: Add support for lower and upper bounds self.automaton.add_variable( - JaniVariable(scxml_data.get_name(), scxml_data.get_type(), init_value)) + JaniVariable(scxml_data.get_name(), scxml_data.get_type(), init_value) + ) # In case of arrays, declare an additional 'length' variable # In this case, use dot notation, as in JS arrays if expected_type is ArrayType: init_expr = string_to_value(scxml_data.get_expr(), scxml_data.get_type()) # TODO: The length variable NEEDS to be bounded self.automaton.add_variable( - JaniVariable(f"{scxml_data.get_name()}.length", int, JaniValue(len(init_expr)))) + JaniVariable(f"{scxml_data.get_name()}.length", int, JaniValue(len(init_expr))) + ) class ScxmlTag(BaseTag): @@ -448,8 +509,16 @@ def handle_entry_state(self): onentry_body = initial_state.get_onentry() hash_str = _hash_element([source_state, target_state, "onentry"]) new_edges, new_locations = _append_scxml_body_to_jani_automaton( - self.automaton, self.events_holder, onentry_body, source_state, - target_state, hash_str, None, None, self.max_array_size) + self.automaton, + self.events_holder, + onentry_body, + source_state, + target_state, + hash_str, + None, + None, + self.max_array_size, + ) # Add the initial state and start sequence to the automaton self.automaton.add_location(source_state) self.automaton.make_initial(source_state) @@ -471,8 +540,7 @@ def add_unhandled_transitions(self): child.add_unhandled_transitions(transitions_set) def write_model(self): - assert isinstance(self.element, ScxmlRoot), \ - f"Expected ScxmlRoot, got {type(self.element)}." + assert isinstance(self.element, ScxmlRoot), f"Expected ScxmlRoot, got {type(self.element)}." self.automaton.set_name(self.element.get_name()) super().write_model() self.add_unhandled_transitions() @@ -492,8 +560,7 @@ def get_children(self) -> List[ScxmlTransition]: return [] if state_transitions is None else state_transitions def get_handled_events(self) -> Set[str]: - """Return the events that are handled by the state. - """ + """Return the events that are handled by the state.""" transition_events = set(self._events_no_condition) for event_name in self._event_to_conditions.keys(): transition_events.add(event_name) @@ -510,8 +577,9 @@ def get_guard_exp_for_prev_conditions(self, event_name: str) -> Optional[JaniExp """ previous_expressions = [ - parse_ecmascript_to_jani_expression(cond) for - cond in self._event_to_conditions.get(event_name, [])] + parse_ecmascript_to_jani_expression(cond) + for cond in self._event_to_conditions.get(event_name, []) + ] if len(previous_expressions) > 0: return _merge_conditions(previous_expressions) else: @@ -524,10 +592,19 @@ def add_unhandled_transitions(self, transitions_set: Set[str]): continue guard_exp = self.get_guard_exp_for_prev_conditions(event_name) edges, locations = _append_scxml_body_to_jani_automaton( - self.automaton, self.events_holder, [], self.element.get_id(), - self.element.get_id(), "", guard_exp, event_name, self.max_array_size) - assert len(locations) == 0 and len(edges) == 1, \ - f"Expected one edge for self-loops, got {len(edges)} edges." + self.automaton, + self.events_holder, + [], + self.element.get_id(), + self.element.get_id(), + "", + guard_exp, + event_name, + self.max_array_size, + ) + assert ( + len(locations) == 0 and len(edges) == 1 + ), f"Expected one edge for self-loops, got {len(edges)} edges." self.automaton.add_edge(edges[0]) self._events_no_condition.append(event_name) @@ -544,12 +621,14 @@ def write_model(self): transition_condition = child.element.get_condition() # Add previous conditions matching the same event trigger to the current child state child.set_previous_siblings_conditions( - self._event_to_conditions.get(transition_event, [])) + self._event_to_conditions.get(transition_event, []) + ) if transition_condition is None: # Make sure we do not have multiple transitions with no condition and same event - assert transition_event not in self._events_no_condition, \ - f"Event {transition_event} in state {self.element.get_id()} already has a" \ + assert transition_event not in self._events_no_condition, ( + f"Event {transition_event} in state {self.element.get_id()} already has a" "transition without condition." + ) self._events_no_condition.append(transition_event) else: # Update the list of conditions related to a transition trigger @@ -573,8 +652,9 @@ def set_previous_siblings_conditions(self, conditions_scripts: List[str]): self._previous_conditions = conditions_scripts def write_model(self): - assert hasattr(self, "_previous_conditions"), \ - "Make sure 'set_previous_siblings_conditions' was called before." + assert hasattr( + self, "_previous_conditions" + ), "Make sure 'set_previous_siblings_conditions' was called before." scxml_root: ScxmlRoot = self.call_trace[0] current_state: ScxmlState = self.call_trace[-1] current_state_id: str = current_state.get_id() @@ -583,8 +663,9 @@ def write_model(self): assert target_state is not None, f"Transition's target state {target_state_id} not found." event_name = self.element.get_events() # TODO: Need to extend this to support multiple events - assert len(event_name) == 0 or len(event_name) == 1, \ - "Transitions triggered by multiple events are not supported." + assert ( + len(event_name) == 0 or len(event_name) == 1 + ), "Transitions triggered by multiple events are not supported." transition_trigger_event = None if len(event_name) == 0 else event_name[0] if transition_trigger_event is not None: # TODO: Maybe get rid of one of the two event variables @@ -600,7 +681,8 @@ def write_model(self): existing_event.add_receiver(self.automaton.get_name(), action_name) # Prepare the previous expressions for the transition guard previous_expressions = [ - parse_ecmascript_to_jani_expression(cond) for cond in self._previous_conditions] + parse_ecmascript_to_jani_expression(cond) for cond in self._previous_conditions + ] if event_name is not None: for expr in previous_expressions: expr.replace_event(transition_trigger_event) @@ -628,11 +710,20 @@ def write_model(self): merged_transition_body.extend(target_state.get_onentry()) # We assume that each transition has a unique combination of the entries below # TODO: If so, we could come up with a more descriptive name, instead of hashing? - hash_str = _hash_element([ - current_state_id, target_state_id, event_name, transition_condition]) + hash_str = _hash_element( + [current_state_id, target_state_id, event_name, transition_condition] + ) new_edges, new_locations = _append_scxml_body_to_jani_automaton( - self.automaton, self.events_holder, merged_transition_body, current_state_id, - target_state_id, hash_str, guard, transition_trigger_event, self.max_array_size) + self.automaton, + self.events_holder, + merged_transition_body, + current_state_id, + target_state_id, + hash_str, + guard, + transition_trigger_event, + self.max_array_size, + ) for edge in new_edges: self.automaton.add_edge(edge) for loc in new_locations: diff --git a/src/as2fm/jani_generator/scxml_helpers/scxml_to_jani.py b/src/as2fm/jani_generator/scxml_helpers/scxml_to_jani.py index 5dfc0a14..3a9a5069 100644 --- a/src/as2fm/jani_generator/scxml_helpers/scxml_to_jani.py +++ b/src/as2fm/jani_generator/scxml_helpers/scxml_to_jani.py @@ -21,20 +21,23 @@ from as2fm.jani_generator.jani_entries.jani_automaton import JaniAutomaton from as2fm.jani_generator.jani_entries.jani_model import JaniModel -from as2fm.jani_generator.ros_helpers.ros_communication_handler import \ - remove_empty_self_loops_from_interface_handlers_in_jani -from as2fm.jani_generator.ros_helpers.ros_timer import ( - RosTimer, make_global_timer_automaton) +from as2fm.jani_generator.ros_helpers.ros_communication_handler import ( + remove_empty_self_loops_from_interface_handlers_in_jani, +) +from as2fm.jani_generator.ros_helpers.ros_timer import RosTimer, make_global_timer_automaton from as2fm.jani_generator.scxml_helpers.scxml_event import EventsHolder -from as2fm.jani_generator.scxml_helpers.scxml_event_processor import \ - implement_scxml_events_as_jani_syncs +from as2fm.jani_generator.scxml_helpers.scxml_event_processor import ( + implement_scxml_events_as_jani_syncs, +) from as2fm.jani_generator.scxml_helpers.scxml_tags import BaseTag from as2fm.scxml_converter.scxml_entries import ScxmlRoot def convert_scxml_root_to_jani_automaton( - scxml_root: ScxmlRoot, jani_automaton: JaniAutomaton, events_holder: EventsHolder, - max_array_size: int + scxml_root: ScxmlRoot, + jani_automaton: JaniAutomaton, + events_holder: EventsHolder, + max_array_size: int, ) -> None: """ Convert an SCXML element to a Jani automaton. @@ -44,15 +47,13 @@ def convert_scxml_root_to_jani_automaton( :param events_holder: The holder for the events to be implemented as Jani syncs. :param max_array_size: The max size of the arrays in the model. """ - BaseTag.from_element(scxml_root, [], (jani_automaton, - events_holder), max_array_size).write_model() + BaseTag.from_element( + scxml_root, [], (jani_automaton, events_holder), max_array_size + ).write_model() def convert_multiple_scxmls_to_jani( - scxmls: List[ScxmlRoot], - timers: List[RosTimer], - max_time_ns: int, - max_array_size: int + scxmls: List[ScxmlRoot], timers: List[RosTimer], max_time_ns: int, max_array_size: int ) -> JaniModel: """ Assemble automata from multiple SCXML files into a Jani model. @@ -70,8 +71,9 @@ def convert_multiple_scxmls_to_jani( for input_scxml in scxmls: assert isinstance(input_scxml, ScxmlRoot) scxml_root = input_scxml - assert scxml_root.is_plain_scxml(), \ - f"Input model {scxml_root.get_name()} does not contain a plain SCXML model." + assert ( + scxml_root.is_plain_scxml() + ), f"Input model {scxml_root.get_name()} does not contain a plain SCXML model." automaton = JaniAutomaton() convert_scxml_root_to_jani_automaton(scxml_root, automaton, events_holder, max_array_size) base_model.add_jani_automaton(automaton) diff --git a/src/as2fm/jani_generator/scxml_helpers/top_level_interpreter.py b/src/as2fm/jani_generator/scxml_helpers/top_level_interpreter.py index 96b0d5be..5052bbec 100644 --- a/src/as2fm/jani_generator/scxml_helpers/top_level_interpreter.py +++ b/src/as2fm/jani_generator/scxml_helpers/top_level_interpreter.py @@ -24,17 +24,15 @@ from xml.etree import ElementTree as ET from as2fm.as2fm_common.common import remove_namespace -from as2fm.jani_generator.ros_helpers.ros_action_handler import \ - RosActionHandler +from as2fm.jani_generator.ros_helpers.ros_action_handler import RosActionHandler from as2fm.jani_generator.ros_helpers.ros_communication_handler import ( - RosCommunicationHandler, generate_plain_scxml_from_handlers, - update_ros_communication_handlers) -from as2fm.jani_generator.ros_helpers.ros_service_handler import \ - RosServiceHandler -from as2fm.jani_generator.ros_helpers.ros_timer import ( - RosTimer, make_global_timer_scxml) -from as2fm.jani_generator.scxml_helpers.scxml_to_jani import \ - convert_multiple_scxmls_to_jani + RosCommunicationHandler, + generate_plain_scxml_from_handlers, + update_ros_communication_handlers, +) +from as2fm.jani_generator.ros_helpers.ros_service_handler import RosServiceHandler +from as2fm.jani_generator.ros_helpers.ros_timer import RosTimer, make_global_timer_scxml +from as2fm.jani_generator.scxml_helpers.scxml_to_jani import convert_multiple_scxmls_to_jani from as2fm.scxml_converter.bt_converter import bt_converter from as2fm.scxml_converter.scxml_entries import ScxmlRoot @@ -58,12 +56,7 @@ def _parse_time_element(time_element: ET.Element) -> int: :param time_element: The time element to interpret. :return: The interpreted time in nanoseconds. """ - TIME_MULTIPLIERS = { - "s": 1_000_000_000, - "ms": 1_000_000, - "us": 1_000, - "ns": 1 - } + TIME_MULTIPLIERS = {"s": 1_000_000_000, "ms": 1_000_000, "us": 1_000, "ns": 1} time_unit = time_element.attrib["unit"] assert time_unit in TIME_MULTIPLIERS, f"Invalid time unit: {time_unit}" return int(time_element.attrib["value"]) * TIME_MULTIPLIERS[time_unit] @@ -83,10 +76,11 @@ def parse_main_xml(xml_path: str) -> FullModel: """ # Used to generate absolute paths of scxml models folder_of_xml = os.path.dirname(xml_path) - with open(xml_path, 'r', encoding='utf-8') as f: + with open(xml_path, "r", encoding="utf-8") as f: xml = ET.parse(f) - assert remove_namespace(xml.getroot().tag) == "convince_mc_tc", \ - "The top-level XML element must be convince_mc_tc." + assert ( + remove_namespace(xml.getroot().tag) == "convince_mc_tc" + ), "The top-level XML element must be convince_mc_tc." model = FullModel() for first_level in xml.getroot(): if remove_namespace(first_level.tag) == "mc_parameters": @@ -100,8 +94,7 @@ def parse_main_xml(xml_path: str) -> FullModel: elif remove_namespace(mc_parameter.tag) == "bt_tick_rate": model.bt_tick_rate = float(mc_parameter.attrib["value"]) else: - raise ValueError( - f"Invalid mc_parameter tag: {mc_parameter.tag}") + raise ValueError(f"Invalid mc_parameter tag: {mc_parameter.tag}") elif remove_namespace(first_level.tag) == "behavior_tree": for child in first_level: if remove_namespace(child.tag) == "input": @@ -109,27 +102,23 @@ def parse_main_xml(xml_path: str) -> FullModel: assert model.bt is None, "Only one Behavior Tree is supported." model.bt = os.path.join(folder_of_xml, child.attrib["src"]) elif child.attrib["type"] == "bt-plugin-ros-scxml": - model.plugins.append( - os.path.join(folder_of_xml, child.attrib["src"])) + model.plugins.append(os.path.join(folder_of_xml, child.attrib["src"])) else: raise ValueError(f"Invalid input type: {child.attrib['type']}") else: - raise ValueError( - f"Invalid behavior_tree tag: {child.tag} != input") + raise ValueError(f"Invalid behavior_tree tag: {child.tag} != input") assert model.bt is not None, "A Behavior Tree must be defined." elif remove_namespace(first_level.tag) == "node_models": for node_model in first_level: - assert remove_namespace(node_model.tag) == "input", \ - "Only input tags are supported." - assert node_model.attrib["type"] == "ros-scxml", \ - "Only ROS-SCXML node models are supported." + assert remove_namespace(node_model.tag) == "input", "Only input tags are supported." + assert ( + node_model.attrib["type"] == "ros-scxml" + ), "Only ROS-SCXML node models are supported." model.skills.append(os.path.join(folder_of_xml, node_model.attrib["src"])) elif remove_namespace(first_level.tag) == "properties": for property in first_level: - assert remove_namespace(property.tag) == "input", \ - "Only input tags are supported." - assert property.attrib["type"] == "jani", \ - "Only Jani properties are supported." + assert remove_namespace(property.tag) == "input", "Only input tags are supported." + assert property.attrib["type"] == "jani", "Only Jani properties are supported." model.properties.append(os.path.join(folder_of_xml, property.attrib["src"])) else: raise ValueError(f"Invalid main point tag: {first_level.tag}") @@ -137,7 +126,8 @@ def parse_main_xml(xml_path: str) -> FullModel: def generate_plain_scxml_models_and_timers( - model: FullModel) -> Tuple[List[ScxmlRoot], List[RosTimer]]: + model: FullModel, +) -> Tuple[List[ScxmlRoot], List[RosTimer]]: """ Generate plain SCXML models and ROS timers from the full model dictionary. """ @@ -155,21 +145,27 @@ def generate_plain_scxml_models_and_timers( all_services: Dict[str, RosCommunicationHandler] = {} all_actions: Dict[str, RosCommunicationHandler] = {} for scxml_entry in ros_scxmls: - plain_scxmls, ros_declarations = \ - scxml_entry.to_plain_scxml_and_declarations() + plain_scxmls, ros_declarations = scxml_entry.to_plain_scxml_and_declarations() # Handle ROS timers for timer_name, timer_rate in ros_declarations._timers.items(): - assert timer_name not in all_timers, \ - f"Timer {timer_name} already exists." + assert timer_name not in all_timers, f"Timer {timer_name} already exists." all_timers.append(RosTimer(timer_name, timer_rate)) # Handle ROS Services update_ros_communication_handlers( - scxml_entry.get_name(), RosServiceHandler, all_services, - ros_declarations._service_servers, ros_declarations._service_clients) + scxml_entry.get_name(), + RosServiceHandler, + all_services, + ros_declarations._service_servers, + ros_declarations._service_clients, + ) # Handle ROS Actions update_ros_communication_handlers( - scxml_entry.get_name(), RosActionHandler, all_actions, - ros_declarations._action_servers, ros_declarations._action_clients) + scxml_entry.get_name(), + RosActionHandler, + all_actions, + ros_declarations._action_servers, + ros_declarations._action_clients, + ) plain_scxml_models.extend(plain_scxmls) # Generate sync SCXML models for services and actions for plain_scxml in generate_plain_scxml_from_handlers(all_services | all_actions): @@ -177,8 +173,9 @@ def generate_plain_scxml_models_and_timers( return plain_scxml_models, all_timers -def interpret_top_level_xml(xml_path: str, jani_file: str, - generated_scxmls_dir: Optional[str] = None): +def interpret_top_level_xml( + xml_path: str, jani_file: str, generated_scxmls_dir: Optional[str] = None +): """ Interpret the top-level XML file as a Jani model. And write it to a file. The generated Jani model is written to the same directory as the input XML file under the @@ -197,24 +194,31 @@ def interpret_top_level_xml(xml_path: str, jani_file: str, plain_scxml_dir = os.path.join(model_dir, generated_scxmls_dir) os.makedirs(plain_scxml_dir, exist_ok=True) for scxml_model in plain_scxml_models: - with open(os.path.join(plain_scxml_dir, f"{scxml_model.get_name()}.scxml"), "w", - encoding='utf-8') as f: + with open( + os.path.join(plain_scxml_dir, f"{scxml_model.get_name()}.scxml"), + "w", + encoding="utf-8", + ) as f: f.write(scxml_model.as_xml_string()) # Additionally, write the timers SCXML model global_timer_scxml = make_global_timer_scxml(all_timers, model.max_time) if global_timer_scxml is not None: - with open(os.path.join(plain_scxml_dir, global_timer_scxml.get_name() + ".scxml"), "w", - encoding='utf-8') as f: + with open( + os.path.join(plain_scxml_dir, global_timer_scxml.get_name() + ".scxml"), + "w", + encoding="utf-8", + ) as f: f.write(global_timer_scxml.as_xml_string()) jani_model = convert_multiple_scxmls_to_jani( - plain_scxml_models, all_timers, model.max_time, model.max_array_size) + plain_scxml_models, all_timers, model.max_time, model.max_array_size + ) jani_dict = jani_model.as_dict() assert len(model.properties) == 1, "Only one property is supported right now." - with open(model.properties[0], "r", encoding='utf-8') as f: + with open(model.properties[0], "r", encoding="utf-8") as f: jani_dict["properties"] = json.load(f)["properties"] output_path = os.path.join(model_dir, jani_file) - with open(output_path, "w", encoding='utf-8') as f: + with open(output_path, "w", encoding="utf-8") as f: json.dump(jani_dict, f, indent=2, ensure_ascii=False) diff --git a/src/as2fm/jani_visualizer/main.py b/src/as2fm/jani_visualizer/main.py index fb63fa7d..f62ae088 100644 --- a/src/as2fm/jani_visualizer/main.py +++ b/src/as2fm/jani_visualizer/main.py @@ -25,42 +25,45 @@ def main_jani_to_plantuml(): - parser = argparse.ArgumentParser( - description='Converts a `*.jani` file to a `*.plantuml` file.') - parser.add_argument('input_fname', type=str, help='The input jani file.') - parser.add_argument('output_plantuml_fname', type=str, help='The output plantuml file.') - parser.add_argument('output_svg_fname', type=str, help='The output svg file.') - parser.add_argument('--no-syncs', action='store_true', - help='Don\'t connects transitions that are synchronized.') - parser.add_argument('--no-assignments', action='store_true', - help='Don\'t show assignments on the edges.') - parser.add_argument('--no-guard', action='store_true', - help='Don\'t show guards on the edges.') + parser = argparse.ArgumentParser(description="Converts a `*.jani` file to a `*.plantuml` file.") + parser.add_argument("input_fname", type=str, help="The input jani file.") + parser.add_argument("output_plantuml_fname", type=str, help="The output plantuml file.") + parser.add_argument("output_svg_fname", type=str, help="The output svg file.") + parser.add_argument( + "--no-syncs", action="store_true", help="Don't connects transitions that are synchronized." + ) + parser.add_argument( + "--no-assignments", action="store_true", help="Don't show assignments on the edges." + ) + parser.add_argument("--no-guard", action="store_true", help="Don't show guards on the edges.") args = parser.parse_args() assert os.path.isfile(args.input_fname), f"File {args.input_fname} must exist." try: - with open(args.input_fname, 'r') as f: + with open(args.input_fname, "r") as f: jani_dict = json.load(f) except json.JSONDecodeError as e: raise ValueError(f"Error while reading the input file {args.input_fname}") from e - assert not os.path.isfile(args.output_plantuml_fname), \ - f"File {args.output_plantuml_fname} must not exist." + assert not os.path.isfile( + args.output_plantuml_fname + ), f"File {args.output_plantuml_fname} must not exist." - assert not os.path.isfile(args.output_svg_fname), \ - f"File {args.output_svg_fname} must not exist." + assert not os.path.isfile( + args.output_svg_fname + ), f"File {args.output_svg_fname} must not exist." pua = PlantUMLAutomata(jani_dict) puml_str = pua.to_plantuml( - with_assignments=not args.no_assignments, - with_guards=not args.no_guard, - with_syncs=not args.no_syncs - ) - with open(args.output_plantuml_fname, 'w') as f: + with_assignments=not args.no_assignments, + with_guards=not args.no_guard, + with_syncs=not args.no_syncs, + ) + with open(args.output_plantuml_fname, "w") as f: f.write(puml_str) - plantuml.PlantUML('http://www.plantuml.com/plantuml/img/').processes_file( - args.output_plantuml_fname, outfile=args.output_svg_fname) - url = plantuml.PlantUML('http://www.plantuml.com/plantuml/img/').get_url(puml_str) + plantuml.PlantUML("http://www.plantuml.com/plantuml/img/").processes_file( + args.output_plantuml_fname, outfile=args.output_svg_fname + ) + url = plantuml.PlantUML("http://www.plantuml.com/plantuml/img/").get_url(puml_str) print(f"{url=}") diff --git a/src/as2fm/jani_visualizer/visualizer.py b/src/as2fm/jani_visualizer/visualizer.py index f9fbedbc..ac276703 100644 --- a/src/as2fm/jani_visualizer/visualizer.py +++ b/src/as2fm/jani_visualizer/visualizer.py @@ -24,19 +24,18 @@ def _compact_assignments(assignments: Union[dict, list, str, int]) -> str: out: str = "" if isinstance(assignments, dict): - if 'ref' in assignments: - assert 'value' in assignments, \ - "The value must be present if ref is present." + if "ref" in assignments: + assert "value" in assignments, "The value must be present if ref is present." out += f"{assignments['ref']}=({_compact_assignments(assignments['value'])})\n" - elif 'op' in assignments: - if 'left' in assignments and 'right' in assignments: + elif "op" in assignments: + if "left" in assignments and "right" in assignments: out += f"{_compact_assignments(assignments['left'])} {assignments['op']} " out += f"{_compact_assignments(assignments['right'])}" - elif 'exp' in assignments: + elif "exp" in assignments: out += f"{assignments['op']}({_compact_assignments(assignments['exp'])})" else: raise ValueError(f"Unknown assignment: {assignments}") - elif assignments.keys() == {'exp'}: + elif assignments.keys() == {"exp"}: out += f"({_compact_assignments(assignments['exp'])})" else: raise ValueError(f"Unknown assignment: {assignments}") @@ -64,20 +63,16 @@ class PlantUMLAutomata: def __init__(self, jani_dict: dict): self.jani_dict = jani_dict - self.jani_automata = jani_dict['automata'] - assert isinstance(self.jani_automata, list), \ - "The automata must be a list." - assert len(self.jani_automata) >= 1, \ - "At least one automaton must be present." + self.jani_automata = jani_dict["automata"] + assert isinstance(self.jani_automata, list), "The automata must be a list." + assert len(self.jani_automata) >= 1, "At least one automaton must be present." def _preprocess_syncs(self): """Preprocess the synchronizations.""" - assert 'system' in self.jani_dict, \ - "The system must be present." - assert 'syncs' in self.jani_dict['system'], \ - "The system must have syncs." + assert "system" in self.jani_dict, "The system must be present." + assert "syncs" in self.jani_dict["system"], "The system must have syncs." n_syncs = len(self.jani_dict["system"]["syncs"]) - automata = [a['name'] for a in self.jani_automata] + automata = [a["name"] for a in self.jani_automata] # define colors for the syncs colors = [] @@ -91,8 +86,9 @@ def _preprocess_syncs(self): colors_per_action = {} for i, sync in enumerate(self.jani_dict["system"]["syncs"]): synchronise = sync["synchronise"] - assert len(synchronise) == len(automata), \ - "The synchronisation must have the same number of elements as the automata." + assert len(synchronise) == len( + automata + ), "The synchronisation must have the same number of elements as the automata." for action, automaton in zip(synchronise, automata): if action is None: continue @@ -101,11 +97,12 @@ def _preprocess_syncs(self): colors_per_action[automaton][action] = colors[i] return colors_per_action - def to_plantuml(self, - with_assignments: bool = False, - with_guards: bool = False, - with_syncs: bool = False, - ) -> str: + def to_plantuml( + self, + with_assignments: bool = False, + with_guards: bool = False, + with_syncs: bool = False, + ) -> str: colors_per_action = self._preprocess_syncs() puml: str = "@startuml\n" @@ -113,49 +110,44 @@ def to_plantuml(self, for automaton in self.jani_automata: # add a box for the automaton - automaton_name = automaton['name'] + automaton_name = automaton["name"] puml += f"package {automaton_name} {{\n" - for i_l, location in enumerate(automaton['locations']): - loc_name = _unique_name(automaton_name, location['name']) + for i_l, location in enumerate(automaton["locations"]): + loc_name = _unique_name(automaton_name, location["name"]) puml += f" usecase \"({i_l}) {location['name']}\" as {loc_name}\n" - for edge in automaton['edges']: - source = _unique_name(automaton_name, edge['location']) - assert len(edge['destinations']) == 1, \ - "Only one destination is supported." - destination = edge['destinations'][0] - target = _unique_name(automaton_name, destination['location']) + for edge in automaton["edges"]: + source = _unique_name(automaton_name, edge["location"]) + assert len(edge["destinations"]) == 1, "Only one destination is supported." + destination = edge["destinations"][0] + target = _unique_name(automaton_name, destination["location"]) edge_label = "" color = "#000" # black by default # Assignments if ( - with_assignments and - 'assignments' in destination and - len(destination['assignments']) > 0 + with_assignments + and "assignments" in destination + and len(destination["assignments"]) > 0 ): - assignments_str = _compact_assignments(destination['assignments']).strip() + assignments_str = _compact_assignments(destination["assignments"]).strip() edge_label += f"⏬{assignments_str}\n" # Guards - if ( - with_guards and - 'guard' in edge - ): - guard_str = _compact_assignments(edge['guard']).strip() + if with_guards and "guard" in edge: + guard_str = _compact_assignments(edge["guard"]).strip() edge_label += f"💂{guard_str}\n" # Syncs - if ( - with_syncs and - 'action' in edge - ): - action = edge['action'] - if (automaton['name'] in colors_per_action and - action in colors_per_action[automaton['name']]): - color = colors_per_action[automaton['name']][action] + if with_syncs and "action" in edge: + action = edge["action"] + if ( + automaton["name"] in colors_per_action + and action in colors_per_action[automaton["name"]] + ): + color = colors_per_action[automaton["name"]][action] edge_label += f"🔗{action}\n" - edge_label = ' \\n\\\n'.join(edge_label.split('\n')) + edge_label = " \\n\\\n".join(edge_label.split("\n")) if len(edge_label.strip()) > 0: puml += f" {source} -[{color}]-> {target} : {edge_label}\n" else: diff --git a/src/as2fm/scxml_converter/bt_converter.py b/src/as2fm/scxml_converter/bt_converter.py index 752bab7e..3dedf344 100644 --- a/src/as2fm/scxml_converter/bt_converter.py +++ b/src/as2fm/scxml_converter/bt_converter.py @@ -27,36 +27,40 @@ from btlib.bts import xml_to_networkx from btlib.common import NODE_CAT -from as2fm.scxml_converter.scxml_entries import (RESERVED_BT_PORT_NAMES, - RosRateCallback, RosTimeRate, - ScxmlRoot, ScxmlSend, - ScxmlState, ScxmlTransition) +from as2fm.scxml_converter.scxml_entries import ( + RESERVED_BT_PORT_NAMES, + RosRateCallback, + RosTimeRate, + ScxmlRoot, + ScxmlSend, + ScxmlState, + ScxmlTransition, +) class BT_EVENT_TYPE(Enum): """Event types for Behavior Tree.""" + TICK = auto() SUCCESS = auto() FAILURE = auto() RUNNING = auto() @staticmethod - def from_str(event_name: str) -> 'BT_EVENT_TYPE': - event_name = event_name.replace('event=', '') - event_name = event_name.replace('"', '') - event_name = event_name.replace('bt_', '') + def from_str(event_name: str) -> "BT_EVENT_TYPE": + event_name = event_name.replace("event=", "") + event_name = event_name.replace('"', "") + event_name = event_name.replace("bt_", "") return BT_EVENT_TYPE[event_name.upper()] def bt_event_name(node_id: str, event_type: BT_EVENT_TYPE) -> str: """Return the event name for the given node and event type.""" - return f'bt_{node_id}_{event_type.name.lower()}' + return f"bt_{node_id}_{event_type.name.lower()}" def bt_converter( - bt_xml_path: str, - bt_plugins_scxml_paths: List[str], - bt_tick_rate: float + bt_xml_path: str, bt_plugins_scxml_paths: List[str], bt_tick_rate: float ) -> List[ScxmlRoot]: """ Convert a Behavior Tree (BT) in XML format to SCXML. @@ -73,35 +77,41 @@ def bt_converter( bt_plugins_scxmls = {} for path in bt_plugins_scxml_paths: - assert os.path.exists(path), f'SCXML must exist. {path} not found.' + assert os.path.exists(path), f"SCXML must exist. {path} not found." bt_plugin_scxml = ScxmlRoot.from_scxml_file(path) bt_plugin_name = bt_plugin_scxml.get_name() - assert bt_plugin_name not in bt_plugins_scxmls, \ - f'Plugin name must be unique. {bt_plugin_name} already exists.' + assert ( + bt_plugin_name not in bt_plugins_scxmls + ), f"Plugin name must be unique. {bt_plugin_name} already exists." bt_plugins_scxmls[bt_plugin_name] = bt_plugin_scxml leaf_node_ids = [] generated_scxmls: List[ScxmlRoot] = [] # Generate the instances of the plugins used in the BT for node in bt_graph.nodes: - assert 'category' in bt_graph.nodes[node], 'Node must have a category.' - if bt_graph.nodes[node]['category'] == NODE_CAT.LEAF: + assert "category" in bt_graph.nodes[node], "Node must have a category." + if bt_graph.nodes[node]["category"] == NODE_CAT.LEAF: leaf_node_ids.append(node) - assert 'ID' in bt_graph.nodes[node], 'Leaf node must have a type.' - node_type = bt_graph.nodes[node]['ID'] + assert "ID" in bt_graph.nodes[node], "Leaf node must have a type." + node_type = bt_graph.nodes[node]["ID"] node_id = node - assert node_type in bt_plugins_scxmls, \ - f'Leaf node must have a plugin. {node_type} not found.' - instance_name = f'{node_id}_{node_type}' + assert ( + node_type in bt_plugins_scxmls + ), f"Leaf node must have a plugin. {node_type} not found." + instance_name = f"{node_id}_{node_type}" scxml_plugin_instance: ScxmlRoot = deepcopy(bt_plugins_scxmls[node_type]) scxml_plugin_instance.set_name(instance_name) scxml_plugin_instance.instantiate_bt_events(node_id) - bt_ports = [(p_name, p_value) for p_name, p_value in bt_graph.nodes[node].items() - if p_name not in RESERVED_BT_PORT_NAMES] + bt_ports = [ + (p_name, p_value) + for p_name, p_value in bt_graph.nodes[node].items() + if p_name not in RESERVED_BT_PORT_NAMES + ] scxml_plugin_instance.set_bt_ports_values(bt_ports) scxml_plugin_instance.update_bt_ports_values() - assert scxml_plugin_instance.check_validity(), \ - f"Error: SCXML plugin instance {instance_name} is not valid." + assert ( + scxml_plugin_instance.check_validity() + ), f"Error: SCXML plugin instance {instance_name} is not valid." generated_scxmls.append(scxml_plugin_instance) # Generate the BT SCXML fsm_graph = Bt2FSM(bt_graph).convert() @@ -111,28 +121,28 @@ def bt_converter( state = ScxmlState(node) node_id = None if name_with_id_pattern.match(node): - node_id = int(node.split('_')[0]) + node_id = int(node.split("_")[0]) if node_id in leaf_node_ids: state.append_on_entry(ScxmlSend(bt_event_name(node_id, BT_EVENT_TYPE.TICK))) for edge in fsm_graph.edges(node): target = edge[1] transition = ScxmlTransition(target) if node_id is not None and node_id in leaf_node_ids: - if 'label' not in fsm_graph.edges[edge]: + if "label" not in fsm_graph.edges[edge]: continue - label = fsm_graph.edges[edge]['label'] - if label == 'on_success': + label = fsm_graph.edges[edge]["label"] + if label == "on_success": event_type = BT_EVENT_TYPE.SUCCESS - elif label == 'on_failure': + elif label == "on_failure": event_type = BT_EVENT_TYPE.FAILURE - elif label == 'on_running': + elif label == "on_running": event_type = BT_EVENT_TYPE.RUNNING else: - raise ValueError(f'Invalid label: {label}') + raise ValueError(f"Invalid label: {label}") event_name = bt_event_name(node_id, event_type) transition.add_event(event_name) state.add_transition(transition) - if node in ['success', 'failure', 'running']: + if node in ["success", "failure", "running"]: state.add_transition(ScxmlTransition("wait_for_tick")) bt_scxml_root.add_state(state) # TODO: Make BT rate configurable, e.g. from main.xml @@ -140,8 +150,7 @@ def bt_converter( bt_scxml_root.add_ros_declaration(rtr) wait_for_tick = ScxmlState("wait_for_tick") - wait_for_tick.add_transition( - RosRateCallback(rtr, "tick")) + wait_for_tick.add_transition(RosRateCallback(rtr, "tick")) bt_scxml_root.add_state(wait_for_tick, initial=True) assert bt_scxml_root.check_validity(), "Error: SCXML root tag is not valid." generated_scxmls.append(bt_scxml_root) diff --git a/src/as2fm/scxml_converter/scxml_entries/__init__.py b/src/as2fm/scxml_converter/scxml_entries/__init__.py index 4981bbf2..76c393ed 100644 --- a/src/as2fm/scxml_converter/scxml_entries/__init__.py +++ b/src/as2fm/scxml_converter/scxml_entries/__init__.py @@ -1,37 +1,65 @@ # isort: skip_file # Skipping file to avoid circular import problem -from .scxml_base import ScxmlBase # noqa: F401 -from .utils import CallbackType # noqa: F401 -from .bt_utils import RESERVED_BT_PORT_NAMES # noqa: F401 -from .scxml_bt import ( # noqa: F401 - BtInputPortDeclaration, BtOutputPortDeclaration, BtGetValueInputPort) # noqa: F401 -from .scxml_param import ScxmlParam # noqa: F401 -from .scxml_ros_field import RosField # noqa: F401 -from .scxml_data import ScxmlData # noqa: F401 -from .scxml_data_model import ScxmlDataModel # noqa: F401 -from .ros_utils import ScxmlRosDeclarationsContainer # noqa: F401 -from .scxml_executable_entries import ScxmlAssign, ScxmlIf, ScxmlSend # noqa: F401 +from .scxml_base import ScxmlBase # noqa: F401 +from .utils import CallbackType # noqa: F401 +from .bt_utils import RESERVED_BT_PORT_NAMES # noqa: F401 +from .scxml_bt import ( # noqa: F401 + BtInputPortDeclaration, + BtOutputPortDeclaration, + BtGetValueInputPort, +) # noqa: F401 +from .scxml_param import ScxmlParam # noqa: F401 +from .scxml_ros_field import RosField # noqa: F401 +from .scxml_data import ScxmlData # noqa: F401 +from .scxml_data_model import ScxmlDataModel # noqa: F401 +from .ros_utils import ScxmlRosDeclarationsContainer # noqa: F401 +from .scxml_executable_entries import ScxmlAssign, ScxmlIf, ScxmlSend # noqa: F401 from .scxml_executable_entries import ScxmlExecutableEntry, ScxmlExecutionBody # noqa: F401 -from .scxml_executable_entries import ( # noqa: F401 - execution_body_from_xml, as_plain_execution_body, # noqa: F401 - execution_entry_from_xml, valid_execution_body, # noqa: F401 - valid_execution_body_entry_types, instantiate_exec_body_bt_events) # noqa: F401 -from .scxml_transition import ScxmlTransition # noqa: F401 -from .scxml_state import ScxmlState # noqa: F401 -from .scxml_ros_timer import (RosTimeRate, RosRateCallback) # noqa: F401 -from .scxml_ros_topic import ( # noqa: F401 - RosTopicPublisher, RosTopicSubscriber, RosTopicCallback, RosTopicPublish) # noqa: F401 -from .scxml_ros_service import ( # noqa: F401 - RosServiceServer, RosServiceClient, RosServiceHandleRequest, # noqa: F401 - RosServiceHandleResponse, RosServiceSendRequest, RosServiceSendResponse) # noqa: F401 -from .scxml_ros_action_client import ( # noqa: F401 - RosActionClient, RosActionSendGoal, RosActionHandleGoalResponse, # noqa: F401 - RosActionHandleFeedback, RosActionHandleSuccessResult, # noqa: F401 - RosActionHandleCanceledResult, RosActionHandleAbortedResult) # noqa: F401 -from .scxml_ros_action_server import ( # noqa: F401 - RosActionServer, RosActionHandleGoalRequest, RosActionAcceptGoal, # noqa: F401 - RosActionRejectGoal, RosActionStartThread, RosActionSendFeedback, # noqa: F401 - RosActionSendSuccessResult) # noqa: F401 -from .scxml_ros_action_server_thread import ( # noqa: F401 - RosActionThread, RosActionHandleThreadStart) # noqa: F401 -from .scxml_root import ScxmlRoot # noqa: F401 +from .scxml_executable_entries import ( # noqa: F401 + execution_body_from_xml, + as_plain_execution_body, # noqa: F401 + execution_entry_from_xml, + valid_execution_body, # noqa: F401 + valid_execution_body_entry_types, + instantiate_exec_body_bt_events, +) # noqa: F401 +from .scxml_transition import ScxmlTransition # noqa: F401 +from .scxml_state import ScxmlState # noqa: F401 +from .scxml_ros_timer import RosTimeRate, RosRateCallback # noqa: F401 +from .scxml_ros_topic import ( # noqa: F401 + RosTopicPublisher, + RosTopicSubscriber, + RosTopicCallback, + RosTopicPublish, +) # noqa: F401 +from .scxml_ros_service import ( # noqa: F401 + RosServiceServer, + RosServiceClient, + RosServiceHandleRequest, # noqa: F401 + RosServiceHandleResponse, + RosServiceSendRequest, + RosServiceSendResponse, +) # noqa: F401 +from .scxml_ros_action_client import ( # noqa: F401 + RosActionClient, + RosActionSendGoal, + RosActionHandleGoalResponse, # noqa: F401 + RosActionHandleFeedback, + RosActionHandleSuccessResult, # noqa: F401 + RosActionHandleCanceledResult, + RosActionHandleAbortedResult, +) # noqa: F401 +from .scxml_ros_action_server import ( # noqa: F401 + RosActionServer, + RosActionHandleGoalRequest, + RosActionAcceptGoal, # noqa: F401 + RosActionRejectGoal, + RosActionStartThread, + RosActionSendFeedback, # noqa: F401 + RosActionSendSuccessResult, +) # noqa: F401 +from .scxml_ros_action_server_thread import ( # noqa: F401 + RosActionThread, + RosActionHandleThreadStart, +) # noqa: F401 +from .scxml_root import ScxmlRoot # noqa: F401 diff --git a/src/as2fm/scxml_converter/scxml_entries/bt_utils.py b/src/as2fm/scxml_converter/scxml_entries/bt_utils.py index 6dbbabea..cdc40250 100644 --- a/src/as2fm/scxml_converter/scxml_entries/bt_utils.py +++ b/src/as2fm/scxml_converter/scxml_entries/bt_utils.py @@ -24,7 +24,7 @@ VALID_BT_OUTPUT_PORT_TYPES: Dict[str, Type] = SCXML_DATA_STR_TO_TYPE """List of keys that are not going to be read as BT ports from the BT XML definition.""" -RESERVED_BT_PORT_NAMES = ['NAME', 'ID', 'category'] +RESERVED_BT_PORT_NAMES = ["NAME", "ID", "category"] def is_bt_event(event_name: str) -> bool: @@ -50,11 +50,13 @@ def is_blackboard_reference(port_value: str) -> bool: class BtPortsHandler: """Collector for declared BT ports and their assigned value.""" + @staticmethod def check_port_name_allowed(port_name: str) -> None: """Check if the port name is allowed.""" - assert port_name not in RESERVED_BT_PORT_NAMES, \ - f"Error: Port name {port_name} is reserved in BT" + assert ( + port_name not in RESERVED_BT_PORT_NAMES + ), f"Error: Port name {port_name} is reserved in BT" def __init__(self): # For each port name, store the port type string and value. @@ -72,23 +74,29 @@ def out_port_exists(self, port_name: str) -> bool: def declare_in_port(self, port_name: str, port_type: str) -> None: """Add an input port to the handler.""" BtPortsHandler.check_port_name_allowed(port_name) - assert not self.in_port_exists(port_name), \ - f"Error: Input port {port_name} already declared as input port." - assert not self.out_port_exists(port_name), \ - f"Error: Input port {port_name} already declared as output port." - assert port_type in VALID_BT_INPUT_PORT_TYPES, \ - f"Error: Unsupported input port type {port_type}." + assert not self.in_port_exists( + port_name + ), f"Error: Input port {port_name} already declared as input port." + assert not self.out_port_exists( + port_name + ), f"Error: Input port {port_name} already declared as output port." + assert ( + port_type in VALID_BT_INPUT_PORT_TYPES + ), f"Error: Unsupported input port type {port_type}." self._in_ports[port_name] = (port_type, None) def declare_out_port(self, port_name: str, port_type: str) -> None: """Add an output port to the handler.""" BtPortsHandler.check_port_name_allowed(port_name) - assert not self.out_port_exists(port_name), \ - f"Error: Output port {port_name} already declared as output port." - assert not self.in_port_exists(port_name), \ - f"Error: Output port {port_name} already declared as input port." - assert port_type in VALID_BT_OUTPUT_PORT_TYPES, \ - f"Error: Unsupported output port type {port_type}." + assert not self.out_port_exists( + port_name + ), f"Error: Output port {port_name} already declared as output port." + assert not self.in_port_exists( + port_name + ), f"Error: Output port {port_name} already declared as input port." + assert ( + port_type in VALID_BT_OUTPUT_PORT_TYPES + ), f"Error: Unsupported output port type {port_type}." self._out_ports[port_name] = (port_type, None) def get_port_value(self, port_name: str) -> str: @@ -102,8 +110,9 @@ def get_port_value(self, port_name: str) -> str: def get_in_port_value(self, port_name: str) -> str: """Get the value of an input port.""" - assert self.in_port_exists(port_name), \ - f"Error: Port {port_name} is not declared as input port." + assert self.in_port_exists( + port_name + ), f"Error: Port {port_name} is not declared as input port." port_value = self._in_ports[port_name][1] assert port_value is not None, f"Error: Port {port_name} has no assigned value." return port_value @@ -125,16 +134,19 @@ def set_port_value(self, port_name: str, port_value: str) -> None: def _set_in_port_value(self, port_name: str, port_value: str): """Set the value of an input port.""" - assert self.in_port_exists(port_name), \ - f"Error: Port {port_name} is not declared as input port." - assert self._in_ports[port_name][1] is None, \ - f"Error: Port {port_name} already has a value assigned." + assert self.in_port_exists( + port_name + ), f"Error: Port {port_name} is not declared as input port." + assert ( + self._in_ports[port_name][1] is None + ), f"Error: Port {port_name} already has a value assigned." port_type = self._in_ports[port_name][0] # Ensure this is not a Blackboard variable reference: currently not supported if is_blackboard_reference(port_value): raise NotImplementedError( f"Error: {port_value} assigns a Blackboard variable to {port_name}. " - "This is not yet supported.") + "This is not yet supported." + ) self._in_ports[port_name] = (port_type, port_value) def _set_out_port_value(self, port_name: str, port_value: str): diff --git a/src/as2fm/scxml_converter/scxml_entries/ros_utils.py b/src/as2fm/scxml_converter/scxml_entries/ros_utils.py index f63af4b1..c881d459 100644 --- a/src/as2fm/scxml_converter/scxml_entries/ros_utils.py +++ b/src/as2fm/scxml_converter/scxml_entries/ros_utils.py @@ -20,15 +20,18 @@ from as2fm.scxml_converter.scxml_entries import RosField, ScxmlBase from as2fm.scxml_converter.scxml_entries.utils import all_non_empty_strings -MSG_TYPE_SUBSTITUTIONS = { - 'boolean': 'bool', - 'sequence': 'int32[]' -} +MSG_TYPE_SUBSTITUTIONS = {"boolean": "bool", "sequence": "int32[]"} -BASIC_FIELD_TYPES = ['boolean', - 'int8', 'int16', 'int32', 'int64', - 'float', 'double', - 'sequence'] +BASIC_FIELD_TYPES = [ + "boolean", + "int8", + "int16", + "int32", + "int64", + "float", + "double", + "sequence", +] """Container for the ROS interface name (e.g. topic or service name) and the related type""" @@ -46,10 +49,13 @@ def is_ros_type_known(type_definition: str, ros_interface: str) -> bool: interface_ns, interface_type = type_definition.split("/") if len(interface_ns) == 0 or len(interface_type) == 0: return False - assert ros_interface in ["msg", "srv", "action"], \ - "Error: SCXML ROS declarations: unknown ROS interface." + assert ros_interface in [ + "msg", + "srv", + "action", + ], "Error: SCXML ROS declarations: unknown ROS interface." try: - interface_importer = __import__(interface_ns + f'.{ros_interface}', fromlist=['']) + interface_importer = __import__(interface_ns + f".{ros_interface}", fromlist=[""]) _ = getattr(interface_importer, interface_type) except (ImportError, AttributeError): print(f"Error: SCXML ROS declarations: topic type {type_definition} not found.") @@ -78,9 +84,10 @@ def extract_params_from_ros_type(ros_interface_type: Type[Any]) -> Dict[str, str """ fields = ros_interface_type.get_fields_and_field_types() for key in fields.keys(): - assert fields[key] in BASIC_FIELD_TYPES, \ - f"Error: SCXML ROS declarations: {ros_interface_type} {key} field is " \ + assert fields[key] in BASIC_FIELD_TYPES, ( + f"Error: SCXML ROS declarations: {ros_interface_type} {key} field is " f"of type {fields[key]}, that is not supported." + ) fields[key] = MSG_TYPE_SUBSTITUTIONS.get(fields[key], fields[key]) return fields @@ -106,10 +113,11 @@ def get_srv_type_params(service_definition: str) -> Tuple[Dict[str, str], Dict[s """ Get the fields of a service request and response as pairs of name and type objects. """ - assert is_srv_type_known(service_definition), \ - f"Error: SCXML ROS declarations: service type {service_definition} not found." + assert is_srv_type_known( + service_definition + ), f"Error: SCXML ROS declarations: service type {service_definition} not found." interface_ns, interface_type = service_definition.split("/") - srv_module = __import__(interface_ns + '.srv', fromlist=['']) + srv_module = __import__(interface_ns + ".srv", fromlist=[""]) srv_class = getattr(srv_module, interface_type) # TODO: Fields can be nested. Look AS2FM/scxml_converter/src/scxml_converter/scxml_converter.py @@ -119,15 +127,17 @@ def get_srv_type_params(service_definition: str) -> Tuple[Dict[str, str], Dict[s return req_fields, res_fields -def get_action_type_params(action_definition: str - ) -> Tuple[Dict[str, str], Dict[str, str], Dict[str, str]]: +def get_action_type_params( + action_definition: str, +) -> Tuple[Dict[str, str], Dict[str, str], Dict[str, str]]: """ Get the fields of an action goal, feedback and result as pairs of name and type objects. """ - assert is_action_type_known(action_definition), \ - f"Error: SCXML ROS declarations: action type {action_definition} not found." + assert is_action_type_known( + action_definition + ), f"Error: SCXML ROS declarations: action type {action_definition} not found." interface_ns, interface_type = action_definition.split("/") - action_module = __import__(interface_ns + '.action', fromlist=['']) + action_module = __import__(interface_ns + ".action", fromlist=[""]) action_class = getattr(action_module, interface_type) action_goal_fields = extract_params_from_ros_type(action_class.Goal) action_feedback_fields = extract_params_from_ros_type(action_class.Feedback) @@ -142,14 +152,17 @@ def get_action_goal_id_definition() -> Tuple[str, str]: def sanitize_ros_interface_name(interface_name: str) -> str: """Replace slashes in a ROS interface name.""" - assert isinstance(interface_name, str), \ - "Error: ROS interface sanitizer: interface name must be a string." + assert isinstance( + interface_name, str + ), "Error: ROS interface sanitizer: interface name must be a string." # Remove potential prepended slash interface_name = interface_name.removeprefix("/") - assert len(interface_name) > 0, \ - "Error: ROS interface sanitizer: interface name must not be empty." - assert interface_name.count(" ") == 0, \ - "Error: ROS interface sanitizer: interface name must not contain spaces." + assert ( + len(interface_name) > 0 + ), "Error: ROS interface sanitizer: interface name must not be empty." + assert ( + interface_name.count(" ") == 0 + ), "Error: ROS interface sanitizer: interface name must not contain spaces." return interface_name.replace("/", "__") @@ -231,8 +244,10 @@ def generate_action_feedback_event(action_name: str) -> str: def generate_action_feedback_handle_event(action_name: str, automaton_name: str) -> str: """Generate the name of the event that handles a feedback in an action client.""" - return f"action_{sanitize_ros_interface_name(action_name)}_" \ - f"feedback_handle_client_{automaton_name}" + return ( + f"action_{sanitize_ros_interface_name(action_name)}_" + f"feedback_handle_client_{automaton_name}" + ) def generate_action_result_event(action_name: str) -> str: @@ -242,8 +257,10 @@ def generate_action_result_event(action_name: str) -> str: def generate_action_result_handle_event(action_name: str, automaton_name: str) -> str: """Generate the name of the event that handles a result in an action client.""" - return f"action_{sanitize_ros_interface_name(action_name)}_" \ - f"result_handle_client_{automaton_name}" + return ( + f"action_{sanitize_ros_interface_name(action_name)}_" + f"result_handle_client_{automaton_name}" + ) class ScxmlRosDeclarationsContainer: @@ -278,50 +295,64 @@ def append_ros_declaration(self, scxml_ros_declaration: ScxmlBase) -> None: :param scxml_ros_declaration: The ROS declaration to add (inheriting from RosDeclaration). """ - from as2fm.scxml_converter.scxml_entries.scxml_ros_action_client import \ - RosActionClient - from as2fm.scxml_converter.scxml_entries.scxml_ros_action_server import \ - RosActionServer - from as2fm.scxml_converter.scxml_entries.scxml_ros_base import \ - RosDeclaration + from as2fm.scxml_converter.scxml_entries.scxml_ros_action_client import RosActionClient + from as2fm.scxml_converter.scxml_entries.scxml_ros_action_server import RosActionServer + from as2fm.scxml_converter.scxml_entries.scxml_ros_base import RosDeclaration from as2fm.scxml_converter.scxml_entries.scxml_ros_service import ( - RosServiceClient, RosServiceServer) - from as2fm.scxml_converter.scxml_entries.scxml_ros_timer import \ - RosTimeRate + RosServiceClient, + RosServiceServer, + ) + from as2fm.scxml_converter.scxml_entries.scxml_ros_timer import RosTimeRate from as2fm.scxml_converter.scxml_entries.scxml_ros_topic import ( - RosTopicPublisher, RosTopicSubscriber) - assert isinstance(scxml_ros_declaration, RosDeclaration), \ - f"Error: SCXML ROS declarations: {type(scxml_ros_declaration)} isn't a ROS declaration." + RosTopicPublisher, + RosTopicSubscriber, + ) + + assert isinstance( + scxml_ros_declaration, RosDeclaration + ), f"Error: SCXML ROS declarations: {type(scxml_ros_declaration)} isn't a ROS declaration." if isinstance(scxml_ros_declaration, RosTimeRate): - self._append_timer(scxml_ros_declaration.get_name(), - scxml_ros_declaration.get_rate()) + self._append_timer(scxml_ros_declaration.get_name(), scxml_ros_declaration.get_rate()) elif isinstance(scxml_ros_declaration, RosTopicPublisher): - self._append_publisher(scxml_ros_declaration.get_name(), - scxml_ros_declaration.get_interface_name(), - scxml_ros_declaration.get_interface_type()) + self._append_publisher( + scxml_ros_declaration.get_name(), + scxml_ros_declaration.get_interface_name(), + scxml_ros_declaration.get_interface_type(), + ) elif isinstance(scxml_ros_declaration, RosTopicSubscriber): - self._append_subscriber(scxml_ros_declaration.get_name(), - scxml_ros_declaration.get_interface_name(), - scxml_ros_declaration.get_interface_type()) + self._append_subscriber( + scxml_ros_declaration.get_name(), + scxml_ros_declaration.get_interface_name(), + scxml_ros_declaration.get_interface_type(), + ) elif isinstance(scxml_ros_declaration, RosServiceServer): - self._append_service_server(scxml_ros_declaration.get_name(), - scxml_ros_declaration.get_interface_name(), - scxml_ros_declaration.get_interface_type()) + self._append_service_server( + scxml_ros_declaration.get_name(), + scxml_ros_declaration.get_interface_name(), + scxml_ros_declaration.get_interface_type(), + ) elif isinstance(scxml_ros_declaration, RosServiceClient): - self._append_service_client(scxml_ros_declaration.get_name(), - scxml_ros_declaration.get_interface_name(), - scxml_ros_declaration.get_interface_type()) + self._append_service_client( + scxml_ros_declaration.get_name(), + scxml_ros_declaration.get_interface_name(), + scxml_ros_declaration.get_interface_type(), + ) elif isinstance(scxml_ros_declaration, RosActionServer): - self._append_action_server(scxml_ros_declaration.get_name(), - scxml_ros_declaration.get_interface_name(), - scxml_ros_declaration.get_interface_type()) + self._append_action_server( + scxml_ros_declaration.get_name(), + scxml_ros_declaration.get_interface_name(), + scxml_ros_declaration.get_interface_type(), + ) elif isinstance(scxml_ros_declaration, RosActionClient): - self._append_action_client(scxml_ros_declaration.get_name(), - scxml_ros_declaration.get_interface_name(), - scxml_ros_declaration.get_interface_type()) + self._append_action_client( + scxml_ros_declaration.get_name(), + scxml_ros_declaration.get_interface_name(), + scxml_ros_declaration.get_interface_type(), + ) else: - raise NotImplementedError(f"Error: SCXML ROS declaration type " - f"{type(scxml_ros_declaration)}.") + raise NotImplementedError( + f"Error: SCXML ROS declaration type " f"{type(scxml_ros_declaration)}." + ) def _append_publisher(self, pub_name: str, topic_name: str, topic_type: str) -> None: """ @@ -331,10 +362,12 @@ def _append_publisher(self, pub_name: str, topic_name: str, topic_type: str) -> :param topic_name: Name of the topic to publish to. :param topic_type: Type of the message to publish. """ - assert all_non_empty_strings(pub_name, topic_name, topic_type), \ - "Error: ROS declarations: publisher name, topic name and type must be strings." - assert pub_name not in self._publishers, \ - f"Error: ROS declarations: topic publisher {pub_name} already declared." + assert all_non_empty_strings( + pub_name, topic_name, topic_type + ), "Error: ROS declarations: publisher name, topic name and type must be strings." + assert ( + pub_name not in self._publishers + ), f"Error: ROS declarations: topic publisher {pub_name} already declared." self._publishers[pub_name] = (topic_name, topic_type) def _append_subscriber(self, sub_name: str, topic_name: str, topic_type: str) -> None: @@ -345,14 +378,17 @@ def _append_subscriber(self, sub_name: str, topic_name: str, topic_type: str) -> :param topic_name: Name of the topic to subscribe to. :param topic_type: Type of the message to subscribe to. """ - assert all_non_empty_strings(sub_name, topic_name, topic_type), \ - "Error: ROS declarations: subscriber name, topic name and type must be strings." - assert sub_name not in self._subscribers, \ - f"Error: ROS declarations: topic subscriber {sub_name} already declared." + assert all_non_empty_strings( + sub_name, topic_name, topic_type + ), "Error: ROS declarations: subscriber name, topic name and type must be strings." + assert ( + sub_name not in self._subscribers + ), f"Error: ROS declarations: topic subscriber {sub_name} already declared." self._subscribers[sub_name] = (topic_name, topic_type) def _append_service_client( - self, client_name: str, service_name: str, service_type: str) -> None: + self, client_name: str, service_name: str, service_type: str + ) -> None: """ Add a service client to the container. @@ -360,14 +396,17 @@ def _append_service_client( :param service_name: Name of the service to call. :param service_type: Type of data used in the service communication. """ - assert all_non_empty_strings(client_name, service_name, service_type), \ - "Error: ROS declarations: client name, service name and type must be strings." - assert client_name not in self._service_clients, \ - f"Error: ROS declarations: service client {client_name} already declared." + assert all_non_empty_strings( + client_name, service_name, service_type + ), "Error: ROS declarations: client name, service name and type must be strings." + assert ( + client_name not in self._service_clients + ), f"Error: ROS declarations: service client {client_name} already declared." self._service_clients[client_name] = (service_name, service_type) def _append_service_server( - self, server_name: str, service_name: str, service_type: str) -> None: + self, server_name: str, service_name: str, service_type: str + ) -> None: """ Add a service server to the container. @@ -375,32 +414,40 @@ def _append_service_server( :param service_name: Name of the provided service (what the client needs to call). :param service_type: Type of data used in the service communication. """ - assert all_non_empty_strings(server_name, service_name, service_type), \ - "Error: ROS declarations: server name, service name and type must be strings." - assert server_name not in self._service_servers, \ - f"Error: ROS declarations: service server {server_name} already declared." + assert all_non_empty_strings( + server_name, service_name, service_type + ), "Error: ROS declarations: server name, service name and type must be strings." + assert ( + server_name not in self._service_servers + ), f"Error: ROS declarations: service server {server_name} already declared." self._service_servers[server_name] = (service_name, service_type) def _append_action_client(self, client_name: str, action_name: str, action_type: str) -> None: - assert all_non_empty_strings(client_name, action_name, action_type), \ - "Error: ROS declarations: client name, action name and type must be strings." - assert client_name not in self._action_clients, \ - f"Error: ROS declarations: action client {client_name} already declared." + assert all_non_empty_strings( + client_name, action_name, action_type + ), "Error: ROS declarations: client name, action name and type must be strings." + assert ( + client_name not in self._action_clients + ), f"Error: ROS declarations: action client {client_name} already declared." self._action_clients[client_name] = (action_name, action_type) def _append_action_server(self, server_name: str, action_name: str, action_type: str) -> None: - assert all_non_empty_strings(server_name, action_name, action_type), \ - "Error: ROS declarations: server name, action name and type must be strings." - assert server_name not in self._action_servers, \ - f"Error: ROS declarations: action server {server_name} already declared." + assert all_non_empty_strings( + server_name, action_name, action_type + ), "Error: ROS declarations: server name, action name and type must be strings." + assert ( + server_name not in self._action_servers + ), f"Error: ROS declarations: action server {server_name} already declared." self._action_servers[server_name] = (action_name, action_type) def _append_timer(self, timer_name: str, timer_rate: float) -> None: assert isinstance(timer_name, str), "Error: ROS declarations: timer name must be a string." - assert isinstance(timer_rate, float) and timer_rate > 0, \ - "Error: ROS declarations: timer rate must be a positive number." - assert timer_name not in self._timers, \ - f"Error: ROS declarations: timer {timer_name} already declared." + assert ( + isinstance(timer_rate, float) and timer_rate > 0 + ), "Error: ROS declarations: timer rate must be a positive number." + assert ( + timer_name not in self._timers + ), f"Error: ROS declarations: timer {timer_name} already declared." self._timers[timer_name] = timer_rate def is_publisher_defined(self, pub_name: str) -> bool: @@ -433,36 +480,41 @@ def get_publisher_info(self, pub_name: str) -> Tuple[str, str]: def get_subscriber_info(self, sub_name: str) -> Tuple[str, str]: """Provide a subscriber topic name and type""" sub_info = self._subscribers.get(sub_name) - assert sub_info is not None, \ - f"Error: SCXML ROS declarations: unknown subscriber {sub_name}." + assert ( + sub_info is not None + ), f"Error: SCXML ROS declarations: unknown subscriber {sub_name}." return sub_info def get_service_server_info(self, server_name: str) -> Tuple[str, str]: """Provide a server's service name and type""" server_info = self._service_servers.get(server_name) - assert server_info is not None, \ - f"Error: SCXML ROS declarations: unknown service server {server_name}." + assert ( + server_info is not None + ), f"Error: SCXML ROS declarations: unknown service server {server_name}." return server_info def get_service_client_info(self, client_name: str) -> Tuple[str, str]: """Provide a client's service name and type""" client_info = self._service_clients.get(client_name) - assert client_info is not None, \ - f"Error: SCXML ROS declarations: unknown service client {client_name}." + assert ( + client_info is not None + ), f"Error: SCXML ROS declarations: unknown service client {client_name}." return client_info def get_action_server_info(self, server_name: str) -> Tuple[str, str]: """Given an action server name, provide the related action name and type.""" server_info = self._action_servers.get(server_name) - assert server_info is not None, \ - f"Error: SCXML ROS declarations: unknown action server {server_name}." + assert ( + server_info is not None + ), f"Error: SCXML ROS declarations: unknown action server {server_name}." return server_info def get_action_client_info(self, client_name: str) -> Tuple[str, str]: """Given an action client name, provide the related action name and type.""" client_info = self._action_clients.get(client_name) - assert client_info is not None, \ - f"Error: SCXML ROS declarations: unknown action client {client_name}." + assert ( + client_info is not None + ), f"Error: SCXML ROS declarations: unknown action client {client_name}." return client_info def get_timers(self) -> Dict[str, float]: @@ -486,8 +538,7 @@ def check_valid_srv_res_fields(self, server_name: str, ros_fields: List[RosField return False return True - def check_valid_action_goal_fields( - self, alias_name: str, ros_fields: List[RosField]) -> bool: + def check_valid_action_goal_fields(self, alias_name: str, ros_fields: List[RosField]) -> bool: """ Check if the provided fields match with the action type's goal entries. @@ -497,8 +548,9 @@ def check_valid_action_goal_fields( if self.is_action_client_defined(alias_name): action_type = self.get_action_client_info(alias_name)[1] else: - assert self.is_action_server_defined(alias_name), \ - f"Error: SCXML ROS declarations: unknown action {alias_name}." + assert self.is_action_server_defined( + alias_name + ), f"Error: SCXML ROS declarations: unknown action {alias_name}." action_type = self.get_action_server_info(alias_name)[1] goal_fields = get_action_type_params(action_type)[0] if not check_all_fields_known(ros_fields, goal_fields): @@ -507,7 +559,8 @@ def check_valid_action_goal_fields( return True def check_valid_action_feedback_fields( - self, server_name: str, ros_fields: List[RosField]) -> bool: + self, server_name: str, ros_fields: List[RosField] + ) -> bool: """ Check if the provided fields match with the action type's feedback entries. @@ -517,13 +570,16 @@ def check_valid_action_feedback_fields( _, action_type = self.get_action_server_info(server_name) _, feedback_fields, _ = get_action_type_params(action_type) if not check_all_fields_known(ros_fields, feedback_fields): - print(f"Error: SCXML ROS declarations: Action feedback {server_name} " - "has invalid fields.") + print( + f"Error: SCXML ROS declarations: Action feedback {server_name} " + "has invalid fields." + ) return False return True def check_valid_action_result_fields( - self, server_name: str, ros_fields: List[RosField]) -> bool: + self, server_name: str, ros_fields: List[RosField] + ) -> bool: """ Check if the provided fields match with the action type's result entries. diff --git a/src/as2fm/scxml_converter/scxml_entries/scxml_bt.py b/src/as2fm/scxml_converter/scxml_entries/scxml_bt.py index b5ce7616..da52e4bb 100644 --- a/src/as2fm/scxml_converter/scxml_entries/scxml_bt.py +++ b/src/as2fm/scxml_converter/scxml_entries/scxml_bt.py @@ -22,8 +22,7 @@ from as2fm.scxml_converter.scxml_entries import ScxmlBase from as2fm.scxml_converter.scxml_entries.utils import is_non_empty_string -from as2fm.scxml_converter.scxml_entries.xml_utils import (assert_xml_tag_ok, - get_xml_argument) +from as2fm.scxml_converter.scxml_entries.xml_utils import assert_xml_tag_ok, get_xml_argument class BtInputPortDeclaration(ScxmlBase): @@ -47,8 +46,9 @@ def __init__(self, key_str: str, type_str: str): self._type = type_str def check_validity(self) -> bool: - return is_non_empty_string(BtInputPortDeclaration, "key", self._key) and \ - is_non_empty_string(BtInputPortDeclaration, "type", self._type) + return is_non_empty_string( + BtInputPortDeclaration, "key", self._key + ) and is_non_empty_string(BtInputPortDeclaration, "type", self._type) def get_key_name(self) -> str: return self._key @@ -63,7 +63,8 @@ def as_plain_scxml(self, _) -> ScxmlBase: def as_xml(self) -> ET.Element: assert self.check_validity(), "Error: SCXML BT Input Port: invalid parameters." xml_bt_in_port = ET.Element( - BtInputPortDeclaration.get_tag_name(), {"key": self._key, "type": self._type}) + BtInputPortDeclaration.get_tag_name(), {"key": self._key, "type": self._type} + ) return xml_bt_in_port @@ -88,8 +89,9 @@ def __init__(self, key_str: str, type_str: str): self._type = type_str def check_validity(self) -> bool: - return is_non_empty_string(BtOutputPortDeclaration, "key", self._key) and \ - is_non_empty_string(BtOutputPortDeclaration, "type", self._type) + return is_non_empty_string( + BtOutputPortDeclaration, "key", self._key + ) and is_non_empty_string(BtOutputPortDeclaration, "type", self._type) def get_key_name(self) -> str: return self._key @@ -104,7 +106,8 @@ def as_plain_scxml(self, _) -> ScxmlBase: def as_xml(self) -> ET.Element: assert self.check_validity(), "Error: SCXML BT Input Port: invalid parameters." xml_bt_in_port = ET.Element( - BtOutputPortDeclaration.get_tag_name(), {"key": self._key, "type": self._type}) + BtOutputPortDeclaration.get_tag_name(), {"key": self._key, "type": self._type} + ) return xml_bt_in_port @@ -138,8 +141,7 @@ def as_plain_scxml(self, _) -> ScxmlBase: def as_xml(self) -> ET.Element: assert self.check_validity(), "Error: SCXML BT Input Port: invalid parameters." - xml_bt_in_port = ET.Element( - BtGetValueInputPort.get_tag_name(), {"key": self._key}) + xml_bt_in_port = ET.Element(BtGetValueInputPort.get_tag_name(), {"key": self._key}) return xml_bt_in_port diff --git a/src/as2fm/scxml_converter/scxml_entries/scxml_data.py b/src/as2fm/scxml_converter/scxml_entries/scxml_data.py index 25531ebf..689f3194 100644 --- a/src/as2fm/scxml_converter/scxml_entries/scxml_data.py +++ b/src/as2fm/scxml_converter/scxml_entries/scxml_data.py @@ -25,10 +25,16 @@ from as2fm.scxml_converter.scxml_entries import BtGetValueInputPort, ScxmlBase from as2fm.scxml_converter.scxml_entries.bt_utils import BtPortsHandler from as2fm.scxml_converter.scxml_entries.utils import ( - convert_string_to_type, get_array_max_size, get_data_type_from_string, - is_non_empty_string) + convert_string_to_type, + get_array_max_size, + get_data_type_from_string, + is_non_empty_string, +) from as2fm.scxml_converter.scxml_entries.xml_utils import ( - assert_xml_tag_ok, get_xml_argument, read_value_from_xml_arg_or_child) + assert_xml_tag_ok, + get_xml_argument, + read_value_from_xml_arg_or_child, +) ValidExpr = Union[BtGetValueInputPort, str, int, float] @@ -51,7 +57,8 @@ def get_tag_name() -> str: @staticmethod def _interpret_type_from_comment_above( - comment_above: Optional[str]) -> Optional[Tuple[str, str]]: + comment_above: Optional[str], + ) -> Optional[Tuple[str, str]]: """Interpret the type of the data from the comment above the data tag. :param comment_above: The comment above the data tag (optional) @@ -60,7 +67,7 @@ def _interpret_type_from_comment_above( if comment_above is None: return None # match string inside xml comment brackets - type_match = re.search(r'TYPE\ (.+):(.+)', comment_above.strip()) + type_match = re.search(r"TYPE\ (.+):(.+)", comment_above.strip()) if type_match is None: return None return type_match.group(1), type_match.group(2) @@ -74,21 +81,30 @@ def from_xml_tree(xml_tree: ET.Element, comment_above: Optional[str] = None) -> if data_type is None: comment_tuple = ScxmlData._interpret_type_from_comment_above(comment_above) assert comment_tuple is not None, f"Error: SCXML data: type of {data_id} not found." - assert comment_tuple[0] == data_id, \ - "Error: SCXML data: unexpected ID in type in comment " \ + assert comment_tuple[0] == data_id, ( + "Error: SCXML data: unexpected ID in type in comment " f"({comment_tuple[0]}!={data_id})." + ) data_type = comment_tuple[1] data_expr = read_value_from_xml_arg_or_child( - ScxmlData, xml_tree, "expr", (BtGetValueInputPort, str)) + ScxmlData, xml_tree, "expr", (BtGetValueInputPort, str) + ) lower_bound = read_value_from_xml_arg_or_child( - ScxmlData, xml_tree, "lower_bound_incl", (BtGetValueInputPort, str), none_allowed=True) + ScxmlData, xml_tree, "lower_bound_incl", (BtGetValueInputPort, str), none_allowed=True + ) upper_bound = read_value_from_xml_arg_or_child( - ScxmlData, xml_tree, "upper_bound_incl", (BtGetValueInputPort, str), none_allowed=True) + ScxmlData, xml_tree, "upper_bound_incl", (BtGetValueInputPort, str), none_allowed=True + ) return ScxmlData(data_id, data_expr, data_type, lower_bound, upper_bound) def __init__( - self, id_: str, expr: ValidExpr, data_type: str, - lower_bound: Optional[ValidExpr] = None, upper_bound: Optional[ValidExpr] = None): + self, + id_: str, + expr: ValidExpr, + data_type: str, + lower_bound: Optional[ValidExpr] = None, + upper_bound: Optional[ValidExpr] = None, + ): self._id: str = id_ self._expr: str = expr self._data_type: str = data_type @@ -100,13 +116,15 @@ def get_name(self) -> str: def get_type(self) -> type: python_type = get_data_type_from_string(self._data_type) - assert python_type is not None, \ - f"Error: SCXML data: '{self._id}' has unknown type '{self._data_type}'." + assert ( + python_type is not None + ), f"Error: SCXML data: '{self._id}' has unknown type '{self._data_type}'." return python_type def get_array_max_size(self) -> Optional[int]: - assert is_array_type(self.get_type()), \ - f"Error: SCXML data: '{self._id}' type is not an array." + assert is_array_type( + self.get_type() + ), f"Error: SCXML data: '{self._id}' type is not an array." return get_array_max_size(self._data_type) def get_expr(self) -> str: @@ -117,8 +135,10 @@ def check_valid_bounds(self) -> bool: # Nothing to check return True if self.get_type() not in (float, int): - print(f"Error: SCXML data: '{self._id}' has bounds but has type {self._data_type}, " - "not a number.") + print( + f"Error: SCXML data: '{self._id}' has bounds but has type {self._data_type}, " + "not a number." + ) return False lower_bound = None upper_bound = None @@ -128,8 +148,10 @@ def check_valid_bounds(self) -> bool: upper_bound = convert_string_to_type(self._upper_bound, self._data_type) if all(bound is not None for bound in [lower_bound, upper_bound]): if lower_bound > upper_bound: - print(f"Error: SCXML data: 'lower_bound_incl' {lower_bound} is not smaller " - f"than 'upper_bound_incl' {upper_bound}.") + print( + f"Error: SCXML data: 'lower_bound_incl' {lower_bound} is not smaller " + f"than 'upper_bound_incl' {upper_bound}." + ) return False return True @@ -144,8 +166,9 @@ def check_validity(self) -> bool: def as_xml(self) -> ET.Element: assert self.check_validity(), "SCXML: found invalid data object." - xml_data = ET.Element(ScxmlData.get_tag_name(), - {"id": self._id, "expr": self._expr, "type": self._data_type}) + xml_data = ET.Element( + ScxmlData.get_tag_name(), {"id": self._id, "expr": self._expr, "type": self._data_type} + ) if self._lower_bound is not None: xml_data.set("lower_bound_incl", str(self._lower_bound)) if self._upper_bound is not None: diff --git a/src/as2fm/scxml_converter/scxml_entries/scxml_executable_entries.py b/src/as2fm/scxml_converter/scxml_entries/scxml_executable_entries.py index 4a9115e8..972fd37f 100644 --- a/src/as2fm/scxml_converter/scxml_entries/scxml_executable_entries.py +++ b/src/as2fm/scxml_converter/scxml_entries/scxml_executable_entries.py @@ -20,26 +20,37 @@ from typing import Dict, List, Optional, Tuple, Union, get_args from xml.etree import ElementTree as ET -from as2fm.scxml_converter.scxml_entries import (BtGetValueInputPort, - ScxmlBase, ScxmlParam, - ScxmlRosDeclarationsContainer) -from as2fm.scxml_converter.scxml_entries.bt_utils import (BtPortsHandler, - is_bt_event, - replace_bt_event) -from as2fm.scxml_converter.scxml_entries.utils import (CallbackType, - get_plain_expression, - is_non_empty_string) +from as2fm.scxml_converter.scxml_entries import ( + BtGetValueInputPort, + ScxmlBase, + ScxmlParam, + ScxmlRosDeclarationsContainer, +) +from as2fm.scxml_converter.scxml_entries.bt_utils import ( + BtPortsHandler, + is_bt_event, + replace_bt_event, +) +from as2fm.scxml_converter.scxml_entries.utils import ( + CallbackType, + get_plain_expression, + is_non_empty_string, +) from as2fm.scxml_converter.scxml_entries.xml_utils import ( - assert_xml_tag_ok, get_xml_argument, read_value_from_xml_child) + assert_xml_tag_ok, + get_xml_argument, + read_value_from_xml_child, +) # Use delayed type evaluation: https://peps.python.org/pep-0484/#forward-references -ScxmlExecutableEntry = Union['ScxmlAssign', 'ScxmlIf', 'ScxmlSend'] +ScxmlExecutableEntry = Union["ScxmlAssign", "ScxmlIf", "ScxmlSend"] ScxmlExecutionBody = List[ScxmlExecutableEntry] ConditionalExecutionBody = Tuple[str, ScxmlExecutionBody] def instantiate_exec_body_bt_events( - exec_body: Optional[ScxmlExecutionBody], instance_id: str) -> None: + exec_body: Optional[ScxmlExecutionBody], instance_id: str +) -> None: """ Instantiate the behavior tree events in the execution body. @@ -52,7 +63,8 @@ def instantiate_exec_body_bt_events( def update_exec_body_bt_ports_values( - exec_body: Optional[ScxmlExecutionBody], bt_ports_handler: BtPortsHandler) -> None: + exec_body: Optional[ScxmlExecutionBody], bt_ports_handler: BtPortsHandler +) -> None: """ Update the BT ports values in the execution body. """ @@ -102,14 +114,17 @@ def from_xml_tree(xml_tree: ET.Element) -> "ScxmlIf": else_body = current_body else: exec_bodies.append(current_body) - assert len(conditions) == len(exec_bodies), \ - "Error: SCXML if: number of conditions and bodies do not match " \ + assert len(conditions) == len(exec_bodies), ( + "Error: SCXML if: number of conditions and bodies do not match " f"({len(conditions)} != {len(exec_bodies)}). Conditions: {conditions}." + ) return ScxmlIf(list(zip(conditions, exec_bodies)), else_body) - def __init__(self, - conditional_executions: List[ConditionalExecutionBody], - else_execution: Optional[ScxmlExecutionBody] = None): + def __init__( + self, + conditional_executions: List[ConditionalExecutionBody], + else_execution: Optional[ScxmlExecutionBody] = None, + ): """ Class representing a conditional execution in SCXML. @@ -148,9 +163,10 @@ def update_bt_ports_values(self, bt_ports_handler: BtPortsHandler): update_exec_body_bt_ports_values(self._else_execution, bt_ports_handler) def check_validity(self) -> bool: - valid_conditional_executions = len(self._conditional_executions) > 0 and \ - all(isinstance(condition, str) and len(body) > 0 and valid_execution_body(body) - for condition, body in self._conditional_executions) + valid_conditional_executions = len(self._conditional_executions) > 0 and all( + isinstance(condition, str) and len(body) > 0 and valid_execution_body(body) + for condition, body in self._conditional_executions + ) if not valid_conditional_executions: print("Error: SCXML if: Found invalid entries in conditional executions.") valid_else_execution = valid_execution_body(self._else_execution) @@ -158,12 +174,14 @@ def check_validity(self) -> bool: print("Error: SCXML if: invalid else execution body found.") return valid_conditional_executions and valid_else_execution - def check_valid_ros_instantiations(self, - ros_declarations: ScxmlRosDeclarationsContainer) -> bool: + def check_valid_ros_instantiations( + self, ros_declarations: ScxmlRosDeclarationsContainer + ) -> bool: """Check if the ros instantiations have been declared.""" # Check the executable content - assert isinstance(ros_declarations, ScxmlRosDeclarationsContainer), \ - "Error: SCXML if: invalid ROS declarations type provided." + assert isinstance( + ros_declarations, ScxmlRosDeclarationsContainer + ), "Error: SCXML if: invalid ROS declarations type provided." for _, exec_body in self._conditional_executions: for exec_entry in exec_body: if not exec_entry.check_valid_ros_instantiations(ros_declarations): @@ -191,7 +209,8 @@ def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> "Sc execution_body = as_plain_execution_body(execution, ros_declarations) assert execution_body is not None, "Error: SCXML if: invalid execution body." conditional_executions.append( - (get_plain_expression(condition, self._cb_type), execution_body)) + (get_plain_expression(condition, self._cb_type), execution_body) + ) set_execution_body_callback_type(self._else_execution, self._cb_type) else_execution = as_plain_execution_body(self._else_execution, ros_declarations) return ScxmlIf(conditional_executions, else_execution) @@ -203,10 +222,10 @@ def as_xml(self) -> ET.Element: xml_if = ET.Element(ScxmlIf.get_tag_name(), {"cond": first_conditional_execution[0]}) append_execution_body_to_xml(xml_if, first_conditional_execution[1]) for condition, execution in self._conditional_executions[1:]: - xml_if.append(ET.Element('elseif', {"cond": condition})) + xml_if.append(ET.Element("elseif", {"cond": condition})) append_execution_body_to_xml(xml_if, execution) if len(self._else_execution) > 0: - xml_if.append(ET.Element('else')) + xml_if.append(ET.Element("else")) append_execution_body_to_xml(xml_if, self._else_execution) return xml_if @@ -226,8 +245,9 @@ def from_xml_tree(xml_tree: ET.Element) -> "ScxmlSend": :param xml_tree: The XML tree to create the object from. :param cb_type: The kind of callback executing this SCXML entry. """ - assert xml_tree.tag == ScxmlSend.get_tag_name(), \ - f"Error: SCXML send: XML tag name is not {ScxmlSend.get_tag_name()}." + assert ( + xml_tree.tag == ScxmlSend.get_tag_name() + ), f"Error: SCXML send: XML tag name is not {ScxmlSend.get_tag_name()}." event = xml_tree.attrib["event"] params: List[ScxmlParam] = [] assert params is not None, "Error: SCXML send: params is not valid." @@ -293,8 +313,9 @@ def check_valid_ros_instantiations(self, _) -> bool: return True def append_param(self, param: ScxmlParam) -> None: - assert self.__class__ is ScxmlSend, \ - f"Error: SCXML send: cannot append param to derived class {self.__class__.__name__}." + assert ( + self.__class__ is ScxmlSend + ), f"Error: SCXML send: cannot append param to derived class {self.__class__.__name__}." assert isinstance(param, ScxmlParam), "Error: SCXML send: invalid param." self._params.append(param) @@ -383,14 +404,15 @@ def as_plain_scxml(self, _) -> "ScxmlAssign": def as_xml(self) -> ET.Element: assert self.check_validity(), "SCXML: found invalid assign object." - return ET.Element(ScxmlAssign.get_tag_name(), { - "location": self._location, "expr": self._expr}) + return ET.Element( + ScxmlAssign.get_tag_name(), {"location": self._location, "expr": self._expr} + ) # Get the resolved types from the forward references in ScxmlExecutableEntry -_ResolvedScxmlExecutableEntry = \ - tuple(entry._evaluate(globals(), locals(), frozenset()) - for entry in get_args(ScxmlExecutableEntry)) +_ResolvedScxmlExecutableEntry = tuple( + entry._evaluate(globals(), locals(), frozenset()) for entry in get_args(ScxmlExecutableEntry) +) def valid_execution_body_entry_types(exec_body: ScxmlExecutionBody) -> bool: @@ -405,8 +427,10 @@ def valid_execution_body_entry_types(exec_body: ScxmlExecutionBody) -> bool: return False for entry in exec_body: if not isinstance(entry, _ResolvedScxmlExecutableEntry): - print(f"Error: SCXML execution body: entry type {type(entry)} not in valid set." - f" {_ResolvedScxmlExecutableEntry}.") + print( + f"Error: SCXML execution body: entry type {type(entry)} not in valid set." + f" {_ResolvedScxmlExecutableEntry}." + ) return False return True @@ -438,11 +462,13 @@ def execution_entry_from_xml(xml_tree: ET.Element) -> ScxmlExecutableEntry: # TODO: This should be generated only once, since it stays as it is tag_to_cls: Dict[str, ScxmlExecutableEntry] = { - cls.get_tag_name(): cls for cls in _ResolvedScxmlExecutableEntry} + cls.get_tag_name(): cls for cls in _ResolvedScxmlExecutableEntry + } tag_to_cls.update({cls.get_tag_name(): cls for cls in RosTrigger.__subclasses__()}) exec_tag = xml_tree.tag - assert exec_tag in tag_to_cls, \ - f"Error: SCXML conversion: tag {exec_tag} isn't an executable entry." + assert ( + exec_tag in tag_to_cls + ), f"Error: SCXML conversion: tag {exec_tag} isn't an executable entry." return tag_to_cls[exec_tag].from_xml_tree(xml_tree) @@ -483,8 +509,8 @@ def set_execution_body_callback_type(exec_body: ScxmlExecutionBody, cb_type: Cal def as_plain_execution_body( - exec_body: Optional[ScxmlExecutionBody], - ros_declarations: ScxmlRosDeclarationsContainer) -> Optional[ScxmlExecutionBody]: + exec_body: Optional[ScxmlExecutionBody], ros_declarations: ScxmlRosDeclarationsContainer +) -> Optional[ScxmlExecutionBody]: """ Convert the execution body to plain SCXML. diff --git a/src/as2fm/scxml_converter/scxml_entries/scxml_param.py b/src/as2fm/scxml_converter/scxml_entries/scxml_param.py index aff4752a..8e631f28 100644 --- a/src/as2fm/scxml_converter/scxml_entries/scxml_param.py +++ b/src/as2fm/scxml_converter/scxml_entries/scxml_param.py @@ -22,10 +22,12 @@ from as2fm.scxml_converter.scxml_entries import BtGetValueInputPort, ScxmlBase from as2fm.scxml_converter.scxml_entries.bt_utils import BtPortsHandler -from as2fm.scxml_converter.scxml_entries.utils import (CallbackType, - is_non_empty_string) +from as2fm.scxml_converter.scxml_entries.utils import CallbackType, is_non_empty_string from as2fm.scxml_converter.scxml_entries.xml_utils import ( - assert_xml_tag_ok, get_xml_argument, read_value_from_xml_arg_or_child) + assert_xml_tag_ok, + get_xml_argument, + read_value_from_xml_arg_or_child, +) class ScxmlParam(ScxmlBase): @@ -40,14 +42,19 @@ def from_xml_tree(xml_tree: ET.Element) -> "ScxmlParam": """Create a ScxmlParam object from an XML tree.""" assert_xml_tag_ok(ScxmlParam, xml_tree) name = get_xml_argument(ScxmlParam, xml_tree, "name") - expr = read_value_from_xml_arg_or_child(ScxmlParam, xml_tree, "expr", - (BtGetValueInputPort, str), True) + expr = read_value_from_xml_arg_or_child( + ScxmlParam, xml_tree, "expr", (BtGetValueInputPort, str), True + ) location = get_xml_argument(ScxmlParam, xml_tree, "location", none_allowed=True) return ScxmlParam(name, expr=expr, location=location) - def __init__(self, name: str, *, - expr: Optional[Union[BtGetValueInputPort, str]] = None, - location: Optional[str] = None): + def __init__( + self, + name: str, + *, + expr: Optional[Union[BtGetValueInputPort, str]] = None, + location: Optional[str] = None, + ): """ Initialize the SCXML Parameter object. diff --git a/src/as2fm/scxml_converter/scxml_entries/scxml_root.py b/src/as2fm/scxml_converter/scxml_entries/scxml_root.py index 0c69f9cd..371d672c 100644 --- a/src/as2fm/scxml_converter/scxml_entries/scxml_root.py +++ b/src/as2fm/scxml_converter/scxml_entries/scxml_root.py @@ -22,18 +22,24 @@ from typing import List, Optional, Tuple, get_args from xml.etree import ElementTree as ET -from as2fm.scxml_converter.scxml_entries import (BtInputPortDeclaration, - BtOutputPortDeclaration, - RosActionThread, ScxmlBase, - ScxmlDataModel, - ScxmlRosDeclarationsContainer, - ScxmlState) +from as2fm.scxml_converter.scxml_entries import ( + BtInputPortDeclaration, + BtOutputPortDeclaration, + RosActionThread, + ScxmlBase, + ScxmlDataModel, + ScxmlRosDeclarationsContainer, + ScxmlState, +) from as2fm.scxml_converter.scxml_entries.bt_utils import BtPortsHandler from as2fm.scxml_converter.scxml_entries.scxml_bt import BtPortDeclarations from as2fm.scxml_converter.scxml_entries.scxml_ros_base import RosDeclaration from as2fm.scxml_converter.scxml_entries.utils import is_non_empty_string from as2fm.scxml_converter.scxml_entries.xml_utils import ( - assert_xml_tag_ok, get_children_as_scxml, get_xml_argument) + assert_xml_tag_ok, + get_children_as_scxml, + get_xml_argument, +) class ScxmlRoot(ScxmlBase): @@ -50,19 +56,23 @@ def from_xml_tree(xml_tree: ET.Element) -> "ScxmlRoot": assert_xml_tag_ok(ScxmlRoot, xml_tree) scxml_name = get_xml_argument(ScxmlRoot, xml_tree, "name") scxml_version = get_xml_argument(ScxmlRoot, xml_tree, "version") - assert scxml_version == "1.0", \ - f"Error: SCXML root: expected version 1.0, found {scxml_version}." + assert ( + scxml_version == "1.0" + ), f"Error: SCXML root: expected version 1.0, found {scxml_version}." scxml_init_state = get_xml_argument(ScxmlRoot, xml_tree, "initial") # Data Model datamodel_elements = get_children_as_scxml(xml_tree, (ScxmlDataModel,)) - assert len(datamodel_elements) <= 1, \ - f"Error: SCXML root: {len(datamodel_elements)} datamodels found, max 1 allowed." + assert ( + len(datamodel_elements) <= 1 + ), f"Error: SCXML root: {len(datamodel_elements)} datamodels found, max 1 allowed." # ROS Declarations ros_declarations: List[RosDeclaration] = get_children_as_scxml( - xml_tree, RosDeclaration.__subclasses__()) + xml_tree, RosDeclaration.__subclasses__() + ) # BT Declarations bt_port_declarations: List[BtPortDeclarations] = get_children_as_scxml( - xml_tree, get_args(BtPortDeclarations)) + xml_tree, get_args(BtPortDeclarations) + ) # Additional threads additional_threads = get_children_as_scxml(xml_tree, (RosActionThread,)) # States @@ -152,8 +162,7 @@ def add_state(self, state: ScxmlState, *, initial: bool = False): If initial is True, set it as the initial state.""" self._states.append(state) if initial: - assert self._initial_state is None, \ - "Error: SCXML root: Initial state already set" + assert self._initial_state is None, "Error: SCXML root: Initial state already set" self._initial_state = state.get_id() def set_data_model(self, data_model: ScxmlDataModel): @@ -161,8 +170,9 @@ def set_data_model(self, data_model: ScxmlDataModel): self._data_model = data_model def add_ros_declaration(self, ros_declaration: RosDeclaration): - assert isinstance(ros_declaration, RosDeclaration), \ - "Error: SCXML root: invalid ROS declaration type." + assert isinstance( + ros_declaration, RosDeclaration + ), "Error: SCXML root: invalid ROS declaration type." assert ros_declaration.check_validity(), "Error: SCXML root: invalid ROS declaration." self._ros_declarations.append(ros_declaration) @@ -170,17 +180,21 @@ def add_bt_port_declaration(self, bt_port_decl: BtPortDeclarations): """Add a BT port declaration to the handler.""" if isinstance(bt_port_decl, BtInputPortDeclaration): self._bt_ports_handler.declare_in_port( - bt_port_decl.get_key_name(), bt_port_decl.get_key_type()) + bt_port_decl.get_key_name(), bt_port_decl.get_key_type() + ) elif isinstance(bt_port_decl, BtOutputPortDeclaration): self._bt_ports_handler.declare_out_port( - bt_port_decl.get_key_name(), bt_port_decl.get_key_type()) + bt_port_decl.get_key_name(), bt_port_decl.get_key_type() + ) else: raise ValueError( - f"Error: SCXML root: invalid BT port declaration type {type(bt_port_decl)}.") + f"Error: SCXML root: invalid BT port declaration type {type(bt_port_decl)}." + ) def add_action_thread(self, action_thread: RosActionThread): - assert isinstance(action_thread, RosActionThread), \ - f"Error: SCXML root: invalid action thread type {type(action_thread)}." + assert isinstance( + action_thread, RosActionThread + ), f"Error: SCXML root: invalid action thread type {type(action_thread)}." self._additional_threads.append(action_thread) def set_bt_port_value(self, port_name: str, port_value: str): @@ -207,8 +221,9 @@ def _generate_ros_declarations_helper(self) -> Optional[ScxmlRosDeclarationsCont """Generate a HelperRosDeclarations object from the existing ROS declarations.""" ros_decl_container = ScxmlRosDeclarationsContainer(self._name) for ros_declaration in self._ros_declarations: - if not (ros_declaration.check_validity() and - ros_declaration.check_valid_instantiation()): + if not ( + ros_declaration.check_validity() and ros_declaration.check_valid_instantiation() + ): return None ros_decl_container.append_ros_declaration(ros_declaration) return ros_decl_container @@ -217,11 +232,13 @@ def check_validity(self) -> bool: valid_name = is_non_empty_string(ScxmlRoot, "name", self._name) valid_initial_state = is_non_empty_string(ScxmlRoot, "initial state", self._initial_state) valid_data_model = self._data_model is None or self._data_model.check_validity() - valid_states = all(isinstance(state, ScxmlState) and state.check_validity() - for state in self._states) - valid_threads = all(isinstance(scxml_thread, RosActionThread) and - scxml_thread.check_validity() for scxml_thread in - self._additional_threads) + valid_states = all( + isinstance(state, ScxmlState) and state.check_validity() for state in self._states + ) + valid_threads = all( + isinstance(scxml_thread, RosActionThread) and scxml_thread.check_validity() + for scxml_thread in self._additional_threads + ) if not valid_data_model: print("Error: SCXML root: datamodel is not valid.") if not valid_states: @@ -231,8 +248,9 @@ def check_validity(self) -> bool: valid_ros = self._check_valid_ros_declarations() if not valid_ros: print("Error: SCXML root: ROS declarations are not valid.") - return valid_name and valid_initial_state and valid_states and valid_data_model and \ - valid_ros + return ( + valid_name and valid_initial_state and valid_states and valid_data_model and valid_ros + ) def _check_valid_ros_declarations(self) -> bool: """Check if the ros declarations and instantiations are valid.""" @@ -241,11 +259,14 @@ def _check_valid_ros_declarations(self) -> bool: if ros_decl_container is None: return False # Check the ROS instantiations - if not all(state.check_valid_ros_instantiations(ros_decl_container) - for state in self._states): + if not all( + state.check_valid_ros_instantiations(ros_decl_container) for state in self._states + ): return False - if not all(scxml_thread.check_valid_ros_instantiations(ros_decl_container) - for scxml_thread in self._additional_threads): + if not all( + scxml_thread.check_valid_ros_instantiations(ros_decl_container) + for scxml_thread in self._additional_threads + ): return False return True @@ -255,8 +276,9 @@ def is_plain_scxml(self) -> bool: # If this is a valid scxml object, just check the absence of ROS and thread declarations return len(self._ros_declarations) == 0 and len(self._additional_threads) == 0 - def to_plain_scxml_and_declarations(self) -> Tuple[List["ScxmlRoot"], - ScxmlRosDeclarationsContainer]: + def to_plain_scxml_and_declarations( + self, + ) -> Tuple[List["ScxmlRoot"], ScxmlRosDeclarationsContainer]: """ Convert all internal ROS specific entries to plain SCXML. @@ -278,27 +300,33 @@ def to_plain_scxml_and_declarations(self) -> Tuple[List["ScxmlRoot"], for scxml_thread in self._additional_threads: converted_scxmls.extend(scxml_thread.as_plain_scxml(ros_declarations)) for plain_scxml in converted_scxmls: - assert isinstance(plain_scxml, ScxmlRoot), \ - "Error: SCXML root: conversion to plain SCXML resulted in invalid object " \ + assert isinstance(plain_scxml, ScxmlRoot), ( + "Error: SCXML root: conversion to plain SCXML resulted in invalid object " f"(expected ScxmlRoot, obtained {type(plain_scxml)}." - assert plain_scxml.check_validity(), \ - f"The SCXML root object {plain_scxml.get_name()} is not valid: " \ + ) + assert plain_scxml.check_validity(), ( + f"The SCXML root object {plain_scxml.get_name()} is not valid: " "conversion to plain SCXML failed." - assert plain_scxml.is_plain_scxml(), \ - f"The SCXML root object {plain_scxml.get_name()} is not plain SCXML: " \ + ) + assert plain_scxml.is_plain_scxml(), ( + f"The SCXML root object {plain_scxml.get_name()} is not plain SCXML: " "conversion to plain SCXML failed." + ) return (converted_scxmls, ros_declarations) def as_xml(self) -> ET.Element: assert self.check_validity(), "SCXML: found invalid root object." assert self._initial_state is not None, "Error: SCXML root: no initial state set." - xml_root = ET.Element("scxml", { - "name": self._name, - "version": self._version, - "model_src": "", - "initial": self._initial_state, - "xmlns": "http://www.w3.org/2005/07/scxml" - }) + xml_root = ET.Element( + "scxml", + { + "name": self._name, + "version": self._version, + "model_src": "", + "initial": self._initial_state, + "xmlns": "http://www.w3.org/2005/07/scxml", + }, + ) if self._data_model is not None: data_model_xml = self._data_model.as_xml() assert data_model_xml is not None, "Error: SCXML root: invalid data model." @@ -313,4 +341,4 @@ def as_xml(self) -> ET.Element: return xml_root def as_xml_string(self) -> str: - return ET.tostring(self.as_xml(), encoding='unicode', xml_declaration=True) + return ET.tostring(self.as_xml(), encoding="unicode", xml_declaration=True) diff --git a/src/as2fm/scxml_converter/scxml_entries/scxml_ros_action_client.py b/src/as2fm/scxml_converter/scxml_entries/scxml_ros_action_client.py index 93229f20..d75b3f57 100644 --- a/src/as2fm/scxml_converter/scxml_entries/scxml_ros_action_client.py +++ b/src/as2fm/scxml_converter/scxml_entries/scxml_ros_action_client.py @@ -24,20 +24,22 @@ from action_msgs.msg import GoalStatus -from as2fm.scxml_converter.scxml_entries import (ScxmlRosDeclarationsContainer, - ScxmlTransition) +from as2fm.scxml_converter.scxml_entries import ScxmlRosDeclarationsContainer, ScxmlTransition from as2fm.scxml_converter.scxml_entries.ros_utils import ( generate_action_feedback_handle_event, generate_action_goal_handle_accepted_event, - generate_action_goal_handle_rejected_event, generate_action_goal_req_event, - generate_action_result_handle_event, is_action_type_known) -from as2fm.scxml_converter.scxml_entries.scxml_ros_base import (RosCallback, - RosDeclaration, - RosTrigger) -from as2fm.scxml_converter.scxml_entries.utils import (CallbackType, - is_non_empty_string) -from as2fm.scxml_converter.scxml_entries.xml_utils import (assert_xml_tag_ok, - get_xml_argument) + generate_action_goal_handle_rejected_event, + generate_action_goal_req_event, + generate_action_result_handle_event, + is_action_type_known, +) +from as2fm.scxml_converter.scxml_entries.scxml_ros_base import ( + RosCallback, + RosDeclaration, + RosTrigger, +) +from as2fm.scxml_converter.scxml_entries.utils import CallbackType, is_non_empty_string +from as2fm.scxml_converter.scxml_entries.xml_utils import assert_xml_tag_ok, get_xml_argument class RosActionClient(RosDeclaration): @@ -78,7 +80,8 @@ def check_fields_validity(self, ros_declarations: ScxmlRosDeclarationsContainer) def get_plain_scxml_event(self, ros_declarations: ScxmlRosDeclarationsContainer) -> str: return generate_action_goal_req_event( ros_declarations.get_action_client_info(self._interface_name)[0], - ros_declarations.get_automaton_name()) + ros_declarations.get_automaton_name(), + ) class RosActionHandleGoalResponse(ScxmlTransition): @@ -103,8 +106,9 @@ def from_xml_tree(xml_tree: ET.Element) -> "RosActionHandleGoalResponse": reject_target = get_xml_argument(RosActionHandleGoalResponse, xml_tree, "reject") return RosActionHandleGoalResponse(action_name, accept_target, reject_target) - def __init__(self, action_client: Union[str, RosActionClient], - accept_target: str, reject_target: str) -> None: + def __init__( + self, action_client: Union[str, RosActionClient], accept_target: str, reject_target: str + ) -> None: """ Initialize a new RosActionHandleGoalResponse object. @@ -123,10 +127,12 @@ def __init__(self, action_client: Union[str, RosActionClient], def check_validity(self) -> bool: valid_name = is_non_empty_string(RosActionHandleGoalResponse, "name", self._client_name) - valid_accept = is_non_empty_string(RosActionHandleGoalResponse, "accept", - self._accept_target) - valid_reject = is_non_empty_string(RosActionHandleGoalResponse, "reject", - self._reject_target) + valid_accept = is_non_empty_string( + RosActionHandleGoalResponse, "accept", self._accept_target + ) + valid_reject = is_non_empty_string( + RosActionHandleGoalResponse, "reject", self._reject_target + ) return valid_name and valid_accept and valid_reject def instantiate_bt_events(self, _: str): @@ -138,22 +144,29 @@ def update_bt_ports_values(self, _) -> None: # We do not expect a body with BT ports to be substituted pass - def check_valid_ros_instantiations(self, - ros_declarations: ScxmlRosDeclarationsContainer) -> bool: - assert isinstance(ros_declarations, ScxmlRosDeclarationsContainer), \ - "Error: SCXML Service Handle Response: invalid ROS declarations container." - assert isinstance(ros_declarations, ScxmlRosDeclarationsContainer), \ - "Error: SCXML action goal request: invalid ROS declarations container." + def check_valid_ros_instantiations( + self, ros_declarations: ScxmlRosDeclarationsContainer + ) -> bool: + assert isinstance( + ros_declarations, ScxmlRosDeclarationsContainer + ), "Error: SCXML Service Handle Response: invalid ROS declarations container." + assert isinstance( + ros_declarations, ScxmlRosDeclarationsContainer + ), "Error: SCXML action goal request: invalid ROS declarations container." if not ros_declarations.is_action_client_defined(self._client_name): - print("Error: SCXML action goal request: " - f"action client {self._client_name} not declared.") + print( + "Error: SCXML action goal request: " + f"action client {self._client_name} not declared." + ) return False return True - def as_plain_scxml(self, - ros_declarations: ScxmlRosDeclarationsContainer) -> List[ScxmlTransition]: - assert self.check_valid_ros_instantiations(ros_declarations), \ - "Error: SCXML service response handler: invalid ROS instantiations." + def as_plain_scxml( + self, ros_declarations: ScxmlRosDeclarationsContainer + ) -> List[ScxmlTransition]: + assert self.check_valid_ros_instantiations( + ros_declarations + ), "Error: SCXML service response handler: invalid ROS instantiations." automaton_name = ros_declarations.get_automaton_name() interface_name, _ = ros_declarations.get_action_client_info(self._client_name) accept_event = generate_action_goal_handle_accepted_event(interface_name, automaton_name) @@ -164,9 +177,14 @@ def as_plain_scxml(self, def as_xml(self) -> ET.Element: assert self.check_validity(), "Error: SCXML Service Handle Response: invalid parameters." - return ET.Element(RosActionHandleGoalResponse.get_tag_name(), - {"name": self._client_name, - "accept": self._accept_target, "reject": self._reject_target}) + return ET.Element( + RosActionHandleGoalResponse.get_tag_name(), + { + "name": self._client_name, + "accept": self._accept_target, + "reject": self._reject_target, + }, + ) class RosActionHandleFeedback(RosCallback): @@ -190,7 +208,8 @@ def check_interface_defined(self, ros_declarations: ScxmlRosDeclarationsContaine def get_plain_scxml_event(self, ros_declarations: ScxmlRosDeclarationsContainer) -> str: return generate_action_feedback_handle_event( ros_declarations.get_action_client_info(self._interface_name)[0], - ros_declarations.get_automaton_name()) + ros_declarations.get_automaton_name(), + ) class RosActionHandleSuccessResult(RosCallback): @@ -214,11 +233,13 @@ def check_interface_defined(self, ros_declarations: ScxmlRosDeclarationsContaine def get_plain_scxml_event(self, ros_declarations: ScxmlRosDeclarationsContainer) -> str: return generate_action_result_handle_event( ros_declarations.get_action_client_info(self._interface_name)[0], - ros_declarations.get_automaton_name()) + ros_declarations.get_automaton_name(), + ) def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> ScxmlTransition: - assert self._condition is None, \ - "Error: SCXML RosActionHandleSuccessResult: condition not supported." + assert ( + self._condition is None + ), "Error: SCXML RosActionHandleSuccessResult: condition not supported." self._condition = f"_wrapped_result.code == {GoalStatus.STATUS_SUCCEEDED}" return super().as_plain_scxml(ros_declarations) @@ -244,11 +265,13 @@ def check_interface_defined(self, ros_declarations: ScxmlRosDeclarationsContaine def get_plain_scxml_event(self, ros_declarations: ScxmlRosDeclarationsContainer) -> str: return generate_action_result_handle_event( ros_declarations.get_action_client_info(self._interface_name)[0], - ros_declarations.get_automaton_name()) + ros_declarations.get_automaton_name(), + ) def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> ScxmlTransition: - assert self._condition is None, \ - "Error: SCXML RosActionHandleSuccessResult: condition not supported." + assert ( + self._condition is None + ), "Error: SCXML RosActionHandleSuccessResult: condition not supported." self._condition = f"_wrapped_result.code == {GoalStatus.STATUS_CANCELED}" return super().as_plain_scxml(ros_declarations) @@ -274,10 +297,12 @@ def check_interface_defined(self, ros_declarations: ScxmlRosDeclarationsContaine def get_plain_scxml_event(self, ros_declarations: ScxmlRosDeclarationsContainer) -> str: return generate_action_result_handle_event( ros_declarations.get_action_client_info(self._interface_name)[0], - ros_declarations.get_automaton_name()) + ros_declarations.get_automaton_name(), + ) def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> ScxmlTransition: - assert self._condition is None, \ - "Error: SCXML RosActionHandleSuccessResult: condition not supported." + assert ( + self._condition is None + ), "Error: SCXML RosActionHandleSuccessResult: condition not supported." self._condition = f"_wrapped_result.code == {GoalStatus.STATUS_ABORTED}" return super().as_plain_scxml(ros_declarations) diff --git a/src/as2fm/scxml_converter/scxml_entries/scxml_ros_action_server.py b/src/as2fm/scxml_converter/scxml_entries/scxml_ros_action_server.py index cd12958f..89402c81 100644 --- a/src/as2fm/scxml_converter/scxml_entries/scxml_ros_action_server.py +++ b/src/as2fm/scxml_converter/scxml_entries/scxml_ros_action_server.py @@ -24,17 +24,22 @@ from action_msgs.msg import GoalStatus -from as2fm.scxml_converter.scxml_entries import (ScxmlParam, - ScxmlRosDeclarationsContainer, - ScxmlSend) +from as2fm.scxml_converter.scxml_entries import ScxmlParam, ScxmlRosDeclarationsContainer, ScxmlSend from as2fm.scxml_converter.scxml_entries.ros_utils import ( - generate_action_feedback_event, generate_action_goal_accepted_event, - generate_action_goal_handle_event, generate_action_goal_rejected_event, - generate_action_result_event, generate_action_thread_execution_start_event, - generate_action_thread_free_event, is_action_type_known) -from as2fm.scxml_converter.scxml_entries.scxml_ros_base import (RosCallback, - RosDeclaration, - RosTrigger) + generate_action_feedback_event, + generate_action_goal_accepted_event, + generate_action_goal_handle_event, + generate_action_goal_rejected_event, + generate_action_result_event, + generate_action_thread_execution_start_event, + generate_action_thread_free_event, + is_action_type_known, +) +from as2fm.scxml_converter.scxml_entries.scxml_ros_base import ( + RosCallback, + RosDeclaration, + RosTrigger, +) from as2fm.scxml_converter.scxml_entries.utils import CallbackType @@ -81,7 +86,8 @@ def check_interface_defined(self, ros_declarations: ScxmlRosDeclarationsContaine def get_plain_scxml_event(self, ros_declarations: ScxmlRosDeclarationsContainer) -> str: return generate_action_goal_handle_event( - ros_declarations.get_action_server_info(self._interface_name)[0]) + ros_declarations.get_action_server_info(self._interface_name)[0] + ) class RosActionAcceptGoal(RosTrigger): @@ -109,7 +115,8 @@ def check_fields_validity(self, _) -> bool: def get_plain_scxml_event(self, ros_declarations: ScxmlRosDeclarationsContainer) -> str: return generate_action_goal_accepted_event( - ros_declarations.get_action_server_info(self._interface_name)[0]) + ros_declarations.get_action_server_info(self._interface_name)[0] + ) def as_xml(self) -> ET.Element: assert self.check_fields_validity(None), "Error: SCXML RosActionAcceptGoal: invalid fields." @@ -117,7 +124,6 @@ def as_xml(self) -> ET.Element: class RosActionRejectGoal(RosTrigger): - """ Object representing the SCXML ROS Event sent from the server when an action Goal is rejected. """ @@ -143,7 +149,8 @@ def check_fields_validity(self, _) -> bool: def get_plain_scxml_event(self, ros_declarations: ScxmlRosDeclarationsContainer) -> str: return generate_action_goal_rejected_event( - ros_declarations.get_action_server_info(self._interface_name)[0]) + ros_declarations.get_action_server_info(self._interface_name)[0] + ) def as_xml(self) -> ET.Element: assert self.check_fields_validity(None), "Error: SCXML RosActionRejectGoal: invalid fields." @@ -177,7 +184,8 @@ def check_fields_validity(self, ros_declarations: ScxmlRosDeclarationsContainer) def get_plain_scxml_event(self, ros_declarations: ScxmlRosDeclarationsContainer) -> str: return generate_action_thread_execution_start_event( - ros_declarations.get_action_server_info(self._interface_name)[0]) + ros_declarations.get_action_server_info(self._interface_name)[0] + ) class RosActionSendFeedback(RosTrigger): @@ -200,12 +208,14 @@ def check_interface_defined(self, ros_declarations: ScxmlRosDeclarationsContaine def check_fields_validity(self, ros_declarations: ScxmlRosDeclarationsContainer) -> bool: """Check if the goal_id and the request fields have been defined.""" - return ros_declarations.check_valid_action_feedback_fields(self._interface_name, - self._fields) + return ros_declarations.check_valid_action_feedback_fields( + self._interface_name, self._fields + ) def get_plain_scxml_event(self, ros_declarations: ScxmlRosDeclarationsContainer) -> str: return generate_action_feedback_event( - ros_declarations.get_action_server_info(self._interface_name)[0]) + ros_declarations.get_action_server_info(self._interface_name)[0] + ) class RosActionSendSuccessResult(RosTrigger): @@ -232,7 +242,8 @@ def check_fields_validity(self, ros_declarations: ScxmlRosDeclarationsContainer) def get_plain_scxml_event(self, ros_declarations: ScxmlRosDeclarationsContainer) -> str: return generate_action_result_event( - ros_declarations.get_action_server_info(self._interface_name)[0]) + ros_declarations.get_action_server_info(self._interface_name)[0] + ) def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> ScxmlSend: plain_send = super().as_plain_scxml(ros_declarations) @@ -264,7 +275,8 @@ def check_fields_validity(self, _) -> bool: def get_plain_scxml_event(self, ros_declarations: ScxmlRosDeclarationsContainer) -> str: return generate_action_result_event( - ros_declarations.get_action_server_info(self._interface_name)[0]) + ros_declarations.get_action_server_info(self._interface_name)[0] + ) def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> ScxmlSend: plain_send = super().as_plain_scxml(ros_declarations) @@ -296,7 +308,8 @@ def check_fields_validity(self, _) -> bool: def get_plain_scxml_event(self, ros_declarations: ScxmlRosDeclarationsContainer) -> str: return generate_action_result_event( - ros_declarations.get_action_server_info(self._interface_name)[0]) + ros_declarations.get_action_server_info(self._interface_name)[0] + ) def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> ScxmlSend: plain_send = super().as_plain_scxml(ros_declarations) @@ -327,4 +340,5 @@ def check_interface_defined(self, ros_declarations: ScxmlRosDeclarationsContaine def get_plain_scxml_event(self, ros_declarations: ScxmlRosDeclarationsContainer) -> str: return generate_action_thread_free_event( - ros_declarations.get_action_server_info(self._interface_name)[0]) + ros_declarations.get_action_server_info(self._interface_name)[0] + ) diff --git a/src/as2fm/scxml_converter/scxml_entries/scxml_ros_action_server_thread.py b/src/as2fm/scxml_converter/scxml_entries/scxml_ros_action_server_thread.py index 8f532bfe..47991d5a 100644 --- a/src/as2fm/scxml_converter/scxml_entries/scxml_ros_action_server_thread.py +++ b/src/as2fm/scxml_converter/scxml_entries/scxml_ros_action_server_thread.py @@ -22,24 +22,30 @@ from typing import List, Optional, Type, Union from xml.etree import ElementTree as ET -from as2fm.scxml_converter.scxml_entries import (RosField, ScxmlBase, - ScxmlDataModel, - ScxmlExecutionBody, - ScxmlParam, - ScxmlRosDeclarationsContainer, - ScxmlState, ScxmlTransition) +from as2fm.scxml_converter.scxml_entries import ( + RosField, + ScxmlBase, + ScxmlDataModel, + ScxmlExecutionBody, + ScxmlParam, + ScxmlRosDeclarationsContainer, + ScxmlState, + ScxmlTransition, +) from as2fm.scxml_converter.scxml_entries.bt_utils import BtPortsHandler from as2fm.scxml_converter.scxml_entries.ros_utils import ( generate_action_thread_execution_start_event, - generate_action_thread_free_event, sanitize_ros_interface_name) -from as2fm.scxml_converter.scxml_entries.scxml_ros_action_server import \ - RosActionServer -from as2fm.scxml_converter.scxml_entries.scxml_ros_base import (RosCallback, - RosTrigger) -from as2fm.scxml_converter.scxml_entries.utils import (CallbackType, - is_non_empty_string) + generate_action_thread_free_event, + sanitize_ros_interface_name, +) +from as2fm.scxml_converter.scxml_entries.scxml_ros_action_server import RosActionServer +from as2fm.scxml_converter.scxml_entries.scxml_ros_base import RosCallback, RosTrigger +from as2fm.scxml_converter.scxml_entries.utils import CallbackType, is_non_empty_string from as2fm.scxml_converter.scxml_entries.xml_utils import ( - assert_xml_tag_ok, get_children_as_scxml, get_xml_argument) + assert_xml_tag_ok, + get_children_as_scxml, + get_xml_argument, +) class RosActionThread(ScxmlBase): @@ -102,12 +108,11 @@ def add_state(self, state: ScxmlState, *, initial: bool = False): If initial is True, set it as the initial state.""" self._states.append(state) if initial: - assert self._initial_state is None, \ - 'Error: RosActionThread: Initial state already set' + assert self._initial_state is None, "Error: RosActionThread: Initial state already set" self._initial_state = state.get_id() def set_data_model(self, data_model: ScxmlDataModel): - assert self._data_model is None, 'Data model already set' + assert self._data_model is None, "Data model already set" self._data_model = data_model def update_bt_ports_values(self, bt_ports_handler: BtPortsHandler) -> None: @@ -121,13 +126,16 @@ def check_validity(self) -> bool: valid_n_threads = isinstance(self._n_threads, int) and self._n_threads > 0 valid_initial_state = self._initial_state is not None valid_data_model = self._data_model is None or self._data_model.check_validity() - valid_states = all(isinstance(state, ScxmlState) and state.check_validity() - for state in self._states) + valid_states = all( + isinstance(state, ScxmlState) and state.check_validity() for state in self._states + ) if not valid_name: return False if not valid_n_threads: - print("Error: SCXML RosActionThread: " - f"{self._name} has invalid n_threads ({self._n_threads}).") + print( + "Error: SCXML RosActionThread: " + f"{self._name} has invalid n_threads ({self._n_threads})." + ) if not valid_initial_state: print(f"Error: SCXML RosActionThread: {self._name} has no initial state.") if not valid_data_model: @@ -137,14 +145,17 @@ def check_validity(self) -> bool: return valid_n_threads and valid_initial_state and valid_data_model and valid_states def check_valid_ros_instantiations(self, ros_decls: ScxmlRosDeclarationsContainer) -> bool: - assert isinstance(ros_decls, ScxmlRosDeclarationsContainer), \ - "Error: SCXML RosActionThread: Invalid ROS declarations container." + assert isinstance( + ros_decls, ScxmlRosDeclarationsContainer + ), "Error: SCXML RosActionThread: Invalid ROS declarations container." if not ros_decls.is_action_server_defined(self._name): print(f"Error: SCXML RosActionThread: undeclared thread action server '{self._name}'.") return False if not all(state.check_valid_ros_instantiations(ros_decls) for state in self._states): - print("Error: SCXML RosActionThread: " - f"invalid ROS instantiation for states in thread '{self._name}'.") + print( + "Error: SCXML RosActionThread: " + f"invalid ROS instantiation for states in thread '{self._name}'." + ) return False return True @@ -155,9 +166,11 @@ def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> Lis This returns a list of ScxmlRoot objects, using ScxmlBase to avoid circular dependencies. """ from as2fm.scxml_converter.scxml_entries import ScxmlRoot + thread_instances: List[ScxmlRoot] = [] action_name = sanitize_ros_interface_name( - ros_declarations.get_action_server_info(self._name)[0]) + ros_declarations.get_action_server_info(self._name)[0] + ) for thread_idx in range(self._n_threads): thread_name = f"{action_name}_thread_{thread_idx}" plain_thread_instance = ScxmlRoot(thread_name) @@ -165,11 +178,13 @@ def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> Lis for state in self._states: initial_state = state.get_id() == self._initial_state state.set_thread_id(thread_idx) - plain_thread_instance.add_state(state.as_plain_scxml(ros_declarations), - initial=initial_state) - assert plain_thread_instance.is_plain_scxml(), \ - "Error: SCXML RosActionThread: " \ + plain_thread_instance.add_state( + state.as_plain_scxml(ros_declarations), initial=initial_state + ) + assert plain_thread_instance.is_plain_scxml(), ( + "Error: SCXML RosActionThread: " f"failed to generate a plain-SCXML instance from thread '{self._name}'" + ) thread_instances.append(plain_thread_instance) return thread_instances @@ -200,9 +215,13 @@ def get_callback_type() -> CallbackType: # The thread is started upon a goal request, so use the action goal type return CallbackType.ROS_ACTION_GOAL - def __init__(self, server_alias: Union[str, RosActionServer], target_state: str, - condition: Optional[str] = None, exec_body: Optional[ScxmlExecutionBody] = None - ) -> None: + def __init__( + self, + server_alias: Union[str, RosActionServer], + target_state: str, + condition: Optional[str] = None, + exec_body: Optional[ScxmlExecutionBody] = None, + ) -> None: """ Initialize a new RosActionHandleResult object. @@ -229,18 +248,21 @@ def check_interface_defined(self, ros_declarations: ScxmlRosDeclarationsContaine def set_thread_id(self, thread_id: int) -> None: """Set the thread ID for this handler.""" # The thread ID is expected to be overwritten every time a new thread is generated. - assert isinstance(thread_id, int) and thread_id >= 0, \ - f"Error: SCXML {self.__class__.__name__}: invalid thread ID ({thread_id})." + assert ( + isinstance(thread_id, int) and thread_id >= 0 + ), f"Error: SCXML {self.__class__.__name__}: invalid thread ID ({thread_id})." self._thread_id = thread_id super().set_thread_id(thread_id) def get_plain_scxml_event(self, ros_declarations: ScxmlRosDeclarationsContainer) -> str: return generate_action_thread_execution_start_event( - ros_declarations.get_action_server_info(self._interface_name)[0]) + ros_declarations.get_action_server_info(self._interface_name)[0] + ) def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> ScxmlTransition: - assert self._thread_id is not None, \ - f"Error: SCXML {self.__class__.__name__}: thread ID not set." + assert ( + self._thread_id is not None + ), f"Error: SCXML {self.__class__.__name__}: thread ID not set." # Append a condition checking the thread ID matches the request self._condition = "_event.thread_id == " + str(self._thread_id) return super().as_plain_scxml(ros_declarations) @@ -262,8 +284,12 @@ def get_tag_name() -> str: def get_declaration_type() -> Type[RosActionServer]: return RosActionServer - def __init__(self, action_name: Union[str, RosActionServer], - fields: Optional[List[RosField]] = None, _=None) -> None: + def __init__( + self, + action_name: Union[str, RosActionServer], + fields: Optional[List[RosField]] = None, + _=None, + ) -> None: super().__init__(action_name, fields) self._thread_id: Optional[int] = None @@ -282,17 +308,20 @@ def check_fields_validity(self, _) -> bool: def set_thread_id(self, thread_id: int) -> None: """Set the thread ID for this handler.""" # The thread ID is expected to be overwritten every time a new thread is generated. - assert isinstance(thread_id, int) and thread_id >= 0, \ - f"Error: SCXML {self.__class__.__name__}: invalid thread ID ({thread_id})." + assert ( + isinstance(thread_id, int) and thread_id >= 0 + ), f"Error: SCXML {self.__class__.__name__}: invalid thread ID ({thread_id})." self._thread_id = thread_id def get_plain_scxml_event(self, ros_declarations: ScxmlRosDeclarationsContainer) -> str: return generate_action_thread_free_event( - ros_declarations.get_action_server_info(self._interface_name)[0]) + ros_declarations.get_action_server_info(self._interface_name)[0] + ) def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> ScxmlTransition: - assert self._thread_id is not None, \ - f"Error: SCXML {self.__class__.__name__}: thread ID not set." + assert ( + self._thread_id is not None + ), f"Error: SCXML {self.__class__.__name__}: thread ID not set." plain_trigger = super().as_plain_scxml(ros_declarations) # Add the thread id to the (empty) param list plain_trigger.append_param(ScxmlParam("thread_id", expr=str(self._thread_id))) diff --git a/src/as2fm/scxml_converter/scxml_entries/scxml_ros_base.py b/src/as2fm/scxml_converter/scxml_entries/scxml_ros_base.py index 8167401b..9c79e417 100644 --- a/src/as2fm/scxml_converter/scxml_entries/scxml_ros_base.py +++ b/src/as2fm/scxml_converter/scxml_entries/scxml_ros_base.py @@ -18,20 +18,33 @@ from typing import Dict, List, Optional, Type, Union from xml.etree import ElementTree as ET -from as2fm.scxml_converter.scxml_entries import (BtGetValueInputPort, RosField, - ScxmlBase, ScxmlExecutionBody, - ScxmlParam, - ScxmlRosDeclarationsContainer, - ScxmlSend, ScxmlTransition) +from as2fm.scxml_converter.scxml_entries import ( + BtGetValueInputPort, + RosField, + ScxmlBase, + ScxmlExecutionBody, + ScxmlParam, + ScxmlRosDeclarationsContainer, + ScxmlSend, + ScxmlTransition, +) from as2fm.scxml_converter.scxml_entries.bt_utils import BtPortsHandler from as2fm.scxml_converter.scxml_entries.scxml_executable_entries import ( - as_plain_execution_body, execution_body_from_xml, - set_execution_body_callback_type, valid_execution_body) -from as2fm.scxml_converter.scxml_entries.utils import (CallbackType, - get_plain_expression, - is_non_empty_string) + as_plain_execution_body, + execution_body_from_xml, + set_execution_body_callback_type, + valid_execution_body, +) +from as2fm.scxml_converter.scxml_entries.utils import ( + CallbackType, + get_plain_expression, + is_non_empty_string, +) from as2fm.scxml_converter.scxml_entries.xml_utils import ( - assert_xml_tag_ok, get_xml_argument, read_value_from_xml_arg_or_child) + assert_xml_tag_ok, + get_xml_argument, + read_value_from_xml_arg_or_child, +) class RosDeclaration(ScxmlBase): @@ -56,17 +69,22 @@ def get_xml_arg_interface_name(cls) -> str: return f"{cls.get_communication_interface()}_name" @classmethod - def from_xml_tree(cls: Type['RosDeclaration'], xml_tree: ET.Element) -> 'RosDeclaration': + def from_xml_tree(cls: Type["RosDeclaration"], xml_tree: ET.Element) -> "RosDeclaration": """Create an instance of the class from an XML tree.""" assert_xml_tag_ok(cls, xml_tree) interface_name = read_value_from_xml_arg_or_child( - cls, xml_tree, cls.get_xml_arg_interface_name(), (BtGetValueInputPort, str)) + cls, xml_tree, cls.get_xml_arg_interface_name(), (BtGetValueInputPort, str) + ) interface_type = get_xml_argument(cls, xml_tree, "type") interface_alias = get_xml_argument(cls, xml_tree, "name", none_allowed=True) return cls(interface_name, interface_type, interface_alias) - def __init__(self, interface_name: Union[str, BtGetValueInputPort], interface_type: str, - interface_alias: Optional[str] = None): + def __init__( + self, + interface_name: Union[str, BtGetValueInputPort], + interface_type: str, + interface_alias: Optional[str] = None, + ): """ Constructor of ROS declaration. @@ -77,13 +95,15 @@ def __init__(self, interface_name: Union[str, BtGetValueInputPort], interface_ty self._interface_name = interface_name self._interface_type = interface_type self._interface_alias = interface_alias - assert isinstance(interface_name, (str, BtGetValueInputPort)), \ - f"Error: SCXML {self.get_tag_name()}: " \ + assert isinstance(interface_name, (str, BtGetValueInputPort)), ( + f"Error: SCXML {self.get_tag_name()}: " f"invalid type of interface_name {type(interface_name)}." + ) if self._interface_alias is None: - assert is_non_empty_string(self.__class__, "interface_name", self._interface_name), \ - f"Error: SCXML {self.__class__.__name__}: " \ + assert is_non_empty_string(self.__class__, "interface_name", self._interface_name), ( + f"Error: SCXML {self.__class__.__name__}: " "an alias name is required for dynamic ROS interfaces." + ) self._interface_alias = interface_name def get_interface_name(self) -> str: @@ -100,12 +120,14 @@ def get_name(self) -> str: def check_valid_interface_type(self) -> bool: return NotImplementedError( - f"{self.__class__.__name__} doesn't implement check_valid_interface_type.") + f"{self.__class__.__name__} doesn't implement check_valid_interface_type." + ) def check_validity(self) -> bool: valid_alias = is_non_empty_string(self.__class__, "name", self._interface_alias) - valid_action_name = isinstance(self._interface_name, BtGetValueInputPort) or \ - is_non_empty_string(self.__class__, "interface_name", self._interface_name) + valid_action_name = isinstance( + self._interface_name, BtGetValueInputPort + ) or is_non_empty_string(self.__class__, "interface_name", self._interface_name) valid_action_type = self.check_valid_interface_type() return valid_alias and valid_action_name and valid_action_type @@ -116,20 +138,26 @@ def check_valid_instantiation(self) -> bool: def update_bt_ports_values(self, bt_ports_handler: BtPortsHandler) -> None: """Update the values of potential entries making use of BT ports.""" if isinstance(self._interface_name, BtGetValueInputPort): - self._interface_name = \ - bt_ports_handler.get_in_port_value(self._interface_name.get_key_name()) + self._interface_name = bt_ports_handler.get_in_port_value( + self._interface_name.get_key_name() + ) def as_plain_scxml(self, _) -> ScxmlBase: # This is discarded in the to_plain_scxml_and_declarations method from ScxmlRoot raise RuntimeError( - f"Error: SCXML {self.__class__.__name__} cannot be converted to plain SCXML.") + f"Error: SCXML {self.__class__.__name__} cannot be converted to plain SCXML." + ) def as_xml(self) -> ET.Element: assert self.check_validity(), f"Error: SCXML {self.__class__.__name__}: invalid parameters." - xml_declaration = ET.Element(self.get_tag_name(), - {"name": self._interface_alias, - self.get_xml_arg_interface_name(): self._interface_name, - "type": self._interface_type}) + xml_declaration = ET.Element( + self.get_tag_name(), + { + "name": self._interface_alias, + self.get_xml_arg_interface_name(): self._interface_name, + "type": self._interface_type, + }, + ) return xml_declaration @@ -156,7 +184,7 @@ def get_callback_type(cls) -> CallbackType: raise NotImplementedError(f"{cls.__name__} doesn't implement get_callback_type.") @classmethod - def from_xml_tree(cls: Type['RosCallback'], xml_tree: ET.Element) -> 'RosCallback': + def from_xml_tree(cls: Type["RosCallback"], xml_tree: ET.Element) -> "RosCallback": """Create an instance of the class from an XML tree.""" assert_xml_tag_ok(cls, xml_tree) interface_name = get_xml_argument(cls, xml_tree, "name") @@ -165,9 +193,13 @@ def from_xml_tree(cls: Type['RosCallback'], xml_tree: ET.Element) -> 'RosCallbac exec_body = execution_body_from_xml(xml_tree) return cls(interface_name, target_state, condition, exec_body) - def __init__(self, interface_decl: Union[str, RosDeclaration], target_state: str, - condition: Optional[str] = None, exec_body: Optional[ScxmlExecutionBody] = None - ) -> None: + def __init__( + self, + interface_decl: Union[str, RosDeclaration], + target_state: str, + condition: Optional[str] = None, + exec_body: Optional[ScxmlExecutionBody] = None, + ) -> None: """ Constructor of ROS callback. @@ -187,14 +219,14 @@ def __init__(self, interface_decl: Union[str, RosDeclaration], target_state: str self._target: str = target_state self._condition: Optional[str] = condition self._body: ScxmlExecutionBody = exec_body - assert self.check_validity(), \ - f"Error: SCXML {self.__class__.__name__}: invalid parameters." + assert self.check_validity(), f"Error: SCXML {self.__class__.__name__}: invalid parameters." def check_validity(self) -> bool: valid_name = is_non_empty_string(self.__class__, "name", self._interface_name) valid_target = is_non_empty_string(self.__class__, "target", self._target) - valid_condition = (self._condition is None or - is_non_empty_string(self.__class__, "cond", self._condition)) + valid_condition = self._condition is None or is_non_empty_string( + self.__class__, "cond", self._condition + ) valid_body = self._body is None or valid_execution_body(self._body) if not valid_body: print(f"Error: SCXML {self.__class__.__name__}: invalid entries in executable body.") @@ -203,31 +235,40 @@ def check_validity(self) -> bool: def check_interface_defined(self, ros_declarations: ScxmlRosDeclarationsContainer) -> bool: """Check if the ROS interface used in the callback exists.""" raise NotImplementedError( - f"{self.__class__.__name__} doesn't implement check_interface_defined.") + f"{self.__class__.__name__} doesn't implement check_interface_defined." + ) def get_plain_scxml_event(self, ros_declarations: ScxmlRosDeclarationsContainer) -> str: """Translate the ROS interface name to a plain scxml event.""" raise NotImplementedError( - f"{self.__class__.__name__} doesn't implement get_plain_scxml_event.") + f"{self.__class__.__name__} doesn't implement get_plain_scxml_event." + ) - def check_valid_ros_instantiations(self, - ros_declarations: ScxmlRosDeclarationsContainer) -> bool: + def check_valid_ros_instantiations( + self, ros_declarations: ScxmlRosDeclarationsContainer + ) -> bool: """Check if the ROS entries in the callback are correctly defined.""" - assert isinstance(ros_declarations, ScxmlRosDeclarationsContainer), \ - f"Error: SCXML {self.__class__.__name__}: invalid type of ROS declarations container." + assert isinstance( + ros_declarations, ScxmlRosDeclarationsContainer + ), f"Error: SCXML {self.__class__.__name__}: invalid type of ROS declarations container." if not self.check_interface_defined(ros_declarations): - print(f"Error: SCXML {self.__class__.__name__}: " - f"undefined ROS interface {self._interface_name}.") + print( + f"Error: SCXML {self.__class__.__name__}: " + f"undefined ROS interface {self._interface_name}." + ) return False valid_body = super().check_valid_ros_instantiations(ros_declarations) if not valid_body: - print(f"Error: SCXML {self.__class__.__name__}: " - f"body of {self._interface_name} has invalid ROS instantiations.") + print( + f"Error: SCXML {self.__class__.__name__}: " + f"body of {self._interface_name} has invalid ROS instantiations." + ) return valid_body def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> ScxmlTransition: - assert self.check_valid_ros_instantiations(ros_declarations), \ - f"Error: SCXML {self.__class__.__name__}: invalid ROS instantiations." + assert self.check_valid_ros_instantiations( + ros_declarations + ), f"Error: SCXML {self.__class__.__name__}: invalid ROS instantiations." set_execution_body_callback_type(self._body, self.get_callback_type()) event_name = self.get_plain_scxml_event(ros_declarations) target = self._target @@ -240,8 +281,9 @@ def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> Scx def as_xml(self) -> ET.Element: """Convert the ROS callback to an XML element.""" assert self.check_validity(), f"Error: SCXML {self.__class__.__name__}: invalid parameters." - xml_callback = ET.Element(self.get_tag_name(), - {"name": self._interface_name, "target": self._target}) + xml_callback = ET.Element( + self.get_tag_name(), {"name": self._interface_name, "target": self._target} + ) if self._condition is not None: xml_callback.set("cond", self._condition) for body_elem in self._body: @@ -272,7 +314,7 @@ def get_additional_arguments() -> List[str]: return [] @classmethod - def from_xml_tree(cls: Type['RosTrigger'], xml_tree: ET.Element) -> 'RosTrigger': + def from_xml_tree(cls: Type["RosTrigger"], xml_tree: ET.Element) -> "RosTrigger": """ Create an instance of the class from an XML tree. @@ -284,13 +326,17 @@ def from_xml_tree(cls: Type['RosTrigger'], xml_tree: ET.Element) -> 'RosTrigger' additional_arg_values: Dict[str, str] = {} for arg_name in cls.get_additional_arguments(): additional_arg_values[arg_name] = get_xml_argument(cls, xml_tree, arg_name) - fields = [RosField.from_xml_tree(field) for field in xml_tree - if field.tag is not ET.Comment] + fields = [ + RosField.from_xml_tree(field) for field in xml_tree if field.tag is not ET.Comment + ] return cls(interface_name, fields, additional_arg_values) - def __init__(self, interface_decl: Union[str, RosDeclaration], - fields: List[RosField], - additional_args: Optional[Dict[str, str]] = None) -> None: + def __init__( + self, + interface_decl: Union[str, RosDeclaration], + fields: List[RosField], + additional_args: Optional[Dict[str, str]] = None, + ) -> None: """ Constructor of a generic ROS trigger. @@ -330,51 +376,68 @@ def update_bt_ports_values(self, bt_ports_handler: BtPortsHandler): def check_validity(self) -> bool: valid_name = is_non_empty_string(self.__class__, "name", self._interface_name) valid_fields = all(isinstance(field, RosField) for field in self._fields) - valid_additional_args = all(is_non_empty_string(self.__class__, arg_name, arg_value) - for arg_name, arg_value in self._additional_args.items()) + valid_additional_args = all( + is_non_empty_string(self.__class__, arg_name, arg_value) + for arg_name, arg_value in self._additional_args.items() + ) if not valid_fields: - print(f"Error: SCXML {self.__class__.__name__}: " - f"invalid entries in fields of {self._interface_name}.") + print( + f"Error: SCXML {self.__class__.__name__}: " + f"invalid entries in fields of {self._interface_name}." + ) if not valid_additional_args: - print(f"Error: SCXML {self.__class__.__name__}: " - f"invalid entries in additional arguments of {self._interface_name}.") + print( + f"Error: SCXML {self.__class__.__name__}: " + f"invalid entries in additional arguments of {self._interface_name}." + ) return valid_name and valid_fields def check_interface_defined(self, ros_declarations: ScxmlRosDeclarationsContainer) -> bool: """Check if the ROS interface used in the trigger exists.""" raise NotImplementedError( - f"{self.__class__.__name__} doesn't implement check_interface_defined.") + f"{self.__class__.__name__} doesn't implement check_interface_defined." + ) def check_fields_validity(self, ros_declarations: ScxmlRosDeclarationsContainer) -> bool: """Check if all fields are assigned, given the ROS interface definition.""" raise NotImplementedError( - f"{self.__class__.__name__} doesn't implement check_fields_validity.") + f"{self.__class__.__name__} doesn't implement check_fields_validity." + ) def get_plain_scxml_event(self, ros_declarations: ScxmlRosDeclarationsContainer) -> str: """Translate the ROS interface name to a plain scxml event.""" raise NotImplementedError( - f"{self.__class__.__name__} doesn't implement get_plain_scxml_event.") + f"{self.__class__.__name__} doesn't implement get_plain_scxml_event." + ) - def check_valid_ros_instantiations(self, - ros_declarations: ScxmlRosDeclarationsContainer) -> bool: + def check_valid_ros_instantiations( + self, ros_declarations: ScxmlRosDeclarationsContainer + ) -> bool: """Check if the ROS entries in the trigger are correctly defined.""" - assert isinstance(ros_declarations, ScxmlRosDeclarationsContainer), \ - f"Error: SCXML {self.__class__.__name__}: invalid type of ROS declarations container." + assert isinstance( + ros_declarations, ScxmlRosDeclarationsContainer + ), f"Error: SCXML {self.__class__.__name__}: invalid type of ROS declarations container." if not self.check_interface_defined(ros_declarations): - print(f"Error: SCXML {self.__class__.__name__}: " - f"undefined ROS interface {self._interface_name}.") + print( + f"Error: SCXML {self.__class__.__name__}: " + f"undefined ROS interface {self._interface_name}." + ) return False if not self.check_fields_validity(ros_declarations): - print(f"Error: SCXML {self.__class__.__name__}: " - f"invalid fields for {self._interface_name}.") + print( + f"Error: SCXML {self.__class__.__name__}: " + f"invalid fields for {self._interface_name}." + ) return False return True def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> ScxmlSend: - assert self.check_valid_ros_instantiations(ros_declarations), \ - f"Error: SCXML {self.__class__.__name__}: invalid ROS instantiations." - assert self._cb_type is not None, \ - f"Error: SCXML {self.__class__.__name__}: {self._interface_name} has no callback type." + assert self.check_valid_ros_instantiations( + ros_declarations + ), f"Error: SCXML {self.__class__.__name__}: invalid ROS instantiations." + assert ( + self._cb_type is not None + ), f"Error: SCXML {self.__class__.__name__}: {self._interface_name} has no callback type." event_name = self.get_plain_scxml_event(ros_declarations) params = [field.as_plain_scxml(ros_declarations) for field in self._fields] for param_name, param_value in self._additional_args.items(): diff --git a/src/as2fm/scxml_converter/scxml_entries/scxml_ros_field.py b/src/as2fm/scxml_converter/scxml_entries/scxml_ros_field.py index be55c70d..be91518b 100644 --- a/src/as2fm/scxml_converter/scxml_entries/scxml_ros_field.py +++ b/src/as2fm/scxml_converter/scxml_entries/scxml_ros_field.py @@ -20,12 +20,17 @@ from as2fm.scxml_converter.scxml_entries import BtGetValueInputPort, ScxmlParam from as2fm.scxml_converter.scxml_entries.bt_utils import BtPortsHandler -from as2fm.scxml_converter.scxml_entries.utils import (ROS_FIELD_PREFIX, - CallbackType, - get_plain_expression, - is_non_empty_string) +from as2fm.scxml_converter.scxml_entries.utils import ( + ROS_FIELD_PREFIX, + CallbackType, + get_plain_expression, + is_non_empty_string, +) from as2fm.scxml_converter.scxml_entries.xml_utils import ( - assert_xml_tag_ok, get_xml_argument, read_value_from_xml_arg_or_child) + assert_xml_tag_ok, + get_xml_argument, + read_value_from_xml_arg_or_child, +) class RosField(ScxmlParam): @@ -40,8 +45,9 @@ def from_xml_tree(xml_tree: ET.Element) -> "RosField": """Create a RosField object from an XML tree.""" assert_xml_tag_ok(RosField, xml_tree) name = get_xml_argument(RosField, xml_tree, "name") - expr = read_value_from_xml_arg_or_child(RosField, xml_tree, "expr", - (BtGetValueInputPort, str)) + expr = read_value_from_xml_arg_or_child( + RosField, xml_tree, "expr", (BtGetValueInputPort, str) + ) return RosField(name, expr) def __init__(self, name: str, expr: Union[BtGetValueInputPort, str]): @@ -52,8 +58,9 @@ def __init__(self, name: str, expr: Union[BtGetValueInputPort, str]): def check_validity(self) -> bool: valid_name = is_non_empty_string(RosField, "name", self._name) - valid_expr = (isinstance(self._expr, BtGetValueInputPort) or - is_non_empty_string(RosField, "expr", self._expr)) + valid_expr = isinstance(self._expr, BtGetValueInputPort) or is_non_empty_string( + RosField, "expr", self._expr + ) return valid_name and valid_expr def update_bt_ports_values(self, bt_ports_handler: BtPortsHandler): @@ -63,11 +70,11 @@ def update_bt_ports_values(self, bt_ports_handler: BtPortsHandler): def as_plain_scxml(self, _) -> ScxmlParam: # In order to distinguish the message body from additional entries, add a prefix to the name - assert self._cb_type is not None, \ - f"Error: SCXML ROS field: {self._name} has not callback type set." + assert ( + self._cb_type is not None + ), f"Error: SCXML ROS field: {self._name} has not callback type set." plain_field_name = ROS_FIELD_PREFIX + self._name - return ScxmlParam(plain_field_name, - expr=get_plain_expression(self._expr, self._cb_type)) + return ScxmlParam(plain_field_name, expr=get_plain_expression(self._expr, self._cb_type)) def as_xml(self) -> ET.Element: assert self.check_validity(), "Error: SCXML topic publish field: invalid parameters." diff --git a/src/as2fm/scxml_converter/scxml_entries/scxml_ros_service.py b/src/as2fm/scxml_converter/scxml_entries/scxml_ros_service.py index b2545136..e49c4dce 100644 --- a/src/as2fm/scxml_converter/scxml_entries/scxml_ros_service.py +++ b/src/as2fm/scxml_converter/scxml_entries/scxml_ros_service.py @@ -24,12 +24,17 @@ from as2fm.scxml_converter.scxml_entries import ScxmlRosDeclarationsContainer from as2fm.scxml_converter.scxml_entries.ros_utils import ( - generate_srv_request_event, generate_srv_response_event, - generate_srv_server_request_event, generate_srv_server_response_event, - is_srv_type_known) -from as2fm.scxml_converter.scxml_entries.scxml_ros_base import (RosCallback, - RosDeclaration, - RosTrigger) + generate_srv_request_event, + generate_srv_response_event, + generate_srv_server_request_event, + generate_srv_server_response_event, + is_srv_type_known, +) +from as2fm.scxml_converter.scxml_entries.scxml_ros_base import ( + RosCallback, + RosDeclaration, + RosTrigger, +) from as2fm.scxml_converter.scxml_entries.utils import CallbackType @@ -89,7 +94,8 @@ def check_fields_validity(self, ros_declarations: ScxmlRosDeclarationsContainer) def get_plain_scxml_event(self, ros_declarations: ScxmlRosDeclarationsContainer) -> str: return generate_srv_request_event( ros_declarations.get_service_client_info(self._interface_name)[0], - ros_declarations.get_automaton_name()) + ros_declarations.get_automaton_name(), + ) class RosServiceHandleRequest(RosCallback): @@ -112,7 +118,8 @@ def check_interface_defined(self, ros_declarations: ScxmlRosDeclarationsContaine def get_plain_scxml_event(self, ros_declarations: ScxmlRosDeclarationsContainer) -> str: return generate_srv_server_request_event( - ros_declarations.get_service_server_info(self._interface_name)[0]) + ros_declarations.get_service_server_info(self._interface_name)[0] + ) class RosServiceSendResponse(RosTrigger): @@ -134,7 +141,8 @@ def check_fields_validity(self, ros_declarations: ScxmlRosDeclarationsContainer) def get_plain_scxml_event(self, ros_declarations: ScxmlRosDeclarationsContainer) -> str: return generate_srv_server_response_event( - ros_declarations.get_service_server_info(self._interface_name)[0]) + ros_declarations.get_service_server_info(self._interface_name)[0] + ) class RosServiceHandleResponse(RosCallback): @@ -158,4 +166,5 @@ def check_interface_defined(self, ros_declarations: ScxmlRosDeclarationsContaine def get_plain_scxml_event(self, ros_declarations: ScxmlRosDeclarationsContainer) -> str: return generate_srv_response_event( ros_declarations.get_service_client_info(self._interface_name)[0], - ros_declarations.get_automaton_name()) + ros_declarations.get_automaton_name(), + ) diff --git a/src/as2fm/scxml_converter/scxml_entries/scxml_ros_timer.py b/src/as2fm/scxml_converter/scxml_entries/scxml_ros_timer.py index ddfb981b..2b39d44b 100644 --- a/src/as2fm/scxml_converter/scxml_entries/scxml_ros_timer.py +++ b/src/as2fm/scxml_converter/scxml_entries/scxml_ros_timer.py @@ -20,14 +20,10 @@ from as2fm.scxml_converter.scxml_entries import ScxmlRosDeclarationsContainer from as2fm.scxml_converter.scxml_entries.bt_utils import BtPortsHandler -from as2fm.scxml_converter.scxml_entries.ros_utils import \ - generate_rate_timer_event -from as2fm.scxml_converter.scxml_entries.scxml_ros_base import (RosCallback, - RosDeclaration) -from as2fm.scxml_converter.scxml_entries.utils import (CallbackType, - is_non_empty_string) -from as2fm.scxml_converter.scxml_entries.xml_utils import (assert_xml_tag_ok, - get_xml_argument) +from as2fm.scxml_converter.scxml_entries.ros_utils import generate_rate_timer_event +from as2fm.scxml_converter.scxml_entries.scxml_ros_base import RosCallback, RosDeclaration +from as2fm.scxml_converter.scxml_entries.utils import CallbackType, is_non_empty_string +from as2fm.scxml_converter.scxml_entries.xml_utils import assert_xml_tag_ok, get_xml_argument class RosTimeRate(RosDeclaration): @@ -83,7 +79,8 @@ def check_valid_instantiation(self) -> bool: def as_xml(self) -> ET.Element: assert self.check_validity(), "Error: SCXML rate timer: invalid parameters." xml_time_rate = ET.Element( - RosTimeRate.get_tag_name(), {"rate_hz": str(self._rate_hz), "name": self._name}) + RosTimeRate.get_tag_name(), {"rate_hz": str(self._rate_hz), "name": self._name} + ) return xml_time_rate diff --git a/src/as2fm/scxml_converter/scxml_entries/scxml_ros_topic.py b/src/as2fm/scxml_converter/scxml_entries/scxml_ros_topic.py index 682c80e7..7350404d 100644 --- a/src/as2fm/scxml_converter/scxml_entries/scxml_ros_topic.py +++ b/src/as2fm/scxml_converter/scxml_entries/scxml_ros_topic.py @@ -23,11 +23,12 @@ from typing import Type from as2fm.scxml_converter.scxml_entries import ScxmlRosDeclarationsContainer -from as2fm.scxml_converter.scxml_entries.ros_utils import ( - generate_topic_event, is_msg_type_known) -from as2fm.scxml_converter.scxml_entries.scxml_ros_base import (RosCallback, - RosDeclaration, - RosTrigger) +from as2fm.scxml_converter.scxml_entries.ros_utils import generate_topic_event, is_msg_type_known +from as2fm.scxml_converter.scxml_entries.scxml_ros_base import ( + RosCallback, + RosDeclaration, + RosTrigger, +) from as2fm.scxml_converter.scxml_entries.utils import CallbackType diff --git a/src/as2fm/scxml_converter/scxml_entries/scxml_state.py b/src/as2fm/scxml_converter/scxml_entries/scxml_state.py index 98ab1fd7..f5850337 100644 --- a/src/as2fm/scxml_converter/scxml_entries/scxml_state.py +++ b/src/as2fm/scxml_converter/scxml_entries/scxml_state.py @@ -20,16 +20,21 @@ from typing import List, Sequence, Union from xml.etree import ElementTree as ET -from as2fm.scxml_converter.scxml_entries import (ScxmlBase, - ScxmlExecutableEntry, - ScxmlExecutionBody, - ScxmlRosDeclarationsContainer, - ScxmlTransition) +from as2fm.scxml_converter.scxml_entries import ( + ScxmlBase, + ScxmlExecutableEntry, + ScxmlExecutionBody, + ScxmlRosDeclarationsContainer, + ScxmlTransition, +) from as2fm.scxml_converter.scxml_entries.bt_utils import BtPortsHandler from as2fm.scxml_converter.scxml_entries.scxml_executable_entries import ( - as_plain_execution_body, execution_body_from_xml, - instantiate_exec_body_bt_events, set_execution_body_callback_type, - valid_execution_body) + as_plain_execution_body, + execution_body_from_xml, + instantiate_exec_body_bt_events, + set_execution_body_callback_type, + valid_execution_body, +) from as2fm.scxml_converter.scxml_entries.utils import CallbackType @@ -42,12 +47,14 @@ def get_tag_name() -> str: @staticmethod def _transitions_from_xml(state_id: str, xml_tree: ET.Element) -> List[ScxmlTransition]: - from as2fm.scxml_converter.scxml_entries.scxml_ros_base import \ - RosCallback + from as2fm.scxml_converter.scxml_entries.scxml_ros_base import RosCallback + transitions: List[ScxmlTransition] = [] - tag_to_cls = {cls.get_tag_name(): cls - for cls in ScxmlTransition.__subclasses__() - if cls != RosCallback} + tag_to_cls = { + cls.get_tag_name(): cls + for cls in ScxmlTransition.__subclasses__() + if cls != RosCallback + } tag_to_cls.update({cls.get_tag_name(): cls for cls in RosCallback.__subclasses__()}) tag_to_cls.update({ScxmlTransition.get_tag_name(): ScxmlTransition}) for child in xml_tree: @@ -56,25 +63,30 @@ def _transitions_from_xml(state_id: str, xml_tree: ET.Element) -> List[ScxmlTran elif child.tag in tag_to_cls: transitions.append(tag_to_cls[child.tag].from_xml_tree(child)) else: - assert child.tag in ("onentry", "onexit"), \ - f"Error: SCXML state {state_id}: unexpected tag {child.tag}." + assert child.tag in ( + "onentry", + "onexit", + ), f"Error: SCXML state {state_id}: unexpected tag {child.tag}." return transitions @staticmethod def from_xml_tree(xml_tree: ET.Element) -> "ScxmlState": """Create a ScxmlState object from an XML tree.""" - assert xml_tree.tag == ScxmlState.get_tag_name(), \ - f"Error: SCXML state: XML tag name is not {ScxmlState.get_tag_name()}." + assert ( + xml_tree.tag == ScxmlState.get_tag_name() + ), f"Error: SCXML state: XML tag name is not {ScxmlState.get_tag_name()}." id_ = xml_tree.attrib.get("id") assert id_ is not None and len(id_) > 0, "Error: SCXML state: id is not valid." scxml_state = ScxmlState(id_) # Get the onentry and onexit execution bodies on_entry = xml_tree.findall("onentry") - assert len(on_entry) <= 1, \ - f"Error: SCXML state: {len(on_entry)} onentry tags found, expected 0 or 1." + assert ( + len(on_entry) <= 1 + ), f"Error: SCXML state: {len(on_entry)} onentry tags found, expected 0 or 1." on_exit = xml_tree.findall("onexit") - assert len(on_exit) <= 1, \ - f"Error: SCXML state: {len(on_exit)} onexit tags found, expected 0 or 1." + assert ( + len(on_exit) <= 1 + ), f"Error: SCXML state: {len(on_exit)} onexit tags found, expected 0 or 1." if len(on_entry) > 0: for exec_entry in execution_body_from_xml(on_entry[0]): scxml_state.append_on_entry(exec_entry) @@ -86,10 +98,14 @@ def from_xml_tree(xml_tree: ET.Element) -> "ScxmlState": scxml_state.add_transition(body_entry) return scxml_state - def __init__(self, state_id: str, *, - on_entry: ScxmlExecutionBody = None, - on_exit: ScxmlExecutionBody = None, - body: List[ScxmlTransition] = None): + def __init__( + self, + state_id: str, + *, + on_entry: ScxmlExecutionBody = None, + on_exit: ScxmlExecutionBody = None, + body: List[ScxmlTransition] = None, + ): """ Initialize a new ScxmlState object. @@ -120,7 +136,7 @@ def set_thread_id(self, thread_idx: int): """Assign the thread ID to the thread-specific transitions in the body.""" for entry in self._on_entry + self._on_exit + self._body: # Assign the thread only to the entries supporting it - if hasattr(entry, 'set_thread_id'): + if hasattr(entry, "set_thread_id"): entry.set_thread_id(thread_idx) def instantiate_bt_events(self, instance_id: str) -> None: @@ -157,8 +173,9 @@ def check_validity(self) -> bool: valid_body = isinstance(self._body, list) if valid_body: for transition in self._body: - valid_transition = isinstance( - transition, ScxmlTransition) and transition.check_validity() + valid_transition = ( + isinstance(transition, ScxmlTransition) and transition.check_validity() + ) if not valid_transition: valid_body = False break @@ -172,8 +189,9 @@ def check_validity(self) -> bool: print(f"Error: SCXML state {self._id}: executable body is not valid.") return valid_on_entry and valid_on_exit and valid_body - def check_valid_ros_instantiations(self, - ros_declarations: ScxmlRosDeclarationsContainer) -> bool: + def check_valid_ros_instantiations( + self, ros_declarations: ScxmlRosDeclarationsContainer + ) -> bool: """Check if the ros instantiations have been declared.""" valid_entry = ScxmlState._check_valid_ros_instantiations(self._on_entry, ros_declarations) valid_exit = ScxmlState._check_valid_ros_instantiations(self._on_exit, ros_declarations) @@ -188,11 +206,13 @@ def check_valid_ros_instantiations(self, @staticmethod def _check_valid_ros_instantiations( - body: Sequence[Union[ScxmlExecutableEntry, ScxmlTransition]], - ros_declarations: ScxmlRosDeclarationsContainer) -> bool: + body: Sequence[Union[ScxmlExecutableEntry, ScxmlTransition]], + ros_declarations: ScxmlRosDeclarationsContainer, + ) -> bool: """Check if the ros instantiations have been declared in the body.""" - return (len(body) == 0 or - all(entry.check_valid_ros_instantiations(ros_declarations) for entry in body)) + return len(body) == 0 or all( + entry.check_valid_ros_instantiations(ros_declarations) for entry in body + ) def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> "ScxmlState": """Convert the ROS-specific entries to be plain SCXML""" @@ -205,25 +225,28 @@ def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> "Sc plain_entries = entry.as_plain_scxml(ros_declarations) if isinstance(plain_entries, ScxmlTransition): plain_body.append(plain_entries) - elif isinstance(plain_entries, list) and \ - all(isinstance(e, ScxmlTransition) for e in plain_entries): + elif isinstance(plain_entries, list) and all( + isinstance(e, ScxmlTransition) for e in plain_entries + ): # Some special entries return multiple transitions plain_body.extend(plain_entries) else: - raise ValueError(f"Error: SCXML state {self._id}: found invalid transition in " - "state body after conversion to plain SCXML.") + raise ValueError( + f"Error: SCXML state {self._id}: found invalid transition in " + "state body after conversion to plain SCXML." + ) return ScxmlState(self._id, on_entry=plain_entry, on_exit=plain_exit, body=plain_body) def as_xml(self) -> ET.Element: assert self.check_validity(), "SCXML: found invalid state object." xml_state = ET.Element(ScxmlState.get_tag_name(), {"id": self._id}) if len(self._on_entry) > 0: - xml_on_entry = ET.Element('onentry') + xml_on_entry = ET.Element("onentry") for executable_entry in self._on_entry: xml_on_entry.append(executable_entry.as_xml()) xml_state.append(xml_on_entry) if len(self._on_exit) > 0: - xml_on_exit = ET.Element('onexit') + xml_on_exit = ET.Element("onexit") for executable_entry in self._on_exit: xml_on_exit.append(executable_entry.as_xml()) xml_state.append(xml_on_exit) diff --git a/src/as2fm/scxml_converter/scxml_entries/scxml_transition.py b/src/as2fm/scxml_converter/scxml_entries/scxml_transition.py index 9c5aa654..be673935 100644 --- a/src/as2fm/scxml_converter/scxml_entries/scxml_transition.py +++ b/src/as2fm/scxml_converter/scxml_entries/scxml_transition.py @@ -20,23 +20,30 @@ from typing import List, Optional from xml.etree import ElementTree as ET -from as2fm.scxml_converter.scxml_entries import (ScxmlBase, - ScxmlExecutableEntry, - ScxmlExecutionBody, - ScxmlRosDeclarationsContainer) -from as2fm.scxml_converter.scxml_entries.bt_utils import (BtPortsHandler, - is_bt_event, - replace_bt_event) +from as2fm.scxml_converter.scxml_entries import ( + ScxmlBase, + ScxmlExecutableEntry, + ScxmlExecutionBody, + ScxmlRosDeclarationsContainer, +) +from as2fm.scxml_converter.scxml_entries.bt_utils import ( + BtPortsHandler, + is_bt_event, + replace_bt_event, +) from as2fm.scxml_converter.scxml_entries.scxml_executable_entries import ( - execution_body_from_xml, instantiate_exec_body_bt_events, - set_execution_body_callback_type, valid_execution_body, - valid_execution_body_entry_types) -from as2fm.scxml_converter.scxml_entries.utils import (CallbackType, - get_plain_expression) + execution_body_from_xml, + instantiate_exec_body_bt_events, + set_execution_body_callback_type, + valid_execution_body, + valid_execution_body_entry_types, +) +from as2fm.scxml_converter.scxml_entries.utils import CallbackType, get_plain_expression class ScxmlTransition(ScxmlBase): """This class represents a single scxml state.""" + @staticmethod def get_tag_name() -> str: return "transition" @@ -44,8 +51,9 @@ def get_tag_name() -> str: @staticmethod def from_xml_tree(xml_tree: ET.Element) -> "ScxmlTransition": """Create a ScxmlTransition object from an XML tree.""" - assert xml_tree.tag == ScxmlTransition.get_tag_name(), \ - f"Error: SCXML transition: XML root tag name is not {ScxmlTransition.get_tag_name()}." + assert ( + xml_tree.tag == ScxmlTransition.get_tag_name() + ), f"Error: SCXML transition: XML root tag name is not {ScxmlTransition.get_tag_name()}." target = xml_tree.get("target") assert target is not None, "Error: SCXML transition: target attribute not found." events_str = xml_tree.get("event") @@ -55,9 +63,13 @@ def from_xml_tree(xml_tree: ET.Element) -> "ScxmlTransition": exec_body = exec_body if exec_body is not None else None return ScxmlTransition(target, events, condition, exec_body) - def __init__(self, - target: str, events: Optional[List[str]] = None, condition: Optional[str] = None, - body: Optional[ScxmlExecutionBody] = None): + def __init__( + self, + target: str, + events: Optional[List[str]] = None, + condition: Optional[str] = None, + body: Optional[ScxmlExecutionBody] = None, + ): """ Generate a new transition. Currently, transitions must have a target. @@ -70,15 +82,18 @@ def __init__(self, events = [] if body is None: body = [] - assert isinstance(target, str) and len( - target) > 0, "Error SCXML transition: target must be a non-empty string." - assert isinstance(events, list) and \ - all((isinstance(ev, str) and len(ev) > 0) for ev in events), \ - f"Error SCXML transition: events must be a list of filled strings. Found {events}." - assert condition is None or (isinstance(condition, str) and len(condition) > 0), \ - "Error SCXML transition: condition must be a non-empty string." - assert valid_execution_body_entry_types(body), \ - "Error SCXML transition: invalid body provided." + assert ( + isinstance(target, str) and len(target) > 0 + ), "Error SCXML transition: target must be a non-empty string." + assert isinstance(events, list) and all( + (isinstance(ev, str) and len(ev) > 0) for ev in events + ), f"Error SCXML transition: events must be a list of filled strings. Found {events}." + assert condition is None or ( + isinstance(condition, str) and len(condition) > 0 + ), "Error SCXML transition: condition must be a non-empty string." + assert valid_execution_body_entry_types( + body + ), "Error SCXML transition: invalid body provided." self._target: str = target self._body: ScxmlExecutionBody = body self._events: List[str] = events @@ -123,15 +138,18 @@ def append_body_executable_entry(self, exec_entry: ScxmlExecutableEntry): if self._body is None: self._body = [] self._body.append(exec_entry) - assert valid_execution_body_entry_types(self._body), \ - "Error SCXML transition: invalid body entry found after extension." + assert valid_execution_body_entry_types( + self._body + ), "Error SCXML transition: invalid body entry found after extension." def check_validity(self) -> bool: valid_target = isinstance(self._target, str) and len(self._target) > 0 - valid_events = self._events is None or \ - (isinstance(self._events, list) and all(isinstance(ev, str) for ev in self._events)) + valid_events = self._events is None or ( + isinstance(self._events, list) and all(isinstance(ev, str) for ev in self._events) + ) valid_condition = self._condition is None or ( - isinstance(self._condition, str) and len(self._condition) > 0) + isinstance(self._condition, str) and len(self._condition) > 0 + ) valid_body = self._body is None or valid_execution_body(self._body) if not valid_target: print("Error: SCXML transition: target is not valid.") @@ -145,12 +163,14 @@ def check_validity(self) -> bool: print("Error: SCXML transition: executable content is not valid.") return valid_target and valid_events and valid_condition and valid_body - def check_valid_ros_instantiations(self, - ros_declarations: ScxmlRosDeclarationsContainer) -> bool: + def check_valid_ros_instantiations( + self, ros_declarations: ScxmlRosDeclarationsContainer + ) -> bool: """Check if the ros instantiations have been declared.""" # For SCXML transitions, ROS interfaces can be found only in the exec body - return self._body is None or \ - all(entry.check_valid_ros_instantiations(ros_declarations) for entry in self._body) + return self._body is None or all( + entry.check_valid_ros_instantiations(ros_declarations) for entry in self._body + ) def set_thread_id(self, thread_id: int) -> None: """Set the thread ID for the executable entries of this transition.""" @@ -160,10 +180,12 @@ def set_thread_id(self, thread_id: int) -> None: entry.set_thread_id(thread_id) def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> "ScxmlTransition": - assert isinstance(ros_declarations, ScxmlRosDeclarationsContainer), \ - "Error: SCXML transition: invalid ROS declarations container." - assert self.check_valid_ros_instantiations(ros_declarations), \ - "Error: SCXML transition: invalid ROS instantiations in transition body." + assert isinstance( + ros_declarations, ScxmlRosDeclarationsContainer + ), "Error: SCXML transition: invalid ROS declarations container." + assert self.check_valid_ros_instantiations( + ros_declarations + ), "Error: SCXML transition: invalid ROS instantiations in transition body." new_body = None set_execution_body_callback_type(self._body, CallbackType.TRANSITION) if self._body is not None: diff --git a/src/as2fm/scxml_converter/scxml_entries/utils.py b/src/as2fm/scxml_converter/scxml_entries/utils.py index 19723fd3..38f71f92 100644 --- a/src/as2fm/scxml_converter/scxml_entries/utils.py +++ b/src/as2fm/scxml_converter/scxml_entries/utils.py @@ -29,46 +29,51 @@ PLAIN_FIELD_EVENT_PREFIX: str = f"{PLAIN_SCXML_EVENT_PREFIX}{ROS_FIELD_PREFIX}" ROS_EVENT_PREFIXES = [ - "_msg.", # Topic-related - "_req.", "_res.", # Service-related - "_goal.", "_feedback.", "_wrapped_result.", "_action." # Action-related + "_msg.", # Topic-related + "_req.", + "_res.", # Service-related + "_goal.", + "_feedback.", + "_wrapped_result.", + "_action.", # Action-related ] # TODO: add lower and upper bounds depending on the n. of bits used. # TODO: add support to uint SCXML_DATA_STR_TO_TYPE: Dict[str, Type] = { - "bool": bool, - "float32": float, - "float64": float, - "int8": int, - "int16": int, - "int32": int, - "int64": int, - "int8[]": MutableSequence[int], # array('i'): https://stackoverflow.com/a/67775675 - "int16[]": MutableSequence[int], - "int32[]": MutableSequence[int], - "int64[]": MutableSequence[int], - "float32[]": MutableSequence[float], # array('d'): https://stackoverflow.com/a/67775675 - "float64[]": MutableSequence[float] + "bool": bool, + "float32": float, + "float64": float, + "int8": int, + "int16": int, + "int32": int, + "int64": int, + "int8[]": MutableSequence[int], # array('i'): https://stackoverflow.com/a/67775675 + "int16[]": MutableSequence[int], + "int32[]": MutableSequence[int], + "int64[]": MutableSequence[int], + "float32[]": MutableSequence[float], # array('d'): https://stackoverflow.com/a/67775675 + "float64[]": MutableSequence[float], } # ------------ Expression-conversion functionalities ------------ class CallbackType(Enum): """Enumeration of the different types of callbacks containing a body.""" - STATE = auto() # No callback (e.g. state entry/exit) - TRANSITION = auto() # Transition callback - ROS_TIMER = auto() # Timer callback - ROS_TOPIC = auto() # Topic callback - ROS_SERVICE_REQUEST = auto() # Service callback - ROS_SERVICE_RESULT = auto() # Service callback - ROS_ACTION_GOAL = auto() # Action callback - ROS_ACTION_RESULT = auto() # Action callback - ROS_ACTION_FEEDBACK = auto() # Action callback + + STATE = auto() # No callback (e.g. state entry/exit) + TRANSITION = auto() # Transition callback + ROS_TIMER = auto() # Timer callback + ROS_TOPIC = auto() # Topic callback + ROS_SERVICE_REQUEST = auto() # Service callback + ROS_SERVICE_RESULT = auto() # Service callback + ROS_ACTION_GOAL = auto() # Action callback + ROS_ACTION_RESULT = auto() # Action callback + ROS_ACTION_FEEDBACK = auto() # Action callback @staticmethod - def get_expected_prefixes(cb_type: 'CallbackType') -> List[str]: + def get_expected_prefixes(cb_type: "CallbackType") -> List[str]: if cb_type in (CallbackType.STATE, CallbackType.ROS_TIMER): return [] elif cb_type == CallbackType.TRANSITION: @@ -87,7 +92,7 @@ def get_expected_prefixes(cb_type: 'CallbackType') -> List[str]: return ["_action.goal_id", "_feedback."] @staticmethod - def get_plain_callback(cb_type: 'CallbackType') -> 'CallbackType': + def get_plain_callback(cb_type: "CallbackType") -> "CallbackType": """Convert ROS-specific transitions to plain ones.""" if cb_type == CallbackType.STATE: return CallbackType.STATE @@ -107,24 +112,30 @@ def _replace_ros_interface_expression(msg_expr: str, expected_prefixes: List[str expected_prefixes.remove(PLAIN_SCXML_EVENT_PREFIX) msg_expr.strip() for prefix in expected_prefixes: - assert prefix.startswith("_"), \ - f"Error: SCXML ROS conversion: prefix {prefix} does not start with underscore." + assert prefix.startswith( + "_" + ), f"Error: SCXML ROS conversion: prefix {prefix} does not start with underscore." if prefix.endswith("."): # Generic field substitution, adding the ROS_FIELD_PREFIX prefix_reg = prefix.replace(".", r"\.") msg_expr = re.sub( rf"(^|[^a-zA-Z0-9_.]){prefix_reg}([a-zA-Z0-9_.])", - rf"\g<1>{PLAIN_FIELD_EVENT_PREFIX}\g<2>", msg_expr) + rf"\g<1>{PLAIN_FIELD_EVENT_PREFIX}\g<2>", + msg_expr, + ) else: # Special fields substitution, no need to add the ROS_FIELD_PREFIX split_prefix = prefix.split(".", maxsplit=1) - assert len(split_prefix) == 2, \ - f"Error: SCXML ROS conversion: prefix {prefix} has no dots." + assert ( + len(split_prefix) == 2 + ), f"Error: SCXML ROS conversion: prefix {prefix} has no dots." substitution = f"{PLAIN_SCXML_EVENT_PREFIX}{split_prefix[1]}" prefix_reg = prefix.replace(".", r"\.") msg_expr = re.sub( rf"(^|[^a-zA-Z0-9_.]){prefix_reg}($|[^a-zA-Z0-9_.])", - rf"\g<1>{substitution}\g<2>", msg_expr) + rf"\g<1>{substitution}\g<2>", + msg_expr, + ) return msg_expr @@ -146,15 +157,17 @@ def get_plain_expression(msg_expr: str, cb_type: CallbackType) -> str: expected_prefixes = CallbackType.get_expected_prefixes(cb_type) # pre-check over the expression if PLAIN_SCXML_EVENT_PREFIX not in expected_prefixes: - assert not _contains_prefixes(msg_expr, [PLAIN_SCXML_EVENT_PREFIX]), \ - "Error: SCXML ROS conversion: "\ + assert not _contains_prefixes(msg_expr, [PLAIN_SCXML_EVENT_PREFIX]), ( + "Error: SCXML ROS conversion: " f"unexpected {PLAIN_SCXML_EVENT_PREFIX} prefix in expr. {msg_expr}" + ) forbidden_prefixes = ROS_EVENT_PREFIXES.copy() if len(expected_prefixes) == 0: forbidden_prefixes.append(PLAIN_SCXML_EVENT_PREFIX) new_expr = _replace_ros_interface_expression(msg_expr, expected_prefixes) - assert not _contains_prefixes(new_expr, forbidden_prefixes), \ - f"Error: SCXML ROS conversion: unexpected ROS interface prefixes in expr.: {msg_expr}" + assert not _contains_prefixes( + new_expr, forbidden_prefixes + ), f"Error: SCXML ROS conversion: unexpected ROS interface prefixes in expr.: {msg_expr}" return new_expr @@ -183,8 +196,10 @@ def is_non_empty_string(scxml_type: Type[ScxmlBase], arg_name: str, arg_value: s """ valid_str = isinstance(arg_value, str) and len(arg_value) > 0 if not valid_str: - print(f"Error: SCXML entry from {scxml_type.__name__}: " - f"Expected non-empty argument {arg_name}, got >{arg_value}<.") + print( + f"Error: SCXML entry from {scxml_type.__name__}: " + f"Expected non-empty argument {arg_name}, got >{arg_value}<." + ) return valid_str @@ -198,7 +213,7 @@ def get_data_type_from_string(data_type: str) -> Optional[Type]: """ data_type = data_type.strip() # If the data type is an array, remove the bound value - if '[' in data_type: + if "[" in data_type: data_type = re.sub(r"(^[a-z0-9]*\[)[0-9]*(\]$)", r"\g<1>\g<2>", data_type) return SCXML_DATA_STR_TO_TYPE.get(data_type, None) @@ -215,8 +230,9 @@ def get_array_max_size(data_type: str) -> Optional[int]: """ Get the maximum size of an array, if the data type is an array. """ - assert is_array_type(get_data_type_from_string(data_type)), \ - f"Error: SCXML data: '{data_type}' is not an array." + assert is_array_type( + get_data_type_from_string(data_type) + ), f"Error: SCXML data: '{data_type}' is not an array." match_obj = re.search(r"\[([0-9]+)\]", data_type) if match_obj is not None: return int(match_obj.group(1)) diff --git a/src/as2fm/scxml_converter/scxml_entries/xml_utils.py b/src/as2fm/scxml_converter/scxml_entries/xml_utils.py index 099590b5..b3207427 100644 --- a/src/as2fm/scxml_converter/scxml_entries/xml_utils.py +++ b/src/as2fm/scxml_converter/scxml_entries/xml_utils.py @@ -21,25 +21,34 @@ def assert_xml_tag_ok(scxml_type: Type[ScxmlBase], xml_tree: ET.Element): """Ensures the xml_tree we are trying to parse has the expected name.""" - assert xml_tree.tag == scxml_type.get_tag_name(), \ - f"SCXML conversion: Expected tag {scxml_type.get_tag_name()}, but got {xml_tree.tag}" - - -def get_xml_argument(scxml_type: Type[ScxmlBase], xml_tree: ET.Element, arg_name: str, *, - none_allowed=False, empty_allowed=False) -> Optional[str]: + assert ( + xml_tree.tag == scxml_type.get_tag_name() + ), f"SCXML conversion: Expected tag {scxml_type.get_tag_name()}, but got {xml_tree.tag}" + + +def get_xml_argument( + scxml_type: Type[ScxmlBase], + xml_tree: ET.Element, + arg_name: str, + *, + none_allowed=False, + empty_allowed=False, +) -> Optional[str]: """Load an argument from the xml tree's root tag.""" arg_value = xml_tree.get(arg_name) error_prefix = f"SCXML conversion of {scxml_type.get_tag_name()}" if arg_value is None: assert none_allowed, f"{error_prefix}: Expected argument {arg_name} in {xml_tree.tag}" elif len(arg_value) == 0: - assert empty_allowed, \ - f"{error_prefix}: Expected non-empty argument {arg_name} in {xml_tree.tag}" + assert ( + empty_allowed + ), f"{error_prefix}: Expected non-empty argument {arg_name} in {xml_tree.tag}" return arg_value def get_children_as_scxml( - xml_tree: ET.Element, scxml_types: Iterable[Type[ScxmlBase]]) -> List[ScxmlBase]: + xml_tree: ET.Element, scxml_types: Iterable[Type[ScxmlBase]] +) -> List[ScxmlBase]: """ Load the children of the xml tree as scxml entries. @@ -58,10 +67,11 @@ def get_children_as_scxml( def read_value_from_xml_child( - xml_tree: ET.Element, - child_tag: str, - valid_types: Iterable[Type[Union[ScxmlBase, str]]], *, - none_allowed: bool = False + xml_tree: ET.Element, + child_tag: str, + valid_types: Iterable[Type[Union[ScxmlBase, str]]], + *, + none_allowed: bool = False, ) -> Optional[Union[str, ScxmlBase]]: """ Try to read the value of a child tag from the xml tree. If the child is not found, return None. @@ -72,8 +82,7 @@ def read_value_from_xml_child( print(f"Error: reading from {xml_tree.tag}: Cannot find child '{child_tag}'.") return None if len(xml_child) > 1: - print( - f"Error: reading from {xml_tree.tag}: multiple children '{child_tag}', expected one.") + print(f"Error: reading from {xml_tree.tag}: multiple children '{child_tag}', expected one.") return None tag_children = [child for child in xml_child[0] if child.tag is not ET.Comment] n_tag_children = len(tag_children) @@ -99,22 +108,28 @@ def read_value_from_xml_child( def read_value_from_xml_arg_or_child( - scxml_type: Type[ScxmlBase], xml_tree: ET.Element, tag_name: str, - valid_types: Iterable[Type[Union[ScxmlBase, str]]], - none_allowed: bool = False) -> Optional[Union[str, ScxmlBase]]: + scxml_type: Type[ScxmlBase], + xml_tree: ET.Element, + tag_name: str, + valid_types: Iterable[Type[Union[ScxmlBase, str]]], + none_allowed: bool = False, +) -> Optional[Union[str, ScxmlBase]]: """ Read a value from an xml attribute or, if not found, the child tag with the same name. To read the value from the xml arguments, valid_types must include string. """ - assert str in valid_types, \ - "Error: read_value_from_arg_or_child: valid_types must include str. " \ + assert str in valid_types, ( + "Error: read_value_from_arg_or_child: valid_types must include str. " "If strings are not expected, use 'read_value_from_xml_child'." + ) read_value = get_xml_argument(scxml_type, xml_tree, tag_name, none_allowed=True) if read_value is None: - read_value = read_value_from_xml_child(xml_tree, tag_name, valid_types, - none_allowed=none_allowed) + read_value = read_value_from_xml_child( + xml_tree, tag_name, valid_types, none_allowed=none_allowed + ) if not none_allowed: - assert read_value is not None, \ - f"Error: SCXML conversion of {scxml_type.get_tag_name()}: Missing argument {tag_name}." + assert ( + read_value is not None + ), f"Error: SCXML conversion of {scxml_type.get_tag_name()}: Missing argument {tag_name}." return read_value diff --git a/src/as2fm/trace_visualizer/main.py b/src/as2fm/trace_visualizer/main.py index da4507ff..1fc19830 100644 --- a/src/as2fm/trace_visualizer/main.py +++ b/src/as2fm/trace_visualizer/main.py @@ -28,29 +28,30 @@ def main_trace_to_png(): """ parser = argparse.ArgumentParser( - description='Converts a trace file produced by smc_storm into two' + - ' images. One image for the first verified trace (if any) and one' + - ' image for the first falsified trace (if any).' + description="Converts a trace file produced by smc_storm into two" + + " images. One image for the first verified trace (if any) and one" + + " image for the first falsified trace (if any)." ) - parser.add_argument('input_fname', type=str, help='The trace as csv file.') + parser.add_argument("input_fname", type=str, help="The trace as csv file.") parser.add_argument( - 'output_png_prefix', type=str, - help='Prefix for the output png file. ' + - 'The output will be saved as _.png' + "output_png_prefix", + type=str, + help="Prefix for the output png file. " + + "The output will be saved as _.png", ) parser.add_argument( - '-l', '--left-to-right', action='store_true', - help='If set, the trace will be visualized from left to right. ' + - 'Otherwise, the trace will be visualized from top to bottom. ' + - '(default: top to bottom)' + "-l", + "--left-to-right", + action="store_true", + help="If set, the trace will be visualized from left to right. " + + "Otherwise, the trace will be visualized from top to bottom. " + + "(default: top to bottom)", ) args = parser.parse_args() traces = Traces(args.input_fname, args.left_to_right) ver, fal = traces.print_info_about_result() if ver is not None: - traces.write_trace_to_img( - ver, args.output_png_prefix + "_verified.png") + traces.write_trace_to_img(ver, args.output_png_prefix + "_verified.png") if fal is not None: - traces.write_trace_to_img( - fal, args.output_png_prefix + "_falsified.png") + traces.write_trace_to_img(fal, args.output_png_prefix + "_falsified.png") diff --git a/src/as2fm/trace_visualizer/visualizer.py b/src/as2fm/trace_visualizer/visualizer.py index 6761dd1f..ddad8467 100644 --- a/src/as2fm/trace_visualizer/visualizer.py +++ b/src/as2fm/trace_visualizer/visualizer.py @@ -23,11 +23,11 @@ import pandas from PIL import Image, ImageDraw, ImageEnhance, ImageFont, ImageOps -LOC_PREFIX = '_loc_' -TRACE_NUMBER = 'Trace number' -RESULT = 'Result' -GLOBAL_TIMER = 'global_timer' -VERIFIED = 'Verified' +LOC_PREFIX = "_loc_" +TRACE_NUMBER = "Trace number" +RESULT = "Result" +GLOBAL_TIMER = "global_timer" +VERIFIED = "Verified" PIXELS_EXTERNAL_BORDER = 2 PIXELS_INTERNAL_BORDER = 1 @@ -66,24 +66,28 @@ def __init__(self, fname: str, left_to_right: bool = False): self._prepare_data(fname) # Precomputations for visualization - self.titles, self.titles_max_height, self.titles_max_width = \ - self._precompute_text() # We swap width and height here because + self.titles, self.titles_max_height, self.titles_max_width = ( + self._precompute_text() + ) # We swap width and height here because # the text was rotated by 90 degrees. self.color_per_automaton = self._get_color_per_automaton() - assert len(self.color_per_automaton) == len(self.automata), \ - 'Must have the same number of automata and colors.' + assert len(self.color_per_automaton) == len( + self.automata + ), "Must have the same number of automata and colors." self.data_per_automaton = self._get_data_per_automaton() - assert len(self.data_per_automaton) == len(self.automata), \ - 'Must have the same number of automata and data.' + assert len(self.data_per_automaton) == len( + self.automata + ), "Must have the same number of automata and data." self.width_per_col = self._get_width_per_col() - assert len(self.width_per_col) > 1, \ - 'Must have more than one pixel no.' + assert len(self.width_per_col) > 1, "Must have more than one pixel no." self.scale_per_col = self._get_scale_per_col() - assert len(self.width_per_col) == len(self.scale_per_col), \ - 'Must have the same number of widths and scale.' + assert len(self.width_per_col) == len( + self.scale_per_col + ), "Must have the same number of widths and scale." self.start_per_column = self._get_start_per_col() - assert len(self.width_per_col) == len(self.start_per_column), \ - 'Must have the same number of widths and starts.' + assert len(self.width_per_col) == len( + self.start_per_column + ), "Must have the same number of widths and starts." self.img_width = self._get_img_width() print(f"{self.img_width=}") @@ -98,14 +102,12 @@ def print_info_about_result(self): else: if falsified is None: falsified = i - print( - 'These are the first verified and falsified traces respectively:') - print(f'{verified=}, {falsified=}') + print("These are the first verified and falsified traces respectively:") + print(f"{verified=}, {falsified=}") return verified, falsified # pylint: disable=too-many-locals - def write_trace_to_img( - self, trace_no: int, fname: str): + def write_trace_to_img(self, trace_no: int, fname: str): """Write one trace to image file. Args: @@ -117,11 +119,9 @@ def write_trace_to_img( trace = self.traces[trace_no] print(trace.df()) data_height = len(trace.df().index) - print(f'{data_height=}') - img_height = text_height + data_height \ - + 2 * PIXELS_EXTERNAL_BORDER + PIXELS_INTERNAL_BORDER - image = Image.new( - 'RGB', (self.img_width, img_height), color='black') + print(f"{data_height=}") + img_height = text_height + data_height + 2 * PIXELS_EXTERNAL_BORDER + PIXELS_INTERNAL_BORDER + image = Image.new("RGB", (self.img_width, img_height), color="black") draw = ImageDraw.Draw(image) # Draw the automata names @@ -133,7 +133,7 @@ def write_trace_to_img( # Draw the data y_data_end = y_data_start + data_height for a in self.automata: - for col in [f'{LOC_PREFIX}{a}'] + self.data_per_automaton[a]: + for col in [f"{LOC_PREFIX}{a}"] + self.data_per_automaton[a]: x_start = self.start_per_column[col] width = self.width_per_col[col] scale = self.scale_per_col[col] @@ -141,11 +141,10 @@ def write_trace_to_img( bg_col = self.color_per_automaton[a][2] fr_col = self.color_per_automaton[a][0] else: - bg_col = 'white' + bg_col = "white" fr_col = self.color_per_automaton[a][1] draw.rectangle( - [x_start, y_data_start, x_start + width - 1, y_data_end - 1], - fill=bg_col + [x_start, y_data_start, x_start + width - 1, y_data_end - 1], fill=bg_col ) y_0: Optional[int] = None for y_data, row in trace.df()[col].items(): @@ -160,76 +159,69 @@ def write_trace_to_img( x = int(row * scale) except TypeError as e: print(e) - print(f'{row=}') - assert x >= 0, f'{x=} must be positive.' - assert x < width, \ - f'{x=} must be smaller than {width=}. ({scale=},' + \ - f' {type(row)=}, {row=})' - draw.point( - (x_start + x, y_data_start + y_start), - fill=fr_col + print(f"{row=}") + assert x >= 0, f"{x=} must be positive." + assert x < width, ( + f"{x=} must be smaller than {width=}. ({scale=}," + + f" {type(row)=}, {row=})" ) + draw.point((x_start + x, y_data_start + y_start), fill=fr_col) # Plot result # find line where Result is not none result: bool = trace.is_verified() - color = 'green' if result else 'red' + color = "green" if result else "red" draw.rectangle( - [PIXELS_EXTERNAL_BORDER, - img_height - PIXELS_EXTERNAL_BORDER - 1, - self.img_width - PIXELS_EXTERNAL_BORDER - 1, - img_height - PIXELS_EXTERNAL_BORDER - 1], - fill=color + [ + PIXELS_EXTERNAL_BORDER, + img_height - PIXELS_EXTERNAL_BORDER - 1, + self.img_width - PIXELS_EXTERNAL_BORDER - 1, + img_height - PIXELS_EXTERNAL_BORDER - 1, + ], + fill=color, ) # If the image is to be left-to-right, flip it such that the leftmost # column is on the bottom. Then data that was plotted from left to # right (increasing) will be plotted from bottom to top. if self.ltr: - image = image.transpose( - Image.Transpose.ROTATE_90) + image = image.transpose(Image.Transpose.ROTATE_90) # Write the image to file image.save(fname) def _prepare_data(self, fname: str): - self.df = pandas.read_csv(fname, sep=';') + self.df = pandas.read_csv(fname, sep=";") self.columns = self.df.columns.values - assert len(self.columns) > 1, 'Must have more than one column.' + assert len(self.columns) > 1, "Must have more than one column." self.traces = self._separate_traces() self.automata = self._get_unique_automata() - assert len(self.automata) > 1, 'Must have more than one automaton.' + assert len(self.automata) > 1, "Must have more than one automaton." def _draw_automata_names(self, image: Image) -> Image: for a in self.automata: - x = self.start_per_column[f'{LOC_PREFIX}{a}'] + x = self.start_per_column[f"{LOC_PREFIX}{a}"] y_start = PIXELS_EXTERNAL_BORDER # bbox = self.titles[a].getbbox() # this_text_height = bbox[3] - bbox[1] # this_text_width = bbox[2] - bbox[0] # print(f'{a=}, {x=}, {y=}, {this_text_width=}, {this_text_height=}') colorized_text = ImageOps.colorize( - self.titles[a], black='black', - white=self.color_per_automaton[a][2]) - image.paste(colorized_text, - box=(x, y_start)) + self.titles[a], black="black", white=self.color_per_automaton[a][2] + ) + image.paste(colorized_text, box=(x, y_start)) return image def _draw_lines(self, draw: ImageDraw.Draw, text_height: int): - y_data_start = PIXELS_EXTERNAL_BORDER + text_height + \ - PIXELS_INTERNAL_BORDER + y_data_start = PIXELS_EXTERNAL_BORDER + text_height + PIXELS_INTERNAL_BORDER for a in self.automata: - x = self.start_per_column[f'{LOC_PREFIX}{a}'] + x = self.start_per_column[f"{LOC_PREFIX}{a}"] bbox = self.titles[a].getbbox() - y_start = PIXELS_EXTERNAL_BORDER + \ - bbox[3] - bbox[1] + PIXELS_EXTERNAL_BORDER + y_start = PIXELS_EXTERNAL_BORDER + bbox[3] - bbox[1] + PIXELS_EXTERNAL_BORDER y_end = y_data_start - 1 - PIXELS_EXTERNAL_BORDER if y_start >= y_end: continue - draw.line( - [x, y_start, x, y_end], - fill=self.color_per_automaton[a][2] - ) + draw.line([x, y_start, x, y_end], fill=self.color_per_automaton[a][2]) return y_data_start def _precompute_text(self): @@ -238,8 +230,7 @@ def _precompute_text(self): max_height = 0 max_width = 0 enhancer = ImageEnhance.Contrast - font_path = os.path.join( - os.path.dirname(__file__), 'data', 'slkscr.ttf') + font_path = os.path.join(os.path.dirname(__file__), "data", "slkscr.ttf") for automaton in self.automata: f = ImageFont.truetype(font_path, 7) bbox = f.getbbox(automaton) @@ -248,7 +239,7 @@ def _precompute_text(self): height = 7 # bbox[3] - bbox[1] max_height = max(max_height, height) # print(f'{automaton=}, {bbox=}, {width=}, {height=}') - txt = Image.new('L', (width, height), color=0) + txt = Image.new("L", (width, height), color=0) d = ImageDraw.Draw(txt) d.text((0, 0), automaton, font=f, fill=255) txt = enhancer(txt).enhance(10.0) @@ -270,23 +261,20 @@ def _precompute_text(self): def _separate_traces(self) -> List[Trace]: """Separates the traces in the dataframe into Trace objects.""" - assert TRACE_NUMBER in self.columns, \ - f'Must have a column named "{TRACE_NUMBER}"' + assert TRACE_NUMBER in self.columns, f'Must have a column named "{TRACE_NUMBER}"' unique_traces = self.df[TRACE_NUMBER].unique() unique_traces.sort() traces = [] for trace in unique_traces: traces.append(Trace(self.df[self.df[TRACE_NUMBER] == trace])) - print(f'{len(traces)=}') + print(f"{len(traces)=}") return traces def _get_unique_automata(self) -> List[str]: """Returns a list of names of automata in the traces.""" - all_automata = sorted([ - x.replace(LOC_PREFIX, '') - for x in self.columns - if x.startswith(LOC_PREFIX) - ]) + all_automata = sorted( + [x.replace(LOC_PREFIX, "") for x in self.columns if x.startswith(LOC_PREFIX)] + ) if GLOBAL_TIMER in all_automata: all_automata.remove(GLOBAL_TIMER) return [GLOBAL_TIMER] + all_automata @@ -295,12 +283,9 @@ def _get_unique_automata(self) -> List[str]: all_automata.reverse() return all_automata - def _get_color_per_automaton(self) -> Dict[ - str, Tuple[ - Tuple[int, int, int], - Tuple[int, int, int], - Tuple[int, int, int]] - ]: + def _get_color_per_automaton( + self, + ) -> Dict[str, Tuple[Tuple[int, int, int], Tuple[int, int, int], Tuple[int, int, int]]]: """Returns a dictionary with the color of each automaton.""" colors = {} random_automata_i = list(range(len(self.automata))) @@ -311,15 +296,15 @@ def _get_color_per_automaton(self) -> Dict[ if automaton == GLOBAL_TIMER: # gray colors[automaton] = ( - _hsv_to_rgb(hue, 0, .5), # dark - _hsv_to_rgb(hue, 0, .7), # mid - _hsv_to_rgb(hue, 0, 1) # light + _hsv_to_rgb(hue, 0, 0.5), # dark + _hsv_to_rgb(hue, 0, 0.7), # mid + _hsv_to_rgb(hue, 0, 1), # light ) else: colors[automaton] = ( - _hsv_to_rgb(hue, 1, .5), # dark - _hsv_to_rgb(hue, 1, .7), # mid - _hsv_to_rgb(hue, .2, 1) # light + _hsv_to_rgb(hue, 1, 0.5), # dark + _hsv_to_rgb(hue, 1, 0.7), # mid + _hsv_to_rgb(hue, 0.2, 1), # light ) return colors @@ -327,14 +312,11 @@ def _get_data_per_automaton(self) -> Dict[str, List[str]]: """Returns a dictionary with the data column names that can be somhow related to that automaton. This is only done by comparing the name, so it is not perfect.""" - data_per_automaton: Dict[str, List[str]] = { - automaton: [] - for automaton in self.automata - } + data_per_automaton: Dict[str, List[str]] = {automaton: [] for automaton in self.automata} for col in self.columns: if col.startswith(LOC_PREFIX): continue - if col.startswith('Unnamed: '): + if col.startswith("Unnamed: "): continue if col == TRACE_NUMBER: continue @@ -362,12 +344,12 @@ def _get_width_per_col(self) -> Dict[str, int]: """ width_per_col = {} for a in self.automata: - width_per_col[f'{LOC_PREFIX}{a}'] = int(self.df[f'{LOC_PREFIX}{a}'].max() + 1) + width_per_col[f"{LOC_PREFIX}{a}"] = int(self.df[f"{LOC_PREFIX}{a}"].max() + 1) # print(width_per_col[f'{LOC_PREFIX}{a}']) # print(self.data_per_automaton[a]) for col in self.data_per_automaton[a]: # print(self.df[col].dtype) - if self.df[col].dtype == 'float64': + if self.df[col].dtype == "float64": width_per_col[col] = int(min(self.df[col].max() + 1, 10)) else: # we assume this is a binary width_per_col[col] = 2 @@ -379,8 +361,8 @@ def _get_scale_per_col(self) -> Dict[str, float]: for col in self.width_per_col: scale_per_col[col] = 1.0 try: - if self.df[col].max()+1 > 10: - scale_per_col[col] = 10.0 / (self.df[col].max()+1) + if self.df[col].max() + 1 > 10: + scale_per_col[col] = 10.0 / (self.df[col].max() + 1) except TypeError as e: print(e) return scale_per_col @@ -394,13 +376,14 @@ def _get_start_per_col(self): for a in self.automata: if start_last_automaton is not None: current_loc = max( - current_loc, start_last_automaton + self.titles_max_width - + PIXELS_INTERNAL_BORDER) + current_loc, + start_last_automaton + self.titles_max_width + PIXELS_INTERNAL_BORDER, + ) start_automaton = current_loc - for col in [f'{LOC_PREFIX}{a}'] + self.data_per_automaton[a]: + for col in [f"{LOC_PREFIX}{a}"] + self.data_per_automaton[a]: start_per_col[col] = current_loc this_width = self.width_per_col[col] - current_loc += (this_width + PIXELS_INTERNAL_BORDER) + current_loc += this_width + PIXELS_INTERNAL_BORDER start_last_automaton = start_automaton return start_per_col @@ -408,6 +391,5 @@ def _get_img_width(self) -> int: """Calculate the width of the image.""" last_col = self.data_per_automaton[self.automata[-1]][-1] return ( - self.start_per_column[last_col] + self.width_per_col[last_col] - + PIXELS_EXTERNAL_BORDER + self.start_per_column[last_col] + self.width_per_col[last_col] + PIXELS_EXTERNAL_BORDER ) diff --git a/test/as2fm_common/test_unittest_ecmascript_interpretation.py b/test/as2fm_common/test_unittest_ecmascript_interpretation.py index 51176caa..405d672f 100644 --- a/test/as2fm_common/test_unittest_ecmascript_interpretation.py +++ b/test/as2fm_common/test_unittest_ecmascript_interpretation.py @@ -20,8 +20,7 @@ import pytest -from as2fm.as2fm_common.ecmascript_interpretation import \ - interpret_ecma_script_expr +from as2fm.as2fm_common.ecmascript_interpretation import interpret_ecma_script_expr class TestEcmascriptInterpreter(unittest.TestCase): @@ -40,7 +39,7 @@ def test_ecmascript_types(self): self.assertEqual(interpret_ecma_script_expr("1.1"), 1.1) self.assertEqual(interpret_ecma_script_expr("true"), True) self.assertEqual(interpret_ecma_script_expr("false"), False) - self.assertEqual(interpret_ecma_script_expr("[1,2,3]"), array('i', [1, 2, 3])) + self.assertEqual(interpret_ecma_script_expr("[1,2,3]"), array("i", [1, 2, 3])) def test_ecmascript_unsupported(self): """ @@ -52,11 +51,11 @@ def test_ecmascript_unsupported(self): src https://alexzhornyak.github.io/SCXML-tutorial/Doc/\ datamodel.html#ecmascript """ - self.assertRaises(ValueError, interpret_ecma_script_expr, "\'this is a string\'") + self.assertRaises(ValueError, interpret_ecma_script_expr, "'this is a string'") self.assertRaises(ValueError, interpret_ecma_script_expr, "null") self.assertRaises(ValueError, interpret_ecma_script_expr, "undefined") self.assertRaises(ValueError, interpret_ecma_script_expr, "new Date()") -if __name__ == '__main__': - pytest.main(['-s', '-v', __file__]) +if __name__ == "__main__": + pytest.main(["-s", "-v", __file__]) diff --git a/test/as2fm_common/test_utilities_smc_storm.py b/test/as2fm_common/test_utilities_smc_storm.py index 4b4ca2e8..7bae2b04 100644 --- a/test/as2fm_common/test_utilities_smc_storm.py +++ b/test/as2fm_common/test_utilities_smc_storm.py @@ -25,10 +25,7 @@ import pytest -def _interpret_output( - output: str, - expected_content: List[str], - not_expected_content: List[str]): +def _interpret_output(output: str, expected_content: List[str], not_expected_content: List[str]): """Interpret the output of the command. Make sure that the expected content is present and that the not expected content is not present.""" @@ -44,26 +41,20 @@ def _run_smc_storm(args: str) -> Tuple[str, str, int]: command = f"smc_storm {args} --max-trace-length 10000 --max-n-traces 10000" print("Running command: ", command) with subprocess.Popen( - command, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=True, - universal_newlines=True + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, universal_newlines=True ) as process: stdout, stderr = process.communicate() return_code = process.returncode - print(f"stdout: \"\"\"\n{stdout}\"\"\"") - print(f"stderr: \"\"\"\n{stderr}\"\"\"") + print(f'stdout: """\n{stdout}"""') + print(f'stderr: """\n{stderr}"""') print(f"return code: {return_code}") - assert return_code == 0, \ - f"Command failed with return code {return_code}" + assert return_code == 0, f"Command failed with return code {return_code}" return stdout, stderr, return_code def run_smc_storm_with_output( - args: str, - expected_content: List[str], - not_expected_content: List[str]): + args: str, expected_content: List[str], not_expected_content: List[str] +): """Run smc_storm with the given arguments and check if the output is as expected.""" stdout, stderr, result = _run_smc_storm(args) @@ -78,5 +69,5 @@ def test_run_smc_storm(): assert result == 0, "smc_storm failed to run" -if __name__ == '__main__': - pytest.main(['-s', '-vv', __file__]) +if __name__ == "__main__": + pytest.main(["-s", "-vv", __file__]) diff --git a/test/jani_generator/test_systemtest_convince_to_plain_jani.py b/test/jani_generator/test_systemtest_convince_to_plain_jani.py index 2f6e9041..0fd16349 100644 --- a/test/jani_generator/test_systemtest_convince_to_plain_jani.py +++ b/test/jani_generator/test_systemtest_convince_to_plain_jani.py @@ -25,8 +25,9 @@ def test_convince_to_plain_jani(): """ Test the conversion of a CONVINCE robotics Jani model to plain Jani. """ - test_file = os.path.join(os.path.dirname(__file__), '_test_data', 'convince_jani', - 'first-model-mc-version.jani') + test_file = os.path.join( + os.path.dirname(__file__), "_test_data", "convince_jani", "first-model-mc-version.jani" + ) jani_model = JaniModel() assert os.path.isfile(test_file), f"File {test_file} does not exist." convince_jani_parser(jani_model, test_file) diff --git a/test/jani_generator/test_systemtest_scxml_to_jani.py b/test/jani_generator/test_systemtest_scxml_to_jani.py index 25a0bbec..835273f2 100644 --- a/test/jani_generator/test_systemtest_scxml_to_jani.py +++ b/test/jani_generator/test_systemtest_scxml_to_jani.py @@ -25,9 +25,10 @@ from as2fm.jani_generator.jani_entries import JaniAutomaton from as2fm.jani_generator.scxml_helpers.scxml_event import EventsHolder from as2fm.jani_generator.scxml_helpers.scxml_to_jani import ( - convert_multiple_scxmls_to_jani, convert_scxml_root_to_jani_automaton) -from as2fm.jani_generator.scxml_helpers.top_level_interpreter import \ - interpret_top_level_xml + convert_multiple_scxmls_to_jani, + convert_scxml_root_to_jani_automaton, +) +from as2fm.jani_generator.scxml_helpers.top_level_interpreter import interpret_top_level_xml from as2fm.scxml_converter.scxml_entries import ScxmlRoot from ..as2fm_common.test_utilities_smc_storm import run_smc_storm_with_output @@ -64,7 +65,7 @@ def test_basic_example(self): automaton = jani_a.as_dict(constant={}) self.assertEqual(len(automaton["locations"]), 2) - locations = [loc['name'] for loc in automaton["locations"]] + locations = [loc["name"] for loc in automaton["locations"]] self.assertIn("Initial-first-exec", locations) self.assertIn("Initial-first-exec", automaton["initial-locations"]) @@ -73,8 +74,8 @@ def test_battery_drainer(self): Testing conversion with the battery_drainer SCXML files. """ scxml_battery_drainer = os.path.join( - os.path.dirname(__file__), '_test_data', 'battery_example', - 'battery_drainer.scxml') + os.path.dirname(__file__), "_test_data", "battery_example", "battery_drainer.scxml" + ) scxml_root = ScxmlRoot.from_scxml_file(scxml_battery_drainer) jani_a = JaniAutomaton() @@ -85,7 +86,7 @@ def test_battery_drainer(self): self.assertEqual(automaton["name"], "BatteryDrainer") self.assertEqual(len(automaton["locations"]), 4) self.assertEqual(len(automaton["initial-locations"]), 1) - locations = [loc['name'] for loc in automaton["locations"]] + locations = [loc["name"] for loc in automaton["locations"]] self.assertIn(automaton.get("initial-locations")[0], locations) self.assertEqual(len(automaton["edges"]), 4) @@ -101,8 +102,8 @@ def test_battery_manager(self): Testing conversion with the battery_manager SCXML files. """ scxml_battery_manager = os.path.join( - os.path.dirname(__file__), '_test_data', 'battery_example', - 'battery_manager.scxml') + os.path.dirname(__file__), "_test_data", "battery_example", "battery_manager.scxml" + ) scxml_root = ScxmlRoot.from_scxml_file(scxml_battery_manager) jani_a = JaniAutomaton() @@ -114,8 +115,7 @@ def test_battery_manager(self): self.assertEqual(len(automaton["locations"]), 1) self.assertEqual(len(automaton["initial-locations"]), 1) init_location = automaton["locations"][0] - self.assertEqual(init_location['name'], - automaton.get("initial-locations")[0]) + self.assertEqual(init_location["name"], automaton.get("initial-locations")[0]) self.assertEqual(len(automaton["edges"]), 1) # Variables @@ -130,19 +130,17 @@ def test_example_with_sync(self): """ Testing the conversion of two SCXML files with a sync. """ - test_data_folder = os.path.join( - os.path.dirname(__file__), '_test_data', 'battery_example') - scxml_battery_drainer_path = os.path.join( - test_data_folder, 'battery_drainer.scxml') - scxml_battery_manager_path = os.path.join( - test_data_folder, 'battery_manager.scxml') - with open(scxml_battery_drainer_path, 'r', encoding='utf-8') as f: + test_data_folder = os.path.join(os.path.dirname(__file__), "_test_data", "battery_example") + scxml_battery_drainer_path = os.path.join(test_data_folder, "battery_drainer.scxml") + scxml_battery_manager_path = os.path.join(test_data_folder, "battery_manager.scxml") + with open(scxml_battery_drainer_path, "r", encoding="utf-8") as f: scxml_battery_drainer = ScxmlRoot.from_scxml_file(f.read()) - with open(scxml_battery_manager_path, 'r', encoding='utf-8') as f: + with open(scxml_battery_manager_path, "r", encoding="utf-8") as f: scxml_battery_manager = ScxmlRoot.from_scxml_file(f.read()) jani_model = convert_multiple_scxmls_to_jani( - [scxml_battery_drainer, scxml_battery_manager], [], 0, 100) + [scxml_battery_drainer, scxml_battery_manager], [], 0, 100 + ) jani_dict = jani_model.as_dict() # pprint(jani_dict) @@ -161,38 +159,37 @@ def test_example_with_sync(self): self.assertIn({"automaton": "level"}, elements) syncs = jani_dict["system"]["syncs"] self.assertEqual(len(syncs), 4) - self.assertIn({'result': 'level_on_send', - 'synchronise': [ - 'level_on_send', None, 'level_on_send']}, - syncs) - self.assertIn({'result': 'level_on_receive', - 'synchronise': [ - None, 'level_on_receive', 'level_on_receive']}, - syncs) + self.assertIn( + {"result": "level_on_send", "synchronise": ["level_on_send", None, "level_on_send"]}, + syncs, + ) + self.assertIn( + { + "result": "level_on_receive", + "synchronise": [None, "level_on_receive", "level_on_receive"], + }, + syncs, + ) # Check global variables for event variables = jani_dict["variables"] self.assertEqual(len(variables), 2) - self.assertIn({"name": "level.valid", - "type": "bool", - "initial-value": False, - "transient": False}, variables) - self.assertIn({"name": "level.data", - "type": "int", - "initial-value": 0, - "transient": False}, variables) + self.assertIn( + {"name": "level.valid", "type": "bool", "initial-value": False, "transient": False}, + variables, + ) + self.assertIn( + {"name": "level.data", "type": "int", "initial-value": 0, "transient": False}, variables + ) # Check full jani file - test_file = os.path.join( - test_data_folder, 'output.jani') - ground_truth_file = os.path.join( - test_data_folder, 'output_GROUND_TRUTH.jani') + test_file = os.path.join(test_data_folder, "output.jani") + ground_truth_file = os.path.join(test_data_folder, "output_GROUND_TRUTH.jani") if os.path.exists(test_file): os.remove(test_file) - with open(test_file, "w", encoding='utf-8') as output_file: - json.dump(jani_dict, output_file, - indent=4, ensure_ascii=False) - with open(ground_truth_file, "r", encoding='utf-8') as f: + with open(test_file, "w", encoding="utf-8") as output_file: + json.dump(jani_dict, output_file, indent=4, ensure_ascii=False) + with open(ground_truth_file, "r", encoding="utf-8") as f: ground_truth = json.load(f) self.maxDiff = None # pylint: disable=invalid-name self.assertEqual(jani_dict, ground_truth) @@ -204,8 +201,13 @@ def test_example_with_sync(self): # pylint: disable=too-many-arguments, too-many-positional-arguments def _test_with_main( - self, folder: str, store_generated_scxmls: bool = False, - property_name: str = "", success: bool = False, skip_smc: bool = False): + self, + folder: str, + store_generated_scxmls: bool = False, + property_name: str = "", + success: bool = False, + skip_smc: bool = False, + ): """ Testing the conversion of the main.xml file with the entrypoint. @@ -215,10 +217,9 @@ def _test_with_main( :param success: If the property is expected to be always satisfied of always not satisfied. :param skip_smc: If the model shall be executed using SMC (uses smc_storm). """ - test_data_dir = os.path.join( - os.path.dirname(__file__), '_test_data', folder) - xml_main_path = os.path.join(test_data_dir, 'main.xml') - ouput_path = os.path.join(test_data_dir, 'main.jani') + test_data_dir = os.path.join(os.path.dirname(__file__), "_test_data", folder) + xml_main_path = os.path.join(test_data_dir, "main.xml") + ouput_path = os.path.join(test_data_dir, "main.jani") if os.path.exists(ouput_path): os.remove(ouput_path) generated_scxml_path = "generated_plain_scxml" if store_generated_scxmls else None @@ -231,86 +232,86 @@ def _test_with_main( run_smc_storm_with_output( f"--model {ouput_path} --properties-names {property_name}", [property_name, ouput_path, pos_res], - [neg_res]) + [neg_res], + ) # if os.path.exists(ouput_path): # os.remove(ouput_path) def test_battery_ros_example_depleted_success(self): """Test the battery_depleted property is satisfied.""" - self._test_with_main('ros_example', False, 'battery_depleted', True) + self._test_with_main("ros_example", False, "battery_depleted", True) def test_battery_ros_example_over_depleted_fail(self): """Here we expect the property to be *not* satisfied.""" - self._test_with_main('ros_example', False, 'battery_over_depleted', False) + self._test_with_main("ros_example", False, "battery_over_depleted", False) def test_battery_ros_example_alarm_on(self): """Here we expect the property to be *not* satisfied.""" - self._test_with_main('ros_example', False, 'alarm_on', False) + self._test_with_main("ros_example", False, "alarm_on", False) def test_battery_example_w_bt_battery_depleted(self): """Here we expect the property to be *not* satisfied.""" # TODO: Improve properties under evaluation! - self._test_with_main('ros_example_w_bt', True, 'battery_depleted', False) + self._test_with_main("ros_example_w_bt", True, "battery_depleted", False) def test_battery_example_w_bt_main_battery_under_twenty(self): """Here we expect the property to be *not* satisfied.""" # TODO: Improve properties under evaluation! - self._test_with_main('ros_example_w_bt', False, 'battery_below_20', False) + self._test_with_main("ros_example_w_bt", False, "battery_below_20", False) def test_battery_example_w_bt_main_alarm_and_charge(self): """Here we expect the property to be satisfied in a battery example with charging feature.""" - self._test_with_main('ros_example_w_bt', False, 'battery_alarm_on', True) + self._test_with_main("ros_example_w_bt", False, "battery_alarm_on", True) def test_battery_example_w_bt_main_charged_after_time(self): """Here we expect the property to be satisfied in a battery example with charging feature.""" - self._test_with_main('ros_example_w_bt', False, 'battery_charged', True) + self._test_with_main("ros_example_w_bt", False, "battery_charged", True) def test_events_sync_handling(self): """Here we make sure, the synchronization can handle events being sent in different orders without deadlocks.""" - self._test_with_main('events_sync_examples', False, 'seq_check', True) + self._test_with_main("events_sync_examples", False, "seq_check", True) def test_multiple_senders_same_event(self): """Test topic synchronization, handling events being sent in different orders without deadlocks.""" - self._test_with_main('multiple_senders_same_event', False, 'seq_check', True) + self._test_with_main("multiple_senders_same_event", False, "seq_check", True) def test_array_model_basic(self): """Test the array model.""" - self._test_with_main('array_model_basic', False, 'array_check', True) + self._test_with_main("array_model_basic", False, "array_check", True) def test_array_model_additional(self): """Test the array model.""" - self._test_with_main('array_model_additional', False, 'array_check', True) + self._test_with_main("array_model_additional", False, "array_check", True) def test_ros_add_int_srv_example(self): """Test the services are properly handled in Jani.""" - self._test_with_main('ros_add_int_srv_example', True, 'happy_clients', True) + self._test_with_main("ros_add_int_srv_example", True, "happy_clients", True) def test_ros_fibonacci_action_example(self): """Test the actions are properly handled in Jani.""" - self._test_with_main('fibonacci_action_example', True, 'clients_ok', True) + self._test_with_main("fibonacci_action_example", True, "clients_ok", True) def test_ros_fibonacci_action_single_client_example(self): """Test the actions are properly handled in Jani.""" - self._test_with_main('fibonacci_action_single_thread', True, 'client1_ok', True) + self._test_with_main("fibonacci_action_single_thread", True, "client1_ok", True) @pytest.mark.skip(reason="Not yet working. The BT ticking needs some revision.") def test_ros_delib_ws_2024_p1(self): """Test the ROS Deliberation Workshop example works.""" - self._test_with_main('delibws24_p1', True, 'snack_at_table', True) + self._test_with_main("delibws24_p1", True, "snack_at_table", True) def test_robot_navigation_demo(self): """Test the robot demo.""" - self._test_with_main('robot_navigation_tutorial', True, - 'goal_reached', True, skip_smc=True) + self._test_with_main("robot_navigation_tutorial", True, "goal_reached", True, skip_smc=True) def test_robot_navigation_with_bt_demo(self): """Test the robot demo.""" - self._test_with_main('robot_navigation_with_bt', True, 'goal_reached', True, skip_smc=True) + self._test_with_main("robot_navigation_with_bt", True, "goal_reached", True, skip_smc=True) -if __name__ == '__main__': - pytest.main(['-s', '-v', __file__]) +if __name__ == "__main__": + pytest.main(["-s", "-v", __file__]) diff --git a/test/jani_generator/test_unittest_jani_model_loading.py b/test/jani_generator/test_unittest_jani_model_loading.py index 321d3dae..10d8dea9 100644 --- a/test/jani_generator/test_unittest_jani_model_loading.py +++ b/test/jani_generator/test_unittest_jani_model_loading.py @@ -25,9 +25,10 @@ def test_jani_file_loading(): """ Test the loading of a Jani file. """ - jani_file = os.path.join(os.path.dirname(__file__), - '_test_data', 'plain_jani_examples', 'array_test.jani') - with open(jani_file, "r", encoding='utf-8') as file: + jani_file = os.path.join( + os.path.dirname(__file__), "_test_data", "plain_jani_examples", "array_test.jani" + ) + with open(jani_file, "r", encoding="utf-8") as file: convince_jani_json = json.load(file) jani_model = JaniModel.from_dict(convince_jani_json) assert isinstance(jani_model, JaniModel) diff --git a/test/jani_generator/test_unittest_ros_timer.py b/test/jani_generator/test_unittest_ros_timer.py index e484306c..6befa3cc 100644 --- a/test/jani_generator/test_unittest_ros_timer.py +++ b/test/jani_generator/test_unittest_ros_timer.py @@ -19,7 +19,10 @@ from as2fm.jani_generator.jani_entries import JaniAutomaton from as2fm.jani_generator.ros_helpers.ros_timer import ( - GLOBAL_TIMER_TICK_ACTION, RosTimer, make_global_timer_automaton) + GLOBAL_TIMER_TICK_ACTION, + RosTimer, + make_global_timer_automaton, +) def generic_ros_timer_check(rate_hz: float, expected_unit: str, expected_int_period: int): @@ -42,8 +45,9 @@ def get_time_step_from_timer_automaton(automaton: JaniAutomaton) -> int: """ Get the time step from the global timer automaton. """ - global_tick_edge = [edge for edge in automaton.get_edges() - if edge.get_action() == GLOBAL_TIMER_TICK_ACTION] + global_tick_edge = [ + edge for edge in automaton.get_edges() if edge.get_action() == GLOBAL_TIMER_TICK_ACTION + ] assert len(global_tick_edge) == 1, "Expected only one edge advancing the global timer" edge_dict = global_tick_edge[0].as_dict({}) return int(edge_dict["destinations"][0]["assignments"][0]["value"]["right"]) @@ -59,8 +63,9 @@ def generic_global_timer_check(timer_rates: List[float], expected_time_step: int timers.append(RosTimer(f"timer{i}", rate)) jani_automaton = make_global_timer_automaton(timers, max_time_ns) time_step = get_time_step_from_timer_automaton(jani_automaton) - assert time_step == expected_time_step, \ - f"Expected the global timer to advance by {expected_time_step} each time." + assert ( + time_step == expected_time_step + ), f"Expected the global timer to advance by {expected_time_step} each time." def test_ros_timer_10hz(): diff --git a/test/jani_visualizer/jani_visualizer_test.py b/test/jani_visualizer/jani_visualizer_test.py index 5907908f..1dbb936f 100644 --- a/test/jani_visualizer/jani_visualizer_test.py +++ b/test/jani_visualizer/jani_visualizer_test.py @@ -23,25 +23,21 @@ def test_plantumlautomata(): """ Regression test to see if the PlantUML automata are correctly generated. """ - for data_prefix in ['demo_manual', 'ros_example_w_bt']: - test_data_folder = os.path.join( - os.path.dirname(__file__), - '_test_data') - jani_fname = os.path.join(test_data_folder, f'{data_prefix}.jani') + for data_prefix in ["demo_manual", "ros_example_w_bt"]: + test_data_folder = os.path.join(os.path.dirname(__file__), "_test_data") + jani_fname = os.path.join(test_data_folder, f"{data_prefix}.jani") - with open(jani_fname, 'r', encoding='utf-8') as f: + with open(jani_fname, "r", encoding="utf-8") as f: jani_dict = json.load(f) pua = PlantUMLAutomata(jani_dict) puml_str = pua.to_plantuml( with_assignments=True, # default with_guards=True, # default - with_syncs=True # default + with_syncs=True, # default ) # Comparing the generated images with the reference images - output_file = os.path.join( - test_data_folder, 'expected_output', f'{data_prefix}.plantuml') - with open(output_file, 'r', encoding='utf-8') as f: + output_file = os.path.join(test_data_folder, "expected_output", f"{data_prefix}.plantuml") + with open(output_file, "r", encoding="utf-8") as f: expected_content = f.read() - assert puml_str == expected_content, \ - f'The content for {output_file} is not as expected.' + assert puml_str == expected_content, f"The content for {output_file} is not as expected." diff --git a/test/scxml_converter/test_systemtest_scxml_entries.py b/test/scxml_converter/test_systemtest_scxml_entries.py index dbb2b15c..e5a3de75 100644 --- a/test/scxml_converter/test_systemtest_scxml_entries.py +++ b/test/scxml_converter/test_systemtest_scxml_entries.py @@ -17,23 +17,31 @@ from test_utils import canonicalize_xml, remove_empty_lines -from as2fm.scxml_converter.scxml_entries import (BtGetValueInputPort, - BtInputPortDeclaration, - RosField, RosRateCallback, - RosTimeRate, RosTopicCallback, - RosTopicPublish, - RosTopicPublisher, - RosTopicSubscriber, - ScxmlAssign, ScxmlData, - ScxmlDataModel, ScxmlParam, - ScxmlRoot, ScxmlSend, - ScxmlState, ScxmlTransition) +from as2fm.scxml_converter.scxml_entries import ( + BtGetValueInputPort, + BtInputPortDeclaration, + RosField, + RosRateCallback, + RosTimeRate, + RosTopicCallback, + RosTopicPublish, + RosTopicPublisher, + RosTopicSubscriber, + ScxmlAssign, + ScxmlData, + ScxmlDataModel, + ScxmlParam, + ScxmlRoot, + ScxmlSend, + ScxmlState, + ScxmlTransition, +) from as2fm.scxml_converter.scxml_entries.utils import ROS_FIELD_PREFIX def _test_scxml_from_code(scxml_root: ScxmlRoot, ref_file_path: str): # Check output xml - with open(ref_file_path, 'r', encoding='utf-8') as f_o: + with open(ref_file_path, "r", encoding="utf-8") as f_o: expected_output = f_o.read() test_output = scxml_root.as_xml_string() test_xml_string = remove_empty_lines(canonicalize_xml(test_output)) @@ -47,9 +55,10 @@ def _test_xml_parsing(xml_file_path: str, valid_xml: bool = True): if valid_xml: test_output = scxml_root.as_xml_string() test_xml_string = remove_empty_lines(canonicalize_xml(test_output)) - ref_file_path = os.path.join(os.path.dirname(xml_file_path), 'gt_parsed_scxml', - os.path.basename(xml_file_path)) - with open(ref_file_path, 'r', encoding='utf-8') as f_o: + ref_file_path = os.path.join( + os.path.dirname(xml_file_path), "gt_parsed_scxml", os.path.basename(xml_file_path) + ) + with open(ref_file_path, "r", encoding="utf-8") as f_o: ref_xml_string = remove_empty_lines(canonicalize_xml(f_o.read())) assert test_xml_string == ref_xml_string # All the test scxml files we are using contain ROS declarations @@ -63,20 +72,38 @@ def test_battery_drainer_from_code(): Test for scxml_entries generation and conversion to xml. """ battery_drainer_scxml = ScxmlRoot("BatteryDrainer") - battery_drainer_scxml.set_data_model(ScxmlDataModel([ - ScxmlData("battery_percent", "100", "int16")])) + battery_drainer_scxml.set_data_model( + ScxmlDataModel([ScxmlData("battery_percent", "100", "int16")]) + ) use_battery_state = ScxmlState( "use_battery", - on_entry=[ScxmlSend("topic_level_msg", - [ScxmlParam(f"{ROS_FIELD_PREFIX}data", expr="battery_percent")])], - body=[ScxmlTransition("use_battery", ["ros_time_rate.my_timer"], - body=[ScxmlAssign("battery_percent", "battery_percent - 1")]), - ScxmlTransition("use_battery", ["topic_charge_msg"], - body=[ScxmlAssign("battery_percent", "100")])]) + on_entry=[ + ScxmlSend( + "topic_level_msg", [ScxmlParam(f"{ROS_FIELD_PREFIX}data", expr="battery_percent")] + ) + ], + body=[ + ScxmlTransition( + "use_battery", + ["ros_time_rate.my_timer"], + body=[ScxmlAssign("battery_percent", "battery_percent - 1")], + ), + ScxmlTransition( + "use_battery", ["topic_charge_msg"], body=[ScxmlAssign("battery_percent", "100")] + ), + ], + ) battery_drainer_scxml.add_state(use_battery_state, initial=True) - _test_scxml_from_code(battery_drainer_scxml, os.path.join( - os.path.dirname(__file__), '_test_data', 'battery_drainer_w_bt', - 'gt_plain_scxml', 'battery_drainer.scxml')) + _test_scxml_from_code( + battery_drainer_scxml, + os.path.join( + os.path.dirname(__file__), + "_test_data", + "battery_drainer_w_bt", + "gt_plain_scxml", + "battery_drainer.scxml", + ), + ) def test_battery_drainer_ros_from_code(): @@ -102,10 +129,11 @@ def test_battery_drainer_ros_from_code(): - field - if / elseif / else - assign -""" + """ battery_drainer_scxml = ScxmlRoot("BatteryDrainer") - battery_drainer_scxml.set_data_model(ScxmlDataModel([ - ScxmlData("battery_percent", "100", "int16")])) + battery_drainer_scxml.set_data_model( + ScxmlDataModel([ScxmlData("battery_percent", "100", "int16")]) + ) ros_topic_sub = RosTopicSubscriber("charge", "std_msgs/Empty", "sub") ros_topic_pub = RosTopicPublisher("level", "std_msgs/Int32", "pub") ros_timer = RosTimeRate("my_timer", 1) @@ -115,17 +143,29 @@ def test_battery_drainer_ros_from_code(): use_battery_state = ScxmlState("use_battery") use_battery_state.append_on_entry( - RosTopicPublish(ros_topic_pub, [RosField("data", "battery_percent")])) + RosTopicPublish(ros_topic_pub, [RosField("data", "battery_percent")]) + ) use_battery_state.add_transition( - RosRateCallback(ros_timer, "use_battery", None, - [ScxmlAssign("battery_percent", "battery_percent - 1")])) + RosRateCallback( + ros_timer, "use_battery", None, [ScxmlAssign("battery_percent", "battery_percent - 1")] + ) + ) use_battery_state.add_transition( - RosTopicCallback(ros_topic_sub, "use_battery", None, - [ScxmlAssign("battery_percent", "100")])) + RosTopicCallback( + ros_topic_sub, "use_battery", None, [ScxmlAssign("battery_percent", "100")] + ) + ) battery_drainer_scxml.add_state(use_battery_state, initial=True) - _test_scxml_from_code(battery_drainer_scxml, os.path.join( - os.path.dirname(__file__), '_test_data', 'battery_drainer_w_bt', - 'gt_parsed_scxml', 'battery_drainer.scxml')) + _test_scxml_from_code( + battery_drainer_scxml, + os.path.join( + os.path.dirname(__file__), + "_test_data", + "battery_drainer_w_bt", + "gt_parsed_scxml", + "battery_drainer.scxml", + ), + ) def test_bt_action_with_ports_from_code(): @@ -134,12 +174,20 @@ def test_bt_action_with_ports_from_code(): """ data_model = ScxmlDataModel([ScxmlData("number", "0", "int16")]) topic_publisher = RosTopicPublisher(BtGetValueInputPort("name"), "std_msgs/Int16", "answer_pub") - init_state = ScxmlState("initial", body=[ - ScxmlTransition("initial", ["bt_tick"], None, [ - ScxmlAssign("number", BtGetValueInputPort("data")), - RosTopicPublish(topic_publisher, [RosField("data", "number")]) - ]) - ]) + init_state = ScxmlState( + "initial", + body=[ + ScxmlTransition( + "initial", + ["bt_tick"], + None, + [ + ScxmlAssign("number", BtGetValueInputPort("data")), + RosTopicPublish(topic_publisher, [RosField("data", "number")]), + ], + ) + ], + ) scxml_root = ScxmlRoot("BtTopicAction") scxml_root.set_data_model(data_model) scxml_root.add_bt_port_declaration(BtInputPortDeclaration("name", "string")) @@ -149,36 +197,60 @@ def test_bt_action_with_ports_from_code(): assert not scxml_root.check_validity(), "Currently, we handle unspecified BT entries as invalid" scxml_root.set_bt_ports_values([("name", "/sys/add_srv"), ("data", "25")]) scxml_root.update_bt_ports_values() - _test_scxml_from_code(scxml_root, os.path.join( - os.path.dirname(__file__), '_test_data', 'bt_ports_only', - 'gt_parsed_scxml', 'bt_topic_action.scxml')) + _test_scxml_from_code( + scxml_root, + os.path.join( + os.path.dirname(__file__), + "_test_data", + "bt_ports_only", + "gt_parsed_scxml", + "bt_topic_action.scxml", + ), + ) def test_xml_parsing_battery_drainer(): """Test the parsing of the battery drainer scxml file.""" - _test_xml_parsing(os.path.join(os.path.dirname(__file__), '_test_data', - 'battery_drainer_w_bt', 'battery_drainer.scxml')) + _test_xml_parsing( + os.path.join( + os.path.dirname(__file__), "_test_data", "battery_drainer_w_bt", "battery_drainer.scxml" + ) + ) def test_xml_parsing_bt_topic_condition(): """Test the parsing of the bt topic condition scxml file.""" - _test_xml_parsing(os.path.join(os.path.dirname(__file__), '_test_data', - 'battery_drainer_w_bt', 'bt_topic_condition.scxml')) + _test_xml_parsing( + os.path.join( + os.path.dirname(__file__), + "_test_data", + "battery_drainer_w_bt", + "bt_topic_condition.scxml", + ) + ) def test_xml_parsing_invalid_battery_drainer_xml(): """Test the parsing of the battery drainer scxml file with invalid xml.""" - _test_xml_parsing(os.path.join(os.path.dirname(__file__), '_test_data', - 'invalid_xmls', 'battery_drainer.scxml'), valid_xml=False) + _test_xml_parsing( + os.path.join( + os.path.dirname(__file__), "_test_data", "invalid_xmls", "battery_drainer.scxml" + ), + valid_xml=False, + ) def test_xml_parsing_invalid_bt_topic_action_xml(): """Test the parsing of the bt topic action scxml file with invalid xml.""" - _test_xml_parsing(os.path.join(os.path.dirname(__file__), '_test_data', - 'invalid_xmls', 'bt_topic_action.scxml'), valid_xml=False) + _test_xml_parsing( + os.path.join( + os.path.dirname(__file__), "_test_data", "invalid_xmls", "bt_topic_action.scxml" + ), + valid_xml=False, + ) -if __name__ == '__main__': +if __name__ == "__main__": test_battery_drainer_from_code() test_battery_drainer_ros_from_code() test_xml_parsing_battery_drainer() diff --git a/test/scxml_converter/test_systemtest_xml.py b/test/scxml_converter/test_systemtest_xml.py index 68d30ce3..493eff9f 100644 --- a/test/scxml_converter/test_systemtest_xml.py +++ b/test/scxml_converter/test_systemtest_xml.py @@ -24,7 +24,7 @@ def get_output_folder(test_folder: str): """Get the output folder for the test.""" - return os.path.join(os.path.dirname(__file__), '_test_data', test_folder, 'output') + return os.path.join(os.path.dirname(__file__), "_test_data", test_folder, "output") def clear_output_folder(test_folder: str): @@ -38,7 +38,8 @@ def clear_output_folder(test_folder: str): def bt_to_scxml_test( - test_folder: str, bt_file: str, bt_plugins: List[str], store_generated: bool = False): + test_folder: str, bt_file: str, bt_plugins: List[str], store_generated: bool = False +): """ Test the conversion of a BT to SCXML. @@ -47,34 +48,35 @@ def bt_to_scxml_test( :param bt_plugins: The names of the BT plugins scxml files. :param store_generated: If True, the generated SCXML files are stored in the output folder. """ - test_data_path = os.path.join(os.path.dirname(__file__), '_test_data', test_folder) + test_data_path = os.path.join(os.path.dirname(__file__), "_test_data", test_folder) bt_file = os.path.join(test_data_path, bt_file) plugin_files = [os.path.join(test_data_path, f) for f in bt_plugins] scxml_objs = bt_converter(bt_file, plugin_files, 1.0) - assert len(scxml_objs) == 3, \ - f"Expecting 3 scxml objects, found {len(scxml_objs)}." + assert len(scxml_objs) == 3, f"Expecting 3 scxml objects, found {len(scxml_objs)}." if store_generated: clear_output_folder(test_folder) for scxml_obj in scxml_objs: output_file = os.path.join( - get_output_folder(test_folder), f'{scxml_obj.get_name()}.scxml') - with open(output_file, 'w', encoding='utf-8') as f_o: + get_output_folder(test_folder), f"{scxml_obj.get_name()}.scxml" + ) + with open(output_file, "w", encoding="utf-8") as f_o: f_o.write(scxml_obj.as_xml_string()) for scxml_root in scxml_objs: scxml_name = scxml_root.get_name() - gt_scxml_path = os.path.join(test_data_path, 'gt_bt_scxml', - f'{scxml_name}.scxml') - with open(gt_scxml_path, 'r', encoding='utf-8') as f_o: + gt_scxml_path = os.path.join(test_data_path, "gt_bt_scxml", f"{scxml_name}.scxml") + with open(gt_scxml_path, "r", encoding="utf-8") as f_o: gt_xml = remove_empty_lines(canonicalize_xml(f_o.read())) scxml_xml = remove_empty_lines(canonicalize_xml(scxml_root.as_xml_string())) assert scxml_xml == gt_xml -def ros_to_plain_scxml_test(test_folder: str, - scxml_bt_ports: Dict[str, List[Tuple[str, str]]], - expected_scxmls: Dict[str, List[str]], - store_generated: bool = False): +def ros_to_plain_scxml_test( + test_folder: str, + scxml_bt_ports: Dict[str, List[Tuple[str, str]]], + expected_scxmls: Dict[str, List[str]], + store_generated: bool = False, +): """ Test the conversion of SCXML with ROS-specific macros to plain SCXML. @@ -84,8 +86,8 @@ def ros_to_plain_scxml_test(test_folder: str, :param store_generated: If True, the generated SCXML files are stored in the output folder. """ # pylint: disable=too-many-locals - test_data_path = os.path.join(os.path.dirname(__file__), '_test_data', test_folder) - scxml_files = [file for file in os.listdir(test_data_path) if file.endswith('.scxml')] + test_data_path = os.path.join(os.path.dirname(__file__), "_test_data", test_folder) + scxml_files = [file for file in os.listdir(test_data_path) if file.endswith(".scxml")] if store_generated: clear_output_folder(test_folder) for fname in scxml_files: @@ -99,27 +101,32 @@ def ros_to_plain_scxml_test(test_folder: str, plain_scxmls, _ = scxml_obj.to_plain_scxml_and_declarations() if store_generated: for generated_scxml in plain_scxmls: - output_file = os.path.join(get_output_folder(test_folder), - f'{generated_scxml.get_name()}.scxml') - with open(output_file, 'w', encoding='utf-8') as f_o: + output_file = os.path.join( + get_output_folder(test_folder), f"{generated_scxml.get_name()}.scxml" + ) + with open(output_file, "w", encoding="utf-8") as f_o: f_o.write(generated_scxml.as_xml_string()) if fname not in expected_scxmls: - gt_files: List[str] = [fname.removesuffix('.scxml')] + gt_files: List[str] = [fname.removesuffix(".scxml")] else: gt_files: List[str] = expected_scxmls[fname] - assert len(plain_scxmls) == len(gt_files), \ - f"Expecting {len(gt_files)} scxml objects, found {len(plain_scxmls)}." + assert len(plain_scxmls) == len( + gt_files + ), f"Expecting {len(gt_files)} scxml objects, found {len(plain_scxmls)}." for generated_scxml in plain_scxmls: # Make sure the comparison uses snake case scxml_object_name = to_snake_case(generated_scxml.get_name()) - assert scxml_object_name in gt_files, \ - f"Generated SCXML {scxml_object_name} not in gt SCXMLs {gt_files}." - gt_file_path = os.path.join(test_data_path, 'gt_plain_scxml', - f'{scxml_object_name}.scxml') - with open(gt_file_path, 'r', encoding='utf-8') as f_o: + assert ( + scxml_object_name in gt_files + ), f"Generated SCXML {scxml_object_name} not in gt SCXMLs {gt_files}." + gt_file_path = os.path.join( + test_data_path, "gt_plain_scxml", f"{scxml_object_name}.scxml" + ) + with open(gt_file_path, "r", encoding="utf-8") as f_o: gt_output = f_o.read() - assert remove_empty_lines(canonicalize_xml(generated_scxml.as_xml_string())) == \ - remove_empty_lines(canonicalize_xml(gt_output)) + assert remove_empty_lines( + canonicalize_xml(generated_scxml.as_xml_string()) + ) == remove_empty_lines(canonicalize_xml(gt_output)) except Exception as e: print(f"Error in file {fname}:") raise e @@ -127,37 +134,43 @@ def ros_to_plain_scxml_test(test_folder: str, def test_bt_to_scxml_battery_drainer(): """Test the conversion of the battery drainer with BT to SCXML.""" - bt_to_scxml_test('battery_drainer_w_bt', 'bt.xml', - ['bt_topic_action.scxml', 'bt_topic_condition.scxml'], False) + bt_to_scxml_test( + "battery_drainer_w_bt", + "bt.xml", + ["bt_topic_action.scxml", "bt_topic_condition.scxml"], + False, + ) def test_ros_to_plain_scxml_battery_drainer(): """Test the conversion of the battery drainer with ROS macros to plain SCXML.""" - ros_to_plain_scxml_test('battery_drainer_w_bt', {}, {}, True) + ros_to_plain_scxml_test("battery_drainer_w_bt", {}, {}, True) def test_bt_to_scxml_bt_ports(): """Test the conversion of the BT with ports to SCXML.""" - bt_to_scxml_test('bt_ports_only', 'bt.xml', ['bt_topic_action.scxml'], False) + bt_to_scxml_test("bt_ports_only", "bt.xml", ["bt_topic_action.scxml"], False) def test_ros_to_plain_scxml_bt_ports(): """Test the conversion of the BT with ports to plain SCXML.""" - ros_to_plain_scxml_test('bt_ports_only', - {'bt_topic_action.scxml': [('name', 'out'), ('data', '123')]}, {}, - True) + ros_to_plain_scxml_test( + "bt_ports_only", {"bt_topic_action.scxml": [("name", "out"), ("data", "123")]}, {}, True + ) def test_ros_to_plain_scxml_add_int_srv(): """Test the conversion of the add_int_srv_example with ROS macros to plain SCXML.""" - ros_to_plain_scxml_test('add_int_srv_example', {}, {}, True) + ros_to_plain_scxml_test("add_int_srv_example", {}, {}, True) def test_ros_to_plain_scxml_fibonacci_action(): """Test the conversion of the fibonacci_action_example with ROS macros to plain SCXML.""" ros_to_plain_scxml_test( - 'fibonacci_action_example', {}, + "fibonacci_action_example", + {}, {"server.scxml": ["server", "fibonacci_thread_0", "fibonacci_thread_1"]}, - True) + True, + ) diff --git a/test/scxml_converter/test_unittest_scxml_data.py b/test/scxml_converter/test_unittest_scxml_data.py index 5db1739d..cb105316 100644 --- a/test/scxml_converter/test_unittest_scxml_data.py +++ b/test/scxml_converter/test_unittest_scxml_data.py @@ -33,35 +33,30 @@ def test_no_type_information(self): """ Test with no type information should raise a ValueError. """ - tag = ET.fromstring( - '') + tag = ET.fromstring('') self.assertRaises(AssertionError, ScxmlData.from_xml_tree, tag) - tag = ET.fromstring( - '') + tag = ET.fromstring('') self.assertRaises(AssertionError, ScxmlData.from_xml_tree, tag) def test_no_expr_information(self): """ Test with no expr information should raise a AssertionError. """ - tag = ET.fromstring( - '') + tag = ET.fromstring('') self.assertRaises(AssertionError, ScxmlData.from_xml_tree, tag) def test_no_id_information(self): """ Test with no id information should raise a AssertionError. """ - tag = ET.fromstring( - '') + tag = ET.fromstring('') self.assertRaises(AssertionError, ScxmlData.from_xml_tree, tag) def test_regular_int_tag(self): """ Test with regular tag with type int32. """ - tag = ET.fromstring( - '') + tag = ET.fromstring('') scxml_data = ScxmlData.from_xml_tree(tag) self.assertEqual(scxml_data.get_name(), "level") self.assertEqual(scxml_data.get_type(), int) @@ -71,8 +66,7 @@ def test_regular_float_tag(self): """ Test with regular tag with type int32. """ - tag = ET.fromstring( - '') + tag = ET.fromstring('') scxml_data = ScxmlData.from_xml_tree(tag) self.assertEqual(scxml_data.get_name(), "level_float") self.assertEqual(scxml_data.get_type(), float) @@ -82,8 +76,7 @@ def test_regular_bool_tag(self): """ Test with regular tag with type int32. """ - tag = ET.fromstring( - '') + tag = ET.fromstring('') scxml_data = ScxmlData.from_xml_tree(tag) self.assertEqual(scxml_data.get_name(), "condition") self.assertEqual(scxml_data.get_type(), bool) @@ -93,8 +86,7 @@ def test_regular_int_array_tag(self): """ Test with regular tag with type int32. """ - tag = ET.fromstring( - '') + tag = ET.fromstring('') scxml_data = ScxmlData.from_xml_tree(tag) self.assertEqual(scxml_data.get_name(), "some_array") self.assertEqual(scxml_data.get_type(), MutableSequence[int]) @@ -110,8 +102,7 @@ def test_comment_int32(self): environment-XML/batteryDriverCmp.scxml#L11C1-L11C28 """ comment_above = "TYPE level:int32" - tag = ET.fromstring( - '') + tag = ET.fromstring('') scxml_data = ScxmlData.from_xml_tree(tag, comment_above) self.assertEqual(scxml_data.get_name(), "level") self.assertEqual(scxml_data.get_expr(), "0") @@ -126,8 +117,7 @@ def test_invalid_id_in_comment(self): environment-XML/batteryDriverCmp.scxml#L11C1-L11C28 """ comment_above = "TYPE other:int32" - tag = ET.fromstring( - '') + tag = ET.fromstring('') self.assertRaises(AssertionError, ScxmlData.from_xml_tree, tag, comment_above) def test_datamodel_loading(self): @@ -136,13 +126,15 @@ def test_datamodel_loading(self): """ xml_parser = ET.XMLParser(target=ET.TreeBuilder(insert_comments=True)) xml_tree = ET.fromstring( - '' + "" '' '' - '' + "" '' '' - '', xml_parser) + "", + xml_parser, + ) scxml_data_model = ScxmlDataModel.from_xml_tree(xml_tree) data_entries = scxml_data_model.get_data_entries() self.assertEqual(len(data_entries), 4) @@ -152,5 +144,5 @@ def test_datamodel_loading(self): self.assertEqual(data_entries[3].get_name(), "some_array") -if __name__ == '__main__': - pytest.main(['-s', '-v', __file__]) +if __name__ == "__main__": + pytest.main(["-s", "-v", __file__]) diff --git a/test/scxml_converter/test_unittest_scxml_utils.py b/test/scxml_converter/test_unittest_scxml_utils.py index dea5ee88..80e93512 100644 --- a/test/scxml_converter/test_unittest_scxml_utils.py +++ b/test/scxml_converter/test_unittest_scxml_utils.py @@ -20,7 +20,10 @@ import pytest from as2fm.scxml_converter.scxml_entries.utils import ( - CallbackType, get_data_type_from_string, get_plain_expression) + CallbackType, + get_data_type_from_string, + get_plain_expression, +) def test_standard_good_expressions(): @@ -42,7 +45,7 @@ def test_standard_bad_expressions(): "_event.data", "x + y + z == _msg.data", "_action.goal_id == 0", - "x + y + z == 0 && _event.data == 1" + "x + y + z == 0 && _event.data == 1", ] for expr in bad_expressions: with pytest.raises(AssertionError): @@ -57,14 +60,14 @@ def test_topic_good_expressions(): "cos(_msg.data) == 1.0", "some_msg.data + _msg.count", "_msg.x<1 && sin(_msg.angle.x+_msg.angle.y)>2", - "_msg.array_entry[_msg.index] == _msg.index" + "_msg.array_entry[_msg.index] == _msg.index", ] expected_expressions: List[str] = [ "_event.ros_fields__data == 1", "cos(_event.ros_fields__data) == 1.0", "some_msg.data + _event.ros_fields__count", "_event.ros_fields__x<1 && sin(_event.ros_fields__angle.x+_event.ros_fields__angle.y)>2", - "_event.ros_fields__array_entry[_event.ros_fields__index] == _event.ros_fields__index" + "_event.ros_fields__array_entry[_event.ros_fields__index] == _event.ros_fields__index", ] for test_expr, gt_expr in zip(ok_expressions, expected_expressions): conv_expr = get_plain_expression(test_expr, CallbackType.ROS_TOPIC) @@ -77,7 +80,7 @@ def test_topic_bad_expressions(): "_event.data", "x + _res.y + z == _msg.data", "_action.goal_id == 0", - "_wrapped_result.code == 1" + "_wrapped_result.code == 1", ] for expr in bad_expressions: with pytest.raises(AssertionError): @@ -87,15 +90,11 @@ def test_topic_bad_expressions(): def test_action_goal_good_expressions(): """Test expressions that have events related to actions.""" - ok_expressions: List[str] = [ - "some_action.goal_id", - "_action.goal_id", - "_goal.x < 1" - ] + ok_expressions: List[str] = ["some_action.goal_id", "_action.goal_id", "_goal.x < 1"] expected_expressions: List[str] = [ "some_action.goal_id", "_event.goal_id", - "_event.ros_fields__x < 1" + "_event.ros_fields__x < 1", ] for test_expr, gt_expr in zip(ok_expressions, expected_expressions): conv_expr = get_plain_expression(test_expr, CallbackType.ROS_ACTION_GOAL) diff --git a/test/scxml_converter/test_utils.py b/test/scxml_converter/test_utils.py index 9d5d578c..7e0f9990 100644 --- a/test/scxml_converter/test_utils.py +++ b/test/scxml_converter/test_utils.py @@ -23,7 +23,7 @@ def to_snake_case(text: str) -> str: """Convert a string to snake case.""" - return re.sub(r'(? str: @@ -33,7 +33,7 @@ def canonicalize_xml(xml: str) -> str: et = ET.fromstring(xml) for elem in et.iter(): elem.attrib = {k: elem.attrib[k] for k in sorted(elem.attrib.keys())} - return ET.tostring(et, encoding='unicode') + return ET.tostring(et, encoding="unicode") def remove_empty_lines(text: str) -> str: diff --git a/test/trace_visualizer/trace_visualizer_test.py b/test/trace_visualizer/trace_visualizer_test.py index 2519988e..3c4fb76d 100644 --- a/test/trace_visualizer/trace_visualizer_test.py +++ b/test/trace_visualizer/trace_visualizer_test.py @@ -27,33 +27,24 @@ def test_traces(): Regression test to see if the traces are correctly read and the images are correctly generated. """ - data_prefix: str = 'ros_example_w_bt_battery_below_20_p_0_0107' - test_data_folder = os.path.join( - os.path.dirname(__file__), - '_test_data') - csv_file = os.path.join(test_data_folder, f'{data_prefix}.csv') + data_prefix: str = "ros_example_w_bt_battery_below_20_p_0_0107" + test_data_folder = os.path.join(os.path.dirname(__file__), "_test_data") + csv_file = os.path.join(test_data_folder, f"{data_prefix}.csv") traces = Traces(csv_file) ver, fal = traces.print_info_about_result() - assert ver == 73, \ - f"The id of the first verified trace must be 73 but is {ver}." - assert fal == 0, \ - f"The id of the first falsified trace must be 0 but is {fal}." + assert ver == 73, f"The id of the first verified trace must be 73 but is {ver}." + assert fal == 0, f"The id of the first falsified trace must be 0 but is {fal}." # Comparing the generated images with the reference images - output_file = 'test.png' + output_file = "test.png" for i, fname_expected in [ - (ver, f'{data_prefix}_verified.png'), - (fal, f'{data_prefix}_falsified.png')]: - assert not os.path.exists(output_file), \ - f'The file {output_file} already exists.' + (ver, f"{data_prefix}_verified.png"), + (fal, f"{data_prefix}_falsified.png"), + ]: + assert not os.path.exists(output_file), f"The file {output_file} already exists." traces.write_trace_to_img(i, output_file) - assert os.path.exists(output_file), \ - f'The file {output_file} was not created.' - path_expected = os.path.join( - test_data_folder, - 'expected_output', - fname_expected) - with open(output_file, 'rb') as f1, open(path_expected, 'rb') as f2: - assert f1.read() == f2.read(), \ - f'The content for {fname_expected} is not as expected.' + assert os.path.exists(output_file), f"The file {output_file} was not created." + path_expected = os.path.join(test_data_folder, "expected_output", fname_expected) + with open(output_file, "rb") as f1, open(path_expected, "rb") as f2: + assert f1.read() == f2.read(), f"The content for {fname_expected} is not as expected." os.remove(output_file)