
Commit

Added fit to numpy dataset option for torch models. Added AUCPR for torch models. Added AUCPR for tensorflow models
ga84jog committed Jun 19, 2024
1 parent 49f5368 commit e4dbb2d
Showing 15 changed files with 2,315 additions and 242 deletions.
1,794 changes: 1,764 additions & 30 deletions datalab/read_dataset.ipynb

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions src/datasets/readers.py
@@ -1040,14 +1040,16 @@ def to_numpy(self,
read_masks=read_masks,
data_type=data_type)
else:
dataset, subject_ids = self.random_samples(n_subjects=n_samples,

dataset, subject_ids = self.random_samples(n_subjects=len(self.subject_ids),
read_timestamps=read_timestamps,
data_type=data_type,
return_ids=True,
read_masks=read_masks,
seed=seed)
for prefix in deepcopy(list(dataset.keys())):
dataset[prefix] = dataset[prefix][:min(n_samples, len(dataset[prefix]))]
if n_samples is not None:
for prefix in deepcopy(list(dataset.keys())):
dataset[prefix] = dataset[prefix][:min(n_samples, len(dataset[prefix]))]
if imputer is not None:
dataset["X"] = [imputer.transform(sample) for sample in dataset["X"]]
if scaler is not None:
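A hedged call sketch for the changed path (parameter names are taken from the hunk above; the reader object and the exact return shape are assumptions for illustration):

```python
# Assumed usage, not verified against the full file: after this commit the
# sampler always draws from len(self.subject_ids) subjects, and the result is
# truncated to n_samples per key only when n_samples is given.
data = reader.to_numpy(n_samples=100, read_masks=True, seed=42)
assert all(len(v) <= 100 for v in data.values())  # at most n_samples per key
```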
24 changes: 11 additions & 13 deletions src/generators/__init__.py
@@ -71,7 +71,6 @@ def __getitem__(self, index=None):

# Start with any remainder from the previous batch
X, y, M = next(self._generator)  # if not deep supervision, M is timestamps; else it is a mask
assert X.shape[1] == y.shape[1] == M.shape[1]
# Fetch new data until we have at least the required batch size
while X.shape[0] < self._batch_size:
X_res = self._remainder_X
@@ -82,15 +81,12 @@ def __getitem__(self, index=None):
m_res = self._remainder_M
M = self._stack_batches((M, m_res)) if m_res.size else M
y = self._stack_batches((y, y_res)) if y_res.size else y
assert X.shape[1] == y.shape[1] == M.shape[1]
else:
y = np.concatenate((y, y_res), axis=0, dtype=np.float32) if y_res.size else y
if X.shape[0] < self._batch_size:
self._remainder_X, \
self._remainder_y, \
self._remainder_M = next(self._generator)
assert self._remainder_X.shape[1] == self._remainder_y.shape[
1] == self._remainder_M.shape[1]

# If the accumulated batch is larger than required, split it
if X.shape[0] > self._batch_size:
@@ -114,7 +110,6 @@ def __getitem__(self, index=None):
self._remainder_M = np.array([])

if self._deep_supervision:
assert X.shape[1] == y.shape[1] == M.shape[1]
return X, y, M
return X, y
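
For readers skimming the hunk: the loop above accumulates generator output and carries any surplus rows over to the next call, so every emitted batch has exactly `_batch_size` rows. A self-contained sketch of that remainder pattern (an illustration under assumed shapes, not the repository's code):

```python
import numpy as np

def accumulate_batch(gen, batch_size, remainder):
    """Pull arrays from `gen` until at least `batch_size` rows exist, then
    split off the surplus so it can seed the next batch."""
    X = remainder if remainder.size else next(gen)
    while X.shape[0] < batch_size:
        X = np.concatenate((X, next(gen)), axis=0)
    return X[:batch_size], X[batch_size:]

gen = (np.ones((n, 3)) for n in [2, 5, 1, 4])  # variable-sized chunks
batch, rem = accumulate_batch(gen, 4, np.array([]))
print(batch.shape, rem.shape)  # (4, 3) (3, 3) -- fixed-size batch, surplus kept
```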

@@ -130,7 +125,8 @@ def __len__(self):
return self._steps

def __del__(self):
self._close()
if self._cpu_count:
self._close()

def _create_workers(self):
'''
@@ -179,7 +175,6 @@ def __generator(self):
dynamic_result = ray.get(ready_ids[0])
for object_result in dynamic_result:
X, y, t = ray.get(object_result)
assert X.shape[1] == y.shape[1] == t.shape[1]
yield X, y, t
else:
random.shuffle(self._random_ids)
@@ -189,7 +184,6 @@ def __generator(self):
reader=self._reader,
scaler=self._scaler,
bining=self._bining):
assert X.shape[1] == y.shape[1] == M.shape[1]
yield X, y, M
else:
for X, y, t in process_subject(args=(self._random_ids, self._batch_size),
@@ -198,7 +192,6 @@ def __generator(self):
row_only=self._row_only,
bining=self._bining,
target_replication=self._target_replication):
assert X.shape[1] == y.shape[1] == t.shape[1]
yield X, y, t

@staticmethod
@@ -236,10 +229,15 @@ def read_timeseries(X_df: pd.DataFrame, y_df: pd.DataFrame, row_only=False, bini
return Xs, ys, ts

def _close(self):
ray.get(self.__results)
for worker in self._ray_workers:
worker.exit.remote()
self._ray_workers.clear()
        try:
            ray.get(self.__results)
            for worker in self._ray_workers:
                worker.exit.remote()
            self._ray_workers.clear()
        except ValueError:
            # If Ray shuts down before this cleanup runs, ray.get()
            # raises a ValueError; it is safe to ignore here.
            pass

@staticmethod
def _stack_batches(data):
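The `worker.exit.remote()` calls in `_close()` imply each worker is a Ray actor exposing an `exit` method. A minimal sketch of such an actor, assuming the standard Ray actor API (the commit's actual worker class is not shown in this diff):

```python
import ray

@ray.remote
class ReaderWorker:  # hypothetical name; the real worker class is not in this diff
    def exit(self):
        # ray.actor.exit_actor() terminates the current actor from inside
        # one of its own methods, which is what _close() relies on.
        ray.actor.exit_actor()

# usage (after ray.init()):
#   worker = ReaderWorker.remote()
#   worker.exit.remote()
```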
54 changes: 54 additions & 0 deletions src/metrics/pytorch.py
@@ -0,0 +1,54 @@
import torch
from torcheval.metrics import BinaryAUPRC, MulticlassAUPRC, MultilabelAUPRC
import torch.nn as nn

# TODO! This absolutely needs testing


class AUCPRC:

    def __init__(self, task: str, num_classes: int = 1):
        if task == "binary":
            self.metric = BinaryAUPRC()
        elif task == "multiclass":
            self.metric = MulticlassAUPRC(num_classes=num_classes)
        elif task == "multilabel":
            self.metric = MultilabelAUPRC(num_labels=num_classes)
        else:
            raise ValueError(f"Unsupported task type: {task}")
        self._task = task

    def update(self, predictions, labels):
        # Reshape predictions and labels to handle the batch dimension
        if self._task == "binary":
            predictions = predictions.view(-1)
            labels = labels.view(-1)
        else:
            predictions = predictions.view(-1, predictions.shape[-1])
            labels = labels.view(-1)

        self.metric.update(predictions, labels)

    def to(self, device):
        # Move the metric to the specified device
        self.metric = self.metric.to(device)
        return self

    def __getattr__(self, name):
        # Redirect attribute access to self.metric if it exists there
        if hasattr(self.metric, name) and name not in ["update", "to", "__dict__"]:
            return getattr(self.metric, name)
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

    def __setattr__(self, name, value):
        if name in ['_task', 'metric']:
            # Set attributes normally if they are part of the AUCPRC class
            super().__setattr__(name, value)
        elif hasattr(self, 'metric') and hasattr(self.metric, name):
            # Redirect attribute setting to self.metric if it exists there
            setattr(self.metric, name, value)
        else:
            # Set attributes normally otherwise
            super().__setattr__(name, value)
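
A hedged usage sketch for the wrapper above (not part of the commit); it assumes torcheval's usual update()/compute() flow, with compute() reached through the __getattr__ delegation:

```python
import torch
# from src.metrics.pytorch import AUCPRC  # assumed import path for this new file

metric = AUCPRC(task="binary")
preds = torch.rand(8, 48)              # e.g. per-timestep probabilities
labels = torch.randint(0, 2, (8, 48))  # matching binary targets
metric.update(preds, labels)           # update() flattens batch and time dims
print(metric.compute())                # delegated to BinaryAUPRC.compute()
```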
