Skip to content

Commit

Permalink
Integrate action functionalities in RosDeclarationsContainer
Browse files Browse the repository at this point in the history
Signed-off-by: Marco Lampacrescia <[email protected]>
  • Loading branch information
MarcoLm993 committed Aug 19, 2024
1 parent ff9777d commit 97255b4
Showing 1 changed file with 60 additions and 10 deletions.
70 changes: 60 additions & 10 deletions scxml_converter/src/scxml_converter/scxml_entries/ros_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

"""Collection of SCXML utilities related to ROS functionalities."""

from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Tuple, Type

from scxml_converter.scxml_entries.scxml_ros_field import RosField

Expand Down Expand Up @@ -325,12 +325,21 @@ def is_publisher_defined(self, pub_name: str) -> bool:
def is_subscriber_defined(self, sub_name: str) -> bool:
return sub_name in self._subscribers

def is_service_client_defined(self, client_name: str) -> bool:
return client_name in self._service_clients

def is_service_server_defined(self, server_name: str) -> bool:
return server_name in self._service_servers

def is_action_client_defined(self, client_name: str) -> bool:
return client_name in self._action_clients

def is_action_server_defined(self, server_name: str) -> bool:
return server_name in self._action_servers

def is_timer_defined(self, timer_name: str) -> bool:
return timer_name in self._timers

def get_timers(self) -> Dict[str, float]:
return self._timers

def get_publisher_info(self, pub_name: str) -> Tuple[str, str]:
"""Provide a publisher topic name and type"""
pub_info = self._publishers.get(pub_name)
Expand Down Expand Up @@ -358,16 +367,27 @@ def get_service_client_info(self, client_name: str) -> Tuple[str, str]:
f"Error: SCXML ROS declarations: unknown service client {client_name}."
return client_info

def is_service_client_defined(self, client_name: str) -> bool:
return client_name in self._service_clients
def get_action_server_info(self, server_name: str) -> Tuple[str, str]:
"""Given an action server name, provide the related action name and type."""
server_info = self._action_servers.get(server_name)
assert server_info is not None, \
f"Error: SCXML ROS declarations: unknown action server {server_name}."
return server_info

def is_service_server_defined(self, server_name: str) -> bool:
return server_name in self._service_servers
def get_action_client_info(self, client_name: str) -> Tuple[str, str]:
"""Given an action client name, provide the related action name and type."""
client_info = self._action_clients.get(client_name)
assert client_info is not None, \
f"Error: SCXML ROS declarations: unknown action client {client_name}."
return client_info

def get_timers(self) -> Dict[str, float]:
return self._timers

def check_valid_srv_req_fields(self, client_name: str, ros_fields: List[RosField]) -> bool:
"""Check if the provided fields match the service request type."""
_, req_type = self.get_service_client_info(client_name)
req_fields, _ = get_srv_type_params(req_type)
_, service_type = self.get_service_client_info(client_name)
req_fields, _ = get_srv_type_params(service_type)
if not check_all_fields_known(ros_fields, req_fields):
print(f"Error: SCXML ROS declarations: Srv request {client_name} has invalid fields.")
return False
Expand All @@ -381,3 +401,33 @@ def check_valid_srv_res_fields(self, server_name: str, ros_fields: List[RosField
print(f"Error: SCXML ROS declarations: Srv response {server_name} has invalid fields.")
return False
return True

def check_valid_action_goal_fields(self, client_name: str, ros_fields: List[RosField]) -> bool:
"""Check if the provided fields match the action goal type."""
_, action_type = self.get_action_client_info(client_name)
goal_fields, _, _ = get_action_type_params(action_type)
if not check_all_fields_known(ros_fields, goal_fields):
print(f"Error: SCXML ROS declarations: Action goal {client_name} has invalid fields.")
return False
return True

def check_valid_action_feedback_fields(
self, server_name: str, ros_fields: List[RosField]) -> bool:
"""Check if the provided fields match the action feedback type."""
_, action_type = self.get_action_server_info(server_name)
_, feedback_fields, _ = get_action_type_params(action_type)
if not check_all_fields_known(ros_fields, feedback_fields):
print(f"Error: SCXML ROS declarations: Action feedback {server_name} "
"has invalid fields.")
return False
return True

def check_valid_action_result_fields(
self, server_name: str, ros_fields: List[RosField]) -> bool:
"""Check if the provided fields match the action result type."""
_, action_type = self.get_action_server_info(server_name)
_, _, result_fields = get_action_type_params(action_type)
if not check_all_fields_known(ros_fields, result_fields):
print(f"Error: SCXML ROS declarations: Action result {server_name} has invalid fields.")
return False
return True

0 comments on commit 97255b4

Please sign in to comment.