Skip to content

Commit

Permalink
improved feature extraction node to cont. log
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasFrey96 committed Aug 16, 2023
1 parent 467149e commit a4268d3
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ image_callback_rate: 10 # hertz
proprio_callback_rate: 4 # hertz
learning_thread_rate: 10 # hertz
logging_thread_rate: 2 # hertz
status_thread_rate: 0.5 # hertz

# Runtime options
device: "cuda"
Expand All @@ -55,7 +56,7 @@ print_image_callback_time: false
print_proprio_callback_time: false
log_time: false
log_confidence: false
verbose: false
verbose: true
debug_supervision_node_index_from_last: 10

extraction_store_folder: "nan"
Expand Down
78 changes: 68 additions & 10 deletions wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@

import rospy
from sensor_msgs.msg import Image, CameraInfo, CompressedImage
from std_msgs.msg import Float32, MultiArrayDimension
from rospy.numpy_msg import numpy_msg
import message_filters
from std_msgs.msg import MultiArrayDimension

from pytictac import Timer, CpuTimer
import os
import torch
import numpy as np
import dataclasses
from torch_geometric.data import Data
import torch.nn.functional as F
from threading import Thread, Event
from prettytable import PrettyTable
from termcolor import colored


class WvnFeatureExtractor:
Expand All @@ -34,6 +34,7 @@ def __init__(self):
feature_type=self.feature_type,
input_size=self.network_input_image_height,
)
self.i = 0
self.setup_ros()

self.model = get_model(self.exp_cfg["model"]).to(self.device)
Expand All @@ -45,7 +46,38 @@ def __init__(self):
self.scale_traversability = True
self.traversability_thershold = 0.5

self.i = 0
if self.verbose:
self.status_thread = Thread(target=self.status_thread_loop, name="status")
self.status_thread.start()

def status_thread_loop(self):
rate = rospy.Rate(self.status_thread_rate)
# Learning loop
while True:
t = rospy.get_time()
x = PrettyTable()
x.field_names = ["Key", "Value"]

for k, v in self.log_data.items():
if "time" in k:
d = t - v
if d < 0:
c = "red"
if d < 0.2:
c = "green"
elif d < 1.0:
c = "yellow"
else:
c = "red"
x.add_row([k, colored(round(d, 2), c)])
else:
x.add_row([k, v])
print(x)
try:
rate.sleep()
except Exception as e:
rate = rospy.Rate(self.status_thread_rate)
print("Ignored jump pack in time!")

def read_params(self):
"""Reads all the parameters from the parameter server"""
Expand All @@ -65,6 +97,7 @@ def read_params(self):
self.confidence_std_factor = rospy.get_param("~confidence_std_factor")
self.scale_traversability = rospy.get_param("~scale_traversability")
self.scale_traversability_max_fpr = rospy.get_param("~scale_traversability_max_fpr")
self.status_thread_rate = rospy.get_param("~status_thread_rate")

# Initialize traversability estimator parameters
# Experiment file
Expand All @@ -80,8 +113,21 @@ def read_params(self):
def setup_ros(self, setup_fully=True):
"""Main function to setup ROS-related stuff: publishers, subscribers and services"""
# Image callback

self.camera_handler = {}

if self.verbose:
# DEBUG Logging
self.log_data = {}
self.log_data[f"time_last_model"] = -1
self.log_data[f"nr_model_updates"] = -1

for cam in self.camera_topics:
if self.verbose:
# DEBUG Logging
self.log_data[f"nr_images_{cam}"] = 0
self.log_data[f"time_last_image_{cam}"] = -1

# Initialize camera handler for given cam
self.camera_handler[cam] = {}
# Store camera name
Expand Down Expand Up @@ -160,7 +206,9 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo
cam (str): Camera name
"""
if self.verbose:
print("Processing Camera: ", cam)
# DEBUG Logging
self.log_data[f"nr_images_{cam}"] += 1
self.log_data[f"time_last_image_{cam}"] = rospy.get_time()

# Update model from file if possible
self.load_model()
Expand Down Expand Up @@ -254,17 +302,16 @@ def load_model(self):
res = torch.load(f"{WVN_ROOT_DIR}/tmp_state_dict2.pt")
if (self.model.state_dict()["layers.0.weight"] != res["layers.0.weight"]).any():
if self.verbose:
print("Model updated.")
self.log_data[f"time_last_model"] = rospy.get_time()
self.log_data[f"nr_model_updates"] += 1

self.model.load_state_dict(res, strict=False)
self.traversability_thershold = res["traversability_thershold"]
self.confidence_generator_state = res["confidence_generator"]

self.confidence_generator.var = self.confidence_generator_state["var"]
self.confidence_generator.mean = self.confidence_generator_state["mean"]
self.confidence_generator.std = self.confidence_generator_state["std"]
else:
if self.verbose:
print("Model did not change.")
except Exception as e:
if self.verbose:
print(f"Model Loading Failed: {e}")
Expand All @@ -273,5 +320,16 @@ def load_model(self):
if __name__ == "__main__":
node_name = "wvn_feature_extractor_node"
rospy.init_node(node_name)

if True:
import rospkg

rospack = rospkg.RosPack()
wvn_path = rospack.get_path("wild_visual_navigation_ros")
os.system(f"rosparam load {wvn_path}/config/wild_visual_navigation/default.yaml wvn_feature_extractor_node")
os.system(
f"rosparam load {wvn_path}/config/wild_visual_navigation/inputs/alphasense_compressed.yaml wvn_feature_extractor_node"
)

wvn = WvnFeatureExtractor()
rospy.spin()

0 comments on commit a4268d3

Please sign in to comment.