Skip to content

Commit

Permalink
added smart carrot refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasFrey96 committed Feb 18, 2024
1 parent ff242f9 commit 6834bea
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 85 deletions.
16 changes: 16 additions & 0 deletions wild_visual_navigation_ros/launch/smart_carrot.launch
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
<launch>
<!-- Launch Smart Carrot node -->
<node name="wild_visual_navigation_smart_carrot" pkg="wild_visual_navigation_ros" type="smart_carrot.py">
<param name="gridmap_sub_topic" value="/elevation_mapping/semantic_map"/>
<param name="debug_pub_topic" value="semantic_map"/>
<param name="goal_pub_topic" value="/initialpose"/>
<param name="debug" value="False"/>

<param name="map_frame" value="odom"/>
<param name="base_frame" value="base_inverted_field_local_planner"/>

<param name="distance_force_factor" value="0.5"/>
<param name="center_force_factor" value="0.0"/>
</node>

</launch>
210 changes: 125 additions & 85 deletions wild_visual_navigation_ros/scripts/smart_carrot.py
Original file line number Diff line number Diff line change
@@ -1,132 +1,172 @@
#
# Copyright (c) 2022-2024, ETH Zurich, Matias Mattamala, Jonas Frey.
# All rights reserved. Licensed under the MIT license.
# See LICENSE file in the project root for details.
#
from grid_map_msgs.msg import GridMap
import rospy
import sys
import numpy as np
import tf2_ros
from scipy.spatial.transform import Rotation as R
import cv2
import math
import time

import rospy
from grid_map_msgs.msg import GridMap, GridMapInfo
from geometry_msgs.msg import PoseWithCovarianceStamped
from dynamic_reconfigure.server import Server
import tf2_ros


class SmartCarrotNode:
def __init__(self):
self.gridmap_sub_topic = "/elevation_mapping/elevation_map_wifi"
self.pub = rospy.Publisher(f"~binary_mask", GridMap, queue_size=5)
self.pub_goal = rospy.Publisher(f"/initialpose", PoseWithCovarianceStamped, queue_size=5)
self.tf_buffer = tf2_ros.Buffer(cache_time=rospy.Duration(3.0))
# ROS topics
self.gridmap_sub_topic = rospy.get_param("~gridmap_sub_topic")
self.debug_pub_topic = rospy.get_param("~debug_pub_topic")
self.goal_pub_topic = rospy.get_param("~goal_pub_topic")
self.map_frame = rospy.get_param("~map_frame")
self.base_frame = rospy.get_param("~base_frame")
# Operation Mode
self.debug = rospy.get_param("~debug")

# Parameters
self.distance_force_factor = rospy.get_param("~distance_force_factor")
self.center_force_factor = rospy.get_param("~center_force_factor")

# TODO this could be implemented
self.filter_chain_funcs = []
if self.distance_force_factor > 0:
self.filter_chain_funcs.append(self.apply_distance_force)
self.distance_force = None

if self.center_force_factor > 0:
self.filter_chain_funcs.append(self.apply_center_force)

def distance_to_line(array, x_cor, y_cor, yaw, start_x, start_y):
return np.abs(np.cos(yaw) * (x_cor - start_x) - np.sin(yaw) * (y_cor - start_y))

self.vdistance_to_line = np.vectorize(distance_to_line)

# Initialize ROS publishers
self.pub = rospy.Publisher(f"~{self.debug_pub_topic}", GridMap, queue_size=5)
self.pub_goal = rospy.Publisher(self.goal_pub_topic, PoseWithCovarianceStamped, queue_size=5)

# Initialize TF listener
self.tf_buffer = tf2_ros.Buffer(cache_time=rospy.Duration(10.0))
self.tf_listener = tf2_ros.TransformListener(self.tf_buffer)
self.sub = rospy.Subscriber(self.gridmap_sub_topic, GridMap, self.callback, queue_size=5)

def apply_distance_force(self, yaw, sdf):
if self.distance_force is None:
# Create the distance force
self.distance_force = np.zeros((sdf.shape[0], sdf.shape[1]))
for x in range(sdf.shape[0]):
for y in range(sdf.shape[1]):
self.distance_force[x, y] = math.sqrt(
(x - int(sdf.shape[0] / 2)) ** 2 + (y - int(sdf.shape[1] / 2)) ** 2
)
self.distance_force /= self.distance_force.max()
self.distance_force *= self.distance_force_factor
return sdf + self.distance_force

