Skip to content

Commit

Permalink
Merge branch 'dev/repair_develop' into dev/repair_develop_robot
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinSchmid7 authored Sep 4, 2023
2 parents 274768c + cf7e0d7 commit 9a43494
Show file tree
Hide file tree
Showing 9 changed files with 157 additions and 139 deletions.
6 changes: 1 addition & 5 deletions wild_visual_navigation/feature_extractor/dino_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(
projection_type: str = None, # nonlinear or None
dropout: bool = False, # True or False
):
self.dim = dim # 90 or 384
self.dim = dim
self.cfg = DictConfig(
{
"dino_patch_size": patch_size,
Expand Down Expand Up @@ -72,8 +72,6 @@ def __init__(
# Just normalization
self.norm = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

self.mean_kernel = torch.ones((1, 5, 5), device=device) / 25

def change_device(self, device):
"""Changes the device of all the class members
Expand Down Expand Up @@ -142,8 +140,6 @@ def inference(self, img: torch.tensor, interpolate: bool = False):
pad = int((W - H) / 2)
features = F.interpolate(features, new_size, mode="bilinear", align_corners=True)
features = F.pad(features, pad=[pad, pad, 0, 0])
# Optionally turn on image feature smoothing
# features = filter2d(features, self.mean_kernel, "replicate")
return features

@property
Expand Down
5 changes: 2 additions & 3 deletions wild_visual_navigation/feature_extractor/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@

class FeatureExtractor:
def __init__(
self, device: str, segmentation_type: str = "slic", feature_type: str = "dino", input_size: int = 448, **kwargs
):
self, device: str, segmentation_type: str = "slic", feature_type: str = "dino", input_size: int = 448, **kwargs):
"""Feature extraction from image
Args:
Expand All @@ -40,7 +39,7 @@ def __init__(
elif self._feature_type == "dino":
self._feature_dim = 90

self.extractor = DinoInterface(device=device, input_size=input_size, patch_size=kwargs.get("patch_size", 8))
self.extractor = DinoInterface(device=device, input_size=input_size, patch_size=kwargs.get("patch_size", 8), dim=kwargs.get("dino_dim", 384))
elif self._feature_type == "sift":
self._feature_dim = 128
self.extractor = DenseSIFTDescriptor().to(device)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,7 @@
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchmetrics import ROC
import numpy as np
import cv2
import rospy
import random
from sensor_msgs.msg import Image
from cv_bridge import CvBridge

to_tensor = transforms.ToTensor()

Expand Down Expand Up @@ -109,12 +104,6 @@ def __init__(
# Visualization
self._visualizer = LearningVisualizer()

if self._vis_training_samples:
self._last_image_mask_pub = rospy.Publisher(
f"/wild_visual_navigation_node/last_node_image_mask", Image, queue_size=1
)
self._bridge = CvBridge()

# Lightning module
seed_everything(42)

Expand Down
9 changes: 6 additions & 3 deletions wild_visual_navigation/visu/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,12 @@ def plot_detectron_classification(
overlay_mask=None,
**kwargs,
):
cmap = cm.get_cmap("RdYlBu", 256)
cmap = np.concatenate([cmap(np.linspace(0, 0.3, 128)), cmap(np.linspace(0.7, 1.0, 128))])
cmap = torch.from_numpy(cmap).to(seg)[:, :3]
if kwargs.get("cmap", None):
cmap = kwargs["cmap"]
else:
cmap = cm.get_cmap("RdYlBu", 256)
cmap = np.concatenate([cmap(np.linspace(0, 0.3, 128)), cmap(np.linspace(0.7, 1.0, 128))])
cmap = torch.from_numpy(cmap).to(seg)[:, :3]

img = self.plot_image(img, not_log=True)
seg_img = self.plot_segmentation(
Expand Down
Loading

0 comments on commit 9a43494

Please sign in to comment.