
Commit

started to update the confidence generator
JonasFrey96 committed Aug 29, 2023
1 parent 1cacd24 commit 85d2103
Showing 4 changed files with 40 additions and 72 deletions.
55 changes: 21 additions & 34 deletions wild_visual_navigation/learning/utils/loss.py
@@ -4,52 +4,39 @@
from typing import Optional
from wild_visual_navigation.utils import ConfidenceGenerator
from torch import nn
from torchmetrics import ROC, AUROC, Accuracy


class AnomalyLoss(nn.Module):
def __init__(self, confidence_std_factor, method):
def __init__(self, confidence_std_factor: float, method: str, log_enabled: bool, log_folder: str):
super(AnomalyLoss, self).__init__()
# self._confidence_generator = ConfidenceGenerator(
# std_factor=confidence_std_factor,
# method=method,
# log_enabled=False,
# log_folder="/tmp",
# )

self._confidence_generator = ConfidenceGenerator(
std_factor=confidence_std_factor,
method=method,
log_enabled=log_enabled,
log_folder=log_folder,
)

def forward(
self,
graph: Optional[Data],
res: dict,
loss_mean: int = None,
loss_std: int = None,
train: bool = False,
update_generator: bool = True,
step: int = 0,
log_step: bool = False,
):
loss_aux = {}
loss_aux["loss_trav"] = torch.tensor([0.0])
loss_aux["loss_reco"] = torch.tensor([0.0])
loss_aux["confidence"] = torch.tensor([0.0])

losses = -(res["logprob"].sum(1) + res["log_det"]) # Sum over all channels, resulting in h*w output dimensions

# print(torch.mean(losses))
l_clip = losses
if loss_mean is not None and loss_std is not None:
# Clip the losses
l_clip = torch.clip(losses, loss_mean - 2 * loss_std, loss_mean + 2 * loss_std)

# Normalize between 0 and 1
l_norm = (losses - torch.min(l_clip)) / (torch.max(l_clip) - torch.min(l_clip))
l_trav = 1 - l_norm
if update_generator:
confidence = self._confidence_generator.update(x=losses, x_positive=losses)

if train:
loss_aux["loss_mean"] = torch.median(losses)
loss_aux["loss_std"] = torch.std(losses)
loss_aux["confidence"] = confidence

return torch.mean(losses), loss_aux, l_trav
return torch.mean(losses), loss_aux, confidence

def update_node_confidence(self, node):
node.confidence = 0
@@ -58,16 +45,16 @@ def update_node_confidence(self, node):
class TraversabilityLoss(nn.Module):
def __init__(
self,
w_trav,
w_reco,
w_temp,
anomaly_balanced,
model,
method,
confidence_std_factor,
w_trav: float,
w_reco: float,
w_temp: float,
anomaly_balanced: bool,
model: nn.Module,
method: str,
confidence_std_factor: float,
log_enabled: bool,
log_folder: str,
trav_cross_entropy=False,
log_enabled: bool = False,
log_folder: str = "/tmp",
):
# TODO remove trav_cross_entropy default param when running in online mode
super(TraversabilityLoss, self).__init__()
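For reference, the updated AnomalyLoss above now receives its logging configuration explicitly and forwards it to the ConfidenceGenerator. Below is a minimal usage sketch based only on the signature and forward() body shown in this diff; the method string, std factor, and tensor shapes are illustrative assumptions, not values from the repository.

import torch
from wild_visual_navigation.learning.utils.loss import AnomalyLoss

# Hypothetical configuration values for illustration only.
loss_fn = AnomalyLoss(
    confidence_std_factor=0.7,   # assumed value
    method="running_mean",       # assumed method name for ConfidenceGenerator
    log_enabled=False,
    log_folder="/tmp",
)

# res mimics the normalizing-flow output consumed in forward(): per-channel
# log-probabilities plus a log-determinant term (e.g. from a LinearRnvp model).
res = {
    "logprob": torch.randn(4, 8, 100),  # assumed (batch, channels, h*w) layout
    "log_det": torch.randn(4, 100),
}

# graph is unused in the forward() body shown above, so None suffices here.
# How ConfidenceGenerator.update handles this shape depends on its implementation.
loss, loss_aux, confidence = loss_fn(graph=None, res=res, train=True)
print(loss.item(), loss_aux["loss_mean"].item(), loss_aux["loss_std"].item())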
@@ -123,7 +123,11 @@ def __init__(
self._model.train()

if self._exp_cfg["model"]["name"] == "LinearRnvp":
self._traversability_loss = AnomalyLoss(**self._exp_cfg["loss_anomaly"])
self._traversability_loss = AnomalyLoss(
**self._exp_cfg["loss_anomaly"],
log_enabled=self._exp_cfg["general"]["log_confidence"],
log_folder=self._exp_cfg["general"]["model_path"],
)
self._traversability_loss.to(self._device)

else:
@@ -545,9 +549,6 @@ def load_checkpoint(self, checkpoint_path: str):
def make_batch(
self,
batch_size: int = 8,
anomaly_detection: bool = False,
n_features: int = 200,
vis_training_samples: bool = False,
):
"""Samples a batch from the mission_graph
@@ -557,7 +558,7 @@

# Just sample N random nodes
mission_nodes = self._mission_graph.get_n_random_valid_nodes(n=batch_size)
batch = Batch.from_data_list([x.as_pyg_data(anomaly_detection=anomaly_detection) for x in mission_nodes])
batch = Batch.from_data_list([x.as_pyg_data(anomaly_detection=self._anomaly_detection) for x in mission_nodes])

return batch

@@ -575,35 +576,19 @@ def train(self):
return_dict = {"mission_graph_num_valid_node": num_valid_nodes}
if num_valid_nodes > self._min_samples_for_training:
# Prepare new batch
graph = self.make_batch(
self._exp_cfg["ablation_data_module"]["batch_size"],
anomaly_detection=self._anomaly_detection,
vis_training_samples=self._vis_training_samples,
)
graph = self.make_batch(self._exp_cfg["ablation_data_module"]["batch_size"])
if graph is not None:

self._loss_mean = None
self._loss_std = None

with self._learning_lock:
# Forward pass

res = self._model(graph)

log_step = (self._step % 20) == 0
self._loss, loss_aux, trav = self._traversability_loss(
graph,
res,
step=self._step,
log_step=log_step,
loss_mean=self._loss_mean,
loss_std=self._loss_std,
train=True,
graph, res, step=self._step, log_step=log_step
)

self._loss_mean = loss_aux["loss_mean"]
self._loss_std = loss_aux["loss_std"]

# Keep track of ROC during training for rescaling the loss when publishing
if self._scale_traversability:
# This mask should contain all the segments corresponding to trees.
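The loss_mean / loss_std bookkeeping removed from the training step above is now absorbed by the ConfidenceGenerator inside the loss, which keeps running statistics (the same var / mean / std fields restored in wvn_feature_extractor_node.py below) and maps a loss value to a confidence in [0, 1]. A rough, hypothetical illustration of that idea, not the repository's actual implementation:

import torch

def loss_to_confidence(loss: torch.Tensor, mean: float, std: float, std_factor: float) -> torch.Tensor:
    """Map anomaly losses to confidences using running statistics.

    Purely illustrative: losses at or below the running mean give confidence ~1,
    and losses more than std_factor standard deviations above it decay toward 0.
    """
    z = torch.clamp((loss - mean) / (std * std_factor + 1e-6), min=0.0)
    return torch.exp(-0.5 * z**2)

# Example: with mean=2.0, std=0.5, std_factor=1.0, a loss of 2.0 maps to 1.0
# and a loss of 3.0 maps to exp(-2) ≈ 0.14.
print(loss_to_confidence(torch.tensor([2.0, 3.0]), mean=2.0, std=0.5, std_factor=1.0))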
@@ -33,7 +33,6 @@ min_samples_for_training: 5
prediction_per_pixel: false
traversability_threshold: 0.55
clip_to_binary: false
anomaly_detection: true
vis_training_samples: true

# Supervision Generator
25 changes: 11 additions & 14 deletions wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py
@@ -51,7 +51,6 @@ def __init__(self):
method=self.exp_cfg["loss"]["method"], std_factor=self.exp_cfg["loss"]["confidence_std_factor"]
)
self.scale_traversability = True
self.traversability_threshold = 0.5
else:
self.traversability_loss = AnomalyLoss(**self.exp_cfg["loss_anomaly"])
self.traversability_loss.to(self.device)
@@ -355,22 +354,20 @@ def load_model(self):
self.log_data[f"nr_model_updates"] += 1

self.model.load_state_dict(res, strict=False)
# if res["traversability_threshold"] is not None:
# self.traversability_threshold = res["traversability_threshold"]
# if res["confidence_generator"] is not None:
# self.confidence_generator_state = res["confidence_generator"]

# self.traversability_threshold = 0.5
try:
if res["traversability_threshold"] is not None:
self.traversability_threshold = res["traversability_threshold"]
if res["confidence_generator"] is not None:
self.confidence_generator_state = res["confidence_generator"]

# self.confidence_generator_state = res["confidence_generator"]
self.confidence_generator_state = res["confidence_generator"]
self.confidence_generator.var = self.confidence_generator_state["var"]
self.confidence_generator.mean = self.confidence_generator_state["mean"]
self.confidence_generator.std = self.confidence_generator_state["std"]
except:
pass

# self.confidence_generator.var = 0
# self.confidence_generator.mean = 1
# self.confidence_generator.std = 1

# self.confidence_generator.var = self.confidence_generator_state["var"]
# self.confidence_generator.mean = self.confidence_generator_state["mean"]
# self.confidence_generator.std = self.confidence_generator_state["std"]
except Exception as e:
if self.verbose:
print(f"Model Loading Failed: {e}")
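The revised load_model() above replaces a bare except/pass with a handler that reports the failure when verbose is set. A small hypothetical helper distilling that pattern; the checkpoint keys ("traversability_threshold", "confidence_generator", "var", "mean", "std") come from the diff, while the helper itself is not part of the repository.

def restore_confidence_state(node, res: dict) -> None:
    """Restore the confidence generator's running statistics from a checkpoint dict.

    `node` stands in for the feature-extractor node; only attributes that appear
    in the diff above (confidence_generator, confidence_generator_state,
    traversability_threshold, verbose) are assumed to exist.
    """
    try:
        if res.get("traversability_threshold") is not None:
            node.traversability_threshold = res["traversability_threshold"]
        state = res.get("confidence_generator")
        if state is not None:
            node.confidence_generator_state = state
            node.confidence_generator.var = state["var"]
            node.confidence_generator.mean = state["mean"]
            node.confidence_generator.std = state["std"]
    except Exception as e:
        # Surface the problem instead of silently swallowing it.
        if node.verbose:
            print(f"Restoring confidence generator state failed: {e}")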
