Skip to content

Commit

Permalink
Fixes to recent commits
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinSchmid7 committed Aug 30, 2023
1 parent 85d2103 commit 112aa61
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 9 deletions.
2 changes: 1 addition & 1 deletion wild_visual_navigation/cfg/experiment_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class LossParams:

@dataclass
class LossAnomalyParams:
method: str = "latest_measurment"
method: str = "latest_measurment" # "latest_measurment", "running_mean"
confidence_std_factor: float = 0.5

loss_anomaly: LossAnomalyParams = LossAnomalyParams()
Expand Down
6 changes: 3 additions & 3 deletions wild_visual_navigation/learning/utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ def forward(
loss_aux["loss_trav"] = torch.tensor([0.0])
loss_aux["loss_reco"] = torch.tensor([0.0])

losses = -(res["logprob"].sum(1) + res["log_det"]) # Sum over all channels, resulting in h*w output dimensions
losses = res["logprob"].sum(1) + res["log_det"] # Sum over all channels, resulting in h*w output dimensions

if update_generator:
confidence = self._confidence_generator.update(x=losses, x_positive=losses)
confidence = self._confidence_generator.update(x=losses, x_positive=losses, step=step)

loss_aux["confidence"] = confidence

return torch.mean(losses), loss_aux, confidence
return -torch.mean(losses), loss_aux, confidence

def update_node_confidence(self, node):
node.confidence = 0
Expand Down
4 changes: 2 additions & 2 deletions wild_visual_navigation/traversability_estimator/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def as_pyg_data(self, previous_node: Optional[BaseNode] = None, anomaly_detectio
if anomaly_detection:
return Data(
x=self.features[self._supervision_signal_valid],
edge_index=self._feature_edges[self._supervision_signal_valid],
edge_index=self._feature_edges,
y=self._supervision_signal[self._supervision_signal_valid],
y_valid=self._supervision_signal_valid[self._supervision_signal_valid],
)
Expand All @@ -205,7 +205,7 @@ def as_pyg_data(self, previous_node: Optional[BaseNode] = None, anomaly_detectio
if anomaly_detection:
return Data(
x=self.features[self._supervision_signal_valid],
edge_index=self._feature_edges[self._supervision_signal_valid],
edge_index=self._feature_edges,
y=self._supervision_signal[self._supervision_signal_valid],
y_valid=self._supervision_signal_valid[self._supervision_signal_valid],
x_previous=previous_node.features,
Expand Down
5 changes: 4 additions & 1 deletion wild_visual_navigation/utils/confidence_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,10 @@ def inference_without_update(self, x: torch.tensor):
if x.device != self.mean.device:
return torch.zeros_like(x)

confidence = torch.exp(-(((x - self.mean) / (self.std * self.std_factor)) ** 2) * 0.5)
# confidence = torch.exp(-(((x - self.mean) / (self.std * self.std_factor)) ** 2) * 0.5)
x = torch.clip(x, self.mean - 2 * self.std, self.mean + 2 * self.std)
confidence = (x - torch.min(x)) / (torch.max(x) - torch.min(x))

return confidence.type(torch.float32)

def forward(self, x: torch.tensor):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def __init__(self):
)
self.scale_traversability = True
else:
self.traversability_loss = AnomalyLoss(**self.exp_cfg["loss_anomaly"])
self.traversability_loss = AnomalyLoss(**self.exp_cfg["loss_anomaly"],
log_enabled=self.exp_cfg["general"]["log_confidence"],
log_folder=self.exp_cfg["general"]["model_path"])
self.traversability_loss.to(self.device)
self.scale_traversability = False

Expand Down Expand Up @@ -348,7 +350,7 @@ def load_model(self):
res = torch.load(f"{WVN_ROOT_DIR}/tmp_state_dict2.pt")
k = list(self.model.state_dict().keys())[-1]

if (self.model.state_dict()[k] != res[k]).any(): # TODO: model params are changing?
if (self.model.state_dict()[k] != res[k]).any():
if self.verbose:
self.log_data[f"time_last_model"] = rospy.get_time()
self.log_data[f"nr_model_updates"] += 1
Expand Down

0 comments on commit 112aa61

Please sign in to comment.