diff --git a/wild_visual_navigation/learning/utils/loss.py b/wild_visual_navigation/learning/utils/loss.py index 50f2d3b5..501e7d86 100644 --- a/wild_visual_navigation/learning/utils/loss.py +++ b/wild_visual_navigation/learning/utils/loss.py @@ -4,26 +4,23 @@ from typing import Optional from wild_visual_navigation.utils import ConfidenceGenerator from torch import nn -from torchmetrics import ROC, AUROC, Accuracy class AnomalyLoss(nn.Module): - def __init__(self, confidence_std_factor, method): + def __init__(self, confidence_std_factor: float, method: str, log_enabled: bool, log_folder: str): super(AnomalyLoss, self).__init__() - # self._confidence_generator = ConfidenceGenerator( - # std_factor=confidence_std_factor, - # method=method, - # log_enabled=False, - # log_folder="/tmp", - # ) + + self._confidence_generator = ConfidenceGenerator( + std_factor=confidence_std_factor, + method=method, + log_enabled=log_enabled, + log_folder=log_folder, + ) def forward( self, graph: Optional[Data], res: dict, - loss_mean: int = None, - loss_std: int = None, - train: bool = False, update_generator: bool = True, step: int = 0, log_step: bool = False, @@ -31,25 +28,15 @@ def forward( loss_aux = {} loss_aux["loss_trav"] = torch.tensor([0.0]) loss_aux["loss_reco"] = torch.tensor([0.0]) - loss_aux["confidence"] = torch.tensor([0.0]) losses = -(res["logprob"].sum(1) + res["log_det"]) # Sum over all channels, resulting in h*w output dimensions - # print(torch.mean(losses)) - l_clip = losses - if loss_mean is not None and loss_std is not None: - # Clip the losses - l_clip = torch.clip(losses, loss_mean - 2 * loss_std, loss_mean + 2 * loss_std) - - # Normalize between 0 and 1 - l_norm = (losses - torch.min(l_clip)) / (torch.max(l_clip) - torch.min(l_clip)) - l_trav = 1 - l_norm + if update_generator: + confidence = self._confidence_generator.update(x=losses, x_positive=losses) - if train: - loss_aux["loss_mean"] = torch.median(losses) - loss_aux["loss_std"] = torch.std(losses) + loss_aux["confidence"] = confidence - return torch.mean(losses), loss_aux, l_trav + return torch.mean(losses), loss_aux, confidence def update_node_confidence(self, node): node.confidence = 0 @@ -58,16 +45,16 @@ def update_node_confidence(self, node): class TraversabilityLoss(nn.Module): def __init__( self, - w_trav, - w_reco, - w_temp, - anomaly_balanced, - model, - method, - confidence_std_factor, + w_trav: float, + w_reco: float, + w_temp: float, + anomaly_balanced: bool, + model: nn.Module, + method: str, + confidence_std_factor: float, + log_enabled: bool, + log_folder: str, trav_cross_entropy=False, - log_enabled: bool = False, - log_folder: str = "/tmp", ): # TODO remove trav_cross_entropy default param when running in online mode super(TraversabilityLoss, self).__init__() diff --git a/wild_visual_navigation/traversability_estimator/traversability_estimator.py b/wild_visual_navigation/traversability_estimator/traversability_estimator.py index 49665f78..0417e3ad 100644 --- a/wild_visual_navigation/traversability_estimator/traversability_estimator.py +++ b/wild_visual_navigation/traversability_estimator/traversability_estimator.py @@ -123,7 +123,11 @@ def __init__( self._model.train() if self._exp_cfg["model"]["name"] == "LinearRnvp": - self._traversability_loss = AnomalyLoss(**self._exp_cfg["loss_anomaly"]) + self._traversability_loss = AnomalyLoss( + **self._exp_cfg["loss_anomaly"], + log_enabled=self._exp_cfg["general"]["log_confidence"], + log_folder=self._exp_cfg["general"]["model_path"], + ) self._traversability_loss.to(self._device) else: @@ -545,9 +549,6 @@ def load_checkpoint(self, checkpoint_path: str): def make_batch( self, batch_size: int = 8, - anomaly_detection: bool = False, - n_features: int = 200, - vis_training_samples: bool = False, ): """Samples a batch from the mission_graph @@ -557,7 +558,7 @@ def make_batch( # Just sample N random nodes mission_nodes = self._mission_graph.get_n_random_valid_nodes(n=batch_size) - batch = Batch.from_data_list([x.as_pyg_data(anomaly_detection=anomaly_detection) for x in mission_nodes]) + batch = Batch.from_data_list([x.as_pyg_data(anomaly_detection=self._anomaly_detection) for x in mission_nodes]) return batch @@ -575,16 +576,9 @@ def train(self): return_dict = {"mission_graph_num_valid_node": num_valid_nodes} if num_valid_nodes > self._min_samples_for_training: # Prepare new batch - graph = self.make_batch( - self._exp_cfg["ablation_data_module"]["batch_size"], - anomaly_detection=self._anomaly_detection, - vis_training_samples=self._vis_training_samples, - ) + graph = self.make_batch(self._exp_cfg["ablation_data_module"]["batch_size"]) if graph is not None: - self._loss_mean = None - self._loss_std = None - with self._learning_lock: # Forward pass @@ -592,18 +586,9 @@ def train(self): log_step = (self._step % 20) == 0 self._loss, loss_aux, trav = self._traversability_loss( - graph, - res, - step=self._step, - log_step=log_step, - loss_mean=self._loss_mean, - loss_std=self._loss_std, - train=True, + graph, res, step=self._step, log_step=log_step ) - self._loss_mean = loss_aux["loss_mean"] - self._loss_std = loss_aux["loss_std"] - # Keep track of ROC during training for rescaling the loss when publishing if self._scale_traversability: # This mask should contain all the segments corrosponding to trees. diff --git a/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml b/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml index 12bd35a1..754b76d8 100644 --- a/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml +++ b/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml @@ -33,7 +33,6 @@ min_samples_for_training: 5 prediction_per_pixel: false traversability_threshold: 0.55 clip_to_binary: false -anomaly_detection: true vis_training_samples: true # Supervision Generator diff --git a/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py b/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py index 86e5f6b9..c3809603 100644 --- a/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py +++ b/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py @@ -51,7 +51,6 @@ def __init__(self): method=self.exp_cfg["loss"]["method"], std_factor=self.exp_cfg["loss"]["confidence_std_factor"] ) self.scale_traversability = True - self.traversability_threshold = 0.5 else: self.traversability_loss = AnomalyLoss(**self.exp_cfg["loss_anomaly"]) self.traversability_loss.to(self.device) @@ -355,22 +354,20 @@ def load_model(self): self.log_data[f"nr_model_updates"] += 1 self.model.load_state_dict(res, strict=False) - # if res["traversability_threshold"] is not None: - # self.traversability_threshold = res["traversability_threshold"] - # if res["confidence_generator"] is not None: - # self.confidence_generator_state = res["confidence_generator"] - # self.traversability_threshold = 0.5 + try: + if res["traversability_threshold"] is not None: + self.traversability_threshold = res["traversability_threshold"] + if res["confidence_generator"] is not None: + self.confidence_generator_state = res["confidence_generator"] - # self.confidence_generator_state = res["confidence_generator"] + self.confidence_generator_state = res["confidence_generator"] + self.confidence_generator.var = self.confidence_generator_state["var"] + self.confidence_generator.mean = self.confidence_generator_state["mean"] + self.confidence_generator.std = self.confidence_generator_state["std"] + except: + pass - # self.confidence_generator.var = 0 - # self.confidence_generator.mean = 1 - # self.confidence_generator.std = 1 - - # self.confidence_generator.var = self.confidence_generator_state["var"] - # self.confidence_generator.mean = self.confidence_generator_state["mean"] - # self.confidence_generator.std = self.confidence_generator_state["std"] except Exception as e: if self.verbose: print(f"Model Loading Failed: {e}")