def apply_center_force(self, yaw, sdf):
xv, yv = np.meshgrid(np.arange(0, sdf.shape[0]), np.arange(0, sdf.shape[1]))
center_force = self.vdistance_to_line(sdf, xv, yv, yaw, int(sdf.shape[0] / 2), int(sdf.shape[1] / 2))
return sdf - center_force * self.center_force_factor

def get_pattern_mask(self, H, W, yaw):
# Defines a pattern based on the yaw of the robot where we search for a minimum within the SDF
binary_mask = np.zeros((H, W), dtype=np.uint8)
distance = 30
center_x = int(H / 2 + math.sin(yaw) * distance)
center_y = int(W / 2 + math.cos(yaw) * distance)
binary_mask = cv2.circle(binary_mask, (center_x, center_y), 0, 255, 30)
distance = 55
center_x = int(H / 2 + math.sin(yaw) * distance)
center_y = int(W / 2 + math.cos(yaw) * distance)
binary_mask = cv2.circle(binary_mask, (center_x, center_y), 0, 255, 40)
distance = 90
center_x = int(H / 2 + math.sin(yaw) * distance)
center_y = int(W / 2 + math.cos(yaw) * distance)
binary_mask = cv2.circle(binary_mask, (center_x, center_y), 0, 255, 50)

self.sub = rospy.Subscriber(f"~{self.gridmap_sub_topic}", GridMap, self.callback, queue_size=10)

self.offset = np.zeros((200, 200))
for x in range(200):
for y in range(200):
self.offset[x, y] = math.sqrt((x - 100) ** 2 + (y - 100) ** 2)
return binary_mask == 0

self.offset /= self.offset.max()
self.offset *= 0.5
self.debug = True
def get_elevation_mask(self, elevation_layer):
invalid_elevation = np.isnan(elevation_layer)
# Increase the size of the invalid elevation to reduce noise
kernel = np.ones((3, 3), np.uint8)
invalid_elevation = cv2.dilate(np.uint8(invalid_elevation) * 255, kernel, iterations=1) == 255
return invalid_elevation

def callback(self, msg):
target_layer = "sdf"
if target_layer in msg.layers:
# extract grid_map layer as numpy array
data_list = msg.data[msg.layers.index(target_layer)].data
layout_info = msg.data[msg.layers.index(target_layer)].layout
n_cols = layout_info.dim[0].size
n_rows = layout_info.dim[1].size
sdf = np.reshape(np.array(data_list), (n_rows, n_cols))
sdf = sdf[::-1, ::-1].transpose().astype(np.float32)

target_layer = "elevation"
if target_layer in msg.layers:
# extract grid_map layer as numpy array
data_list = msg.data[msg.layers.index(target_layer)].data
layout_info = msg.data[msg.layers.index(target_layer)].layout
n_cols = layout_info.dim[0].size
n_rows = layout_info.dim[1].size
elevation = np.reshape(np.array(data_list), (n_rows, n_cols))
elevation = elevation[::-1, ::-1].transpose().astype(np.float32)
print("called callback")
# Convert GridMap to numpy array
layers = {}
for layer_name in ["sdf", "elevation"]:
if layer_name in msg.layers:
data_list = msg.data[msg.layers.index(layer_name)].data
layout_info = msg.data[msg.layers.index(layer_name)].layout
n_cols = layout_info.dim[0].size
n_rows = layout_info.dim[1].size
layer = np.reshape(np.array(data_list), (n_rows, n_cols))
layer = layer[::-1, ::-1].transpose().astype(np.float32)
layers[layer_name] = layer
else:
rospy.logwarn(f"Layer {layer_name} not found in GridMap")
return False

