
While debugging the deep supervision metric, the need for a running average wrapper has become obvious, as it would reduce training function complexity.
ga84jog committed Sep 11, 2024
1 parent 6712ed8 commit 5a069ae
Showing 1 changed file with 31 additions and 27 deletions.
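For context, a minimal sketch of what such a running average wrapper might look like, assuming a batch-weighted mean. The class name RunningAverage and its interface are hypothetical; the commit only states the need for one, it does not add it:

class RunningAverage:
    """Batch-weighted running mean of a metric value.

    Hypothetical sketch, not part of this commit; the commit message
    only motivates the need for such a wrapper.
    """

    def __init__(self):
        self._total = 0.0
        self._count = 0

    def update(self, value: float, n: int = 1):
        # Weight each update by its sample count so unevenly
        # sized batches do not skew the average.
        self._total += value * n
        self._count += n

    @property
    def value(self) -> float:
        return self._total / max(self._count, 1)

    def reset(self):
        self._total = 0.0
        self._count = 0

With a wrapper like this, _on_epoch_end could read wrapper.value and call wrapper.reset() instead of zeroing each metric["value"] entry and the bookkeeping counters by hand, as the diff below currently does.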
58 changes: 31 additions & 27 deletions src/models/pytorch/__init__.py
@@ -65,6 +65,10 @@ def __init__(self,
            raise ValueError("Task not specified and could not be inferred from "
                             "output_dim and final_activation! Provide task on initialization.")

        # TODO! this is for debugging, remove:
        self._labels = list()
        self._outputs = list()

    @property
    def optimizer(self):
        if hasattr(self, "_optimizer"):
@@ -290,24 +294,43 @@ def _on_epoch_end(self, epoch: int, prefix: str = ""):
            loss_history[epoch] = avg_loss
        else:
            setattr(self._history, f"{prefix}_loss", avg_loss)

        # Update best loss
        if hasattr(self._history, f"best_{prefix}"):
            best_epochs = getattr(self._history, f"best_{prefix}")
            if best_epochs["loss"] > avg_loss:
                best_epochs["loss"] = avg_loss
                best_epochs["epoch"] = epoch

        # Update metric history
        if hasattr(self._history, f"{prefix}_metrics"):
            metric_history = getattr(self._history, f"{prefix}_metrics")
            for key, metric in self._metrics[prefix].items():
                metric_history[key][epoch] = metric["value"]

        # Reset metric values
        for metric in self._metrics[prefix].values():
            metric["value"] = 0.0

        self._current_metrics = dict()
        self._sample_count = 0
        self._batch_count = 0
        self._generator_size = 0
        self._batch_size = 0
        self._has_val = False
        '''
        # TODO! remove
        labels = torch.cat(self._labels, axis=1).to("cpu").detach().squeeze().numpy()
        outputs = torch.cat(self._outputs, axis=1).to("cpu").detach().squeeze().numpy()
        from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
        precision, recall, _ = precision_recall_curve(labels, outputs)
        print(f"PR AUC: {auc(recall, precision)}")
        print(f"ROC AUC: {roc_auc_score(labels, outputs)}")
        self._labels = list()
        self._outputs = list()
        '''

    def _update_metrics(self,
                        loss: torch.Tensor,
@@ -432,6 +455,9 @@ def _train_with_arrays(self,
            aggr_labels.append(label)

            # Optimize network on batch
            self._labels.append(label)
            self._outputs.append(output)

            aggr_outputs, \
            aggr_labels = self._optimize_batch(outputs=aggr_outputs, \
                                               labels=aggr_labels, \
@@ -662,10 +688,12 @@ def _evaluate_with_arrays(self,
        # Only update history if test
        if is_test:
            self._on_epoch_end(epoch, prefix="test")
            return list(dict(self._get_metrics(self._metrics["test"])).values())
            return_metrics = list(dict(self._get_metrics(self._metrics["test"])).values())
            return return_metrics
        # Also complete progbar if eval
        self._on_epoch_end(epoch, prefix="val")
        return list(dict(self._get_metrics(self._metrics["val"])).values())
        return_metrics = list(dict(self._get_metrics(self._metrics["val"])).values())
        return return_metrics

    def _evaluate_with_dataloader(self,
                                  generator: DataLoader,
@@ -937,32 +965,8 @@ def _predict_numpy(self, x, batch_size=None, verbose="auto", steps=None, **kwarg

        # On device data
        x = torch.tensor(x, dtype=torch.float32).to(self._device)

        aggr_outputs = []

        with torch.no_grad():
            for sample_idx, inputs in enumerate(x):
                if masking_flag:
                    # Set labels to zero when masking since forward does the same
                    mask = masks[sample_idx]
                    if len(inputs.shape) < 3:
                        mask = mask.unsqueeze(0)
                else:
                    # TODO! Is there a better way? What if there is an actual patient with all zeros
                    # TODO! across 59 columns? Am I paranoid? Maybe adding the discretizer masking can alleviate this.
                    inputs = self._remove_end_padding(inputs)
                    mask = None

                # Adjust dimensions
                if len(inputs.shape) < 3:
                    inputs = inputs.unsqueeze(0)

                # Create predictions
                output = self(inputs, masks=mask)
                aggr_outputs.append(output.to("cpu").numpy())

        return zeropad_samples(aggr_outputs)
        return self(x, masks=masks).to("cpu").detach().numpy()

    def _get_dataloader_or_array(self, *args, has_y=True, **kwargs):
        args = list(args)
