From c8bfe6a0febb2961856360e92b04180c5c17e027 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20D=C4=85browski?= Date: Mon, 3 Feb 2025 18:21:59 +0100 Subject: [PATCH] Add benchmarking code from manipulation demo MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Kacper DÄ…browski --- src/rai/rai/tools/ros/manipulation.py | 9 +- src/rai_benchmarks/__init__.py | 0 src/rai_benchmarks/benchmark.py | 26 +++ src/rai_benchmarks/manager.py | 182 ++++++++++++++++ .../scenarios/longest_object.py | 123 +++++++++++ .../scenarios/move_to_the_left.py | 125 +++++++++++ src/rai_benchmarks/scenarios/place_on_top.py | 118 +++++++++++ src/rai_benchmarks/scenarios/replace_types.py | 194 ++++++++++++++++++ src/rai_benchmarks/scenarios/scenario_base.py | 145 +++++++++++++ src/rai_benchmarks/showcase.py | 27 +++ .../tools/segmentation_tools.py | 7 +- 11 files changed, 951 insertions(+), 5 deletions(-) create mode 100644 src/rai_benchmarks/__init__.py create mode 100644 src/rai_benchmarks/benchmark.py create mode 100644 src/rai_benchmarks/manager.py create mode 100644 src/rai_benchmarks/scenarios/longest_object.py create mode 100644 src/rai_benchmarks/scenarios/move_to_the_left.py create mode 100644 src/rai_benchmarks/scenarios/place_on_top.py create mode 100644 src/rai_benchmarks/scenarios/replace_types.py create mode 100644 src/rai_benchmarks/scenarios/scenario_base.py create mode 100644 src/rai_benchmarks/showcase.py diff --git a/src/rai/rai/tools/ros/manipulation.py b/src/rai/rai/tools/ros/manipulation.py index 8deacd603..1ddf22b56 100644 --- a/src/rai/rai/tools/ros/manipulation.py +++ b/src/rai/rai/tools/ros/manipulation.py @@ -29,6 +29,7 @@ from rclpy.node import Node from tf2_geometry_msgs import do_transform_pose +from rai.communication.ros2.connectors import ROS2ARIConnector from rai.tools.utils import TF2TransformFetcher from rai_interfaces.srv import ManipulatorMoveTo @@ -156,9 +157,13 @@ class GetObjectPositionsTool(BaseTool): node: Node get_grabbing_point_tool: GetGrabbingPointTool - def __init__(self, node: Node, **kwargs): + def __init__(self, connector: ROS2ARIConnector, node: Node, **kwargs): super(GetObjectPositionsTool, self).__init__( - node=node, get_grabbing_point_tool=GetGrabbingPointTool(node=node), **kwargs + node=node, + get_grabbing_point_tool=GetGrabbingPointTool( + connector=connector, node=node + ), + **kwargs, ) args_schema: Type[GetObjectPositionsToolInput] = GetObjectPositionsToolInput diff --git a/src/rai_benchmarks/__init__.py b/src/rai_benchmarks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/rai_benchmarks/benchmark.py b/src/rai_benchmarks/benchmark.py new file mode 100644 index 000000000..b193af1ab --- /dev/null +++ b/src/rai_benchmarks/benchmark.py @@ -0,0 +1,26 @@ +import rclpy +from manager import RaiBenchmarkManager +from rclpy.executors import MultiThreadedExecutor +from scenarios.longest_object import LongestObject +from scenarios.move_to_the_left import MoveToTheLeft +from scenarios.place_on_top import PlaceOnTop +from scenarios.replace_types import ReplaceTypes + + +def main(args=None): + rclpy.init(args=args) + + manager = RaiBenchmarkManager( + [PlaceOnTop, LongestObject, MoveToTheLeft, ReplaceTypes], list(range(4)) + ) + + executor = MultiThreadedExecutor(2) + executor.add_node(manager) + executor.spin() + + manager.destroy_node() + rclpy.shutdown() + + +if __name__ == "__main__": + main() diff --git a/src/rai_benchmarks/manager.py b/src/rai_benchmarks/manager.py new file mode 100644 index 000000000..647009dc4 --- /dev/null +++ b/src/rai_benchmarks/manager.py @@ -0,0 +1,182 @@ +import random +import time +from threading import Thread + +from gazebo_msgs.srv import DeleteEntity, SpawnEntity +from geometry_msgs.msg import Point, Quaternion +from langchain_core.messages import HumanMessage +from rclpy.node import Node +from rclpy.task import Future +from scenarios.scenario_base import ScenarioBase +from tf2_ros import Buffer, TransformListener + +from rai.agents.conversational_agent import create_conversational_agent +from rai.communication.ros2.connectors import ROS2ARIConnector +from rai.node import RaiBaseNode +from rai.tools.ros2.topics import GetROS2ImageTool +from rai.tools.ros.manipulation import GetObjectPositionsTool, MoveToPointTool +from rai.tools.ros.native import Ros2GetTopicsNamesAndTypesTool +from rai.utils.model_initialization import get_llm_model +from rai_interfaces.srv import ManipulatorMoveTo + + +class ScenarioManager(Node): + """ + A class responsible for playing the scenarios + """ + + def __init__(self, scenario_types, seeds=[]): + """ + Initializes the ScenarioManager + + Args: + scenario_types: A list of scenario classes to play + seeds: A list of seeds to use for each scenario + """ + super().__init__("scenario_manager") + self.scenario_types = scenario_types + self.seeds = seeds + + self.spawn_client = self.create_client(SpawnEntity, "/spawn_entity") + self.delete_client = self.create_client(DeleteEntity, "/delete_entity") + self.manipulator_client = self.create_client( + ManipulatorMoveTo, "/manipulator_move_to" + ) + self.tf2_buffer = Buffer() + self.tf2_listener = TransformListener(self.tf2_buffer, self) + + while not self.spawn_client.wait_for_service(timeout_sec=1.0): + self.get_logger().info("service not available, waiting again...") + while not self.delete_client.wait_for_service(timeout_sec=1.0): + self.get_logger().info("service not available, waiting again...") + + self.timer = self.create_timer(1.0, self.timer_callback) + self.scenario: ScenarioBase = None + self.current_scenario = 0 + self.agent_thread: Thread = None + self.manipulator_ready = False + self.scores = [] + + def _init_scenario(self): + self.scenario = self.scenario_types[self.current_scenario]( + self.spawn_client, self.delete_client, self.manipulator_client, self + ) + self.manipulator_ready = False + request = ManipulatorMoveTo.Request() + request.target_pose.pose.orientation = Quaternion( + x=0.923880, y=-0.382683, z=0.0, w=0.0 + ) + request.target_pose.pose.position = Point(x=0.2, y=0.0, z=0.2) + if self.current_scenario < len(self.seeds): + random.seed(self.seeds[self.current_scenario]) + else: + random.seed(42) + + def callback(future: Future): + self.manipulator_ready = True + self.scenario.reset() + + self.scenario.manipulator_client.call_async(request).add_done_callback(callback) + + def _terminate_scenario(self): + self.get_logger().info(f"Scenario terminated with score {self.scores[-1]}") + self.scenario = None + self.tf2_buffer = Buffer() + self.tf2_listener = TransformListener(self.tf2_buffer, self) + if self.current_scenario == len(self.scenario_types) - 1: + self.get_logger().info( + f"All scenarios are completed, with scores: {self.scores}" + ) + request = ManipulatorMoveTo.Request() + request.target_pose.pose.orientation = Quaternion( + x=0.923880, y=-0.382683, z=0.0, w=0.0 + ) + request.target_pose.pose.position = Point(x=0.2, y=0.0, z=0.2) + + def callback(future: Future): + self.manipulator_ready = True + self.executor.shutdown() + + self.manipulator_client.call_async(request).add_done_callback(callback) + self.timer.cancel() + return + self.current_scenario = (self.current_scenario + 1) % len(self.scenario_types) + self.manipulator_ready = False + + def timer_callback(self): + if self.scenario is None: + self._init_scenario() + + if not self.manipulator_ready: + return + + progress, terminated = self.scenario.step() + if terminated and not (self.agent_thread and self.agent_thread.is_alive()): + self.scores.append(progress) + self._terminate_scenario() + return + if self.agent_thread and not self.agent_thread.is_alive(): + self.get_logger().info( + "Agent failed to fulfill the task, terminating the scenario." + ) + self.scores.append(progress) + self._terminate_scenario() + + +class RaiBenchmarkManager(ScenarioManager): + """ + A class responsible for playing the scenarios and running the conversational agent for each scenario + """ + + def __init__(self, scenario_types, seeds=[]): + super().__init__(scenario_types, seeds) + self.agent = None + + def _init_scenario(self): + super()._init_scenario() + + self.rai_node = RaiBaseNode(node_name="manipulation_demo") + self.rai_node.declare_parameter("conversion_ratio", 1.0) + + connector = ROS2ARIConnector() + tools = [ + GetObjectPositionsTool( + connector=connector, + node=self.rai_node, + target_frame="panda_link0", + source_frame="RGBDCamera5", + camera_topic="/color_image5", + depth_topic="/depth_image5", + camera_info_topic="/color_camera_info5", + ), + MoveToPointTool(node=self.rai_node, manipulator_frame="panda_link0"), + GetROS2ImageTool(node=self.rai_node, connector=connector), + Ros2GetTopicsNamesAndTypesTool(node=self.rai_node), + ] + + llm = get_llm_model(model_type="complex_model") + + system_prompt = """ + You are a robotic arm with interfaces to detect and manipulate objects. + Here are the coordinates information: + x - front to back (positive is forward) + y - left to right (positive is right) + z - up to down (positive is up) + + Before starting the task, make sure to grab the camera image to understand the environment. + """ + + self.agent = create_conversational_agent( + llm=llm, + tools=tools, + system_prompt=system_prompt, + ) + + def run_agent(): + time.sleep(1) + self.agent.invoke( + {"messages": [HumanMessage(content=self.scenario.get_prompt())]} + )["messages"][-1].pretty_print() + + self.agent_thread = Thread(target=run_agent) + self.agent_thread.start() diff --git a/src/rai_benchmarks/scenarios/longest_object.py b/src/rai_benchmarks/scenarios/longest_object.py new file mode 100644 index 000000000..7656fa98c --- /dev/null +++ b/src/rai_benchmarks/scenarios/longest_object.py @@ -0,0 +1,123 @@ +import rclpy +from geometry_msgs.msg import Point, PoseStamped, Quaternion +from rclpy.client import Client +from rclpy.node import Node +from rclpy.task import Future +from scenarios.scenario_base import ScenarioBase + +from rai_interfaces.srv import ManipulatorMoveTo + + +class LongestObject(ScenarioBase): + def __init__( + self, + spawn_client: Client, + delete_client: Client, + manipulator_client: Client, + node: Node, + ): + super().__init__(spawn_client, delete_client, manipulator_client, node) + + def get_prompt(self): + return "Put the longest object from the table into the toy box." + + def reset(self): + super().reset() + + prefabs = ["apple", "yellow_cube", "blue_cube", "carrot"] + self.spawn_entities_in_random_positions(prefabs, prefabs) + + pose = PoseStamped() + pose.header.frame_id = "world" + pose.pose.position = Point(x=0.4, y=-0.5, z=0.1) + pose.pose.orientation = Quaternion(x=0.0, y=0.0, z=0.0, w=1.0) + pose_transformed = self.node.tf2_buffer.transform( + pose, "odom", timeout=rclpy.time.Duration(seconds=5.0) + ).pose + self.spawn_entity("toy_box", "toy_box", pose_transformed) + + def calculate_progress(self): + if len(self.entities) == 0: + return 0.0 + + carrot_position = self.pose_transformed(self.get_entity_pose("carrot")).position + toy_box_position = self.pose_transformed( + self.get_entity_pose("toy_box") + ).position + + def distance(a, b): + return ((a.x - b.x) ** 2 + (a.y - b.y) ** 2 + (a.z - b.z) ** 2) ** 0.5 + + return max(0.0, 1.0 - 3.0 * distance(carrot_position, toy_box_position)) + + def step(self): + if len(self.entities) == 0: + return 0.0, False + + progress = self.calculate_progress() + + return progress, progress >= 0.5 + + +class LongestObjectAuto(LongestObject): + def __init__( + self, + spawn_client: Client, + delete_client: Client, + manipulator_client: Client, + node: Node, + ): + self.manipulator_busy = False + + super().__init__(spawn_client, delete_client, manipulator_client, node) + + def reset(self): + super().reset() + + self.manipulator_busy = False + self.manipulator_queue = [] + + def place_on_top(self, bot_object: str, top_object: str): + pose = self.get_entity_pose(top_object) + pose.position.z += 0.1 + + req = ManipulatorMoveTo.Request() + req.initial_gripper_state = True + req.target_pose.pose = self.pose_transformed(pose) + req.final_gripper_state = False + self.manipulator_queue.append(req) + + pose = self.get_entity_pose(bot_object) + pose.position.z += 0.2 + req = ManipulatorMoveTo.Request() + req.initial_gripper_state = False + req.target_pose.pose = self.pose_transformed(pose) + req.final_gripper_state = True + self.manipulator_queue.append(req) + + def move_callback(self, future: Future): + result = future.result() + if result.success: + self.node.get_logger().debug("Move performed") + else: + self.node.get_logger().error("Failed to perform move") + self.manipulator_busy = False + + def step(self): + if len(self.entities) == 0: + return 0.0, False + + progress = self.calculate_progress() + + if not self.manipulator_busy: + if len(self.manipulator_queue) == 0 and progress < 0.8: + self.place_on_top("toy_box", "carrot") + + if len(self.manipulator_queue) > 0: + req = self.manipulator_queue.pop(0) + self.manipulator_busy = True + self.manipulator_client.call_async(req).add_done_callback( + self.move_callback + ) + + return progress, progress >= 0.8 and not self.manipulator_busy diff --git a/src/rai_benchmarks/scenarios/move_to_the_left.py b/src/rai_benchmarks/scenarios/move_to_the_left.py new file mode 100644 index 000000000..55b26bc3a --- /dev/null +++ b/src/rai_benchmarks/scenarios/move_to_the_left.py @@ -0,0 +1,125 @@ +import rclpy +from geometry_msgs.msg import Point, PoseStamped, Quaternion +from rclpy.client import Client +from rclpy.node import Node +from rclpy.task import Future +from scenarios.scenario_base import ScenarioBase + +from rai_interfaces.srv import ManipulatorMoveTo + + +class MoveToTheLeft(ScenarioBase): + def __init__( + self, + spawn_client: Client, + delete_client: Client, + manipulator_client: Client, + node: Node, + ): + super().__init__(spawn_client, delete_client, manipulator_client, node) + + def get_prompt(self): + return "There are 5 apples on the right half of the table. Their Y position is positive. Move each of them to the left half of the table, such that they dont collide with each other. The Y position of the apples should be negative after the task is completed. First grab one apple using the 'grab' tool, and then drop it in the appropriate position using the 'drop' tool. Repeat this procedure for each apple." + + def reset(self): + super().reset() + + for i in range(5): + pose = PoseStamped() + pose.header.frame_id = "world" + pose.pose.position = Point(x=0.4, y=float(i) / 10 + 0.15, z=0.1) + pose.pose.orientation = Quaternion(x=0.0, y=0.0, z=0.0, w=1.0) + + pose_transformed = self.node.tf2_buffer.transform( + pose, "odom", timeout=rclpy.time.Duration(seconds=5.0) + ) + + prefab = "apple" + self.spawn_entity(prefab, f"{prefab}{i}", pose_transformed.pose) + + def calculate_progress(self): + num_apples = sum(1 for name in self.entities if name.startswith("apple")) + num_good_apples = sum( + 1 + for name in self.entities + if name.startswith("apple") + and self.pose_transformed(self.get_entity_pose(name)).position.y < 0.0 + ) + return num_good_apples / num_apples + + def step(self): + if len(self.entities) == 0: # The entities are not spawned yet + return 0.0, False + + progress = self.calculate_progress() + + return progress, progress >= 1.0 + + +class MoveToTheLeftAuto(MoveToTheLeft): + def __init__( + self, + spawn_client: Client, + delete_client: Client, + manipulator_client: Client, + node: Node, + ): + super().__init__(spawn_client, delete_client, manipulator_client, node) + self.manipulator_busy = False + + def reset(self): + super().reset() + + self.manipulator_busy = False + self.manipulator_queue = [] + + def move_to_the_left(self, name: str): + pose = self.get_entity_pose(name) + pose.position.z += 0.1 + + req = ManipulatorMoveTo.Request() + req.initial_gripper_state = True + req.target_pose.pose = self.pose_transformed(pose) + req.final_gripper_state = False + self.manipulator_queue.append(req) + + pose.position.y -= 0.6 + req = ManipulatorMoveTo.Request() + req.initial_gripper_state = False + req.target_pose.pose = self.pose_transformed(pose) + req.final_gripper_state = True + self.manipulator_queue.append(req) + + def move_callback(self, future: Future): + result = future.result() + if result.success: + self.node.get_logger().debug("Move performed") + else: + self.node.get_logger().error("Failed to perform move") + self.manipulator_busy = False + + def step(self): + if len(self.entities) == 0: # The entities are not spawned yet + return 0.0, False + + progress = self.calculate_progress() + + if not self.manipulator_busy: + if len(self.manipulator_queue) == 0: + for name in self.entities: + pose = self.pose_transformed(self.get_entity_pose(name)) + if pose.position.y > 0.0: + self.node.get_logger().info( + f"Moving {name} to the left from {pose.position.y}" + ) + self.move_to_the_left(name) + break + + if len(self.manipulator_queue) > 0: + req = self.manipulator_queue.pop(0) + self.manipulator_busy = True + self.manipulator_client.call_async(req).add_done_callback( + self.move_callback + ) + + return progress, progress >= 1.0 and not self.manipulator_busy diff --git a/src/rai_benchmarks/scenarios/place_on_top.py b/src/rai_benchmarks/scenarios/place_on_top.py new file mode 100644 index 000000000..e3afe123b --- /dev/null +++ b/src/rai_benchmarks/scenarios/place_on_top.py @@ -0,0 +1,118 @@ +from rclpy.client import Client +from rclpy.node import Node +from rclpy.task import Future +from scenarios.scenario_base import ScenarioBase + +from rai_interfaces.srv import ManipulatorMoveTo + + +class PlaceOnTop(ScenarioBase): + def __init__( + self, + spawn_client: Client, + delete_client: Client, + manipulator_client: Client, + node: Node, + ): + self.top_object = None + self.bot_object = None + + super().__init__(spawn_client, delete_client, manipulator_client, node) + + def get_prompt(self): + return "Place the yellow cube on top of the blue cube. Remember to increase the Z position of the 'drop' task by around 0.2 to avoid collision." + + def reset(self): + super().reset() + + prefabs = ["apple", "yellow_cube", "blue_cube", "carrot"] + self.spawn_entities_in_random_positions(prefabs, prefabs) + self.bot_object = "blue_cube" + self.top_object = "yellow_cube" + + def calculate_progress(self): + if self.top_object is None or self.bot_object is None: + return 0.0 + + top_pose = self.pose_transformed(self.get_entity_pose(self.top_object)) + bot_pose = self.pose_transformed(self.get_entity_pose(self.bot_object)) + + goal_position = bot_pose.position + goal_position.z += 0.05 + + def distance(a, b): + return ((a.x - b.x) ** 2 + (a.y - b.y) ** 2 + (a.z - b.z) ** 2) ** 0.5 + + return max(0.0, 1.0 - 10.0 * distance(top_pose.position, goal_position)) + + def step(self): + if len(self.entities) == 0: + return 0.0, False + + progress = self.calculate_progress() + + return progress, progress >= 0.8 + + +class PlaceOnTopAuto(PlaceOnTop): + def __init__( + self, + spawn_client: Client, + delete_client: Client, + manipulator_client: Client, + node: Node, + ): + self.manipulator_busy = False + + super().__init__(spawn_client, delete_client, manipulator_client, node) + + def reset(self): + super().reset() + + self.manipulator_busy = False + self.manipulator_queue = [] + + def place_on_top(self, bot_object: str, top_object: str): + pose = self.get_entity_pose(top_object) + pose.position.z += 0.1 + + req = ManipulatorMoveTo.Request() + req.initial_gripper_state = True + req.target_pose.pose = self.pose_transformed(pose) + req.final_gripper_state = False + self.manipulator_queue.append(req) + + pose = self.get_entity_pose(bot_object) + pose.position.z += 0.2 + req = ManipulatorMoveTo.Request() + req.initial_gripper_state = False + req.target_pose.pose = self.pose_transformed(pose) + req.final_gripper_state = True + self.manipulator_queue.append(req) + + def move_callback(self, future: Future): + result = future.result() + if result.success: + self.node.get_logger().debug("Move performed") + else: + self.node.get_logger().error("Failed to perform move") + self.manipulator_busy = False + + def step(self): + if len(self.entities) == 0: + return 0.0, False + + progress = self.calculate_progress() + + if not self.manipulator_busy: + if len(self.manipulator_queue) == 0 and progress < 0.8: + self.place_on_top(self.bot_object, self.top_object) + + if len(self.manipulator_queue) > 0: + req = self.manipulator_queue.pop(0) + self.manipulator_busy = True + self.manipulator_client.call_async(req).add_done_callback( + self.move_callback + ) + + return progress, progress >= 0.8 and not self.manipulator_busy diff --git a/src/rai_benchmarks/scenarios/replace_types.py b/src/rai_benchmarks/scenarios/replace_types.py new file mode 100644 index 000000000..71fefbe26 --- /dev/null +++ b/src/rai_benchmarks/scenarios/replace_types.py @@ -0,0 +1,194 @@ +from rclpy.client import Client +from rclpy.node import Node +from rclpy.task import Future +from scenarios.scenario_base import ScenarioBase + +from rai_interfaces.srv import ManipulatorMoveTo + + +class ReplaceTypes(ScenarioBase): + def __init__( + self, + spawn_client: Client, + delete_client: Client, + manipulator_client: Client, + node: Node, + ): + self.vegetable_poses = [] + self.toy_poses = [] + self.current_index = 0 + + super().__init__(spawn_client, delete_client, manipulator_client, node) + + def get_prompt(self): + return "Replace the objects in such a way that all the toys are in the places of vegetables and vice versa." + + def reset(self): + super().reset() + + self.vegetable_poses = [] + self.toy_poses = [] + self.current_index = 0 + prefabs = ["apple", "yellow_cube", "blue_cube", "carrot"] + self.spawn_entities_in_random_positions(prefabs, prefabs) + for entity in prefabs: + if self._get_type(entity) == "vegetable": + self.vegetable_poses.append( + self.pose_transformed(self.get_entity_pose(entity)) + ) + elif self._get_type(entity) == "toy": + self.toy_poses.append( + self.pose_transformed(self.get_entity_pose(entity)) + ) + return + + def _get_type(self, object_name: str): + if object_name.startswith("apple") or object_name.startswith("carrot"): + return "vegetable" + elif object_name.startswith("yellow_cube") or object_name.startswith( + "blue_cube" + ): + return "toy" + else: + return None + + def _min_distance(self, entity: str): + def distance(a, b): + return ((a.x - b.x) ** 2 + (a.y - b.y) ** 2) ** 0.5 + + pose = self.pose_transformed(self.get_entity_pose(entity)) + min_distance = float("inf") + if self._get_type(entity) == "vegetable": + for toy_pose in self.toy_poses: + min_distance = min( + min_distance, distance(pose.position, toy_pose.position) + ) + elif self._get_type(entity) == "toy": + for vegetable_pose in self.vegetable_poses: + min_distance = min( + min_distance, distance(pose.position, vegetable_pose.position) + ) + return min_distance + + def calculate_progress(self): + if self.vegetable_poses == [] or self.toy_poses == []: + return 0.0 + + progress = 0.0 + + for entity in self.entities: + progress += max(0.0, 1.0 - self._min_distance(entity) * 10.0) + + return progress / len(self.entities) + + def step(self): + if len(self.entities) == 0: + return 0.0, False + + progress = self.calculate_progress() + + return progress, progress >= 0.8 + + +class ReplaceTypesAuto(ReplaceTypes): + def __init__( + self, + spawn_client: Client, + delete_client: Client, + manipulator_client: Client, + node: Node, + ): + self.manipulator_busy = False + + super().__init__(spawn_client, delete_client, manipulator_client, node) + + def reset(self): + super().reset() + + self.manipulator_busy = False + self.manipulator_queue = [] + + def replace(self, a: str, b: str): + from copy import deepcopy + + pose_a = self.pose_transformed(self.get_entity_pose(a)) + pose_a.position.z += 0.1 + + buffer_pose = deepcopy(pose_a) + buffer_pose.position.x = 0.4 + buffer_pose.position.y = -0.5 + + req = ManipulatorMoveTo.Request() + req.initial_gripper_state = True + req.target_pose.pose = pose_a + req.final_gripper_state = False + self.manipulator_queue.append(req) + + req = ManipulatorMoveTo.Request() + req.initial_gripper_state = False + req.target_pose.pose = buffer_pose + req.final_gripper_state = True + self.manipulator_queue.append(req) + + pose_b = self.pose_transformed(self.get_entity_pose(b)) + pose_b.position.z += 0.1 + + req = ManipulatorMoveTo.Request() + req.initial_gripper_state = True + req.target_pose.pose = pose_b + req.final_gripper_state = False + self.manipulator_queue.append(req) + + req = ManipulatorMoveTo.Request() + req.initial_gripper_state = False + req.target_pose.pose = pose_a + req.final_gripper_state = True + self.manipulator_queue.append(req) + + req = ManipulatorMoveTo.Request() + req.initial_gripper_state = True + req.target_pose.pose = buffer_pose + req.final_gripper_state = False + self.manipulator_queue.append(req) + + req = ManipulatorMoveTo.Request() + req.initial_gripper_state = False + req.target_pose.pose = pose_b + req.final_gripper_state = True + self.manipulator_queue.append(req) + + def move_callback(self, future: Future): + result = future.result() + if result.success: + self.node.get_logger().debug("Move performed") + else: + self.node.get_logger().error("Failed to perform move") + self.manipulator_busy = False + + def step(self): + if len(self.entities) < 4: + return 0.0, False + + progress = self.calculate_progress() + + if not self.manipulator_busy: + vegetables = [] + toys = [] + for entity in self.entities: + if self._get_type(entity) == "vegetable": + vegetables.append(entity) + elif self._get_type(entity) == "toy": + toys.append(entity) + + if len(self.manipulator_queue) == 0 and progress < 0.8: + self.replace(vegetables[self.current_index], toys[self.current_index]) + self.current_index = (self.current_index + 1) % len(vegetables) + + if len(self.manipulator_queue) > 0: + req = self.manipulator_queue.pop(0) + self.manipulator_busy = True + self.manipulator_client.call_async(req).add_done_callback( + self.move_callback + ) + + return progress, progress >= 0.8 and not self.manipulator_busy diff --git a/src/rai_benchmarks/scenarios/scenario_base.py b/src/rai_benchmarks/scenarios/scenario_base.py new file mode 100644 index 000000000..123db35e5 --- /dev/null +++ b/src/rai_benchmarks/scenarios/scenario_base.py @@ -0,0 +1,145 @@ +import random + +import rclpy +from gazebo_msgs.srv import DeleteEntity, SpawnEntity +from geometry_msgs.msg import Pose, PoseStamped, Quaternion +from rclpy.client import Client +from rclpy.node import Node +from rclpy.task import Future + + +class ScenarioBase: + """ + Base class for a scenario. A scenario is a task that the agent has to perform. + """ + + def __init__( + self, + spawn_client: Client, + delete_client: Client, + manipulator_client: Node, + node: Node, + ): + self.spawn_client = spawn_client + self.delete_client = delete_client + self.manipulator_client = manipulator_client + self.node = node + self.entities: dict[str, str] = {} + + def __del__(self): + for name in self.entities: + self.delete_entity(name) + + def get_prompt(self) -> str: + """ + Returns a prompt that describes the task that the agent has to perform. + """ + return "Please do something interesting" + + def reset(self) -> None: + """ + Resets the scenario to its initial state. + + This method is called before the scenario is started. + """ + + for name in self.entities: + self.delete_entity(name) + self.entities = {} + + def spawn_entity(self, prefab_name: str, name: str, pose: Pose) -> None: + """ + Spawns an entity in the simulation. + + Args: + prefab_name: The name of the prefab to spawn. + name: The name of the entity, by which it can be later referenced. + pose: The pose of the entity. + """ + req = SpawnEntity.Request() + req.name = prefab_name + req.xml = "" + req.robot_namespace = name + req.initial_pose = pose + + self.spawn_client.call_async(req).add_done_callback( + lambda future: self.entity_spawned_callback(future, name) + ) + + def spawn_entities_in_random_positions( + self, prefab_names: list[str], names: list[str] + ) -> None: + """ + Spawns entities randomly positione around the table. + + Args: + prefab_names: A list of prefab names to spawn. + names: A list of names for the entities, by which they can be later referenced. + """ + grid = [(x / 10.0 + 0.4, y / 10.0) for x in range(-1, 2) for y in range(-3, 4)] + + positions = random.sample(grid, k=len(prefab_names)) + for prefab_name, name, position in zip(prefab_names, names, positions): + pose = PoseStamped() + pose.header.frame_id = "world" + pose.pose.position.x = position[0] + pose.pose.position.y = position[1] + pose.pose.position.z = 0.05 + pose.pose.orientation = Quaternion(x=0.0, y=0.0, z=0.0, w=1.0) + pose_transformed = self.node.tf2_buffer.transform( + pose, "odom", timeout=rclpy.time.Duration(seconds=5.0) + ) + self.spawn_entity(prefab_name, name, pose_transformed.pose) + + def delete_entity(self, name: str) -> None: + req = DeleteEntity.Request() + req.name = self.entities[name] + + self.delete_client.call_async(req) + + def get_entity_pose(self, name: str) -> Pose: + """ + Returns the pose of an entity. + """ + pose = PoseStamped() + entity_frame = name + "/" + pose.header.frame_id = entity_frame + pose = self.node.tf2_buffer.transform( + pose, entity_frame + "odom", timeout=rclpy.time.Duration(seconds=5.0) + ) + return pose.pose + + def pose_transformed(self, pose: Pose) -> Pose: + """ + Transforms the pose into the frame of the panda robot. + """ + pose_stamped = PoseStamped() + pose_stamped.pose = pose + pose_stamped.header.frame_id = "odom" + pose = self.node.tf2_buffer.transform( + pose_stamped, "world", timeout=rclpy.time.Duration(seconds=5.0) + ).pose + pose.orientation = Quaternion(x=0.923880, y=-0.382683, z=0.0, w=0.0) + return pose + + def step(self) -> tuple[float, bool]: + """ + Performs a step in the scenario. + + Returns: + progress: A float between 0.0 and 1.0 that represents the progress of the task. + terminated: A boolean that indicates whether the task is terminated. + """ + return 0.0, False + + def entity_spawned_callback(self, future: Future, name: str): + result = future.result() + if result.success: + self.node.get_logger().info( + f"Entity spawned: {name} ({result.status_message})" + ) + self.entities[name] = result.status_message + else: + self.node.get_logger().error( + f"Failed to spawn entity: {result.status_message}" + ) diff --git a/src/rai_benchmarks/showcase.py b/src/rai_benchmarks/showcase.py new file mode 100644 index 000000000..52730fa7c --- /dev/null +++ b/src/rai_benchmarks/showcase.py @@ -0,0 +1,27 @@ +import rclpy +from manager import ScenarioManager +from rclpy.executors import MultiThreadedExecutor +from scenarios.longest_object import LongestObjectAuto +from scenarios.move_to_the_left import MoveToTheLeftAuto +from scenarios.place_on_top import PlaceOnTopAuto +from scenarios.replace_types import ReplaceTypesAuto + + +def main(args=None): + rclpy.init(args=args) + + manager = ScenarioManager( + [PlaceOnTopAuto, LongestObjectAuto, MoveToTheLeftAuto, ReplaceTypesAuto], + list(range(4)), + ) + + executor = MultiThreadedExecutor(2) + executor.add_node(manager) + executor.spin() + + manager.destroy_node() + rclpy.shutdown() + + +if __name__ == "__main__": + main() diff --git a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py index c264bd885..2bfb753e4 100644 --- a/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py +++ b/src/rai_extensions/rai_open_set_vision/rai_open_set_vision/tools/segmentation_tools.py @@ -26,6 +26,7 @@ ParameterUninitializedException, ) +from rai.communication.ros2.connectors import ROS2ARIConnector from rai.node import RaiBaseNode from rai.tools.ros import Ros2BaseInput, Ros2BaseTool from rai.tools.ros.utils import convert_ros_img_to_base64, convert_ros_img_to_ndarray @@ -65,6 +66,7 @@ class GetGrabbingPointInput(Ros2BaseInput): # --------------------- Tools --------------------- class GetSegmentationTool(Ros2BaseTool): + connector: ROS2ARIConnector node: RaiBaseNode = Field(..., exclude=True) name: str = "" @@ -84,7 +86,7 @@ def _get_gsam_response(self, future: Future) -> Optional[RAIGroundedSam.Response return get_future_result(future) def _get_image_message(self, topic: str) -> sensor_msgs.msg.Image: - msg = self.node.get_raw_message_from_topic(topic) + msg = self.connector.receive_message(topic).payload if type(msg) is sensor_msgs.msg.Image: return msg else: @@ -186,7 +188,6 @@ def depth_to_point_cloud( class GetGrabbingPointTool(GetSegmentationTool): - name: str = "GetGrabbingPointTool" description: str = "Get the grabbing point of an object" pcd: List[Any] = [] @@ -195,7 +196,7 @@ class GetGrabbingPointTool(GetSegmentationTool): def _get_camera_info_message(self, topic: str) -> sensor_msgs.msg.CameraInfo: for _ in range(3): - msg = self.node.get_raw_message_from_topic(topic, timeout_sec=3.0) + msg = self.connector.receive_message(topic, timeout_sec=3.0).payload if isinstance(msg, sensor_msgs.msg.CameraInfo): return msg self.node.get_logger().warn("Received wrong message type. Retrying...")