Skip to content

Commit

Permalink
added valid nodes to state - dino working
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasFrey96 committed Aug 16, 2023
1 parent a4268d3 commit 81825ba
Show file tree
Hide file tree
Showing 8 changed files with 465 additions and 37 deletions.
6 changes: 5 additions & 1 deletion wild_visual_navigation/feature_extractor/dino_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from omegaconf import DictConfig
from torchvision import transforms as T
from stego.src.train_segmentation import DinoFeaturizer
from kornia.filters import filter2d


class DinoInterface:
Expand Down Expand Up @@ -68,6 +69,8 @@ def __init__(
# Just normalization
self.norm = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

self.mean_kernel = torch.ones((1, 5, 5), device=device) / 25

def change_device(self, device):
"""Changes the device of all the class members
Expand Down Expand Up @@ -136,7 +139,8 @@ def inference(self, img: torch.tensor, interpolate: bool = False):
pad = int((W - H) / 2)
features = F.interpolate(features, new_size, mode="bilinear", align_corners=True)
features = F.pad(features, pad=[pad, pad, 0, 0])

# Optionally turn on image feature smoothing
# features = filter2d(features, self.mean_kernel, "replicate")
return features

@property
Expand Down
16 changes: 8 additions & 8 deletions wild_visual_navigation/feature_extractor/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
pass

def extract(self, img, **kwargs):
if kwargs.get("fast_random", False):
if self._segmentation_type == "random":
dense_feat = self.compute_features(img, None, None, **kwargs)

H, W = img.shape[2:]
Expand All @@ -83,16 +83,16 @@ def extract(self, img, **kwargs):
return None, feat, seg, None

# Compute segments, their centers, and edges connecting them (graph structure)
with Timer("feature_extractor - compute_segments"):
edges, seg, center = self.compute_segments(img, **kwargs)
# with Timer("feature_extractor - compute_segments"):
edges, seg, center = self.compute_segments(img, **kwargs)

# Compute features
with Timer("feature_extractor - compute_features"):
dense_feat = self.compute_features(img, seg, center, **kwargs)
# with Timer("feature_extractor - compute_features"):
dense_feat = self.compute_features(img, seg, center, **kwargs)

with Timer("feature_extractor - compute_features"):
# Sparsify features to match the centers if required
feat = self.sparsify_features(dense_feat, seg)
# with Timer("feature_extractor - compute_features"):
# Sparsify features to match the centers if required
feat = self.sparsify_features(dense_feat, seg)

if kwargs.get("return_dense_features", False):
return edges, feat, seg, center, dense_feat
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -692,8 +692,9 @@ def train(self):
if self._pause_training:
return {}

return_dict = {}
if self._mission_graph.get_num_valid_nodes() > self._min_samples_for_training:
num_valid_nodes = self._mission_graph.get_num_valid_nodes()
return_dict = {"mission_graph_num_valid_node": num_valid_nodes}
if num_valid_nodes > self._min_samples_for_training:
# Prepare new batch
graph = self.make_batch(self._exp_cfg["ablation_data_module"]["batch_size"])

Expand Down Expand Up @@ -737,8 +738,10 @@ def train(self):
return_dict["loss_total"] = self._loss.item()
return_dict["loss_trav"] = loss_aux["loss_trav"].item()
return_dict["loss_reco"] = loss_aux["loss_reco"].item()

return return_dict
return {"loss_total": -1}
return_dict["loss_total"] = -1
return return_dict

@accumulate_time
def plot_mission_node_prediction(self, node: MissionNode):
Expand Down
2 changes: 2 additions & 0 deletions wild_visual_navigation_msgs/msg/SystemState.msg
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Pipeline operation mode
uint32 mode
# Current valid samples
uint32 mission_graph_num_valid_node
# Training loss
float32 loss_total
# Training loss
Expand Down
Loading

0 comments on commit 81825ba

Please sign in to comment.