From 8e1829f27a0cd0c7d12c016db657e85cd34c50fe Mon Sep 17 00:00:00 2001 From: Christian Henkel <6976069+ct2034@users.noreply.github.com> Date: Wed, 7 Aug 2024 10:11:33 +0200 Subject: [PATCH] Adding the usual Linters (#8) Signed-off-by: Christian Henkel Co-authored-by: Marco Lampacrescia <65171491+MarcoLm993@users.noreply.github.com> --- .github/workflows/lint.yml | 77 ++++++++++++ as2fm_common/pyproject.toml | 8 +- jani_generator/pyproject.toml | 8 +- .../convince_to_plain_jani.py | 9 +- .../jani_generator/jani_entries/__init__.py | 2 + .../jani_entries/jani_assignment.py | 6 +- .../jani_entries/jani_automaton.py | 16 ++- .../jani_entries/jani_composition.py | 14 +-- .../jani_entries/jani_constant.py | 10 +- .../jani_convince_expression_expansion.py | 35 ++++-- .../jani_generator/jani_entries/jani_edge.py | 9 +- .../jani_entries/jani_expression.py | 18 +-- .../jani_generator/jani_entries/jani_model.py | 14 +-- .../jani_entries/jani_property.py | 10 +- .../jani_generator/jani_entries/jani_value.py | 2 +- .../jani_entries/jani_variable.py | 36 +++--- jani_generator/src/jani_generator/main.py | 7 +- .../ros_helpers/ros_services.py | 17 +-- .../scxml_helpers/scxml_data.py | 13 +- .../scxml_helpers/scxml_event.py | 2 + .../scxml_helpers/scxml_expression.py | 3 +- .../scxml_helpers/scxml_tags.py | 21 ++-- .../scxml_helpers/scxml_to_jani.py | 11 +- .../scxml_helpers/top_level_interpreter.py | 82 +++++++------ .../battery_example/output_GROUND_TRUTH.jani | 14 +-- .../test_systemtest_convince_to_plain_jani.py | 2 +- .../test/test_systemtest_scxml_to_jani.py | 6 +- .../test/test_unittest_ros_timer.py | 3 +- .../test/test_unittest_scxml_data.py | 5 +- .../test/test_utilities_smc_storm.py | 3 +- scxml_converter/pyproject.toml | 3 +- .../src/scxml_converter/bt_converter.py | 2 + .../src/scxml_converter/scxml_converter.py | 7 +- .../scxml_converter/scxml_entries/__init__.py | 2 + .../scxml_entries/scxml_base.py | 6 +- .../scxml_entries/scxml_data.py | 14 ++- .../scxml_entries/scxml_data_model.py | 14 ++- .../scxml_entries/scxml_executable_entries.py | 79 ++++++------ .../scxml_entries/scxml_param.py | 25 ++-- .../scxml_entries/scxml_root.py | 33 +++-- .../scxml_entries/scxml_ros_entries.py | 15 ++- .../scxml_entries/scxml_ros_field.py | 10 +- .../scxml_entries/scxml_ros_service.py | 44 +++---- .../scxml_entries/scxml_ros_timer.py | 23 ++-- .../scxml_entries/scxml_ros_topic.py | 23 ++-- .../scxml_entries/scxml_state.py | 114 +++++++++--------- .../scxml_entries/scxml_transition.py | 20 +-- .../scxml_converter/scxml_entries/utils.py | 32 ++--- .../test/test_systemtest_scxml_entries.py | 15 ++- scxml_converter/test/test_systemtest_xml.py | 3 +- 50 files changed, 566 insertions(+), 381 deletions(-) create mode 100644 .github/workflows/lint.yml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..ba784665 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,77 @@ +name: Lint +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: +jobs: + build: + name: ${{ matrix.package }} ⏩ ${{ matrix.linter }} + strategy: + fail-fast: false + matrix: + linter: [ + # "pylint", + # "pycodestyle", + # "flake8", + "mypy", + "isort" + ] + package: [ "jani_generator", "as2fm_common", "scxml_converter" ] + include: + # (for humble): + - python-version: "3.10" + # os: "ubuntu-latest" + runs-on: ubuntu-latest + + steps: + - uses: szenius/set-timezone@v1.0 + with: + timezoneLinux: "Europe/Berlin" + - uses: actions/checkout@v3 + # Get bt_tools TODO: remove after the release of bt_tools + - name: Checkout bt_tools + uses: actions/checkout@v2 + with: + repository: boschresearch/bt_tools + ref: main + path: bt_tools + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Update pip + run: | + pip install --upgrade pip + pip install setuptools_rust + # Install btlib TODO: remove after the release of bt_tools + - name: Install btlib + run: | + cd bt_tools + pip install -e btlib/. + - name: Install packages + run: | + pip install jani_generator/.[dev] + pip install as2fm_common/.[dev] + pip install scxml_converter/.[dev] + - uses: marian-code/python-lint-annotate@v4 + with: + python-root-list: ${{ matrix.package }}/src/${{ matrix.package }} + python-version: ${{ matrix.python-version }} + use-pylint: ${{ matrix.linter == 'pylint' }} + use-pycodestyle: ${{ matrix.linter == 'pycodestyle' }} + use-flake8: ${{ matrix.linter == 'flake8' }} + use-black: false + use-mypy: ${{ matrix.linter == 'mypy' }} + use-isort: ${{ matrix.linter == 'isort' }} + use-vulture: false + use-pydocstyle: false + extra-pylint-options: "" + extra-pycodestyle-options: "" + extra-flake8-options: "" + extra-black-options: "" + extra-mypy-options: "--ignore-missing-imports" + extra-isort-options: "" \ No newline at end of file diff --git a/as2fm_common/pyproject.toml b/as2fm_common/pyproject.toml index b69753d2..754e5042 100644 --- a/as2fm_common/pyproject.toml +++ b/as2fm_common/pyproject.toml @@ -19,10 +19,14 @@ keywords = [] dependencies = [ ] -requires-python = ">=3.7" +requires-python = ">=3.10" [project.optional-dependencies] dev = ["pytest", "pytest-cov", "pycodestyle", "flake8", "mypy", "isort", "bumpver"] [isort] -profile = "google" \ No newline at end of file +profile = "google" +line_length = 100 + +[flake8] +max_line_length = 100 \ No newline at end of file diff --git a/jani_generator/pyproject.toml b/jani_generator/pyproject.toml index 719f662f..05b6f3a9 100644 --- a/jani_generator/pyproject.toml +++ b/jani_generator/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "js2py", "esprima" ] -requires-python = ">=3.7" +requires-python = ">=3.10" [project.optional-dependencies] dev = ["pytest", "pytest-cov", "pycodestyle", "flake8", "mypy", "isort", "bumpver"] @@ -31,4 +31,8 @@ convince_to_plain_jani = "jani_generator.main:main_convince_to_plain_jani" scxml_to_jani = "jani_generator.main:main_scxml_to_jani" [isort] -profile = "google" \ No newline at end of file +profile = "google" +line_length = 100 + +[flake8] +max_line_length = 100 \ No newline at end of file diff --git a/jani_generator/src/jani_generator/convince_jani_helpers/convince_to_plain_jani.py b/jani_generator/src/jani_generator/convince_jani_helpers/convince_to_plain_jani.py index 23e30a58..0b98ab73 100644 --- a/jani_generator/src/jani_generator/convince_jani_helpers/convince_to_plain_jani.py +++ b/jani_generator/src/jani_generator/convince_jani_helpers/convince_to_plain_jani.py @@ -17,12 +17,13 @@ Module to convert convince-flavored robotic, specific jani into plain jani. """ -from typing import List - -from os import path import json -from jani_generator.jani_entries import JaniModel, JaniAutomaton, JaniComposition, JaniProperty from math import degrees +from os import path +from typing import List + +from jani_generator.jani_entries import (JaniAutomaton, JaniComposition, + JaniModel, JaniProperty) def to_cm(value: float) -> int: diff --git a/jani_generator/src/jani_generator/jani_entries/__init__.py b/jani_generator/src/jani_generator/jani_entries/__init__.py index 83b44f3e..b5c1b5b9 100644 --- a/jani_generator/src/jani_generator/jani_entries/__init__.py +++ b/jani_generator/src/jani_generator/jani_entries/__init__.py @@ -1,3 +1,5 @@ +# isort: skip_file +# Skipping file to avoid circular import problem from .jani_value import JaniValue # noqa: F401 from .jani_expression import JaniExpression # noqa: F401 from .jani_constant import JaniConstant # noqa: F401 diff --git a/jani_generator/src/jani_generator/jani_entries/jani_assignment.py b/jani_generator/src/jani_generator/jani_entries/jani_assignment.py index 33013990..597932dc 100644 --- a/jani_generator/src/jani_generator/jani_entries/jani_assignment.py +++ b/jani_generator/src/jani_generator/jani_entries/jani_assignment.py @@ -18,8 +18,10 @@ """ from typing import Dict -from jani_generator.jani_entries import JaniExpression, JaniConstant -from jani_generator.jani_entries.jani_convince_expression_expansion import expand_expression + +from jani_generator.jani_entries import JaniConstant, JaniExpression +from jani_generator.jani_entries.jani_convince_expression_expansion import \ + expand_expression class JaniAssignment: diff --git a/jani_generator/src/jani_generator/jani_entries/jani_automaton.py b/jani_generator/src/jani_generator/jani_entries/jani_automaton.py index b722ebee..55f112de 100644 --- a/jani_generator/src/jani_generator/jani_entries/jani_automaton.py +++ b/jani_generator/src/jani_generator/jani_entries/jani_automaton.py @@ -15,12 +15,14 @@ """An automaton for jani.""" -from typing import List, Dict, Set, Optional -from jani_generator.jani_entries import JaniEdge, JaniConstant, JaniVariable, JaniExpression +from typing import Any, Dict, List, Optional, Set + +from jani_generator.jani_entries import (JaniConstant, JaniEdge, + JaniExpression, JaniVariable) class JaniAutomaton: - def __init__(self, *, automaton_dict: Optional[dict] = None): + def __init__(self, *, automaton_dict: Optional[Dict[str, Any]] = None): self._locations: Set[str] = set() self._initial_locations: Set[str] = set() self._local_variables: Dict[str, JaniVariable] = {} @@ -84,7 +86,7 @@ def remove_empty_self_loop_edges(self): """Remove all self-loop edges from the automaton.""" self._edges = [edge for edge in self._edges if not edge.is_empty_self_loop()] - def _generate_locations(self, location_list: List[str], initial_locations: List[str]): + def _generate_locations(self, location_list: List[Dict[str, Any]], initial_locations: List[str]): for location in location_list: self._locations.add(location["name"]) for init_location in initial_locations: @@ -107,10 +109,12 @@ def _generate_edges(self, edge_list: List[dict]): jani_edge = JaniEdge(edge) self.add_edge(jani_edge) - def get_actions(self) -> Set[JaniEdge]: + def get_actions(self) -> Set[str]: actions = set() for edge in self._edges: - actions.add(edge.get_action()) + action = edge.get_action() + if action is not None: + actions.add(action) return actions def merge(self, other: 'JaniAutomaton'): diff --git a/jani_generator/src/jani_generator/jani_entries/jani_composition.py b/jani_generator/src/jani_generator/jani_entries/jani_composition.py index cf78db06..75a5ebf4 100644 --- a/jani_generator/src/jani_generator/jani_entries/jani_composition.py +++ b/jani_generator/src/jani_generator/jani_entries/jani_composition.py @@ -15,7 +15,7 @@ """This allows the composition of multiple automata in jani.""" -from typing import Optional, Dict, Any, List +from typing import Any, Dict, List, Optional class JaniComposition: @@ -52,16 +52,16 @@ def add_sync(self, sync_name: str, syncs: Dict[str, str]): :param sync_name: The name of the synchronization action :param syncs: A dictionary relating each automaton to the action to be executed in the sync """ - new_sync = { - "result": sync_name, - "synchronise": [None] * len(self._elements) - } # Generate the synchronize list + sync_list: List[Optional[str]] = [None] * len(self._elements) for automata, action in syncs.items(): assert automata in self._element_to_id, \ f"Automaton {automata} does not exist in the composition" - new_sync["synchronise"][self._element_to_id[automata]] = action - self._syncs.append(new_sync) + sync_list[self._element_to_id[automata]] = action + self._syncs.append({ + "result": sync_name, + "synchronise": sync_list + }) def get_syncs_for_element(self, element: str) -> List[str]: """Get the existing syncs for a specific element (=automaton).""" diff --git a/jani_generator/src/jani_generator/jani_entries/jani_constant.py b/jani_generator/src/jani_generator/jani_entries/jani_constant.py index 03c37c67..7031ced0 100644 --- a/jani_generator/src/jani_generator/jani_entries/jani_constant.py +++ b/jani_generator/src/jani_generator/jani_entries/jani_constant.py @@ -16,8 +16,8 @@ """A constant value expression.""" from typing import Type, Union, get_args -from jani_generator.jani_entries import JaniExpression, JaniValue +from jani_generator.jani_entries import JaniExpression, JaniValue ValidTypes = Union[bool, int, float] @@ -33,13 +33,14 @@ def __init__(self, c_name: str, c_type: Type, c_value: JaniExpression): def name(self) -> str: return self._name - def value(self) -> JaniValue: + def value(self) -> ValidTypes: assert self._value is not None, "Value not set" jani_value = self._value.value assert jani_value is not None and jani_value.is_valid(), "The expression can't be evaluated to a constant value" return jani_value.value() - def jani_type_from_string(str_type: str) -> ValidTypes: + @staticmethod + def jani_type_from_string(str_type: str) -> Type[ValidTypes]: """ Translate a (Jani) type string to a Python type. """ @@ -53,7 +54,8 @@ def jani_type_from_string(str_type: str) -> ValidTypes: raise ValueError(f"Type {str_type} not supported by Jani") # TODO: Move this to a util function file - def jani_type_to_string(c_type: ValidTypes) -> str: + @staticmethod + def jani_type_to_string(c_type: Type[ValidTypes]) -> str: """ Translate a Python type to the name of the type in Jani. diff --git a/jani_generator/src/jani_generator/jani_entries/jani_convince_expression_expansion.py b/jani_generator/src/jani_generator/jani_entries/jani_convince_expression_expansion.py index edaa26c3..9817dd6e 100644 --- a/jani_generator/src/jani_generator/jani_entries/jani_convince_expression_expansion.py +++ b/jani_generator/src/jani_generator/jani_entries/jani_convince_expression_expansion.py @@ -15,14 +15,15 @@ """Expand expressions into jani.""" -from typing import Dict -from jani_generator.jani_entries.jani_expression_generator import minus_operator, plus_operator, \ - equal_operator, max_operator, min_operator, greater_equal_operator, lower_operator, \ - and_operator, or_operator, if_operator, multiply_operator, divide_operator, pow_operator, \ - abs_operator, floor_operator, modulo_operator -from jani_generator.jani_entries import JaniExpression, JaniConstant from math import pi +from typing import Dict, Union +from jani_generator.jani_entries import JaniConstant, JaniExpression, JaniValue +from jani_generator.jani_entries.jani_expression_generator import ( + abs_operator, and_operator, divide_operator, equal_operator, + floor_operator, greater_equal_operator, if_operator, lower_operator, + max_operator, min_operator, minus_operator, modulo_operator, + multiply_operator, or_operator, plus_operator, pow_operator) BASIC_EXPRESSIONS_MAPPING = { "-": "-", @@ -277,8 +278,12 @@ def __expression_interpolation( jani_expression: JaniExpression, jani_constants: Dict[str, JaniConstant]) -> JaniExpression: assert isinstance(jani_expression, JaniExpression), "The input must be a JaniExpression" assert jani_expression.op == "intersect" - robot_name = jani_expression.operands["robot"].identifier - barrier_name = jani_expression.operands["barrier"].identifier + robot_op = jani_expression.operands["robot"] + assert isinstance(robot_op, JaniExpression), "The robot operand must be a JaniExpression" + barrier_op = jani_expression.operands["barrier"] + assert isinstance(barrier_op, JaniExpression), "The barrier operand must be a JaniExpression" + robot_name = robot_op.identifier + barrier_name = barrier_op.identifier if barrier_name == "all": return max_operator( __expression_interpolation_next_boundaries(jani_constants, robot_name, 0), @@ -345,8 +350,12 @@ def __expression_distance( jani_expression: JaniExpression, jani_constants: Dict[str, JaniConstant]) -> JaniExpression: assert isinstance(jani_expression, JaniExpression), "The input must be a JaniExpression" assert jani_expression.op == "distance" - robot_name = jani_expression.operands["robot"].identifier - barrier_name = jani_expression.operands["barrier"].identifier + robot_op = jani_expression.operands["robot"] + assert isinstance(robot_op, JaniExpression), "The robot operand must be a JaniExpression" + barrier_op = jani_expression.operands["barrier"] + assert isinstance(barrier_op, JaniExpression), "The barrier operand must be a JaniExpression" + robot_name = robot_op.identifier + barrier_name = barrier_op.identifier if barrier_name == "all": return min_operator( __expression_distance_next_boundaries(jani_constants, robot_name, 0), @@ -362,7 +371,9 @@ def __expression_distance_to_point( jani_expression: JaniExpression, jani_constants: Dict[str, JaniConstant]) -> JaniExpression: assert isinstance(jani_expression, JaniExpression), "The input must be a JaniExpression" assert jani_expression.op == "distance_to_point" - robot_name = jani_expression.operands["robot"].identifier + robot_op = jani_expression.operands["robot"] + assert isinstance(robot_op, JaniExpression), "The robot operand must be a JaniExpression" + robot_name = robot_op.identifier target_x_cm = to_cm_operator(expand_expression(jani_expression.operands["x"], jani_constants)) target_y_cm = to_cm_operator(expand_expression(jani_expression.operands["y"], jani_constants)) robot_x_cm = f"robots.{robot_name}.pose.x_cm" @@ -380,7 +391,7 @@ def __substitute_expression_op(expression: JaniExpression) -> JaniExpression: def expand_expression( - expression: JaniExpression, jani_constants: Dict[str, JaniConstant]) -> JaniExpression: + expression: Union[JaniExpression, JaniValue], jani_constants: Dict[str, JaniConstant]) -> JaniExpression: # Given a CONVINCE JaniExpression, expand it to a plain JaniExpression assert isinstance(expression, JaniExpression), \ f"The expression should be a JaniExpression instance, found {type(expression)} instead." diff --git a/jani_generator/src/jani_generator/jani_entries/jani_edge.py b/jani_generator/src/jani_generator/jani_entries/jani_edge.py index ce024857..e50a67f0 100644 --- a/jani_generator/src/jani_generator/jani_entries/jani_edge.py +++ b/jani_generator/src/jani_generator/jani_entries/jani_edge.py @@ -16,14 +16,17 @@ """And edge defining the possible transition from one state to another in jani.""" from typing import Dict, Optional -from jani_generator.jani_entries import JaniGuard, JaniAssignment, JaniExpression, JaniConstant -from jani_generator.jani_entries.jani_convince_expression_expansion import expand_expression + +from jani_generator.jani_entries import (JaniAssignment, JaniConstant, + JaniExpression, JaniGuard) +from jani_generator.jani_entries.jani_convince_expression_expansion import \ + expand_expression class JaniEdge: def __init__(self, edge_dict: dict): self.location = edge_dict["location"] - self.action: str = None + self.action: Optional[str] = None if "action" in edge_dict: self.action = edge_dict["action"] self.guard = None diff --git a/jani_generator/src/jani_generator/jani_entries/jani_expression.py b/jani_generator/src/jani_generator/jani_entries/jani_expression.py index 90d61ceb..0e161ccd 100644 --- a/jani_generator/src/jani_generator/jani_entries/jani_expression.py +++ b/jani_generator/src/jani_generator/jani_entries/jani_expression.py @@ -17,7 +17,8 @@ Expressions in Jani """ -from typing import Dict, Union +from typing import Any, Dict, Optional, Union + from jani_generator.jani_entries import JaniValue SupportedExp = Union[str, int, float, bool, dict] @@ -36,10 +37,10 @@ class JaniExpression: - operands: a dictionary of operands, related to the specified operator """ def __init__(self, expression: Union[SupportedExp, 'JaniExpression', JaniValue]): - self.identifier: str = None - self.value: JaniValue = None - self.op = None - self.operands: Dict[str, Union[JaniExpression, JaniValue]] = None + self.identifier: Optional[str] = None + self.value: Optional[JaniValue] = None + self.op: Optional[str] = None + self.operands: Dict[str, Union[JaniExpression, JaniValue]] = {} if isinstance(expression, JaniExpression): self.identifier = expression.identifier self.value = expression.value @@ -48,8 +49,8 @@ def __init__(self, expression: Union[SupportedExp, 'JaniExpression', JaniValue]) elif isinstance(expression, JaniValue): self.value = expression else: - assert isinstance(expression, SupportedExp), \ - f"Unexpected expression type: {type(expression)} should be a dict or a base type." + if (not isinstance(expression, SupportedExp)): # type: ignore + raise RuntimeError(f"Unexpected expression type: {type(expression)} should be a dict or a base type.") if isinstance(expression, str): # If it is a reference to a constant or variable, we do not need to expand further self.identifier = expression @@ -66,6 +67,7 @@ def __init__(self, expression: Union[SupportedExp, 'JaniExpression', JaniValue]) self.operands = self._get_operands(expression) def _get_operands(self, expression_dict: dict): + assert self.op is not None, "Operator not set" if (self.op in ("intersect", "distance")): # intersect: returns a value in [0.0, 1.0], indicating where on the robot trajectory # the intersection occurs. @@ -142,7 +144,7 @@ def as_dict(self) -> Union[str, int, float, bool, dict]: return self.identifier if self.value is not None: return self.value.as_dict() - op_dict = { + op_dict: Dict[str, Any] = { "op": self.op, } for op_key, op_value in self.operands.items(): diff --git a/jani_generator/src/jani_generator/jani_entries/jani_model.py b/jani_generator/src/jani_generator/jani_entries/jani_model.py index ce00f99b..8dfc5e59 100644 --- a/jani_generator/src/jani_generator/jani_entries/jani_model.py +++ b/jani_generator/src/jani_generator/jani_entries/jani_model.py @@ -18,11 +18,11 @@ """ -from typing import List, Dict, Optional, Union, Type -from jani_generator.jani_entries import ( - JaniValue, JaniVariable, JaniConstant, JaniAutomaton, JaniComposition, - JaniProperty, JaniExpression) +from typing import Dict, List, Optional, Type, Union +from jani_generator.jani_entries import (JaniAutomaton, JaniComposition, + JaniConstant, JaniExpression, + JaniProperty, JaniValue, JaniVariable) ValidValue = Union[int, float, bool, dict, JaniExpression] @@ -108,9 +108,9 @@ def _generate_missing_syncs(self): "We expect there to be explicit syncs for all automata." for automaton in self._automata: existing_syncs = self._system.get_syncs_for_element(automaton.get_name()) - for edge in automaton.get_actions(): - if edge not in existing_syncs: - self._system.add_sync(edge, {automaton.get_name(): edge}) + for action in automaton.get_actions(): + if action not in existing_syncs: + self._system.add_sync(action, {automaton.get_name(): action}) def add_jani_property(self, property: JaniProperty): self._properties.append(property) diff --git a/jani_generator/src/jani_generator/jani_entries/jani_property.py b/jani_generator/src/jani_generator/jani_entries/jani_property.py index 940cca28..1a0e3d86 100644 --- a/jani_generator/src/jani_generator/jani_entries/jani_property.py +++ b/jani_generator/src/jani_generator/jani_entries/jani_property.py @@ -18,9 +18,11 @@ """ -from typing import Dict, Any +from typing import Any, Dict, Union + from jani_generator.jani_entries import JaniConstant, JaniExpression -from jani_generator.jani_entries.jani_convince_expression_expansion import expand_expression +from jani_generator.jani_entries.jani_convince_expression_expansion import \ + expand_expression class FilterProperty: @@ -34,7 +36,7 @@ def __init__(self, property_filter_exp: Dict[str, Any]): self._process_values(property_filter_exp["values"]) def _process_values(self, prop_values: Dict[str, Any]) -> None: - self._values = ProbabilityProperty(prop_values) + self._values: Union[ProbabilityProperty, RewardProperty, NumPathsProperty] = ProbabilityProperty(prop_values) if self._values.is_valid(): return self._values = RewardProperty(prop_values) @@ -44,6 +46,8 @@ def _process_values(self, prop_values: Dict[str, Any]) -> None: assert self._values.is_valid(), "Unexpected values in FilterProperty" def as_dict(self, constants: Dict[str, JaniConstant]): + assert isinstance(self._values, ProbabilityProperty), \ + "Only ProbabilityProperty is supported in FilterProperty" return { "op": "filter", "fun": self._fun, diff --git a/jani_generator/src/jani_generator/jani_entries/jani_value.py b/jani_generator/src/jani_generator/jani_entries/jani_value.py index 434f38df..3ad07fab 100644 --- a/jani_generator/src/jani_generator/jani_entries/jani_value.py +++ b/jani_generator/src/jani_generator/jani_entries/jani_value.py @@ -17,8 +17,8 @@ Values in Jani """ -from typing import Union from math import e, pi +from typing import Union class JaniValue: diff --git a/jani_generator/src/jani_generator/jani_entries/jani_variable.py b/jani_generator/src/jani_generator/jani_entries/jani_variable.py index 851ca478..93ec8e86 100644 --- a/jani_generator/src/jani_generator/jani_entries/jani_variable.py +++ b/jani_generator/src/jani_generator/jani_entries/jani_variable.py @@ -18,55 +18,56 @@ """ from typing import Optional, Union, get_args -from jani_generator.jani_entries import JaniExpression, JaniValue from as2fm_common.common import ValidTypes +from jani_generator.jani_entries import JaniExpression, JaniValue class JaniVariable: def __init__(self, v_name: str, v_type: ValidTypes, - v_init_value: Optional[Union[JaniExpression, JaniValue]] = None, + init_value: Optional[Union[JaniExpression, JaniValue]] = None, v_transient: bool = False): - assert v_init_value is None or isinstance(v_init_value, (JaniExpression, JaniValue)), \ + assert init_value is None or isinstance(init_value, (JaniExpression, JaniValue)), \ "Init value should be a JaniExpression or a JaniValue" - if v_init_value is not None: - if isinstance(v_init_value, JaniExpression): - self._init_value = v_init_value - else: # In this case it can only be a JaniValue - self._init_value = JaniExpression(v_init_value) - assert v_type in get_args(ValidTypes), f"Type {v_type} not supported by Jani" self._name = v_name - self._init_value = v_init_value self._type = v_type self._transient = v_transient - # Some Model Checkers really need them to be defined to have a unique initial state - if self._init_value is None: + self._init_expr: Optional[JaniExpression] = None + if init_value is not None: + self._init_expr = JaniExpression(init_value) + else: + # Some Model Checkers need a explicit initial value. if self._type == int: - self._init_value = JaniExpression(0) + self._init_expr = JaniExpression(0) elif self._type == bool: - self._init_value = JaniExpression(False) + self._init_expr = JaniExpression(False) elif self._type == float: - self._init_value = JaniExpression(0.0) + self._init_expr = JaniExpression(0.0) + assert v_type in get_args(ValidTypes), f"Type {v_type} not supported by Jani" if not self._transient and self._type == float: print(f"Warning: Variable {self._name} is not transient and has type float." "This is not supported by STORM yet.") def name(self): + """Get name.""" return self._name def get_type(self): + """Get type.""" return self._type def as_dict(self): + """Return the variable as a dictionary.""" d = { "name": self._name, "type": JaniVariable.jani_type_to_string(self._type), "transient": self._transient } - if self._init_value is not None: - d["initial-value"] = self._init_value.as_dict() + if self._init_expr is not None: + d["initial-value"] = self._init_expr.as_dict() return d + @staticmethod def jani_type_from_string(str_type: str) -> ValidTypes: """ Translate a (Jani) type string to a Python type. @@ -80,6 +81,7 @@ def jani_type_from_string(str_type: str) -> ValidTypes: else: raise ValueError(f"Type {str_type} not supported by Jani") + @staticmethod def jani_type_to_string(v_type: ValidTypes) -> str: """ Translate a Python type to the name of the type in Jani. diff --git a/jani_generator/src/jani_generator/main.py b/jani_generator/src/jani_generator/main.py index 4ebc26d5..9d0e0453 100644 --- a/jani_generator/src/jani_generator/main.py +++ b/jani_generator/src/jani_generator/main.py @@ -16,14 +16,15 @@ # limitations under the License. import argparse +import json import os import timeit -import json from typing import Optional, Sequence -from jani_generator.jani_entries import JaniModel from jani_generator.convince_jani_helpers import convince_jani_parser -from jani_generator.scxml_helpers.top_level_interpreter import interpret_top_level_xml +from jani_generator.jani_entries import JaniModel +from jani_generator.scxml_helpers.top_level_interpreter import \ + interpret_top_level_xml def main_convince_to_plain_jani(_args: Optional[Sequence[str]] = None) -> None: diff --git a/jani_generator/src/jani_generator/ros_helpers/ros_services.py b/jani_generator/src/jani_generator/ros_helpers/ros_services.py index 9f1869e4..62616c0f 100644 --- a/jani_generator/src/jani_generator/ros_helpers/ros_services.py +++ b/jani_generator/src/jani_generator/ros_helpers/ros_services.py @@ -18,16 +18,17 @@ """ from typing import Dict, List, Optional -from scxml_converter.scxml_entries import ( - ScxmlRoot, ScxmlData, ScxmlDataModel, ScxmlState, - ScxmlParam, ScxmlAssign, ScxmlTransition, ScxmlSend) -from scxml_converter.scxml_entries.utils import ( - get_srv_type_params, sanitize_ros_interface_name, generate_srv_request_event, - generate_srv_response_event, generate_srv_server_request_event, - generate_srv_server_response_event, get_default_expression_for_type) from jani_generator.jani_entries import JaniModel - +from scxml_converter.scxml_entries import (ScxmlAssign, ScxmlData, + ScxmlDataModel, ScxmlParam, + ScxmlRoot, ScxmlSend, ScxmlState, + ScxmlTransition) +from scxml_converter.scxml_entries.utils import ( + generate_srv_request_event, generate_srv_response_event, + generate_srv_server_request_event, generate_srv_server_response_event, + get_default_expression_for_type, get_srv_type_params, + sanitize_ros_interface_name) SRV_PREFIX = "srv_handler_" diff --git a/jani_generator/src/jani_generator/scxml_helpers/scxml_data.py b/jani_generator/src/jani_generator/scxml_helpers/scxml_data.py index 18ae9bc3..8fa5a640 100644 --- a/jani_generator/src/jani_generator/scxml_helpers/scxml_data.py +++ b/jani_generator/src/jani_generator/scxml_helpers/scxml_data.py @@ -22,8 +22,7 @@ from typing import Dict, List, Optional, get_args from as2fm_common.common import ros_type_name_to_python_type -from as2fm_common.ecmascript_interpretation import \ - interpret_ecma_script_expr +from as2fm_common.ecmascript_interpretation import interpret_ecma_script_expr from jani_generator.jani_entries.jani_expression import JaniExpression from jani_generator.jani_entries.jani_variable import JaniVariable, ValidTypes @@ -79,16 +78,18 @@ def __init__(self, element: ET.Element, else self.type()) def _interpret_type_from_comment_above( - self, comment_above: Optional[str]) -> Optional[type]: + self, comment_above: Optional[str]) -> Optional[Dict[str, type]]: """Interpret the type of the data from the comment above the data tag. :param comment_above: The comment above the data tag (optional) - :return: The type of the data + :return: The type of the data, None if not found """ if comment_above is None: return None # match string inside xml comment brackets match = re.match(r'', comment_above.strip()) + if match is None: + return None comment_content = match.group(1).strip() if 'TYPE' not in comment_content: return None @@ -140,9 +141,9 @@ def _evalute_possible_types( if len(types) == 0: raise ValueError( f"Could not determine type for data {self.id}") - if len(types) == 1: + elif len(types) == 1: return types.pop() - if len(types) > 1: + else: # len(types) > 1 raise ValueError( f"Multiple types found for data {self.id}: {types}") diff --git a/jani_generator/src/jani_generator/scxml_helpers/scxml_event.py b/jani_generator/src/jani_generator/scxml_helpers/scxml_event.py index 2fa97a24..90359e1b 100644 --- a/jani_generator/src/jani_generator/scxml_helpers/scxml_event.py +++ b/jani_generator/src/jani_generator/scxml_helpers/scxml_event.py @@ -72,6 +72,8 @@ def get_receivers(self) -> List[EventReceiver]: def get_data_structure(self) -> Dict[str, type]: """Get the data structure of the event.""" + if self.data_struct is None: + return {} return self.data_struct def set_data_structure(self, data_struct: Dict[str, type]): diff --git a/jani_generator/src/jani_generator/scxml_helpers/scxml_expression.py b/jani_generator/src/jani_generator/scxml_helpers/scxml_expression.py index 7813e3dc..fb0f4848 100644 --- a/jani_generator/src/jani_generator/scxml_helpers/scxml_expression.py +++ b/jani_generator/src/jani_generator/scxml_helpers/scxml_expression.py @@ -19,9 +19,10 @@ import esprima +from jani_generator.jani_entries.jani_convince_expression_expansion import \ + BASIC_EXPRESSIONS_MAPPING from jani_generator.jani_entries.jani_expression import JaniExpression from jani_generator.jani_entries.jani_value import JaniValue -from jani_generator.jani_entries.jani_convince_expression_expansion import BASIC_EXPRESSIONS_MAPPING def parse_ecmascript_to_jani_expression(ecmascript: str) -> JaniExpression: diff --git a/jani_generator/src/jani_generator/scxml_helpers/scxml_tags.py b/jani_generator/src/jani_generator/scxml_helpers/scxml_tags.py index 167a6ef3..a8586954 100644 --- a/jani_generator/src/jani_generator/scxml_helpers/scxml_tags.py +++ b/jani_generator/src/jani_generator/scxml_helpers/scxml_tags.py @@ -21,17 +21,15 @@ from hashlib import sha256 from typing import Dict, List, Optional, Set, Tuple, Union +from as2fm_common.ecmascript_interpretation import interpret_ecma_script_expr from jani_generator.jani_entries import (JaniAssignment, JaniAutomaton, JaniEdge, JaniExpression, JaniGuard, JaniVariable) from jani_generator.jani_entries.jani_expression_generator import ( and_operator, not_operator) - from jani_generator.scxml_helpers.scxml_event import Event, EventsHolder from jani_generator.scxml_helpers.scxml_expression import \ parse_ecmascript_to_jani_expression -from as2fm_common.ecmascript_interpretation import \ - interpret_ecma_script_expr from scxml_converter.scxml_entries import (ScxmlAssign, ScxmlBase, ScxmlData, ScxmlDataModel, ScxmlExecutionBody, ScxmlIf, ScxmlRoot, ScxmlSend, @@ -103,7 +101,7 @@ def _append_scxml_body_to_jani_automaton(jani_automaton: JaniAutomaton, events_h body: ScxmlExecutionBody, source: str, target: str, hash_str: str, guard: Optional[JaniGuard], trigger_event: Optional[str]) \ - -> Tuple[List[JaniEdge], List[str]]: + -> Tuple[List[JaniEdge], List[str]]: """ Converts the body of an SCXML element to a set of locations and edges. @@ -184,7 +182,7 @@ def _append_scxml_body_to_jani_automaton(jani_automaton: JaniAutomaton, events_h interm_loc_before = f"{source}_{i}_before_if" interm_loc_after = f"{source}_{i}_after_if" new_edges[-1].destinations[0]['location'] = interm_loc_before - previous_conditions = [] + previous_conditions: List[JaniExpression] = [] for cond_str, conditional_body in ec.get_conditional_executions(): print(f"Condition: {cond_str}") print(f"Body: {conditional_body}") @@ -321,7 +319,7 @@ def handle_entry_state(self): initial_state_id = self.element.get_initial_state_id() initial_state = self.element.get_state_by_id(initial_state_id) # Make sure we execute the onentry block of the initial state at the start - if initial_state.get_onentry() is not None: + if len(initial_state.get_onentry()) > 0: source_state = f"{initial_state_id}-first-exec" target_state = initial_state_id onentry_body = initial_state.get_onentry() @@ -392,10 +390,9 @@ def get_guard_for_prev_conditions(self, event_name: str) -> Optional[JaniGuard]: parse_ecmascript_to_jani_expression(cond) for cond in self._event_to_conditions.get(event_name, [])] if len(previous_expressions) > 0: - guard = JaniGuard(_merge_conditions(previous_expressions)) + return JaniGuard(_merge_conditions(previous_expressions)) else: - guard = None - return guard + return None def add_unhandled_transitions(self, transitions_set: Set[str]): """Add self-loops for transitions that weren't handled yet.""" @@ -420,7 +417,7 @@ def write_model(self): self._events_no_condition: List[str] = [] for child in self.children: transition_events = child.element.get_events() - transition_event = "" if transition_events is None else transition_events[0] + transition_event = "" if len(transition_events) == 0 else transition_events[0] transition_condition = child.element.get_condition() # Add previous conditions matching the same event trigger to the current child state child.set_previous_siblings_conditions( @@ -463,9 +460,9 @@ def write_model(self): assert target_state is not None, f"Transition's target state {target_state_id} not found." event_name = self.element.get_events() # TODO: Need to extend this to support multiple events - assert event_name is None or len(event_name) == 1, \ + assert len(event_name) == 0 or len(event_name) == 1, \ "Transitions triggered by multiple events are not supported." - transition_trigger_event = None if event_name is None else event_name[0] + transition_trigger_event = None if len(event_name) == 0 else event_name[0] if transition_trigger_event is not None: # TODO: Maybe get rid of one of the two event variables assert len(transition_trigger_event) > 0, "Empty event name not supported." diff --git a/jani_generator/src/jani_generator/scxml_helpers/scxml_to_jani.py b/jani_generator/src/jani_generator/scxml_helpers/scxml_to_jani.py index 624d0193..3afacdb2 100644 --- a/jani_generator/src/jani_generator/scxml_helpers/scxml_to_jani.py +++ b/jani_generator/src/jani_generator/scxml_helpers/scxml_to_jani.py @@ -19,18 +19,17 @@ from typing import List, Union -from scxml_converter.scxml_entries import ScxmlRoot - from jani_generator.jani_entries.jani_automaton import JaniAutomaton from jani_generator.jani_entries.jani_model import JaniModel -from jani_generator.ros_helpers.ros_timer import ( - RosTimer, make_global_timer_automaton) -from jani_generator.ros_helpers.ros_services import ( - remove_empty_self_loops_from_srv_handlers_in_jani) +from jani_generator.ros_helpers.ros_services import \ + remove_empty_self_loops_from_srv_handlers_in_jani +from jani_generator.ros_helpers.ros_timer import (RosTimer, + make_global_timer_automaton) from jani_generator.scxml_helpers.scxml_event import EventsHolder from jani_generator.scxml_helpers.scxml_event_processor import \ implement_scxml_events_as_jani_syncs from jani_generator.scxml_helpers.scxml_tags import BaseTag +from scxml_converter.scxml_entries import ScxmlRoot def convert_scxml_root_to_jani_automaton( diff --git a/jani_generator/src/jani_generator/scxml_helpers/top_level_interpreter.py b/jani_generator/src/jani_generator/scxml_helpers/top_level_interpreter.py index c12a5014..f6215993 100644 --- a/jani_generator/src/jani_generator/scxml_helpers/top_level_interpreter.py +++ b/jani_generator/src/jani_generator/scxml_helpers/top_level_interpreter.py @@ -17,20 +17,30 @@ Module reading the top level xml file containing the whole model to check. """ -import os import json - -from typing import Any, Dict, List, Tuple - +import os +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Union from xml.etree import ElementTree as ET from as2fm_common.common import remove_namespace -from scxml_converter.bt_converter import bt_converter -from scxml_converter.scxml_entries import ScxmlRoot from jani_generator.jani_entries import JaniModel +from jani_generator.ros_helpers.ros_services import RosService, RosServices from jani_generator.ros_helpers.ros_timer import RosTimer -from jani_generator.ros_helpers.ros_services import RosServices, RosService -from jani_generator.scxml_helpers.scxml_to_jani import convert_multiple_scxmls_to_jani +from jani_generator.scxml_helpers.scxml_to_jani import \ + convert_multiple_scxmls_to_jani +from scxml_converter.bt_converter import bt_converter +from scxml_converter.scxml_entries import ScxmlRoot + + +@dataclass() +class FullModel: + max_time: Optional[int] = None + bt: Optional[str] = None + plugins: List[str] = field(default_factory=list) + skills: List[str] = field(default_factory=list) + components: List[str] = field(default_factory=list) + properties: List[str] = field(default_factory=list) def _parse_time_element(time_element: ET.Element) -> int: @@ -51,7 +61,7 @@ def _parse_time_element(time_element: ET.Element) -> int: return int(time_element.attrib["value"]) * TIME_MULTIPLIERS[time_unit] -def parse_main_xml(xml_path: str) -> Dict[str, Any]: +def parse_main_xml(xml_path: str) -> FullModel: """ Interpret the top-level XML file as a dictionary. @@ -69,21 +79,14 @@ def parse_main_xml(xml_path: str) -> Dict[str, Any]: xml = ET.parse(f) assert remove_namespace(xml.getroot().tag) == "convince_mc_tc", \ "The top-level XML element must be convince_mc_tc." - main_dict = { - "max_time": None, - "bt": None, - "plugins": [], - "skills": [], - "components": [], - "properties": [] - } + model = FullModel() for first_level in xml.getroot(): if remove_namespace(first_level.tag) == "mc_parameters": for mc_parameter in first_level: # if remove_namespace(mc_parameter.tag) == "time_resolution": # time_resolution = _parse_time_element(mc_parameter) if remove_namespace(mc_parameter.tag) == "max_time": - main_dict["max_time"] = _parse_time_element(mc_parameter) + model.max_time = _parse_time_element(mc_parameter) else: raise ValueError( f"Invalid mc_parameter tag: {mc_parameter.tag}") @@ -91,48 +94,48 @@ def parse_main_xml(xml_path: str) -> Dict[str, Any]: for child in first_level: if remove_namespace(child.tag) == "input": if child.attrib["type"] == "bt.cpp-xml": - assert main_dict["bt"] is None, "Only one BT is supported." - main_dict["bt"] = os.path.join(folder_of_xml, child.attrib["src"]) + assert model.bt is None, "Only one Behavior Tree is supported." + model.bt = os.path.join(folder_of_xml, child.attrib["src"]) elif child.attrib["type"] == "bt-plugin-ros-scxml": - main_dict["plugins"].append( + model.plugins.append( os.path.join(folder_of_xml, child.attrib["src"])) else: raise ValueError(f"Invalid input type: {child.attrib['type']}") else: raise ValueError( f"Invalid behavior_tree tag: {child.tag} != input") - assert main_dict["bt"] is not None, "There must be a Behavior Tree defined." + assert model.bt is not None, "A Behavior Tree must be defined." elif remove_namespace(first_level.tag) == "node_models": for node_model in first_level: assert remove_namespace(node_model.tag) == "input", \ "Only input tags are supported." assert node_model.attrib["type"] == "ros-scxml", \ "Only ROS-SCXML node models are supported." - main_dict["skills"].append(os.path.join(folder_of_xml, node_model.attrib["src"])) + model.skills.append(os.path.join(folder_of_xml, node_model.attrib["src"])) elif remove_namespace(first_level.tag) == "properties": for property in first_level: assert remove_namespace(property.tag) == "input", \ "Only input tags are supported." assert property.attrib["type"] == "jani", \ "Only Jani properties are supported." - main_dict["properties"].append(os.path.join(folder_of_xml, property.attrib["src"])) + model.properties.append(os.path.join(folder_of_xml, property.attrib["src"])) else: raise ValueError(f"Invalid main point tag: {first_level.tag}") - return main_dict + return model def generate_plain_scxml_models_and_timers( - full_model_dict: str) -> Tuple[List[ScxmlRoot], List[RosTimer]]: + model: FullModel) -> Tuple[List[ScxmlRoot], List[RosTimer]]: """ Generate plain SCXML models and ROS timers from the full model dictionary. """ # Convert behavior tree and plugins to ROS-scxml - scxml_files_to_convert: list = full_model_dict["skills"] + full_model_dict["components"] - if full_model_dict["bt"] is not None: - bt_out_dir = os.path.join(os.path.dirname(full_model_dict["bt"]), "generated_bt_scxml") + scxml_files_to_convert: list = model.skills + model.components + if model.bt is not None: + bt_out_dir = os.path.join(os.path.dirname(model.bt), "generated_bt_scxml") os.makedirs(bt_out_dir, exist_ok=True) expanded_bt_plugin_scxmls = bt_converter( - full_model_dict["bt"], full_model_dict["plugins"], bt_out_dir) + model.bt, model.plugins, bt_out_dir) scxml_files_to_convert.extend(expanded_bt_plugin_scxmls) # Convert ROS-SCXML FSMs to plain SCXML @@ -165,17 +168,18 @@ def generate_plain_scxml_models_and_timers( return plain_scxml_models, all_timers -def interpret_top_level_xml(xml_path: str, store_generated_scxmls: bool = False) -> JaniModel: +def interpret_top_level_xml(xml_path: str, store_generated_scxmls: bool = False): """ - Interpret the top-level XML file as a Jani model. + Interpret the top-level XML file as a Jani model. And write it to a file. + The generated Jani model is written to the same directory as the input XML file under the + name `main.jani`. :param xml_path: The path to the XML file to interpret. - :return: The interpreted Jani model. """ model_dir = os.path.dirname(xml_path) - full_model_dict = parse_main_xml(xml_path) - assert full_model_dict["max_time"] is not None, f"Max time must be defined in {xml_path}." - plain_scxml_models, all_timers = generate_plain_scxml_models_and_timers(full_model_dict) + model = parse_main_xml(xml_path) + assert model.max_time is not None, f"Max time must be defined in {xml_path}." + plain_scxml_models, all_timers = generate_plain_scxml_models_and_timers(model) if store_generated_scxmls: plain_scxml_dir = os.path.join(model_dir, "generated_plain_scxml") @@ -186,11 +190,11 @@ def interpret_top_level_xml(xml_path: str, store_generated_scxmls: bool = False) f.write(scxml_model.as_xml_string()) jani_model = convert_multiple_scxmls_to_jani( - plain_scxml_models, all_timers, full_model_dict["max_time"]) + plain_scxml_models, all_timers, model.max_time) jani_dict = jani_model.as_dict() - assert len(full_model_dict["properties"]) == 1, "Only one property is supported right now." - with open(full_model_dict["properties"][0], "r", encoding='utf-8') as f: + assert len(model.properties) == 1, "Only one property is supported right now." + with open(model.properties[0], "r", encoding='utf-8') as f: jani_dict["properties"] = json.load(f)["properties"] output_path = os.path.join(model_dir, "main.jani") diff --git a/jani_generator/test/_test_data/battery_example/output_GROUND_TRUTH.jani b/jani_generator/test/_test_data/battery_example/output_GROUND_TRUTH.jani index c9fd97c6..33dddfef 100644 --- a/jani_generator/test/_test_data/battery_example/output_GROUND_TRUTH.jani +++ b/jani_generator/test/_test_data/battery_example/output_GROUND_TRUTH.jani @@ -31,7 +31,7 @@ "name": "use_battery-first-exec-use_battery-766fa6e4" }, { - "name": "use_battery-use_battery-1b935c10" + "name": "use_battery-use_battery-cf7e7c41" } ], "automata": [ @@ -42,7 +42,7 @@ "name": "use_battery" }, { - "name": "use_battery-1-1b935c10" + "name": "use_battery-1-cf7e7c41" }, { "name": "use_battery-first-exec" @@ -59,7 +59,7 @@ "location": "use_battery", "destinations": [ { - "location": "use_battery-1-1b935c10", + "location": "use_battery-1-cf7e7c41", "assignments": [ { "ref": "battery_percent", @@ -73,10 +73,10 @@ ] } ], - "action": "use_battery-use_battery-1b935c10" + "action": "use_battery-use_battery-cf7e7c41" }, { - "location": "use_battery-1-1b935c10", + "location": "use_battery-1-cf7e7c41", "destinations": [ { "location": "use_battery", @@ -259,9 +259,9 @@ ] }, { - "result": "use_battery-use_battery-1b935c10", + "result": "use_battery-use_battery-cf7e7c41", "synchronise": [ - "use_battery-use_battery-1b935c10", + "use_battery-use_battery-cf7e7c41", null, null ] diff --git a/jani_generator/test/test_systemtest_convince_to_plain_jani.py b/jani_generator/test/test_systemtest_convince_to_plain_jani.py index 4dd2c83f..fe995cc0 100644 --- a/jani_generator/test/test_systemtest_convince_to_plain_jani.py +++ b/jani_generator/test/test_systemtest_convince_to_plain_jani.py @@ -17,8 +17,8 @@ import os -from jani_generator.jani_entries import JaniModel from jani_generator.convince_jani_helpers import convince_jani_parser +from jani_generator.jani_entries import JaniModel def test_convince_to_plain_jani(): diff --git a/jani_generator/test/test_systemtest_scxml_to_jani.py b/jani_generator/test/test_systemtest_scxml_to_jani.py index 96529f31..b54dec6e 100644 --- a/jani_generator/test/test_systemtest_scxml_to_jani.py +++ b/jani_generator/test/test_systemtest_scxml_to_jani.py @@ -22,12 +22,14 @@ import pytest -from scxml_converter.scxml_entries import ScxmlRoot from jani_generator.jani_entries import JaniAutomaton from jani_generator.scxml_helpers.scxml_event import EventsHolder from jani_generator.scxml_helpers.scxml_to_jani import ( convert_multiple_scxmls_to_jani, convert_scxml_root_to_jani_automaton) -from jani_generator.scxml_helpers.top_level_interpreter import interpret_top_level_xml +from jani_generator.scxml_helpers.top_level_interpreter import \ + interpret_top_level_xml +from scxml_converter.scxml_entries import ScxmlRoot + from .test_utilities_smc_storm import run_smc_storm_with_output diff --git a/jani_generator/test/test_unittest_ros_timer.py b/jani_generator/test/test_unittest_ros_timer.py index 767bcea2..4010d1bf 100644 --- a/jani_generator/test/test_unittest_ros_timer.py +++ b/jani_generator/test/test_unittest_ros_timer.py @@ -15,9 +15,10 @@ """Test the ROS timer conversion""" -import pytest import unittest +import pytest + from jani_generator.ros_helpers.ros_timer import RosTimer diff --git a/jani_generator/test/test_unittest_scxml_data.py b/jani_generator/test/test_unittest_scxml_data.py index 278ad57c..b40dbbbe 100644 --- a/jani_generator/test/test_unittest_scxml_data.py +++ b/jani_generator/test/test_unittest_scxml_data.py @@ -15,9 +15,10 @@ """"Test the SCXML data conversion""" -import pytest -import xml.etree.ElementTree as ET import unittest +import xml.etree.ElementTree as ET + +import pytest from jani_generator.scxml_helpers.scxml_data import ScxmlData diff --git a/jani_generator/test/test_utilities_smc_storm.py b/jani_generator/test/test_utilities_smc_storm.py index e4d2861a..a1d7de33 100644 --- a/jani_generator/test/test_utilities_smc_storm.py +++ b/jani_generator/test/test_utilities_smc_storm.py @@ -13,9 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple - import subprocess +from typing import List, Tuple import pytest diff --git a/scxml_converter/pyproject.toml b/scxml_converter/pyproject.toml index e1ae4c9e..daaf5652 100644 --- a/scxml_converter/pyproject.toml +++ b/scxml_converter/pyproject.toml @@ -20,13 +20,14 @@ dependencies = [ "networkx", "btlib", ] -requires-python = ">=3.7" +requires-python = ">=3.10" [project.optional-dependencies] dev = ["pytest", "pytest-cov", "pycodestyle", "flake8", "mypy", "isort", "bumpver"] [isort] profile = "google" +line_length = 100 [flake8] max_line_length = 100 \ No newline at end of file diff --git a/scxml_converter/src/scxml_converter/bt_converter.py b/scxml_converter/src/scxml_converter/bt_converter.py index 15d4e1b5..e4e6b411 100644 --- a/scxml_converter/src/scxml_converter/bt_converter.py +++ b/scxml_converter/src/scxml_converter/bt_converter.py @@ -27,6 +27,7 @@ from btlib.bt_to_fsm.bt_to_fsm import Bt2FSM from btlib.bts import xml_to_networkx from btlib.common import NODE_CAT + from scxml_converter.scxml_entries import (RosRateCallback, RosTimeRate, ScxmlRoot, ScxmlSend, ScxmlState, ScxmlTransition) @@ -39,6 +40,7 @@ class BT_EVENT_TYPE(Enum): FAILURE = auto() RUNNING = auto() + @staticmethod def from_str(event_name: str) -> 'BT_EVENT_TYPE': event_name = event_name.replace('event=', '') event_name = event_name.replace('"', '') diff --git a/scxml_converter/src/scxml_converter/scxml_converter.py b/scxml_converter/src/scxml_converter/scxml_converter.py index 157689f1..4791f663 100644 --- a/scxml_converter/src/scxml_converter/scxml_converter.py +++ b/scxml_converter/src/scxml_converter/scxml_converter.py @@ -23,11 +23,10 @@ import xml.etree.ElementTree as ET from typing import Dict, Tuple, Union -from scxml_converter.scxml_entries import ScxmlRoot, ScxmlRosDeclarationsContainer - from as2fm_common.common import ros_type_name_to_python_type -from as2fm_common.ecmascript_interpretation import \ - interpret_ecma_script_expr +from as2fm_common.ecmascript_interpretation import interpret_ecma_script_expr +from scxml_converter.scxml_entries import (ScxmlRoot, + ScxmlRosDeclarationsContainer) BASIC_FIELD_TYPES = ['boolean', 'int32', 'int16', 'float', 'double'] diff --git a/scxml_converter/src/scxml_converter/scxml_entries/__init__.py b/scxml_converter/src/scxml_converter/scxml_entries/__init__.py index 88c3d328..c827fae1 100644 --- a/scxml_converter/src/scxml_converter/scxml_entries/__init__.py +++ b/scxml_converter/src/scxml_converter/scxml_entries/__init__.py @@ -1,3 +1,5 @@ +# isort: skip_file +# Skipping file to avoid circular import problem from .scxml_base import ScxmlBase # noqa: F401 from .scxml_param import ScxmlParam # noqa: F401 from .scxml_ros_field import RosField # noqa: F401 diff --git a/scxml_converter/src/scxml_converter/scxml_entries/scxml_base.py b/scxml_converter/src/scxml_converter/scxml_entries/scxml_base.py index 4641ddd4..8102baed 100644 --- a/scxml_converter/src/scxml_converter/scxml_entries/scxml_base.py +++ b/scxml_converter/src/scxml_converter/scxml_entries/scxml_base.py @@ -17,15 +17,19 @@ Base SCXML class, defining the methods all SCXML entries shall implement. """ +from xml.etree import ElementTree as ET + class ScxmlBase: """This class is the base class for all SCXML entries.""" + @staticmethod def get_tag_name() -> str: """Get the tag name of the XML element.""" raise NotImplementedError - def from_xml_tree(xml_tree) -> "ScxmlBase": + @staticmethod + def from_xml_tree(xml_tree: ET.Element) -> "ScxmlBase": """Create a ScxmlBase object from an XML tree.""" raise NotImplementedError diff --git a/scxml_converter/src/scxml_converter/scxml_entries/scxml_data.py b/scxml_converter/src/scxml_converter/scxml_entries/scxml_data.py index 8537d062..e00dee73 100644 --- a/scxml_converter/src/scxml_converter/scxml_entries/scxml_data.py +++ b/scxml_converter/src/scxml_converter/scxml_entries/scxml_data.py @@ -18,12 +18,11 @@ """ from typing import Any -from scxml_converter.scxml_entries import ScxmlBase +from xml.etree import ElementTree as ET +from scxml_converter.scxml_entries import ScxmlBase from scxml_converter.scxml_entries.utils import SCXML_DATA_STR_TO_TYPE -from xml.etree import ElementTree as ET - class ScxmlData(ScxmlBase): """This class represents the variables defined in the model.""" @@ -48,9 +47,9 @@ def from_xml_tree(xml_tree: ET.Element) -> "ScxmlData": return ScxmlData(data_id, data_expr, data_type, lower_bound, upper_bound) def __init__( - self, id: str, expr: str, data_type: str, + self, id_ : str, expr: str, data_type: str, lower_bound: Any = None, upper_bound: Any = None): - self._id = id + self._id = id_ self._expr = expr self._data_type = data_type self._lower_bound = lower_bound @@ -103,7 +102,10 @@ def as_xml(self) -> ET.Element: xml_data = ET.Element(ScxmlData.get_tag_name(), {"id": self._id, "expr": self._expr, "type": self._data_type}) if self._lower_bound is not None: - xml_data.set("lower_bound_incl", str(self._lower_bound_incl)) + xml_data.set("lower_bound_incl", str(self._lower_bound)) if self._upper_bound is not None: xml_data.set("upper_bound_incl", str(self._upper_bound)) return xml_data + + def as_plain_scxml(self, ros_declarations): + raise NotImplementedError("Error: SCXML data: as_plain_scxml not implemented.") diff --git a/scxml_converter/src/scxml_converter/scxml_entries/scxml_data_model.py b/scxml_converter/src/scxml_converter/scxml_entries/scxml_data_model.py index 654c288e..e63db7a4 100644 --- a/scxml_converter/src/scxml_converter/scxml_entries/scxml_data_model.py +++ b/scxml_converter/src/scxml_converter/scxml_entries/scxml_data_model.py @@ -17,25 +17,24 @@ Container for the variables defined in the SCXML model. In XML, it has the tag `datamodel`. """ -from scxml_converter.scxml_entries import ScxmlBase, ScxmlData - from typing import List, Optional - from xml.etree import ElementTree as ET +from scxml_converter.scxml_entries import ScxmlBase, ScxmlData + class ScxmlDataModel(ScxmlBase): """This class represents the variables defined in the model.""" + def __init__(self, data_entries: List[ScxmlData] = None): # TODO: Check ScxmlData from scxml_helpers, for alternative parsing self._data_entries = data_entries + @staticmethod def get_tag_name() -> str: return "datamodel" - def get_data_entries(self) -> List[ScxmlData]: - return self._data_entries - + @staticmethod def from_xml_tree(xml_tree: ET.Element) -> "ScxmlDataModel": """Create a ScxmlDataModel object from an XML tree.""" assert xml_tree.tag == ScxmlDataModel.get_tag_name(), \ @@ -47,6 +46,9 @@ def from_xml_tree(xml_tree: ET.Element) -> "ScxmlDataModel": data_entries.append(ScxmlData.from_xml_tree(data_entry_xml)) return ScxmlDataModel(data_entries) + def get_data_entries(self) -> Optional[List[ScxmlData]]: + return self._data_entries + def check_validity(self) -> bool: valid_data_entries = True if self._data_entries is not None: diff --git a/scxml_converter/src/scxml_converter/scxml_entries/scxml_executable_entries.py b/scxml_converter/src/scxml_converter/scxml_entries/scxml_executable_entries.py index 1032b18b..e17ac00c 100644 --- a/scxml_converter/src/scxml_converter/scxml_entries/scxml_executable_entries.py +++ b/scxml_converter/src/scxml_converter/scxml_entries/scxml_executable_entries.py @@ -17,12 +17,13 @@ Definition of SCXML Tags that can be part of executable content """ -from typing import List, Optional, Union, Tuple, get_args +from typing import List, Optional, Tuple, Union, get_args from xml.etree import ElementTree as ET -from scxml_converter.scxml_entries import (ScxmlBase, ScxmlParam, ScxmlRosDeclarationsContainer) - -from scxml_converter.scxml_entries.utils import replace_ros_interface_expression +from scxml_converter.scxml_entries import (ScxmlBase, ScxmlParam, + ScxmlRosDeclarationsContainer) +from scxml_converter.scxml_entries.utils import \ + replace_ros_interface_expression # Use delayed type evaluation: https://peps.python.org/pep-0484/#forward-references ScxmlExecutableEntry = Union['ScxmlAssign', 'ScxmlIf', 'ScxmlSend'] @@ -45,17 +46,11 @@ def __init__(self, self._conditional_executions = conditional_executions self._else_execution = else_execution + @staticmethod def get_tag_name() -> str: return "if" - def get_conditional_executions(self) -> List[ConditionalExecutionBody]: - """Get the conditional executions.""" - return self._conditional_executions - - def get_else_execution(self) -> Optional[ScxmlExecutionBody]: - """Get the else execution.""" - return self._else_execution - + @staticmethod def from_xml_tree(xml_tree: ET.Element) -> "ScxmlIf": """Create a ScxmlIf object from an XML tree.""" assert xml_tree.tag == ScxmlIf.get_tag_name(), \ @@ -63,7 +58,8 @@ def from_xml_tree(xml_tree: ET.Element) -> "ScxmlIf": conditions: List[str] = [] exec_bodies: List[ScxmlExecutionBody] = [] conditions.append(xml_tree.attrib["cond"]) - current_body: ScxmlExecutionBody = [] + current_body: Optional[ScxmlExecutionBody] = [] + assert current_body is not None, "Error: SCXML if: current body is not valid." for child in xml_tree: if child.tag == "elseif": conditions.append(child.attrib["cond"]) @@ -80,6 +76,14 @@ def from_xml_tree(xml_tree: ET.Element) -> "ScxmlIf": current_body = None return ScxmlIf(list(zip(conditions, exec_bodies)), current_body) + def get_conditional_executions(self) -> List[ConditionalExecutionBody]: + """Get the conditional executions.""" + return self._conditional_executions + + def get_else_execution(self) -> Optional[ScxmlExecutionBody]: + """Get the else execution.""" + return self._else_execution + def check_validity(self) -> bool: valid_conditional_executions = len(self._conditional_executions) > 0 if not valid_conditional_executions: @@ -123,8 +127,10 @@ def check_valid_ros_instantiations(self, def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> "ScxmlIf": condional_executions = [] for condition, execution in self._conditional_executions: + execution_body = as_plain_execution_body(execution, ros_declarations) + assert execution_body is not None, "Error: SCXML if: invalid execution body." condional_executions.append((replace_ros_interface_expression(condition), - as_plain_execution_body(execution, ros_declarations))) + execution_body)) else_execution = as_plain_execution_body(self._else_execution, ros_declarations) return ScxmlIf(condional_executions, else_execution) @@ -135,7 +141,7 @@ def as_xml(self) -> ET.Element: xml_if = ET.Element(ScxmlIf.get_tag_name(), {"cond": first_conditional_execution[0]}) append_execution_body_to_xml(xml_if, first_conditional_execution[1]) for condition, execution in self._conditional_executions[1:]: - xml_if.append = ET.Element('elseif', {"cond": condition}) + xml_if.append(ET.Element('elseif', {"cond": condition})) append_execution_body_to_xml(xml_if, execution) if self._else_execution is not None: xml_if.append(ET.Element('else')) @@ -152,29 +158,30 @@ def __init__(self, event: str, params: Optional[List[ScxmlParam]] = None): self._event = event self._params = params + @staticmethod def get_tag_name() -> str: return "send" - def get_event(self) -> str: - """Get the event to send.""" - return self._event - - def get_params(self) -> List[ScxmlParam]: - """Get the parameters to send.""" - return self._params - + @staticmethod def from_xml_tree(xml_tree: ET.Element) -> "ScxmlSend": """Create a ScxmlSend object from an XML tree.""" assert xml_tree.tag == ScxmlSend.get_tag_name(), \ f"Error: SCXML send: XML tag name is not {ScxmlSend.get_tag_name()}." event = xml_tree.attrib["event"] - params = [] + params: List[ScxmlParam] = [] + assert params is not None, "Error: SCXML send: params is not valid." for param_xml in xml_tree: params.append(ScxmlParam.from_xml_tree(param_xml)) - if len(params) == 0: - params = None return ScxmlSend(event, params) + def get_event(self) -> str: + """Get the event to send.""" + return self._event + + def get_params(self) -> List[ScxmlParam]: + """Get the parameters to send.""" + return self._params + def check_validity(self) -> bool: valid_event = isinstance(self._event, str) and len(self._event) > 0 valid_params = True @@ -214,17 +221,11 @@ def __init__(self, location: str, expr: str): self._location = location self._expr = expr + @staticmethod def get_tag_name() -> str: return "assign" - def get_location(self) -> str: - """Get the location to assign.""" - return self._location - - def get_expr(self) -> str: - """Get the expression to assign.""" - return self._expr - + @staticmethod def from_xml_tree(xml_tree: ET.Element) -> "ScxmlAssign": """Create a ScxmlAssign object from an XML tree.""" assert xml_tree.tag == ScxmlAssign.get_tag_name(), \ @@ -236,6 +237,15 @@ def from_xml_tree(xml_tree: ET.Element) -> "ScxmlAssign": assert expr is not None and len(expr) > 0, \ "Error: SCXML assign: expr is not valid." return ScxmlAssign(location, expr) + + def get_location(self) -> str: + """Get the location to assign.""" + return self._location + + def get_expr(self) -> str: + """Get the expression to assign.""" + return self._expr + def check_validity(self) -> bool: # TODO: Check that the location to assign exists in the data-model @@ -301,6 +311,7 @@ def execution_entry_from_xml(xml_tree: ET.Element) -> ScxmlExecutableEntry: """ # TODO: This is pretty bad, need to re-check how to break the circle from .scxml_ros_entries import ScxmlRosSends + # TODO: This should be generated only once, since it stays as it is tag_to_cls = {cls.get_tag_name(): cls for cls in _ResolvedScxmlExecutableEntry + ScxmlRosSends} exec_tag = xml_tree.tag diff --git a/scxml_converter/src/scxml_converter/scxml_entries/scxml_param.py b/scxml_converter/src/scxml_converter/scxml_entries/scxml_param.py index 67c6b5e7..c524cb66 100644 --- a/scxml_converter/src/scxml_converter/scxml_entries/scxml_param.py +++ b/scxml_converter/src/scxml_converter/scxml_entries/scxml_param.py @@ -17,12 +17,11 @@ Container for a single parameter, sent within an event. In XML, it has the tag `param`. """ -from scxml_converter.scxml_entries import ScxmlBase - from typing import Optional - from xml.etree import ElementTree as ET +from scxml_converter.scxml_entries import ScxmlBase + class ScxmlParam(ScxmlBase): """This class represents a single parameter.""" @@ -33,18 +32,11 @@ def __init__(self, name: str, *, expr: Optional[str] = None, location: Optional[ self._expr = expr self._location = location + @staticmethod def get_tag_name() -> str: return "param" - def get_name(self) -> str: - return self._name - - def get_expr(self) -> Optional[str]: - return self._expr - - def get_location(self) -> Optional[str]: - return self._location - + @staticmethod def from_xml_tree(xml_tree: ET.Element) -> "ScxmlParam": """Create a ScxmlParam object from an XML tree.""" assert xml_tree.tag == ScxmlParam.get_tag_name(), \ @@ -59,6 +51,15 @@ def from_xml_tree(xml_tree: ET.Element) -> "ScxmlParam": "Error: SCXML param: expr and location are both unset." return ScxmlParam(name, expr=expr, location=location) + def get_name(self) -> str: + return self._name + + def get_expr(self) -> Optional[str]: + return self._expr + + def get_location(self) -> Optional[str]: + return self._location + def check_validity(self) -> bool: valid_name = len(self._name) > 0 if not valid_name: diff --git a/scxml_converter/src/scxml_converter/scxml_entries/scxml_root.py b/scxml_converter/src/scxml_converter/scxml_entries/scxml_root.py index 4f1f202b..7dfd3d64 100644 --- a/scxml_converter/src/scxml_converter/scxml_entries/scxml_root.py +++ b/scxml_converter/src/scxml_converter/scxml_entries/scxml_root.py @@ -17,17 +17,19 @@ The main entry point of an SCXML Model. In XML, it has the tag `scxml`. """ -from typing import List, Optional, Tuple, get_args -from scxml_converter.scxml_entries import (ScxmlBase, ScxmlState, ScxmlDataModel, - ScxmlRosDeclarations, RosTimeRate, RosTopicSubscriber, - RosTopicPublisher, RosServiceServer, RosServiceClient, - ScxmlRosDeclarationsContainer) - from copy import deepcopy from os.path import isfile - +from typing import List, Optional, Tuple, get_args from xml.etree import ElementTree as ET +from scxml_converter.scxml_entries import (RosServiceClient, RosServiceServer, + RosTimeRate, RosTopicPublisher, + RosTopicSubscriber, ScxmlBase, + ScxmlDataModel, + ScxmlRosDeclarations, + ScxmlRosDeclarationsContainer, + ScxmlState) + class ScxmlRoot(ScxmlBase): """This class represents a whole scxml model, that is used to define specific skills.""" @@ -51,7 +53,7 @@ def from_xml_tree(xml_tree: ET.Element) -> "ScxmlRoot": assert datamodel_elements is None or len(datamodel_elements) <= 1, \ f"Error: SCXML root: {len(datamodel_elements)} datamodels found, max 1 allowed." # ROS Declarations - ros_declarations = [] + ros_declarations: List[ScxmlRosDeclarations] = [] for child in xml_tree: if child.tag == RosTimeRate.get_tag_name(): ros_declarations.append(RosTimeRate.from_xml_tree(child)) @@ -103,10 +105,10 @@ def from_scxml_file(xml_file: str) -> "ScxmlRoot": def __init__(self, name: str): self._name = name self._version = "1.0" # This is the only version mentioned in the official documentation - self._initial_state: str = None + self._initial_state: Optional[str] = None self._states: List[ScxmlState] = [] - self._data_model: ScxmlDataModel = None - self._ros_declarations: List[ScxmlRosDeclarations] = None + self._data_model: Optional[ScxmlDataModel] = None + self._ros_declarations: List[ScxmlRosDeclarations] = [] def get_name(self) -> str: """Get the name of the automaton represented by this SCXML model.""" @@ -114,6 +116,7 @@ def get_name(self) -> str: def get_initial_state_id(self) -> str: """Get the ID of the initial state of the SCXML model.""" + assert self._initial_state is not None, "Error: SCXML root: Initial state not set." return self._initial_state def get_data_model(self) -> Optional[ScxmlDataModel]: @@ -147,7 +150,7 @@ def add_ros_declaration(self, ros_declaration: ScxmlRosDeclarations): self._ros_declarations = [] self._ros_declarations.append(ros_declaration) - def _generate_ros_declarations_helper(self) -> ScxmlRosDeclarationsContainer: + def _generate_ros_declarations_helper(self) -> Optional[ScxmlRosDeclarationsContainer]: """Generate a HelperRosDeclarations object from the existing ROS declarations.""" ros_decl_container = ScxmlRosDeclarationsContainer(self._name) if self._ros_declarations is not None: @@ -230,12 +233,14 @@ def to_plain_scxml_and_declarations(self) -> Tuple["ScxmlRoot", ScxmlRosDeclarat plain_root._data_model = deepcopy(self._data_model) plain_root._initial_state = self._initial_state ros_declarations = self._generate_ros_declarations_helper() + assert ros_declarations is not None, "Error: SCXML root: invalid ROS declarations." plain_root._states = [state.as_plain_scxml(ros_declarations) for state in self._states] assert plain_root.is_plain_scxml(), "SCXML root: conversion to plain SCXML failed." return (plain_root, ros_declarations) def as_xml(self) -> ET.Element: assert self.check_validity(), "SCXML: found invalid root object." + assert self._initial_state is not None, "Error: SCXML root: no initial state set." xml_root = ET.Element("scxml", { "name": self._name, "version": self._version, @@ -244,7 +249,9 @@ def as_xml(self) -> ET.Element: "xmlns": "http://www.w3.org/2005/07/scxml" }) if self._data_model is not None: - xml_root.append(self._data_model.as_xml()) + data_model_xml = self._data_model.as_xml() + assert data_model_xml is not None, "Error: SCXML root: invalid data model." + xml_root.append(data_model_xml) if self._ros_declarations is not None: for ros_declaration in self._ros_declarations: xml_root.append(ros_declaration.as_xml()) diff --git a/scxml_converter/src/scxml_converter/scxml_entries/scxml_ros_entries.py b/scxml_converter/src/scxml_converter/scxml_entries/scxml_ros_entries.py index af342ad4..b9a4dfb3 100644 --- a/scxml_converter/src/scxml_converter/scxml_entries/scxml_ros_entries.py +++ b/scxml_converter/src/scxml_converter/scxml_entries/scxml_ros_entries.py @@ -16,11 +16,16 @@ """Declaration of ROS-Specific SCXML tags extensions.""" from typing import Union -from scxml_converter.scxml_entries import ( - RosTimeRate, RosTopicPublisher, RosTopicSubscriber, RosServiceServer, RosServiceClient, - RosServiceHandleRequest, RosServiceHandleResponse, - RosServiceSendRequest, RosServiceSendResponse, - RosTopicPublish, RosTopicCallback, RosRateCallback) + +from scxml_converter.scxml_entries import (RosRateCallback, RosServiceClient, + RosServiceHandleRequest, + RosServiceHandleResponse, + RosServiceSendRequest, + RosServiceSendResponse, + RosServiceServer, RosTimeRate, + RosTopicCallback, RosTopicPublish, + RosTopicPublisher, + RosTopicSubscriber) ScxmlRosDeclarations = Union[RosTimeRate, RosTopicPublisher, RosTopicSubscriber, RosServiceServer, RosServiceClient] diff --git a/scxml_converter/src/scxml_converter/scxml_entries/scxml_ros_field.py b/scxml_converter/src/scxml_converter/scxml_entries/scxml_ros_field.py index 1f04d241..3d790fbb 100644 --- a/scxml_converter/src/scxml_converter/scxml_entries/scxml_ros_field.py +++ b/scxml_converter/src/scxml_converter/scxml_entries/scxml_ros_field.py @@ -15,9 +15,10 @@ """Declaration of the ROS Field SCXML tag extension.""" -from scxml_converter.scxml_entries import ScxmlParam from xml.etree import ElementTree as ET +from scxml_converter.scxml_entries import ScxmlParam + class RosField(ScxmlParam): """Field of a ROS msg published in a topic.""" @@ -27,9 +28,11 @@ def __init__(self, name: str, expr: str): self._expr = expr assert self.check_validity(), "Error: SCXML topic publish field: invalid parameters." + @staticmethod def get_tag_name() -> str: return "field" + @staticmethod def from_xml_tree(xml_tree: ET.Element) -> "RosField": """Create a RosField object from an XML tree.""" assert xml_tree.tag == RosField.get_tag_name(), \ @@ -49,8 +52,9 @@ def check_validity(self) -> bool: print("Error: SCXML topic publish field: expr is not valid.") return valid_name and valid_expr - def as_plain_scxml(self) -> ScxmlParam: - from scxml_converter.scxml_entries.utils import replace_ros_interface_expression + def as_plain_scxml(self, _) -> ScxmlParam: + from scxml_converter.scxml_entries.utils import \ + replace_ros_interface_expression return ScxmlParam(self._name, expr=replace_ros_interface_expression(self._expr)) def as_xml(self) -> ET.Element: diff --git a/scxml_converter/src/scxml_converter/scxml_entries/scxml_ros_service.py b/scxml_converter/src/scxml_converter/scxml_entries/scxml_ros_service.py index 449228d7..6246cbab 100644 --- a/scxml_converter/src/scxml_converter/scxml_entries/scxml_ros_service.py +++ b/scxml_converter/src/scxml_converter/scxml_entries/scxml_ros_service.py @@ -20,18 +20,19 @@ https://docs.ros.org/en/iron/Tutorials/Beginner-CLI-Tools/Understanding-ROS2-Services/Understanding-ROS2-Services.html """ -from typing import Optional, List, Union -from scxml_converter.scxml_entries import (ScxmlBase, RosField, ScxmlSend, ScxmlTransition, - ScxmlExecutionBody) -from scxml_converter.scxml_entries import (execution_body_from_xml, valid_execution_body, - as_plain_execution_body) +from typing import List, Optional, Union from xml.etree import ElementTree as ET -from scxml_converter.scxml_entries.utils import ScxmlRosDeclarationsContainer -from scxml_converter.scxml_entries.utils import (is_srv_type_known, - generate_srv_request_event, - generate_srv_response_event, - generate_srv_server_request_event, - generate_srv_server_response_event) + +from scxml_converter.scxml_entries import (RosField, ScxmlBase, + ScxmlExecutionBody, ScxmlSend, + ScxmlTransition, + as_plain_execution_body, + execution_body_from_xml, + valid_execution_body) +from scxml_converter.scxml_entries.utils import ( + ScxmlRosDeclarationsContainer, generate_srv_request_event, + generate_srv_response_event, generate_srv_server_request_event, + generate_srv_server_response_event, is_srv_type_known) class RosServiceServer(ScxmlBase): @@ -164,16 +165,14 @@ def from_xml_tree(xml_tree: ET.Element) -> "RosServiceSendRequest": srv_name = xml_tree.attrib.get("service_name") assert srv_name is not None, \ "Error: SCXML service request: 'service_name' attribute not found in input xml." - fields = [] + fields: List[RosField] = [] for field_xml in xml_tree: fields.append(RosField.from_xml_tree(field_xml)) - if len(fields) == 0: - fields = None return RosServiceSendRequest(srv_name, fields) def __init__(self, service_decl: Union[str, RosServiceClient], - fields: Optional[List[RosField]]) -> None: + fields: List[RosField] = None) -> None: """ Initialize a new RosServiceSendRequest object. @@ -187,6 +186,8 @@ def __init__(self, assert isinstance(service_decl, str), \ "Error: SCXML Service Send Request: invalid service name." self._srv_name = service_decl + if fields is None: + fields = [] self._fields = fields assert self.check_validity(), "Error: SCXML Service Send Request: invalid parameters." @@ -220,7 +221,7 @@ def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> Scx "Error: SCXML service request: invalid ROS instantiations." event_name = generate_srv_request_event( self._srv_name, ros_declarations.get_automaton_name()) - event_params = [field.as_plain_scxml() for field in self._fields] + event_params = [field.as_plain_scxml(ros_declarations) for field in self._fields] return ScxmlSend(event_name, event_params) def as_xml(self) -> ET.Element: @@ -326,7 +327,7 @@ def get_tag_name() -> str: return "ros_service_send_response" @staticmethod - def from_xml_tree(xml_tree: ET.Element) -> "RosServiceClient": + def from_xml_tree(xml_tree: ET.Element) -> "RosServiceSendResponse": """Create a RosServiceServer object from an XML tree.""" assert xml_tree.tag == RosServiceSendResponse.get_tag_name(), \ "Error: SCXML service response: XML tag name is not " + \ @@ -334,7 +335,8 @@ def from_xml_tree(xml_tree: ET.Element) -> "RosServiceClient": srv_name = xml_tree.attrib.get("service_name") assert srv_name is not None, \ "Error: SCXML service response: 'service_name' attribute not found in input xml." - fields = [] + fields: Optional[List[RosField]] = [] + assert fields is not None, "Error: SCXML service response: fields is not valid." for field_xml in xml_tree: fields.append(RosField.from_xml_tree(field_xml)) if len(fields) == 0: @@ -356,7 +358,7 @@ def __init__(self, service_name: Union[str, RosServiceServer], assert isinstance(service_name, str), \ "Error: SCXML Service Send Response: invalid service name." self._service_name = service_name - self._fields = fields + self._fields = fields if fields is not None else [] assert self.check_validity(), "Error: SCXML Service Send Response: invalid parameters." def check_validity(self) -> bool: @@ -388,7 +390,7 @@ def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> Scx assert self.check_valid_ros_instantiations(ros_declarations), \ "Error: SCXML service response: invalid ROS instantiations." event_name = generate_srv_server_response_event(self._service_name) - event_params = [field.as_plain_scxml() for field in self._fields] + event_params = [field.as_plain_scxml(ros_declarations) for field in self._fields] return ScxmlSend(event_name, event_params) def as_xml(self) -> ET.Element: @@ -409,7 +411,7 @@ def get_tag_name() -> str: return "ros_service_handle_response" @staticmethod - def from_xml_tree(xml_tree: ET.Element) -> "RosServiceClient": + def from_xml_tree(xml_tree: ET.Element) -> "RosServiceHandleResponse": """Create a RosServiceServer object from an XML tree.""" assert xml_tree.tag == RosServiceHandleResponse.get_tag_name(), \ "Error: SCXML service response handler: XML tag name is not " + \ diff --git a/scxml_converter/src/scxml_converter/scxml_entries/scxml_ros_timer.py b/scxml_converter/src/scxml_converter/scxml_entries/scxml_ros_timer.py index 0f619288..9b10c7e4 100644 --- a/scxml_converter/src/scxml_converter/scxml_entries/scxml_ros_timer.py +++ b/scxml_converter/src/scxml_converter/scxml_entries/scxml_ros_timer.py @@ -16,12 +16,15 @@ """Declaration of SCXML tags related to ROS Timers.""" from typing import Optional, Union -from scxml_converter.scxml_entries import (ScxmlBase, ScxmlTransition, - ScxmlExecutionBody, ScxmlRosDeclarationsContainer, - valid_execution_body, execution_body_from_xml, - as_plain_execution_body) from xml.etree import ElementTree as ET +from scxml_converter.scxml_entries import (ScxmlBase, ScxmlExecutionBody, + ScxmlRosDeclarationsContainer, + ScxmlTransition, + as_plain_execution_body, + execution_body_from_xml, + valid_execution_body) + class RosTimeRate(ScxmlBase): """Object used in the SCXML root to declare a new timer with its related tick rate.""" @@ -30,21 +33,23 @@ def __init__(self, name: str, rate_hz: float): self._name = name self._rate_hz = float(rate_hz) + @staticmethod def get_tag_name() -> str: return "ros_time_rate" + @staticmethod def from_xml_tree(xml_tree: ET.Element) -> "RosTimeRate": """Create a RosTimeRate object from an XML tree.""" assert xml_tree.tag == RosTimeRate.get_tag_name(), \ f"Error: SCXML rate timer: XML tag name is not {RosTimeRate.get_tag_name()}" timer_name = xml_tree.attrib.get("name") - timer_rate = xml_tree.attrib.get("rate_hz") - assert timer_name is not None and timer_rate is not None, \ + timer_rate_str = xml_tree.attrib.get("rate_hz") + assert timer_name is not None and timer_rate_str is not None, \ "Error: SCXML rate timer: 'name' or 'rate_hz' attribute not found in input xml." try: - timer_rate = float(timer_rate) - except ValueError: - raise ValueError("Error: SCXML rate timer: rate is not a number.") + timer_rate = float(timer_rate_str) + except ValueError as e: + raise ValueError("Error: SCXML rate timer: rate is not a number.") from e return RosTimeRate(timer_name, timer_rate) def check_validity(self) -> bool: diff --git a/scxml_converter/src/scxml_converter/scxml_entries/scxml_ros_topic.py b/scxml_converter/src/scxml_converter/scxml_entries/scxml_ros_topic.py index 05cc761d..5fa2beda 100644 --- a/scxml_converter/src/scxml_converter/scxml_entries/scxml_ros_topic.py +++ b/scxml_converter/src/scxml_converter/scxml_entries/scxml_ros_topic.py @@ -21,12 +21,15 @@ """ from typing import List, Optional, Union -from scxml_converter.scxml_entries import (RosField, ScxmlBase, ScxmlSend, ScxmlParam, - ScxmlTransition, ScxmlExecutionBody, - ScxmlRosDeclarationsContainer, - valid_execution_body, execution_body_from_xml, - as_plain_execution_body) from xml.etree import ElementTree as ET + +from scxml_converter.scxml_entries import (RosField, ScxmlBase, + ScxmlExecutionBody, ScxmlParam, + ScxmlRosDeclarationsContainer, + ScxmlSend, ScxmlTransition, + as_plain_execution_body, + execution_body_from_xml, + valid_execution_body) from scxml_converter.scxml_entries.utils import is_msg_type_known @@ -225,15 +228,15 @@ def from_xml_tree(xml_tree: ET.Element) -> ScxmlSend: topic_name = xml_tree.attrib.get("topic") assert topic_name is not None, \ "Error: SCXML topic publish: 'topic' attribute not found in input xml." - fields = [] + fields: List[RosField] = [] for field_xml in xml_tree: fields.append(RosField.from_xml_tree(field_xml)) - if len(fields) == 0: - fields = None return RosTopicPublish(topic_name, fields) def __init__(self, topic: Union[RosTopicPublisher, str], - fields: Optional[List[RosField]] = None): + fields: List[RosField] = None) -> None: + if fields is None: + fields = [] if isinstance(topic, RosTopicPublisher): self._topic = topic.get_topic_name() else: @@ -279,7 +282,7 @@ def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> Scx "Error: SCXML topic publish: invalid ROS instantiations." event_name = "ros_topic." + self._topic params = None if self._fields is None else \ - [field.as_plain_scxml() for field in self._fields] + [field.as_plain_scxml(ros_declarations) for field in self._fields] return ScxmlSend(event_name, params) def as_xml(self) -> ET.Element: diff --git a/scxml_converter/src/scxml_converter/scxml_entries/scxml_state.py b/scxml_converter/src/scxml_converter/scxml_entries/scxml_state.py index ce6bb041..ad1be9db 100644 --- a/scxml_converter/src/scxml_converter/scxml_entries/scxml_state.py +++ b/scxml_converter/src/scxml_converter/scxml_entries/scxml_state.py @@ -17,13 +17,17 @@ A single state in SCXML. In XML, it has the tag `state`. """ -from typing import List, Optional, Union +from typing import List, Optional, Sequence, Union from xml.etree import ElementTree as ET -from scxml_converter.scxml_entries import ( - ScxmlBase, ScxmlTransition, ScxmlRosDeclarationsContainer, - ScxmlRosTransitions, ScxmlExecutableEntry, ScxmlExecutionBody, - as_plain_execution_body, execution_body_from_xml, valid_execution_body) +from scxml_converter.scxml_entries import (ScxmlBase, ScxmlExecutableEntry, + ScxmlExecutionBody, + ScxmlRosDeclarationsContainer, + ScxmlRosTransitions, + ScxmlTransition, + as_plain_execution_body, + execution_body_from_xml, + valid_execution_body) class ScxmlState(ScxmlBase): @@ -33,50 +37,42 @@ class ScxmlState(ScxmlBase): def get_tag_name() -> str: return "state" - def __init__(self, id: str, *, - on_entry: Optional[ScxmlExecutionBody] = None, - on_exit: Optional[ScxmlExecutionBody] = None, - body: Optional[List[ScxmlTransition]] = None): - self._id = id - self._on_entry = on_entry - self._on_exit = on_exit - self._body = body - - def get_id(self) -> str: - return self._id - - def get_onentry(self) -> Optional[ScxmlExecutionBody]: - return self._on_entry - - def get_onexit(self) -> Optional[ScxmlExecutionBody]: - return self._on_exit - - def get_body(self) -> Optional[List[ScxmlTransition]]: - """Return the transitions leaving the state.""" - return self._body + def __init__(self, id_: str, *, + on_entry: ScxmlExecutionBody = None, + on_exit: ScxmlExecutionBody = None, + body: List[ScxmlTransition] = None): + self._id = id_ + self._on_entry = on_entry if on_entry is not None else [] + self._on_exit = on_exit if on_exit is not None else [] + self._body: List[ScxmlTransition] = body if body is not None else [] + @staticmethod def from_xml_tree(xml_tree: ET.Element) -> "ScxmlState": """Create a ScxmlState object from an XML tree.""" assert xml_tree.tag == ScxmlState.get_tag_name(), \ f"Error: SCXML state: XML tag name is not {ScxmlState.get_tag_name()}." - id = xml_tree.attrib.get("id") - assert id is not None and len(id) > 0, "Error: SCXML state: id is not valid." - scxml_state = ScxmlState(id) + id_ = xml_tree.attrib.get("id") + assert id_ is not None and len(id_) > 0, "Error: SCXML state: id is not valid." + scxml_state = ScxmlState(id_) # Get the onentry and onexit execution bodies - on_entry = xml_tree.findall("onentry") - if on_entry is not None and len(on_entry) == 0: - on_entry = None - assert on_entry is None or len(on_entry) == 1, \ + on_entry_xml = xml_tree.findall("onentry") + if on_entry_xml is None: + on_entry = [] + else: + on_entry = on_entry_xml + assert len(on_entry) == 0 or len(on_entry) == 1, \ f"Error: SCXML state: {len(on_entry)} onentry tags found, expected 0 or 1." - on_exit = xml_tree.findall("onexit") - if on_exit is not None and len(on_exit) == 0: - on_exit = None - assert on_exit is None or len(on_exit) == 1, \ + on_exit_xml = xml_tree.findall("onexit") + if on_exit_xml is None: + on_exit = [] + else: + on_exit = on_exit_xml + assert len(on_exit) == 0 or len(on_exit) == 1, \ f"Error: SCXML state: {len(on_exit)} onexit tags found, expected 0 or 1." - if on_entry is not None: + if len(on_entry) > 0: for exec_entry in execution_body_from_xml(on_entry[0]): scxml_state.append_on_entry(exec_entry) - if on_exit is not None: + if len(on_exit) > 0: for exec_entry in execution_body_from_xml(on_exit[0]): scxml_state.append_on_exit(exec_entry) # Get the transitions in the state body @@ -84,9 +80,23 @@ def from_xml_tree(xml_tree: ET.Element) -> "ScxmlState": scxml_state.add_transition(body_entry) return scxml_state - def _transitions_from_xml(xml_tree: ET.Element) -> List[ScxmlTransition]: + def get_id(self) -> str: + return self._id + + def get_onentry(self) -> ScxmlExecutionBody: + return self._on_entry + + def get_onexit(self) -> ScxmlExecutionBody: + return self._on_exit + + def get_body(self) -> List[ScxmlTransition]: + """Return the transitions leaving the state.""" + return self._body + + @classmethod + def _transitions_from_xml(cls, xml_tree: ET.Element) -> List[ScxmlTransition]: transitions: List[ScxmlTransition] = [] - tag_to_cls = {cls.get_tag_name(): cls for cls in ScxmlRosTransitions} + tag_to_cls = {cls.get_tag_name(): cls for cls in ScxmlTransition.__subclasses__()} tag_to_cls.update({ScxmlTransition.get_tag_name(): ScxmlTransition}) for child in xml_tree: if child.tag in tag_to_cls: @@ -94,18 +104,12 @@ def _transitions_from_xml(xml_tree: ET.Element) -> List[ScxmlTransition]: return transitions def add_transition(self, transition: ScxmlTransition): - if self._body is None: - self._body = [] self._body.append(transition) def append_on_entry(self, executable_entry: ScxmlExecutableEntry): - if self._on_entry is None: - self._on_entry = [] self._on_entry.append(executable_entry) def append_on_exit(self, executable_entry: ScxmlExecutableEntry): - if self._on_exit is None: - self._on_exit = [] self._on_exit.append(executable_entry) def check_validity(self) -> bool: @@ -146,33 +150,33 @@ def check_valid_ros_instantiations(self, print("Error: SCXML state: found invalid transition in state body.") return valid_entry and valid_exit and valid_body - def _check_valid_ros_instantiations(body: List[Union[ScxmlExecutableEntry, ScxmlTransition]], + @staticmethod + def _check_valid_ros_instantiations(body: Sequence[Union[ScxmlExecutableEntry, ScxmlTransition]], ros_declarations: ScxmlRosDeclarationsContainer) -> bool: """Check if the ros instantiations have been declared in the body.""" - return body is None or \ + return len(body) == 0 or \ all(entry.check_valid_ros_instantiations(ros_declarations) for entry in body) def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> "ScxmlState": """Convert the ROS-specific entries to be plain SCXML""" plain_entry = as_plain_execution_body(self._on_entry, ros_declarations) plain_exit = as_plain_execution_body(self._on_exit, ros_declarations) - plain_body = as_plain_execution_body(self._body, ros_declarations) + plain_body = [entry.as_plain_scxml(ros_declarations) for entry in self._body] return ScxmlState(self._id, on_entry=plain_entry, on_exit=plain_exit, body=plain_body) def as_xml(self) -> ET.Element: assert self.check_validity(), "SCXML: found invalid state object." xml_state = ET.Element(ScxmlState.get_tag_name(), {"id": self._id}) - if self._on_entry is not None: + if len(self._on_entry) > 0: xml_on_entry = ET.Element('onentry') for executable_entry in self._on_entry: xml_on_entry.append(executable_entry.as_xml()) xml_state.append(xml_on_entry) - if self._on_exit is not None: + if len(self._on_exit) > 0: xml_on_exit = ET.Element('onexit') for executable_entry in self._on_exit: xml_on_exit.append(executable_entry.as_xml()) xml_state.append(xml_on_exit) - if self._body is not None: - for transition in self._body: - xml_state.append(transition.as_xml()) + for transition in self._body: + xml_state.append(transition.as_xml()) return xml_state diff --git a/scxml_converter/src/scxml_converter/scxml_entries/scxml_transition.py b/scxml_converter/src/scxml_converter/scxml_entries/scxml_transition.py index 8910c49f..5b3fbc02 100644 --- a/scxml_converter/src/scxml_converter/scxml_entries/scxml_transition.py +++ b/scxml_converter/src/scxml_converter/scxml_entries/scxml_transition.py @@ -18,12 +18,14 @@ """ from typing import List, Optional -from scxml_converter.scxml_entries import (ScxmlBase, ScxmlExecutionBody, ScxmlExecutableEntry, - ScxmlRosDeclarationsContainer, valid_execution_body, - execution_body_from_xml) - from xml.etree import ElementTree as ET +from scxml_converter.scxml_entries import (ScxmlBase, ScxmlExecutableEntry, + ScxmlExecutionBody, + ScxmlRosDeclarationsContainer, + execution_body_from_xml, + valid_execution_body) + class ScxmlTransition(ScxmlBase): """This class represents a single scxml state.""" @@ -38,8 +40,8 @@ def from_xml_tree(xml_tree: ET.Element) -> "ScxmlTransition": f"Error: SCXML transition: XML root tag name is not {ScxmlTransition.get_tag_name()}." target = xml_tree.get("target") assert target is not None, "Error: SCXML transition: target attribute not found." - events = xml_tree.get("event") - events = events.split(" ") if events is not None else None + events_str = xml_tree.get("event") + events = events_str.split(" ") if events_str is not None else [] condition = xml_tree.get("cond") exec_body = execution_body_from_xml(xml_tree) exec_body = exec_body if exec_body is not None else None @@ -67,7 +69,7 @@ def __init__(self, body), "Error SCXML transition: invalid body provided." self._target = target self._body = body - self._events = events + self._events = events if events is not None else [] self._condition = condition def get_target_state_id(self) -> str: @@ -87,8 +89,6 @@ def get_executable_body(self) -> ScxmlExecutionBody: return self._body if self._body is not None else [] def add_event(self, event: str): - if self._events is None: - self._events = [] self._events.append(event) def append_body_executable_entry(self, exec_entry: ScxmlExecutableEntry): @@ -137,7 +137,7 @@ def as_plain_scxml(self, ros_declarations: ScxmlRosDeclarationsContainer) -> "Sc def as_xml(self) -> ET.Element: assert self.check_validity(), "SCXML: found invalid transition." xml_transition = ET.Element(ScxmlTransition.get_tag_name(), {"target": self._target}) - if self._events is not None: + if len(self._events) > 0: xml_transition.set("event", " ".join(self._events)) if self._condition is not None: xml_transition.set("cond", self._condition) diff --git a/scxml_converter/src/scxml_converter/scxml_entries/utils.py b/scxml_converter/src/scxml_converter/scxml_entries/utils.py index a4e8f8cb..493ae98d 100644 --- a/scxml_converter/src/scxml_converter/scxml_entries/utils.py +++ b/scxml_converter/src/scxml_converter/scxml_entries/utils.py @@ -15,9 +15,9 @@ """Collection of various utilities for scxml entries.""" -from typing import Dict, List, Tuple, Optional -from scxml_converter.scxml_entries.scxml_ros_field import RosField +from typing import Dict, List, Optional, Tuple +from scxml_converter.scxml_entries.scxml_ros_field import RosField MSG_TYPE_SUBSTITUTIONS = { "boolean": "bool", @@ -225,42 +225,42 @@ def get_service_client_type(self, service_name: str) -> Optional[str]: def get_service_server_type(self, service_name: str) -> Optional[str]: return self._service_servers.get(service_name, None) - def check_valid_srv_req_fields(self, service_name: str, fields: List[RosField]) -> bool: + def check_valid_srv_req_fields(self, service_name: str, ros_fields: List[RosField]) -> bool: """Check if the provided fields match the service request type.""" req_type = self.get_service_client_type(service_name) if req_type is None: print(f"Error: SCXML ROS declarations: unknown service client {service_name}.") return False req_fields, _ = get_srv_type_params(req_type) - for field in fields: - if field.get_name() not in req_fields: + for ros_field in ros_fields: + if ros_field.get_name() not in req_fields: print("Error: SCXML ROS declarations: " - f"unknown field {field.get_name()} in service request.") + f"unknown field {ros_field.get_name()} in service request.") return False - req_fields.pop(field.get_name()) + req_fields.pop(ros_field.get_name()) if len(req_fields) > 0: print("Error: SCXML ROS declarations: missing fields in service request.") - for field in req_fields.keys(): - print(f"\t-{field}.") + for req_field in req_fields.keys(): + print(f"\t-{req_field}.") return False return True - def check_valid_srv_res_fields(self, service_name: str, fields: List[RosField]) -> bool: + def check_valid_srv_res_fields(self, service_name: str, ros_fields: List[RosField]) -> bool: """Check if the provided fields match the service response type.""" res_type = self.get_service_server_type(service_name) if res_type is None: print(f"Error: SCXML ROS declarations: unknown service server {service_name}.") return False _, res_fields = get_srv_type_params(res_type) - for field in fields: - if field.get_name() not in res_fields: + for ros_field in ros_fields: + if ros_field.get_name() not in res_fields: print("Error: SCXML ROS declarations: " - f"unknown field {field.get_name()} in service response.") + f"unknown field {ros_field.get_name()} in service response.") return False - res_fields.pop(field.get_name()) + res_fields.pop(ros_field.get_name()) if len(res_fields) > 0: print("Error: SCXML ROS declarations: missing fields in service response.") - for field in res_fields.keys(): - print(f"\t-{field}.") + for res_field in res_fields.keys(): + print(f"\t-{res_field}.") return False return True diff --git a/scxml_converter/test/test_systemtest_scxml_entries.py b/scxml_converter/test/test_systemtest_scxml_entries.py index 99a0a792..2f79400a 100644 --- a/scxml_converter/test/test_systemtest_scxml_entries.py +++ b/scxml_converter/test/test_systemtest_scxml_entries.py @@ -15,13 +15,16 @@ import os -from scxml_converter.scxml_entries import (ScxmlAssign, ScxmlData, ScxmlDataModel, ScxmlParam, - ScxmlRoot, ScxmlSend, ScxmlState, ScxmlTransition, - RosTimeRate, RosTopicPublisher, RosTopicSubscriber, - RosRateCallback, RosTopicPublish, RosTopicCallback, - RosField) from test_utils import canonicalize_xml, remove_empty_lines +from scxml_converter.scxml_entries import (RosField, RosRateCallback, + RosTimeRate, RosTopicCallback, + RosTopicPublish, RosTopicPublisher, + RosTopicSubscriber, ScxmlAssign, + ScxmlData, ScxmlDataModel, + ScxmlParam, ScxmlRoot, ScxmlSend, + ScxmlState, ScxmlTransition) + def test_battery_drainer_from_code(): """ @@ -128,7 +131,7 @@ def test_battery_drainer_ros_from_code(): def _test_xml_parsing(xml_file_path: str, valid_xml: bool = True): - # TODO: Input path to scxml file fro args + # TODO: Input path to scxml file from args scxml_root = ScxmlRoot.from_scxml_file(xml_file_path) # Check output xml if valid_xml: diff --git a/scxml_converter/test/test_systemtest_xml.py b/scxml_converter/test/test_systemtest_xml.py index d8bf3064..42634141 100644 --- a/scxml_converter/test/test_systemtest_xml.py +++ b/scxml_converter/test/test_systemtest_xml.py @@ -14,7 +14,8 @@ # limitations under the License. import os -from test_utils import canonicalize_xml, remove_empty_lines + +from test_utils import canonicalize_xml, remove_empty_lines from scxml_converter.bt_converter import bt_converter from scxml_converter.scxml_entries import ScxmlRoot