try:
res = self.tf_buffer.lookup_transform(
"odom",
"base_inverted_field_local_planner",
msg.info.header.stamp,
timeout=rospy.Duration(0.01),
self.map_frame, self.base_frame, msg.info.header.stamp, timeout=rospy.Duration(0.01)
)
except Exception as e:
print("error")
print("Error in query tf: ", e)
rospy.logwarn(f"Couldn't get between odom and base")
return
error = str(e)
rospy.logwarn(f"Couldn't get between odom and base {error}")
return False

yaw = R.from_quat(
[
res.transform.rotation.x,
res.transform.rotation.y,
res.transform.rotation.z,
res.transform.rotation.w,
]
).as_euler("zxy", degrees=False)[0]
H, W = layers["sdf"].shape
rot = res.transform.rotation
yaw = R.from_quat([rot.x, rot.y, rot.z, rot.w]).as_euler("zxy", degrees=False)[0]

binary_mask = np.zeros((sdf.shape[0], sdf.shape[1]), dtype=np.uint8)
mask_pattern = self.get_pattern_mask(H, W, yaw)
mask_elevation = self.get_elevation_mask(layers["elevation"])

distance = 30 # sdf.shape[0] / 5
center_x = int(sdf.shape[0] / 2 + math.sin(yaw) * distance)
center_y = int(sdf.shape[1] / 2 + math.cos(yaw) * distance)
binary_mask = cv2.circle(binary_mask, (center_x, center_y), 0, 255, 30)
distance = 55 # sdf.shape[0] / 3
center_x = int(sdf.shape[0] / 2 + math.sin(yaw) * distance)
center_y = int(sdf.shape[1] / 2 + math.cos(yaw) * distance)
binary_mask = cv2.circle(binary_mask, (center_x, center_y), 0, 255, 40)
distance = 90 # sdf.shape[0] / 2
center_x = int(sdf.shape[0] / 2 + math.sin(yaw) * distance)
center_y = int(sdf.shape[1] / 2 + math.cos(yaw) * distance)
binary_mask = cv2.circle(binary_mask, (center_x, center_y), 0, 255, 50)
m = binary_mask == 0
sdf += self.offset
m2 = np.isnan(elevation)
for filter_func in self.filter_chain_funcs:
layers["sdf"] = filter_func(yaw, layers["sdf"])

kernel = np.ones((3, 3), np.uint8)
m2 = cv2.dilate(np.uint8(m2) * 255, kernel, iterations=1) == 255
sdf[m] = sdf.min()
sdf[m2] = sdf.min()
layers["sdf"][mask_pattern] = -np.inf
layers["sdf"][mask_elevation] = -np.inf

if sdf.min() == sdf.max():
if layers["sdf"].min() == layers["sdf"].max():
rospy.logwarn(f"No valid elevation within the SDF of the defined pattern {e}")
return

x, y = np.where(sdf == sdf.max())
# Get index of the maximum gridmax cell index within the SDF
x, y = np.where(layers["sdf"] == layers["sdf"].max())
x = x[0]
y = y[0]
x -= sdf.shape[0] / 2
y -= sdf.shape[1] / 2

# Convert the GridMap index to a map frame position
x -= H / 2
y -= W / 2
x *= msg.info.resolution
y *= msg.info.resolution
x += msg.info.pose.position.x
y += msg.info.pose.position.y

# Publish the goal
goal = PoseWithCovarianceStamped()
goal.header.stamp = msg.info.header.stamp
goal.header.frame_id = "odom"
goal.header.frame_id = self.map_frame
goal.pose.pose.position.x = x
goal.pose.pose.position.y = y
goal.pose.pose.position.z = res.transform.translation.z + 0.3
goal.pose.pose.orientation = res.transform.rotation
goal.pose.pose.orientation = rot
self.pub_goal.publish(goal)

if self.debug:
target_layer = "sdf"
msg.data[msg.layers.index(target_layer)].data = sdf[::-1, ::-1].transpose().ravel()
# Republish the SDF used to search for the maximum goal
msg.data[msg.layers.index("sdf")].data = layers["sdf"][::-1, ::-1].transpose().ravel()
self.pub.publish(msg)


if __name__ == "__main__":
rospy.init_node("wild_visual_navigation_smart_carrot")
print("Start")
wvn = SmartCarrotNode()
rospy.spin()
rospy.spin()

0 comments on commit 6834bea

Please sign in to comment.