Enable customized loss and val funcs #526

Open · wants to merge 13 commits into base: dev
54 changes: 48 additions & 6 deletions pypots/base.py
@@ -9,7 +9,7 @@
from abc import ABC
from abc import abstractmethod
from datetime import datetime
from typing import Optional, Union, Iterable
from typing import Optional, Union, Iterable, Callable

import torch
from torch.utils.tensorboard import SummaryWriter
@@ -221,7 +221,9 @@ def _save_log_into_tb_file(self, step: int, stage: str, loss_dict: dict) -> None
# save all items containing "loss" or "error" in the name
# WDU: may enable customization keywords in the future
if ("loss" in item_name) or ("error" in item_name):
self.summary_writer.add_scalar(f"{stage}/{item_name}", loss.sum(), step)
if isinstance(loss, torch.Tensor):
loss = loss.sum()
self.summary_writer.add_scalar(f"{stage}/{item_name}", loss, step)

def _auto_save_model_if_necessary(
self,
@@ -415,9 +417,17 @@ class BaseNNModel(BaseModel):
Training epochs, i.e. the maximum rounds of the model to be trained with.

patience :
Number of epochs the training procedure will keep if loss doesn't decrease.
Once exceeding the number, the training will stop.
Must be smaller than or equal to the value of ``epochs``.
The patience for the early-stopping mechanism. Given a positive integer, the training process will be
stopped when the model does not perform better after that number of epochs.
Leaving it default as None will disable the early-stopping.

train_loss_func:
The customized loss function designed by users for training the model.
If not given, will use the default loss as claimed in the original paper.

val_metric_func:
The customized metric function designed by users for validating the model.
If not given, will use the default MSE metric.

num_workers :
The number of subprocesses to use for data loading.
@@ -474,6 +484,8 @@ def __init__(
batch_size: int,
epochs: int,
patience: Optional[int] = None,
train_loss_func: Optional[dict] = None,
val_metric_func: Optional[dict] = None,
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
saving_path: str = None,
@@ -487,17 +499,47 @@
verbose,
)

# check patience
if patience is None:
patience = -1 # early stopping on patience won't work if it is set as < 0
else:
assert (
patience <= epochs
), f"patience must be smaller than epochs which is {epochs}, but got patience={patience}"

# training hype-parameters
# check train_loss_func and val_metric_func
train_loss_func_name, val_metric_func_name = "default", "loss (default)"
if train_loss_func is not None:
assert (
len(train_loss_func) == 1
), f"train_loss_func should have only 1 item, but got {len(train_loss_func)}"
train_loss_func_name, train_loss_func = train_loss_func.popitem()
assert isinstance(
train_loss_func, Callable
), "train_loss_func should be a callable function"
logger.info(
f"Using customized {train_loss_func_name} as the training loss function."
)
if val_metric_func is not None:
assert (
len(val_metric_func) == 1
), f"val_metric_func should have only 1 item, but got {len(val_metric_func)}"
val_metric_func_name, val_metric_func = val_metric_func.popitem()
assert isinstance(
val_metric_func, Callable
), "val_metric_func should be a callable function"
logger.info(
f"Using customized {val_metric_func_name} as the validation metric function."
)

# set up the hyper-parameters
self.batch_size = batch_size
self.epochs = epochs
self.patience = patience
self.train_loss_func = train_loss_func
self.train_loss_func_name = train_loss_func_name
self.val_metric_func = val_metric_func
self.val_metric_func_name = val_metric_func_name
self.original_patience = patience
self.num_workers = num_workers

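
For orientation, here is a minimal usage sketch of the new arguments (not part of the diff; the loss below is illustrative): both train_loss_func and val_metric_func take a single-item dict mapping a display name to a callable, which is exactly what the assertions in BaseNNModel.__init__ above enforce before unpacking the dict with popitem().

import torch.nn.functional as F

def smoothed_cross_entropy(prediction, target):
    # hypothetical customized training loss; the exact call signature is
    # decided by the concrete subclass's training loop, not by BaseNNModel
    return F.cross_entropy(prediction, target, label_smoothing=0.1)

train_loss_func = {"SmoothedCE": smoothed_cross_entropy}  # exactly one {name: callable} item
# val_metric_func follows the same format, e.g. {"Accuracy": my_accuracy_func}

The dict key only feeds the log messages via train_loss_func_name / val_metric_func_name; the attribute stored on the model is the callable itself.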
70 changes: 49 additions & 21 deletions pypots/classification/base.py
@@ -15,6 +15,7 @@
from torch.utils.data import DataLoader

from ..base import BaseModel, BaseNNModel
from ..nn.functional import calc_acc
from ..utils.logging import logger

try:
@@ -155,9 +156,17 @@ class BaseNNClassifier(BaseNNModel):
Training epochs, i.e. the maximum rounds of the model to be trained with.

patience :
Number of epochs the training procedure will keep if loss doesn't decrease.
Once exceeding the number, the training will stop.
Must be smaller than or equal to the value of ``epochs``.
The patience for the early-stopping mechanism. Given a positive integer, the training process will be
stopped when the model does not perform better after that number of epochs.
Leaving it default as None will disable the early-stopping.

train_loss_func:
The customized loss function designed by users for training the model.
If not given, will use the default loss as claimed in the original paper.

val_metric_func:
The customized metric function designed by users for validating the model.
If not given, will use the default Accuracy metric.

num_workers :
The number of subprocesses to use for data loading.
@@ -202,24 +211,36 @@ def __init__(
batch_size: int,
epochs: int,
patience: Optional[int] = None,
train_loss_func: Optional[dict] = None,
val_metric_func: Optional[dict] = None,
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
saving_path: str = None,
model_saving_strategy: Optional[str] = "best",
verbose: bool = True,
):
super().__init__(
batch_size,
epochs,
patience,
num_workers,
device,
saving_path,
model_saving_strategy,
verbose,
batch_size=batch_size,
epochs=epochs,
patience=patience,
train_loss_func=train_loss_func,
val_metric_func=val_metric_func,
num_workers=num_workers,
device=device,
saving_path=saving_path,
model_saving_strategy=model_saving_strategy,
verbose=verbose,
)
self.n_classes = n_classes

# set default training loss function and validation metric function if not given
if train_loss_func is None:
self.train_loss_func = torch.nn.functional.cross_entropy
self.train_loss_func_name = "CrossEntropy"
if val_metric_func is None:
self.val_metric_func = calc_acc
self.val_metric_func_name = "Accuracy"

@abstractmethod
def _assemble_input_for_training(self, data: list) -> dict:
"""Assemble the given data into a dictionary for training input.
@@ -308,30 +329,39 @@ def _train_model(

if val_loader is not None:
self.model.eval()
epoch_val_loss_collector = []
epoch_val_pred_collector = []
epoch_val_label_collector = []
with torch.no_grad():
for idx, data in enumerate(val_loader):
inputs = self._assemble_input_for_validating(data)
results = self.model.forward(inputs)
epoch_val_loss_collector.append(results["loss"].sum().item())
results = self.model(inputs)
epoch_val_pred_collector.append(results["classification_pred"])
epoch_val_label_collector.append(inputs["y"])

epoch_val_pred_collector = torch.cat(epoch_val_pred_collector, dim=-1)
epoch_val_label_collector = torch.cat(epoch_val_label_collector, dim=-1)

mean_val_loss = np.mean(epoch_val_loss_collector)
# TODO: refactor the following code to a function
epoch_val_pred_collector = np.argmax(epoch_val_pred_collector, axis=1)
mean_val_loss = self.val_metric_func(epoch_val_pred_collector, epoch_val_label_collector.numpy())

# save validation loss logs into the tensorboard file for every epoch if in need
if self.summary_writer is not None:
val_loss_dict = {
"classification_loss": mean_val_loss,
self.val_metric_func_name: mean_val_loss,
}
self._save_log_into_tb_file(epoch, "validating", val_loss_dict)

logger.info(
f"Epoch {epoch:03d} - "
f"training loss: {mean_train_loss:.4f}, "
f"validation loss: {mean_val_loss:.4f}"
f"training loss ({self.train_loss_func_name}): {mean_train_loss:.4f}, "
f"validation {self.val_metric_func_name}: {mean_val_loss:.4f}"
)
mean_loss = mean_val_loss
else:
logger.info(f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}")
logger.info(
f"Epoch {epoch:03d} - training loss ({self.train_loss_func_name}): {mean_train_loss:.4f}"
)
mean_loss = mean_train_loss

if np.isnan(mean_loss):
@@ -431,8 +461,6 @@ def classify(
) -> np.ndarray:
"""Classify the input data with the trained model.



Parameters
----------
test_set :
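
With the reworked validation loop above, a customized val_metric_func for a classifier is called as val_metric_func(predictions, labels), where predictions are the argmax-ed class indices and labels are the collected inputs["y"] values converted to NumPy. A minimal sketch of a compatible metric (illustrative, not part of the PR):

import numpy as np

def balanced_accuracy(predictions, labels):
    # predictions and labels are 1-D arrays of class indices
    per_class_recall = [
        (predictions[labels == c] == c).mean() for c in np.unique(labels)
    ]
    return float(np.mean(per_class_recall))

# passed to the model as val_metric_func={"BalancedAccuracy": balanced_accuracy}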
8 changes: 4 additions & 4 deletions pypots/classification/brits/core.py
@@ -36,7 +36,7 @@ def __init__(
self.f_classifier = nn.Linear(self.rnn_hidden_size, n_classes)
self.b_classifier = nn.Linear(self.rnn_hidden_size, n_classes)

def forward(self, inputs: dict, training: bool = True) -> dict:
def forward(self, inputs: dict) -> dict:
(
imputed_data,
f_reconstruction,
@@ -59,11 +59,11 @@ def forward(self, inputs: dict) -> dict:
}

# if in training mode, return results with losses
if training:
if self.training:
results["consistency_loss"] = consistency_loss
results["reconstruction_loss"] = reconstruction_loss
f_classification_loss = F.nll_loss(torch.log(f_prediction), inputs["label"])
b_classification_loss = F.nll_loss(torch.log(b_prediction), inputs["label"])
f_classification_loss = F.nll_loss(torch.log(f_prediction), inputs["y"])
b_classification_loss = F.nll_loss(torch.log(b_prediction), inputs["y"])
classification_loss = (f_classification_loss + b_classification_loss) / 2
loss = (
consistency_loss
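
The dropped training argument in forward() relies on PyTorch's built-in module mode: nn.Module keeps a self.training flag that model.train() and model.eval() toggle, and the validation and prediction loops above already switch the model to eval mode. A minimal illustration of that mechanism (generic PyTorch, not PyPOTS code):

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def forward(self, x):
        results = {"pred": x * 2}
        if self.training:  # True after model.train(), False after model.eval()
            results["loss"] = results["pred"].sum()
        return results

model = TinyModel()
print("loss" in model(torch.ones(3)))  # True: modules start in training mode
model.eval()
print("loss" in model(torch.ones(3)))  # False: no loss attached during evaluation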
40 changes: 28 additions & 12 deletions pypots/classification/brits/model.py
@@ -53,6 +53,14 @@ class BRITS(BaseNNClassifier):
stopped when the model does not perform better after that number of epochs.
Leaving it default as None will disable the early-stopping.

train_loss_func:
The customized loss function designed by users for training the model.
If not given, will use the default loss as claimed in the original paper.

val_metric_func:
The customized metric function designed by users for validating the model.
If not given, will use the default Accuracy metric.

optimizer :
The optimizer for model training.
If not given, will use a default Adam optimizer.
@@ -96,6 +104,8 @@ def __init__(
batch_size: int = 32,
epochs: int = 100,
patience: Optional[int] = None,
train_loss_func: Optional[dict] = None,
val_metric_func: Optional[dict] = None,
optimizer: Optional[Optimizer] = Adam(),
num_workers: int = 0,
device: Optional[Union[str, torch.device, list]] = None,
@@ -104,15 +114,17 @@
verbose: bool = True,
):
super().__init__(
n_classes,
batch_size,
epochs,
patience,
num_workers,
device,
saving_path,
model_saving_strategy,
verbose,
n_classes=n_classes,
batch_size=batch_size,
epochs=epochs,
patience=patience,
train_loss_func=train_loss_func,
val_metric_func=val_metric_func,
num_workers=num_workers,
device=device,
saving_path=saving_path,
model_saving_strategy=model_saving_strategy,
verbose=verbose,
)

self.n_steps = n_steps
@@ -121,6 +133,10 @@ def __init__(
self.classification_weight = classification_weight
self.reconstruction_weight = reconstruction_weight

# BRITS has its own loss function defined inside its forward process, so we set train_loss_func as None here
self.train_loss_func = None
self.train_loss_func_name = "default"

# set up the model
self.model = _BRITS(
self.n_steps,
@@ -147,13 +163,13 @@ def _assemble_input_for_training(self, data: list) -> dict:
back_X,
back_missing_mask,
back_deltas,
label,
y,
) = self._send_data_to_given_device(data)

# assemble input data
inputs = {
"indices": indices,
"label": label,
"y": y,
"forward": {
"X": X,
"missing_mask": missing_mask,
@@ -248,7 +264,7 @@ def predict(
with torch.no_grad():
for idx, data in enumerate(test_loader):
inputs = self._assemble_input_for_testing(data)
results = self.model.forward(inputs, training=False)
results = self.model.forward(inputs)
classification_pred = results["classification_pred"]
classification_collector.append(classification_pred)

6 changes: 3 additions & 3 deletions pypots/classification/grud/core.py
@@ -36,7 +36,7 @@ def __init__(
)
self.classifier = nn.Linear(self.rnn_hidden_size, self.n_classes)

def forward(self, inputs: dict, training: bool = True) -> dict:
def forward(self, inputs: dict) -> dict:
"""Forward processing of GRU-D.

Parameters
@@ -65,8 +65,8 @@ def forward(self, inputs: dict) -> dict:
results = {"classification_pred": classification_pred}

# if in training mode, return results with losses
if training:
classification_loss = F.nll_loss(torch.log(classification_pred), inputs["label"])
if self.training:
classification_loss = F.nll_loss(torch.log(classification_pred), inputs["y"])
results["loss"] = classification_loss

return results
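
Note that GRU-D (like BRITS above) computes its internal classification loss as F.nll_loss over torch.log of softmax probabilities, while the new default train_loss_func in BaseNNClassifier is torch.nn.functional.cross_entropy, which expects raw logits. The two agree once the log-softmax is folded in; a small check (illustrative only, not part of the diff):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)                 # raw scores: 4 samples, 3 classes
y = torch.tensor([0, 2, 1, 1])

probs = F.softmax(logits, dim=1)           # probabilities, as GRU-D's classifier outputs
loss_nll = F.nll_loss(torch.log(probs), y)
loss_ce = F.cross_entropy(logits, y)       # applies log-softmax internally

print(torch.allclose(loss_nll, loss_ce))   # True up to floating-point error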