-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added fit to numpy dataset option for torch models. Added aucpr for t…
…orch models. Addded AUCPR for tensorflow models
- Loading branch information
ga84jog
committed
Jun 19, 2024
1 parent
49f5368
commit e4dbb2d
Showing
15 changed files
with
2,315 additions
and
242 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import torch | ||
from torcheval.metrics import BinaryAUPRC, MulticlassAUPRC, MultilabelAUPRC | ||
import torch.nn as nn | ||
|
||
# TODO! This absolutetly needs testing | ||
|
||
|
||
class AUCPRC: | ||
|
||
def __init__(self, task: str, num_classes: int = 1): | ||
if task == "binary": | ||
self.metric = BinaryAUPRC() | ||
elif task == "multiclass": | ||
self.metric = MulticlassAUPRC(num_classes=num_classes) | ||
elif task == "multilabel": | ||
self.metric = MultilabelAUPRC(num_labels=num_classes) | ||
else: | ||
raise ValueError("Unsupported task type or activation function") | ||
self._task = task | ||
|
||
def update(self, predictions, labels): | ||
# Reshape predictions and labels to handle the batch dimension | ||
if self._task == "binary": | ||
predictions = predictions.view(-1) | ||
labels = labels.view(-1) | ||
else: | ||
predictions = predictions.view(-1, predictions.shape[-1]) | ||
labels = labels.view(-1) | ||
|
||
self.metric.update(predictions, labels) | ||
|
||
def to(self, device): | ||
# Move the metric to the specified device | ||
self.metric = self.metric.to(device) | ||
return self | ||
|
||
def __getattr__(self, name): | ||
# Redirect attribute access to self.metric if it exists there | ||
if hasattr(self.metric, name) and not name in ["update", "to", "__dict__"]: | ||
return getattr(self.metric, name) | ||
# if name in self.__dict__: | ||
# return self.__dict__[name] | ||
# raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") | ||
|
||
def __setattr__(self, name, value): | ||
if name in ['_task', 'metric']: | ||
# Set attributes normally if they are part of the AUCPRC class | ||
super().__setattr__(name, value) | ||
elif hasattr(self, 'metric') and hasattr(self.metric, name): | ||
# Redirect attribute setting to self.metric if it exists there | ||
setattr(self.metric, name, value) | ||
else: | ||
# Set attributes normally otherwise | ||
super().__setattr__(name, value) |
Oops, something went wrong.