From 41b00a32daa45c5ee768a1a56664cf2c3e5b6021 Mon Sep 17 00:00:00 2001
From: Wenjie Du
Date: Wed, 8 May 2024 14:58:22 +0800
Subject: [PATCH 1/9] feat: enable users to customize training loss func and val metric func;

---
 pypots/base.py | 54 ++++++++-
 pypots/classification/base.py | 62 +++++++---
 pypots/classification/brits/core.py | 8 +-
 pypots/classification/brits/model.py | 22 +++-
 pypots/classification/grud/core.py | 6 +-
 pypots/classification/grud/model.py | 18 ++-
 pypots/classification/raindrop/core.py | 6 +-
 pypots/classification/raindrop/model.py | 18 ++-
 pypots/classification/template/model.py | 2 +
 pypots/clustering/base.py | 26 ++--
 pypots/clustering/crli/core.py | 5 -
 pypots/clustering/crli/model.py | 16 ++-
 pypots/clustering/template/model.py | 2 +
 pypots/clustering/vader/core.py | 113 +++++++++---------
 pypots/clustering/vader/model.py | 20 +++-
 pypots/forecasting/base.py | 26 ++--
 pypots/forecasting/csdi/core.py | 12 +-
 pypots/forecasting/csdi/model.py | 18 +--
 pypots/forecasting/template/model.py | 7 +-
 pypots/imputation/autoformer/core.py | 4 +-
 pypots/imputation/autoformer/model.py | 17 ++-
 pypots/imputation/base.py | 38 ++++--
 pypots/imputation/brits/core.py | 4 +-
 pypots/imputation/brits/model.py | 15 ++-
 pypots/imputation/crossformer/core.py | 4 +-
 pypots/imputation/crossformer/model.py | 17 ++-
 pypots/imputation/csdi/core.py | 12 +-
 pypots/imputation/csdi/model.py | 18 +--
 pypots/imputation/dlinear/core.py | 4 +-
 pypots/imputation/dlinear/model.py | 17 ++-
 pypots/imputation/etsformer/core.py | 4 +-
 pypots/imputation/etsformer/model.py | 17 ++-
 pypots/imputation/fedformer/core.py | 4 +-
 pypots/imputation/fedformer/model.py | 17 ++-
 pypots/imputation/film/core.py | 4 +-
 pypots/imputation/film/model.py | 16 ++-
 pypots/imputation/frets/core.py | 4 +-
 pypots/imputation/frets/model.py | 17 ++-
 pypots/imputation/gpvae/core.py | 4 +-
 pypots/imputation/gpvae/model.py | 18 +--
 pypots/imputation/informer/core.py | 4 +-
 pypots/imputation/informer/model.py | 17 ++-
 pypots/imputation/itransformer/core.py | 4 +-
 pypots/imputation/itransformer/model.py | 15 ++-
 pypots/imputation/mrnn/core.py | 4 +-
 pypots/imputation/mrnn/model.py | 15 ++-
 .../nonstationary_transformer/core.py | 4 +-
 .../nonstationary_transformer/model.py | 16 ++-
 pypots/imputation/patchtst/core.py | 4 +-
 pypots/imputation/patchtst/model.py | 16 ++-
 pypots/imputation/pyraformer/core.py | 4 +-
 pypots/imputation/pyraformer/model.py | 17 ++-
 pypots/imputation/saits/core.py | 19 ++-
 pypots/imputation/saits/model.py | 35 ++++--
 pypots/imputation/template/model.py | 7 +-
 pypots/imputation/timesnet/core.py | 4 +-
 pypots/imputation/timesnet/model.py | 17 ++-
 pypots/imputation/transformer/core.py | 4 +-
 pypots/imputation/transformer/model.py | 15 ++-
 pypots/imputation/usgan/core.py | 17 +--
 pypots/imputation/usgan/model.py | 8 +-
 pypots/nn/modules/csdi/backbone.py | 12 +-
 pypots/nn/modules/usgan/backbone.py | 3 +-
 63 files changed, 664 insertions(+), 293 deletions(-)

diff --git a/pypots/base.py b/pypots/base.py
index d10c7c6e..402fb0b8 100644
--- a/pypots/base.py
+++ b/pypots/base.py
@@ -9,7 +9,7 @@
 from abc import ABC
 from abc import abstractmethod
 from datetime import datetime
-from typing import Optional, Union, Iterable
+from typing import Optional, Union, Iterable, Callable
 import torch
 from torch.utils.tensorboard import SummaryWriter
@@ -219,7 +219,9 @@ def _save_log_into_tb_file(self, step: int, stage: str, loss_dict: dict) -> None
             # save all items containing "loss" or "error" in the name
             # WDU: may enable customization
keywords in the future if ("loss" in item_name) or ("error" in item_name): - self.summary_writer.add_scalar(f"{stage}/{item_name}", loss.sum(), step) + if isinstance(loss, torch.Tensor): + loss = loss.sum() + self.summary_writer.add_scalar(f"{stage}/{item_name}", loss, step) def _auto_save_model_if_necessary( self, @@ -414,9 +416,17 @@ class BaseNNModel(BaseModel): Training epochs, i.e. the maximum rounds of the model to be trained with. patience : - Number of epochs the training procedure will keep if loss doesn't decrease. - Once exceeding the number, the training will stop. - Must be smaller than or equal to the value of ``epochs``. + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. num_workers : The number of subprocesses to use for data loading. @@ -471,6 +481,8 @@ def __init__( batch_size: int, epochs: int, patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, @@ -482,6 +494,7 @@ def __init__( model_saving_strategy, ) + # check patience if patience is None: patience = -1 # early stopping on patience won't work if it is set as < 0 else: @@ -489,10 +502,39 @@ def __init__( patience <= epochs ), f"patience must be smaller than epochs which is {epochs}, but got patience={patience}" - # training hype-parameters + # check train_loss_func and val_metric_func + train_loss_func_name, val_metric_func_name = "default", "loss (default)" + if train_loss_func is not None: + assert ( + len(train_loss_func) == 1 + ), f"train_loss_func should have only 1 item, but got {len(train_loss_func)}" + train_loss_func_name, train_loss_func = train_loss_func.popitem() + assert isinstance( + train_loss_func, Callable + ), "train_loss_func should be a callable function" + logger.info( + f"Using customized {train_loss_func_name} as the training loss function." + ) + if val_metric_func is not None: + assert ( + len(val_metric_func) == 1 + ), f"val_metric_func should have only 1 item, but got {len(val_metric_func)}" + val_metric_func_name, val_metric_func = val_metric_func.popitem() + assert isinstance( + val_metric_func, Callable + ), "val_metric_func should be a callable function" + logger.info( + f"Using customized {val_metric_func_name} as the validation metric function." 
+ ) + + # set up the hype-parameters self.batch_size = batch_size self.epochs = epochs self.patience = patience + self.train_loss_func = train_loss_func + self.train_loss_func_name = train_loss_func_name + self.val_metric_func = val_metric_func + self.val_metric_func_name = val_metric_func_name self.original_patience = patience self.num_workers = num_workers diff --git a/pypots/classification/base.py b/pypots/classification/base.py index a758fed3..37732f79 100644 --- a/pypots/classification/base.py +++ b/pypots/classification/base.py @@ -16,6 +16,7 @@ from ..base import BaseModel, BaseNNModel from ..utils.logging import logger +from ..utils.metrics import calc_acc try: import nni @@ -151,9 +152,17 @@ class BaseNNClassifier(BaseNNModel): Training epochs, i.e. the maximum rounds of the model to be trained with. patience : - Number of epochs the training procedure will keep if loss doesn't decrease. - Once exceeding the number, the training will stop. - Must be smaller than or equal to the value of ``epochs``. + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. num_workers : The number of subprocesses to use for data loading. @@ -196,6 +205,8 @@ def __init__( batch_size: int, epochs: int, patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, @@ -205,6 +216,8 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, @@ -212,6 +225,14 @@ def __init__( ) self.n_classes = n_classes + # set default training loss function and validation metric function if not given + if train_loss_func is None: + self.train_loss_func = torch.nn.functional.cross_entropy + self.train_loss_func_name = "CrossEntropy" + if val_metric_func is None: + self.val_metric_func = calc_acc + self.val_metric_func_name = "Accuracy" + @abstractmethod def _assemble_input_for_training(self, data: list) -> dict: """Assemble the given data into a dictionary for training input. 
@@ -300,33 +321,48 @@ def _train_model( if val_loader is not None: self.model.eval() - epoch_val_loss_collector = [] + epoch_val_pred_collector = [] + epoch_val_label_collector = [] with torch.no_grad(): for idx, data in enumerate(val_loader): inputs = self._assemble_input_for_validating(data) - results = self.model.forward(inputs) - epoch_val_loss_collector.append( - results["loss"].sum().item() + results = self.model(inputs) + epoch_val_pred_collector.append( + results["classification_pred"] ) + epoch_val_label_collector.append(inputs["y"]) + + epoch_val_pred_collector = torch.cat( + epoch_val_pred_collector, dim=-1 + ) + epoch_val_label_collector = torch.cat( + epoch_val_label_collector, dim=-1 + ) - mean_val_loss = np.mean(epoch_val_loss_collector) + # TODO: refactor the following code to a function + epoch_val_pred_collector = np.argmax( + epoch_val_pred_collector, axis=1 + ) + mean_val_loss = self.val_metric_func( + epoch_val_pred_collector, epoch_val_label_collector + ) # save validation loss logs into the tensorboard file for every epoch if in need if self.summary_writer is not None: val_loss_dict = { - "classification_loss": mean_val_loss, + self.val_metric_func_name: mean_val_loss, } self._save_log_into_tb_file(epoch, "validating", val_loss_dict) logger.info( f"Epoch {epoch:03d} - " - f"training loss: {mean_train_loss:.4f}, " - f"validation loss: {mean_val_loss:.4f}" + f"training loss ({self.train_loss_func_name}): {mean_train_loss:.4f}, " + f"validation {self.val_metric_func_name}: {mean_val_loss:.4f}" ) mean_loss = mean_val_loss else: logger.info( - f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}" + f"Epoch {epoch:03d} - training loss ({self.train_loss_func_name}): {mean_train_loss:.4f}" ) mean_loss = mean_train_loss @@ -431,8 +467,6 @@ def classify( ) -> np.ndarray: """Classify the input data with the trained model. - - Parameters ---------- test_set : diff --git a/pypots/classification/brits/core.py b/pypots/classification/brits/core.py index ebd1bae3..3b676bfb 100644 --- a/pypots/classification/brits/core.py +++ b/pypots/classification/brits/core.py @@ -36,7 +36,7 @@ def __init__( self.f_classifier = nn.Linear(self.rnn_hidden_size, n_classes) self.b_classifier = nn.Linear(self.rnn_hidden_size, n_classes) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: ( imputed_data, f_reconstruction, @@ -59,11 +59,11 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: results["consistency_loss"] = consistency_loss results["reconstruction_loss"] = reconstruction_loss - f_classification_loss = F.nll_loss(torch.log(f_prediction), inputs["label"]) - b_classification_loss = F.nll_loss(torch.log(b_prediction), inputs["label"]) + f_classification_loss = F.nll_loss(torch.log(f_prediction), inputs["y"]) + b_classification_loss = F.nll_loss(torch.log(b_prediction), inputs["y"]) classification_loss = (f_classification_loss + b_classification_loss) / 2 loss = ( consistency_loss diff --git a/pypots/classification/brits/model.py b/pypots/classification/brits/model.py index b52bf4d5..a818ed63 100644 --- a/pypots/classification/brits/model.py +++ b/pypots/classification/brits/model.py @@ -53,6 +53,14 @@ class BRITS(BaseNNClassifier): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. 
+ train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + optimizer : The optimizer for model training. If not given, will use a default Adam optimizer. @@ -94,6 +102,8 @@ def __init__( batch_size: int = 32, epochs: int = 100, patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -105,6 +115,8 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, @@ -117,6 +129,10 @@ def __init__( self.classification_weight = classification_weight self.reconstruction_weight = reconstruction_weight + # BRITS has its own defined loss function, so we set train_loss_func as None here + self.train_loss_func = None + self.train_loss_func_name = "default" + # set up the model self.model = _BRITS( self.n_steps, @@ -143,13 +159,13 @@ def _assemble_input_for_training(self, data: list) -> dict: back_X, back_missing_mask, back_deltas, - label, + y, ) = self._send_data_to_given_device(data) # assemble input data inputs = { "indices": indices, - "label": label, + "y": y, "forward": { "X": X, "missing_mask": missing_mask, @@ -244,7 +260,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) classification_pred = results["classification_pred"] classification_collector.append(classification_pred) diff --git a/pypots/classification/grud/core.py b/pypots/classification/grud/core.py index 16cd2723..1af7e56d 100644 --- a/pypots/classification/grud/core.py +++ b/pypots/classification/grud/core.py @@ -40,7 +40,7 @@ def __init__( ) self.classifier = nn.Linear(self.rnn_hidden_size, self.n_classes) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: """Forward processing of GRU-D. Parameters ---------- @@ -71,9 +71,9 @@ def forward(self, inputs: dict) -> dict: results = {"classification_pred": classification_pred} # if in training mode, return results with losses - if training: + if self.training: classification_loss = F.nll_loss( - torch.log(classification_pred), inputs["label"] + torch.log(classification_pred), inputs["y"] ) results["loss"] = classification_loss diff --git a/pypots/classification/grud/model.py b/pypots/classification/grud/model.py index f6413d9e..fc5c0123 100644 --- a/pypots/classification/grud/model.py +++ b/pypots/classification/grud/model.py @@ -48,6 +48,14 @@ class GRUD(BaseNNClassifier): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + optimizer : The optimizer for model training. If not given, will use a default Adam optimizer.
@@ -87,6 +95,8 @@ def __init__( batch_size: int = 32, epochs: int = 100, patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -98,6 +108,8 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, @@ -132,7 +144,7 @@ def _assemble_input_for_training(self, data: list) -> dict: missing_mask, deltas, empirical_mean, - label, + y, ) = self._send_data_to_given_device(data) # assemble input data @@ -143,7 +155,7 @@ def _assemble_input_for_training(self, data: list) -> dict: "missing_mask": missing_mask, "deltas": deltas, "empirical_mean": empirical_mean, - "label": label, + "y": y, } return inputs @@ -221,7 +233,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) prediction = results["classification_pred"] classification_collector.append(prediction) diff --git a/pypots/classification/raindrop/core.py b/pypots/classification/raindrop/core.py index 5e6deb99..e6be79e9 100644 --- a/pypots/classification/raindrop/core.py +++ b/pypots/classification/raindrop/core.py @@ -64,7 +64,7 @@ def __init__( nn.Linear(d_final, n_classes), ) - def forward(self, inputs, training=True): + def forward(self, inputs): X, missing_mask, static, timestamps, lengths = ( inputs["X"], inputs["missing_mask"], @@ -115,9 +115,9 @@ def forward(self, inputs, training=True): results = {"classification_pred": classification_pred} # if in training mode, return results with losses - if training: + if self.training: classification_loss = F.nll_loss( - torch.log(classification_pred), inputs["label"] + torch.log(classification_pred), inputs["y"] ) results["loss"] = classification_loss diff --git a/pypots/classification/raindrop/model.py b/pypots/classification/raindrop/model.py index 78d64267..d9e28c58 100644 --- a/pypots/classification/raindrop/model.py +++ b/pypots/classification/raindrop/model.py @@ -74,6 +74,14 @@ class Raindrop(BaseNNClassifier): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + optimizer : The optimizer for model training. If not given, will use a default Adam optimizer. 
@@ -121,6 +129,8 @@ def __init__( batch_size=32, epochs=100, patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -132,6 +142,8 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, @@ -172,7 +184,7 @@ def _assemble_input_for_training(self, data: list) -> dict: missing_mask, deltas, empirical_mean, - label, + y, ) = self._send_data_to_given_device(data) bz, n_steps, n_features = X.shape @@ -185,7 +197,7 @@ def _assemble_input_for_training(self, data: list) -> dict: "timestamps": times, "lengths": lengths, "missing_mask": missing_mask, - "label": label, + "y": y, } return inputs @@ -265,7 +277,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) prediction = results["classification_pred"] classification_collector.append(prediction) diff --git a/pypots/classification/template/model.py b/pypots/classification/template/model.py index dec46806..8445b4a3 100644 --- a/pypots/classification/template/model.py +++ b/pypots/classification/template/model.py @@ -35,6 +35,8 @@ def __init__( batch_size: int = 32, epochs: int = 100, patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, diff --git a/pypots/clustering/base.py b/pypots/clustering/base.py index bdf68645..0df0f298 100644 --- a/pypots/clustering/base.py +++ b/pypots/clustering/base.py @@ -150,9 +150,17 @@ class BaseNNClusterer(BaseNNModel): Training epochs, i.e. the maximum rounds of the model to be trained with. patience : - Number of epochs the training procedure will keep if loss doesn't decrease. - Once exceeding the number, the training will stop. - Must be smaller than or equal to the value of ``epochs``. + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. num_workers : The number of subprocesses to use for data loading. 
@@ -193,8 +201,10 @@ def __init__( self, n_clusters: int, batch_size: int, - epochs: int, + epochs: int = 100, patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, @@ -204,6 +214,8 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, @@ -319,13 +331,13 @@ def _train_model( mean_val_loss = np.mean(epoch_val_loss_collector) logger.info( f"Epoch {epoch:03d} - " - f"training loss: {mean_train_loss:.4f}, " - f"validation loss: {mean_val_loss:.4f}" + f"training loss ({self.train_loss_func_name}): {mean_train_loss:.4f}, " + f"validation {self.val_metric_func_name}: {mean_val_loss:.4f}" ) mean_loss = mean_val_loss else: logger.info( - f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}" + f"Epoch {epoch:03d} - training loss ({self.train_loss_func_name}): {mean_train_loss:.4f}" ) mean_loss = mean_train_loss diff --git a/pypots/clustering/crli/core.py b/pypots/clustering/crli/core.py index 755d9ff7..74bc7605 100644 --- a/pypots/clustering/crli/core.py +++ b/pypots/clustering/crli/core.py @@ -55,7 +55,6 @@ def forward( self, inputs: dict, training_object: str = "generator", - training: bool = True, ) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] imputation_latent, discrimination, reconstruction, fcn_latent = self.backbone( @@ -68,10 +67,6 @@ def forward( "fcn_latent": fcn_latent, } - # return results directly, skip loss calculation to reduce inference time - if not training: - return results - if training_object == "discriminator": l_D = F.binary_cross_entropy_with_logits(discrimination, missing_mask) results["discrimination_loss"] = l_D diff --git a/pypots/clustering/crli/model.py b/pypots/clustering/crli/model.py index abe4e655..640e3e0c 100644 --- a/pypots/clustering/crli/model.py +++ b/pypots/clustering/crli/model.py @@ -74,6 +74,14 @@ class CRLI(BaseNNClusterer): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + G_optimizer : The optimizer for the generator training. If not given, will use a default Adam optimizer. 
@@ -123,6 +131,8 @@ def __init__( batch_size: int = 32, epochs: int = 100, patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, G_optimizer: Optional[Optimizer] = Adam(), D_optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, @@ -135,6 +145,8 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, @@ -259,7 +271,7 @@ def _train_model( with torch.no_grad(): for idx, data in enumerate(val_loader): inputs = self._assemble_input_for_validating(data) - results = self.model.forward(inputs, training=True) + results = self.model.forward(inputs) epoch_val_loss_G_collector.append( results["generation_loss"].sum().item() ) @@ -415,7 +427,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - inputs = self.model.forward(inputs, training=False) + inputs = self.model.forward(inputs) clustering_latent_collector.append(inputs["fcn_latent"]) if return_latent_vars: imputation_collector.append(inputs["imputation_latent"]) diff --git a/pypots/clustering/template/model.py b/pypots/clustering/template/model.py index 0ed75220..f86e172c 100644 --- a/pypots/clustering/template/model.py +++ b/pypots/clustering/template/model.py @@ -35,6 +35,8 @@ def __init__( batch_size: int = 32, epochs: int = 100, patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, diff --git a/pypots/clustering/vader/core.py b/pypots/clustering/vader/core.py index 52843b72..bb19c1e3 100644 --- a/pypots/clustering/vader/core.py +++ b/pypots/clustering/vader/core.py @@ -75,7 +75,6 @@ def forward( self, inputs: dict, pretrain: bool = False, - training: bool = True, ) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] device = X.device @@ -113,63 +112,59 @@ def forward( results["loss"] = reconstruction_loss return results - # if in training mode, return results with losses - if training: - # calculate the latent loss for model training - var_tilde = torch.exp(stddev_tilde) - stddev_c = torch.log(var_c + self.eps) - log_2pi = torch.log(torch.tensor([2 * torch.pi], device=device)) - log_phi_c = torch.log(phi_c + self.eps) - - batch_size = z.shape[0] - - ii, jj = torch.meshgrid( - torch.arange(self.n_clusters, dtype=torch.int64, device=device), - torch.arange(batch_size, dtype=torch.int64, device=device), - indexing="ij", - ) - ii = ii.flatten() - jj = jj.flatten() - - lsc_b = stddev_c.index_select(dim=0, index=ii) - mc_b = mu_c.index_select(dim=0, index=ii) - sc_b = var_c.index_select(dim=0, index=ii) - z_b = z.index_select(dim=0, index=jj) - log_pdf_z = -0.5 * (lsc_b + log_2pi + torch.square(z_b - mc_b) / sc_b) - log_pdf_z = log_pdf_z.reshape( - [batch_size, self.n_clusters, self.d_mu_stddev] - ) - - log_p = log_phi_c + log_pdf_z.sum(dim=2) - lse_p = log_p.logsumexp(dim=1, keepdim=True) - log_gamma_c = log_p - lse_p - gamma_c = torch.exp(log_gamma_c) - - term1 = torch.log(var_c + self.eps) - st_b = var_tilde.index_select(dim=0, index=jj) - sc_b = var_c.index_select(dim=0, index=ii) - term2 = torch.reshape( - st_b / (sc_b + self.eps), - [batch_size, self.n_clusters, self.d_mu_stddev], - ) - mt_b = mu_tilde.index_select(dim=0, index=jj) - mc_b = mu_c.index_select(dim=0, index=ii) - term3 = torch.reshape( - torch.square(mt_b - mc_b) / (sc_b + self.eps), 
- [batch_size, self.n_clusters, self.d_mu_stddev], - ) - - latent_loss1 = 0.5 * torch.sum( - gamma_c * torch.sum(term1 + term2 + term3, dim=2), dim=1 - ) - latent_loss2 = -torch.sum(gamma_c * (log_phi_c - log_gamma_c), dim=1) - latent_loss3 = -0.5 * torch.sum(1 + stddev_tilde, dim=1) - - latent_loss1 = latent_loss1.mean() - latent_loss2 = latent_loss2.mean() - latent_loss3 = latent_loss3.mean() - latent_loss = latent_loss1 + latent_loss2 + latent_loss3 - - results["loss"] = reconstruction_loss + self.alpha * latent_loss + # calculate the latent loss for model training + var_tilde = torch.exp(stddev_tilde) + stddev_c = torch.log(var_c + self.eps) + log_2pi = torch.log(torch.tensor([2 * torch.pi], device=device)) + log_phi_c = torch.log(phi_c + self.eps) + + batch_size = z.shape[0] + + ii, jj = torch.meshgrid( + torch.arange(self.n_clusters, dtype=torch.int64, device=device), + torch.arange(batch_size, dtype=torch.int64, device=device), + indexing="ij", + ) + ii = ii.flatten() + jj = jj.flatten() + + lsc_b = stddev_c.index_select(dim=0, index=ii) + mc_b = mu_c.index_select(dim=0, index=ii) + sc_b = var_c.index_select(dim=0, index=ii) + z_b = z.index_select(dim=0, index=jj) + log_pdf_z = -0.5 * (lsc_b + log_2pi + torch.square(z_b - mc_b) / sc_b) + log_pdf_z = log_pdf_z.reshape([batch_size, self.n_clusters, self.d_mu_stddev]) + + log_p = log_phi_c + log_pdf_z.sum(dim=2) + lse_p = log_p.logsumexp(dim=1, keepdim=True) + log_gamma_c = log_p - lse_p + gamma_c = torch.exp(log_gamma_c) + + term1 = torch.log(var_c + self.eps) + st_b = var_tilde.index_select(dim=0, index=jj) + sc_b = var_c.index_select(dim=0, index=ii) + term2 = torch.reshape( + st_b / (sc_b + self.eps), + [batch_size, self.n_clusters, self.d_mu_stddev], + ) + mt_b = mu_tilde.index_select(dim=0, index=jj) + mc_b = mu_c.index_select(dim=0, index=ii) + term3 = torch.reshape( + torch.square(mt_b - mc_b) / (sc_b + self.eps), + [batch_size, self.n_clusters, self.d_mu_stddev], + ) + + latent_loss1 = 0.5 * torch.sum( + gamma_c * torch.sum(term1 + term2 + term3, dim=2), dim=1 + ) + latent_loss2 = -torch.sum(gamma_c * (log_phi_c - log_gamma_c), dim=1) + latent_loss3 = -0.5 * torch.sum(1 + stddev_tilde, dim=1) + + latent_loss1 = latent_loss1.mean() + latent_loss2 = latent_loss2.mean() + latent_loss3 = latent_loss3.mean() + latent_loss = latent_loss1 + latent_loss2 + latent_loss3 + + results["loss"] = reconstruction_loss + self.alpha * latent_loss return results diff --git a/pypots/clustering/vader/model.py b/pypots/clustering/vader/model.py index cfc85f97..22116682 100644 --- a/pypots/clustering/vader/model.py +++ b/pypots/clustering/vader/model.py @@ -63,6 +63,14 @@ class VaDER(BaseNNClusterer): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + optimizer : The optimizer for model training. If not given, will use a default Adam optimizer. 
@@ -101,6 +109,8 @@ def __init__( epochs: int = 100, pretrain_epochs: int = 10, patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -112,6 +122,8 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, @@ -288,13 +300,13 @@ def _train_model( logger.info( f"Epoch {epoch:03d} - " - f"training loss: {mean_train_loss:.4f}, " - f"validation loss: {mean_val_loss:.4f}" + f"training loss ({self.train_loss_func_name}): {mean_train_loss:.4f}, " + f"validation {self.val_metric_func_name}: {mean_val_loss:.4f}" ) mean_loss = mean_val_loss else: logger.info( - f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}" + f"Epoch {epoch:03d} - training loss ({self.train_loss_func_name}): {mean_train_loss:.4f}" ) mean_loss = mean_train_loss @@ -432,7 +444,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) mu_tilde = results["mu_tilde"].cpu().numpy() mu_tilde_collector.append(mu_tilde) diff --git a/pypots/forecasting/base.py b/pypots/forecasting/base.py index 0b8a153d..85964dda 100644 --- a/pypots/forecasting/base.py +++ b/pypots/forecasting/base.py @@ -141,9 +141,17 @@ class BaseNNForecaster(BaseNNModel): Training epochs, i.e. the maximum rounds of the model to be trained with. patience : - Number of epochs the training procedure will keep if loss doesn't decrease. - Once exceeding the number, the training will stop. - Must be smaller than or equal to the value of ``epochs``. + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. num_workers : The number of subprocesses to use for data loading. 
@@ -184,6 +192,8 @@ def __init__( batch_size: int, epochs: int, patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, @@ -193,6 +203,8 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, @@ -291,7 +303,7 @@ def _train_model( with torch.no_grad(): for idx, data in enumerate(val_loader): inputs = self._assemble_input_for_validating(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) forecasting_mse = ( calc_mse( results["forecasting_data"], @@ -315,13 +327,13 @@ def _train_model( logger.info( f"Epoch {epoch:03d} - " - f"training loss: {mean_train_loss:.4f}, " - f"validation loss: {mean_val_loss:.4f}" + f"training loss ({self.train_loss_func_name}): {mean_train_loss:.4f}, " + f"validation {self.val_metric_func_name}: {mean_val_loss:.4f}" ) mean_loss = mean_val_loss else: logger.info( - f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}" + f"Epoch {epoch:03d} - training loss ({self.train_loss_func_name}): {mean_train_loss:.4f}" ) mean_loss = mean_train_loss diff --git a/pypots/forecasting/csdi/core.py b/pypots/forecasting/csdi/core.py index e488cb20..d869f2c0 100644 --- a/pypots/forecasting/csdi/core.py +++ b/pypots/forecasting/csdi/core.py @@ -96,9 +96,9 @@ def get_side_info(self, observed_tp, cond_mask, feature_id): return side_info - def forward(self, inputs, training=True, n_sampling_times=1): + def forward(self, inputs, n_sampling_times=1): results = {} - if training: # for training + if self.training: # for training (observed_data, indicating_mask, cond_mask, observed_tp, feature_id) = ( inputs["X_ori"], inputs["indicating_mask"], @@ -108,10 +108,10 @@ def forward(self, inputs, training=True, n_sampling_times=1): ) side_info = self.get_side_info(observed_tp, cond_mask, feature_id) training_loss = self.backbone.calc_loss( - observed_data, cond_mask, indicating_mask, side_info, training + observed_data, cond_mask, indicating_mask, side_info ) results["loss"] = training_loss - elif not training and n_sampling_times == 0: # for validating + elif not self.training and n_sampling_times == 0: # for validating (observed_data, indicating_mask, cond_mask, observed_tp, feature_id) = ( inputs["X_ori"], inputs["indicating_mask"], @@ -121,10 +121,10 @@ def forward(self, inputs, training=True, n_sampling_times=1): ) side_info = self.get_side_info(observed_tp, cond_mask, feature_id) validating_loss = self.backbone.calc_loss_valid( - observed_data, cond_mask, indicating_mask, side_info, training + observed_data, cond_mask, indicating_mask, side_info ) results["loss"] = validating_loss - elif not training and n_sampling_times > 0: # for testing + elif not self.training and n_sampling_times > 0: # for testing observed_data, cond_mask, observed_tp, feature_id = ( inputs["X"], inputs["cond_mask"], diff --git a/pypots/forecasting/csdi/model.py b/pypots/forecasting/csdi/model.py index 77d32d86..a9915ee7 100644 --- a/pypots/forecasting/csdi/model.py +++ b/pypots/forecasting/csdi/model.py @@ -152,6 +152,8 @@ def __init__( batch_size, epochs, patience, + None, + None, num_workers, device, saving_path, @@ -168,6 +170,11 @@ def __init__( self.n_pred_steps = n_pred_steps self.n_pred_features = n_pred_features self.target_strategy = target_strategy + # CSDI has its own defined loss function and validation loss, so we set them as None 
here + self.train_loss_func = None + self.train_loss_func_name = "default" + self.val_metric_func = None + self.val_metric_func_name = "loss (default)" # set up the model self.model = _CSDI( @@ -268,9 +275,7 @@ def _train_model( with torch.no_grad(): for idx, data in enumerate(val_loader): inputs = self._assemble_input_for_validating(data) - results = self.model.forward( - inputs, training=False, n_sampling_times=0 - ) + results = self.model.forward(inputs, n_sampling_times=0) val_loss_collector.append(results["loss"].sum().item()) mean_val_loss = np.asarray(val_loss_collector).mean() @@ -284,13 +289,13 @@ def _train_model( logger.info( f"Epoch {epoch:03d} - " - f"training loss: {mean_train_loss:.4f}, " - f"validation loss: {mean_val_loss:.4f}" + f"training loss ({self.train_loss_func_name}): {mean_train_loss:.4f}, " + f"validation {self.val_metric_func_name}: {mean_val_loss:.4f}" ) mean_loss = mean_val_loss else: logger.info( - f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}" + f"Epoch {epoch:03d} - training loss ({self.train_loss_func_name}): {mean_train_loss:.4f}" ) mean_loss = mean_train_loss @@ -441,7 +446,6 @@ def predict( inputs = self._assemble_input_for_testing(data) results = self.model( inputs, - training=False, n_sampling_times=n_sampling_times, ) forecasting_data = results["forecasting_data"][ diff --git a/pypots/forecasting/template/model.py b/pypots/forecasting/template/model.py index fd817694..60595bd9 100644 --- a/pypots/forecasting/template/model.py +++ b/pypots/forecasting/template/model.py @@ -34,6 +34,8 @@ def __init__( batch_size: int = 32, epochs: int = 100, patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -44,12 +46,13 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, model_saving_strategy, - ) - # set up the hyper-parameters + ) # set up the hyper-parameters # TODO: set up your model's hyper-parameters here # set up the model diff --git a/pypots/imputation/autoformer/core.py b/pypots/imputation/autoformer/core.py index fb883c4e..a382be5b 100644 --- a/pypots/imputation/autoformer/core.py +++ b/pypots/imputation/autoformer/core.py @@ -53,7 +53,7 @@ def __init__( self.output_projection = nn.Linear(d_model, n_features) self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] # WDU: the original Autoformer paper isn't proposed for imputation task. Hence the model doesn't take @@ -74,7 +74,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] loss, ORT_loss, MIT_loss = self.saits_loss_func( reconstruction, X_ori, missing_mask, indicating_mask diff --git a/pypots/imputation/autoformer/model.py b/pypots/imputation/autoformer/model.py index dcdc8b64..5fd0f818 100644 --- a/pypots/imputation/autoformer/model.py +++ b/pypots/imputation/autoformer/model.py @@ -71,6 +71,14 @@ class Autoformer(BaseNNImputer): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. 
+ train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + optimizer : The optimizer for model training. If not given, will use a default Adam optimizer. @@ -116,7 +124,9 @@ def __init__( MIT_weight: float = 1, batch_size: int = 32, epochs: int = 100, - patience: int = None, + patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -127,12 +137,13 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, model_saving_strategy, ) - self.n_steps = n_steps self.n_features = n_features # model hype-parameters @@ -284,7 +295,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return diff --git a/pypots/imputation/base.py b/pypots/imputation/base.py index f08d310f..e79d7e23 100644 --- a/pypots/imputation/base.py +++ b/pypots/imputation/base.py @@ -142,9 +142,17 @@ class BaseNNImputer(BaseNNModel): Training epochs, i.e. the maximum rounds of the model to be trained with. patience : - Number of epochs the training procedure will keep if loss doesn't decrease. - Once exceeding the number, the training will stop. - Must be smaller than or equal to the value of ``epochs``. + The patience for the early-stopping mechanism. Given a positive integer, the training process will be + stopped when the model does not perform better after that number of epochs. + Leaving it default as None will disable the early-stopping. + + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. num_workers : The number of subprocesses to use for data loading. @@ -185,6 +193,8 @@ def __init__( batch_size: int, epochs: int, patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, saving_path: str = None, @@ -194,12 +204,22 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, model_saving_strategy, ) + # set default training loss function and validation metric function if not given + if train_loss_func is None: + self.train_loss_func = calc_mse + self.train_loss_func_name = "MSE" + if val_metric_func is None: + self.val_metric_func = calc_mse + self.val_metric_func_name = "MSE" + @abstractmethod def _assemble_input_for_training(self, data: list) -> dict: """Assemble the given data into a dictionary for training input. 
@@ -293,8 +313,8 @@ def _train_model( with torch.no_grad(): for idx, data in enumerate(val_loader): inputs = self._assemble_input_for_validating(data) - results = self.model.forward(inputs, training=False) - imputation_mse = ( + results = self.model.forward(inputs) + imputation_error = ( calc_mse( results["imputed_data"], inputs["X_ori"], @@ -304,7 +324,7 @@ def _train_model( .detach() .item() ) - imputation_loss_collector.append(imputation_mse) + imputation_loss_collector.append(imputation_error) mean_val_loss = np.mean(imputation_loss_collector) @@ -317,13 +337,13 @@ def _train_model( logger.info( f"Epoch {epoch:03d} - " - f"training loss: {mean_train_loss:.4f}, " - f"validation loss: {mean_val_loss:.4f}" + f"training loss ({self.train_loss_func_name}): {mean_train_loss:.4f}, " + f"validation {self.val_metric_func_name}: {mean_val_loss:.4f}" ) mean_loss = mean_val_loss else: logger.info( - f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}" + f"Epoch {epoch:03d} - training loss ({self.train_loss_func_name}): {mean_train_loss:.4f}" ) mean_loss = mean_train_loss diff --git a/pypots/imputation/brits/core.py b/pypots/imputation/brits/core.py index 9d1734c4..c6869c83 100644 --- a/pypots/imputation/brits/core.py +++ b/pypots/imputation/brits/core.py @@ -41,7 +41,7 @@ def __init__( self.model = BackboneBRITS(n_steps, n_features, rnn_hidden_size) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: ( imputed_data, f_reconstruction, @@ -57,7 +57,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: results["consistency_loss"] = consistency_loss results["reconstruction_loss"] = reconstruction_loss loss = consistency_loss + reconstruction_loss diff --git a/pypots/imputation/brits/model.py b/pypots/imputation/brits/model.py index 5f1676cf..d22bf86d 100644 --- a/pypots/imputation/brits/model.py +++ b/pypots/imputation/brits/model.py @@ -45,6 +45,14 @@ class BRITS(BaseNNImputer): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + optimizer : The optimizer for model training. If not given, will use a default Adam optimizer. 
@@ -83,6 +91,8 @@ def __init__( batch_size: int = 32, epochs: int = 100, patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -93,12 +103,13 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, model_saving_strategy, ) - self.n_steps = n_steps self.n_features = n_features self.rnn_hidden_size = rnn_hidden_size @@ -238,7 +249,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) diff --git a/pypots/imputation/crossformer/core.py b/pypots/imputation/crossformer/core.py index e26f27ca..48d449b0 100644 --- a/pypots/imputation/crossformer/core.py +++ b/pypots/imputation/crossformer/core.py @@ -83,7 +83,7 @@ def __init__( # apply SAITS loss function to Crossformer on the imputation task self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] # WDU: the original Crossformer paper isn't proposed for imputation task. Hence the model doesn't take @@ -113,7 +113,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] loss, ORT_loss, MIT_loss = self.saits_loss_func( reconstruction, X_ori, missing_mask, indicating_mask diff --git a/pypots/imputation/crossformer/model.py b/pypots/imputation/crossformer/model.py index 7db9aaba..6e826c56 100644 --- a/pypots/imputation/crossformer/model.py +++ b/pypots/imputation/crossformer/model.py @@ -74,6 +74,14 @@ class Crossformer(BaseNNImputer): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + optimizer : The optimizer for model training. If not given, will use a default Adam optimizer. 
@@ -120,7 +128,9 @@ def __init__( MIT_weight: float = 1, batch_size: int = 32, epochs: int = 100, - patience: int = None, + patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -131,12 +141,13 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, model_saving_strategy, ) - self.n_steps = n_steps self.n_features = n_features # model hype-parameters @@ -290,7 +301,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return diff --git a/pypots/imputation/csdi/core.py b/pypots/imputation/csdi/core.py index a80acce3..0c6ec279 100644 --- a/pypots/imputation/csdi/core.py +++ b/pypots/imputation/csdi/core.py @@ -88,9 +88,9 @@ def get_side_info(self, observed_tp, cond_mask): return side_info - def forward(self, inputs, training=True, n_sampling_times=1): + def forward(self, inputs, n_sampling_times=1): results = {} - if training: # for training + if self.training: # for training (observed_data, indicating_mask, cond_mask, observed_tp) = ( inputs["X_ori"], inputs["indicating_mask"], @@ -99,10 +99,10 @@ def forward(self, inputs, training=True, n_sampling_times=1): ) side_info = self.get_side_info(observed_tp, cond_mask) training_loss = self.backbone.calc_loss( - observed_data, cond_mask, indicating_mask, side_info, training + observed_data, cond_mask, indicating_mask, side_info ) results["loss"] = training_loss - elif not training and n_sampling_times == 0: # for validating + elif not self.training and n_sampling_times == 0: # for validating (observed_data, indicating_mask, cond_mask, observed_tp) = ( inputs["X_ori"], inputs["indicating_mask"], @@ -111,10 +111,10 @@ def forward(self, inputs, training=True, n_sampling_times=1): ) side_info = self.get_side_info(observed_tp, cond_mask) validating_loss = self.backbone.calc_loss_valid( - observed_data, cond_mask, indicating_mask, side_info, training + observed_data, cond_mask, indicating_mask, side_info ) results["loss"] = validating_loss - elif not training and n_sampling_times > 0: # for testing + elif not self.training and n_sampling_times > 0: # for testing observed_data, cond_mask, observed_tp = ( inputs["X"], inputs["cond_mask"], diff --git a/pypots/imputation/csdi/model.py b/pypots/imputation/csdi/model.py index 33c1535b..832e6dc4 100644 --- a/pypots/imputation/csdi/model.py +++ b/pypots/imputation/csdi/model.py @@ -144,6 +144,8 @@ def __init__( batch_size, epochs, patience, + None, + None, num_workers, device, saving_path, @@ -153,6 +155,11 @@ def __init__( assert schedule in ["quad", "linear"] self.n_steps = n_steps self.target_strategy = target_strategy + # CSDI has its own defined loss function and validation loss, so we set them as None here + self.train_loss_func = None + self.train_loss_func_name = "default" + self.val_metric_func = None + self.val_metric_func_name = "loss (default)" # set up the model self.model = _CSDI( @@ -248,9 +255,7 @@ def _train_model( with torch.no_grad(): for idx, data in enumerate(val_loader): inputs = self._assemble_input_for_validating(data) - results = self.model.forward( - inputs, training=False, n_sampling_times=0 
- ) + results = self.model.forward(inputs, n_sampling_times=0) val_loss_collector.append(results["loss"].sum().item()) mean_val_loss = np.asarray(val_loss_collector).mean() @@ -264,13 +269,13 @@ def _train_model( logger.info( f"Epoch {epoch:03d} - " - f"training loss: {mean_train_loss:.4f}, " - f"validation loss: {mean_val_loss:.4f}" + f"training loss ({self.train_loss_func_name}): {mean_train_loss:.4f}, " + f"validation {self.val_metric_func_name}: {mean_val_loss:.4f}" ) mean_loss = mean_val_loss else: logger.info( - f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}" + f"Epoch {epoch:03d} - training loss ({self.train_loss_func_name}): {mean_train_loss:.4f}" ) mean_loss = mean_train_loss @@ -420,7 +425,6 @@ def predict( inputs = self._assemble_input_for_testing(data) results = self.model( inputs, - training=False, n_sampling_times=n_sampling_times, ) imputed_data = results["imputed_data"] diff --git a/pypots/imputation/dlinear/core.py b/pypots/imputation/dlinear/core.py index 78d3bcbd..3f43f4a0 100644 --- a/pypots/imputation/dlinear/core.py +++ b/pypots/imputation/dlinear/core.py @@ -48,7 +48,7 @@ def __init__( # apply SAITS loss function to Transformer on the imputation task self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] # input preprocessing and embedding for DLinear @@ -78,7 +78,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] loss, ORT_loss, MIT_loss = self.saits_loss_func( reconstruction, X_ori, missing_mask, indicating_mask diff --git a/pypots/imputation/dlinear/model.py b/pypots/imputation/dlinear/model.py index 3721eead..cb48708d 100644 --- a/pypots/imputation/dlinear/model.py +++ b/pypots/imputation/dlinear/model.py @@ -60,6 +60,14 @@ class DLinear(BaseNNImputer): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + optimizer : The optimizer for model training. If not given, will use a default Adam optimizer. 
@@ -101,7 +109,9 @@ def __init__( MIT_weight: float = 1, batch_size: int = 32, epochs: int = 100, - patience: int = None, + patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -112,12 +122,13 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, model_saving_strategy, ) - self.n_steps = n_steps self.n_features = n_features # model hype-parameters @@ -261,7 +272,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return diff --git a/pypots/imputation/etsformer/core.py b/pypots/imputation/etsformer/core.py index 92c61f5d..162b6502 100644 --- a/pypots/imputation/etsformer/core.py +++ b/pypots/imputation/etsformer/core.py @@ -77,7 +77,7 @@ def __init__( # apply SAITS loss function to ETSformer on the imputation task self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] # WDU: the original ETSformer paper isn't proposed for imputation task. Hence the model doesn't take @@ -98,7 +98,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] loss, ORT_loss, MIT_loss = self.saits_loss_func( reconstruction, X_ori, missing_mask, indicating_mask diff --git a/pypots/imputation/etsformer/model.py b/pypots/imputation/etsformer/model.py index 6dbb2fbc..8dbd9009 100644 --- a/pypots/imputation/etsformer/model.py +++ b/pypots/imputation/etsformer/model.py @@ -71,6 +71,14 @@ class ETSformer(BaseNNImputer): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + optimizer : The optimizer for model training. If not given, will use a default Adam optimizer. 
@@ -116,7 +124,9 @@ def __init__( MIT_weight: float = 1, batch_size: int = 32, epochs: int = 100, - patience: int = None, + patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -127,12 +137,13 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, model_saving_strategy, ) - self.n_steps = n_steps self.n_features = n_features # model hype-parameters @@ -284,7 +295,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return diff --git a/pypots/imputation/fedformer/core.py b/pypots/imputation/fedformer/core.py index 617a1462..679d5868 100644 --- a/pypots/imputation/fedformer/core.py +++ b/pypots/imputation/fedformer/core.py @@ -58,7 +58,7 @@ def __init__( # apply SAITS loss function to ETSformer on the imputation task self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] # WDU: the original FEDformer paper isn't proposed for imputation task. Hence the model doesn't take @@ -78,7 +78,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] loss, ORT_loss, MIT_loss = self.saits_loss_func( reconstruction, X_ori, missing_mask, indicating_mask diff --git a/pypots/imputation/fedformer/model.py b/pypots/imputation/fedformer/model.py index 2d8ca073..94da1858 100644 --- a/pypots/imputation/fedformer/model.py +++ b/pypots/imputation/fedformer/model.py @@ -79,6 +79,14 @@ class FEDformer(BaseNNImputer): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + optimizer : The optimizer for model training. If not given, will use a default Adam optimizer. 
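A second change that repeats across these core modules is dropping the explicit `training: bool = True` argument from `forward` in favour of PyTorch's built-in `nn.Module.training` flag, which `model.train()` and `model.eval()` toggle. A minimal, self-contained sketch of that pattern (not PyPOTS code):

import torch
import torch.nn as nn

class TinyReconstructor(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def forward(self, inputs: dict) -> dict:
        reconstruction = self.proj(inputs["X"])
        results = {"reconstruction": reconstruction}
        if self.training:  # True after model.train(), False after model.eval()
            results["loss"] = torch.mean((reconstruction - inputs["X_ori"]) ** 2)
        return results

model = TinyReconstructor()
batch = {"X": torch.randn(2, 4), "X_ori": torch.randn(2, 4)}

model.train()
assert "loss" in model(batch)          # the training branch attaches the loss

model.eval()
with torch.no_grad():
    assert "loss" not in model(batch)  # the inference branch skips loss computation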
@@ -126,7 +134,9 @@ def __init__( MIT_weight: float = 1, batch_size: int = 32, epochs: int = 100, - patience: int = None, + patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -137,12 +147,13 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, model_saving_strategy, ) - self.n_steps = n_steps self.n_features = n_features # model hype-parameters @@ -298,7 +309,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return diff --git a/pypots/imputation/film/core.py b/pypots/imputation/film/core.py index 2e48f8c2..63270edb 100644 --- a/pypots/imputation/film/core.py +++ b/pypots/imputation/film/core.py @@ -48,7 +48,7 @@ def __init__( self.output_projection = nn.Linear(d_model, n_features) self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] # WDU: the original FiLM paper isn't proposed for imputation task. Hence the model doesn't take @@ -69,7 +69,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] loss, ORT_loss, MIT_loss = self.saits_loss_func( reconstruction, X_ori, missing_mask, indicating_mask diff --git a/pypots/imputation/film/model.py b/pypots/imputation/film/model.py index 8caae0d5..2bbe7e6d 100644 --- a/pypots/imputation/film/model.py +++ b/pypots/imputation/film/model.py @@ -67,6 +67,14 @@ class FiLM(BaseNNImputer): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + optimizer : The optimizer for model training. If not given, will use a default Adam optimizer. 
@@ -111,7 +119,9 @@ def __init__( MIT_weight: float = 1, batch_size: int = 32, epochs: int = 100, - patience: int = None, + patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -122,6 +132,8 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, @@ -278,7 +290,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return diff --git a/pypots/imputation/frets/core.py b/pypots/imputation/frets/core.py index 488880d9..3247adb4 100644 --- a/pypots/imputation/frets/core.py +++ b/pypots/imputation/frets/core.py @@ -45,7 +45,7 @@ def __init__( self.output_projection = nn.Linear(embed_size, n_features) self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] # WDU: the original FreTS paper isn't proposed for imputation task. Hence the model doesn't take @@ -65,7 +65,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] loss, ORT_loss, MIT_loss = self.saits_loss_func( reconstruction, X_ori, missing_mask, indicating_mask diff --git a/pypots/imputation/frets/model.py b/pypots/imputation/frets/model.py index 42101ac4..15c06f5a 100644 --- a/pypots/imputation/frets/model.py +++ b/pypots/imputation/frets/model.py @@ -59,6 +59,14 @@ class FreTS(BaseNNImputer): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + optimizer : The optimizer for model training. If not given, will use a default Adam optimizer. 
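Most of the imputation cores touched here delegate their objective to `SaitsLoss(ORT_weight, MIT_weight)` and unpack `loss, ORT_loss, MIT_loss` from it. As a sketch of what that combination computes, assuming the masked-MAE helper used elsewhere in this series (the actual implementation lives in `pypots.nn.modules.saits`):

import torch

def masked_mae(predictions, targets, masks):
    return torch.sum(torch.abs(predictions - targets) * masks) / (torch.sum(masks) + 1e-12)

def saits_style_loss(reconstruction, X_ori, missing_mask, indicating_mask,
                     ORT_weight=1.0, MIT_weight=1.0):
    # ORT: reconstruct the values actually observed in the input (missing_mask marks them)
    ORT_loss = ORT_weight * masked_mae(reconstruction, X_ori, missing_mask)
    # MIT: recover the values artificially masked out for training (indicating_mask marks them)
    MIT_loss = MIT_weight * masked_mae(reconstruction, X_ori, indicating_mask)
    loss = ORT_loss + MIT_loss  # the item backpropagated to update the model
    return loss, ORT_loss, MIT_loss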
@@ -100,7 +108,9 @@ def __init__( MIT_weight: float = 1, batch_size: int = 32, epochs: int = 100, - patience: int = None, + patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -111,12 +121,13 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, model_saving_strategy, ) - self.n_steps = n_steps self.n_features = n_features # model hype-parameters @@ -260,7 +271,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return diff --git a/pypots/imputation/gpvae/core.py b/pypots/imputation/gpvae/core.py index 79b8f724..0a99fe69 100644 --- a/pypots/imputation/gpvae/core.py +++ b/pypots/imputation/gpvae/core.py @@ -89,11 +89,11 @@ def __init__( window_size, ) - def forward(self, inputs, training=True, n_sampling_times=1): + def forward(self, inputs, n_sampling_times=1): X, missing_mask = inputs["X"], inputs["missing_mask"] results = {} - if training: + if self.training: elbo_loss = self.backbone(X, missing_mask) results["loss"] = elbo_loss else: diff --git a/pypots/imputation/gpvae/model.py b/pypots/imputation/gpvae/model.py index 0af6a73d..471e3996 100644 --- a/pypots/imputation/gpvae/model.py +++ b/pypots/imputation/gpvae/model.py @@ -132,6 +132,8 @@ def __init__( batch_size: int = 32, epochs: int = 100, patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -142,6 +144,8 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, @@ -266,9 +270,7 @@ def _train_model( with torch.no_grad(): for idx, data in enumerate(val_loader): inputs = self._assemble_input_for_validating(data) - results = self.model.forward( - inputs, training=False, n_sampling_times=1 - ) + results = self.model.forward(inputs, n_sampling_times=1) imputed_data = results["imputed_data"].mean(axis=1) imputation_mse = ( calc_mse( @@ -293,13 +295,13 @@ def _train_model( logger.info( f"Epoch {epoch:03d} - " - f"training loss: {mean_train_loss:.4f}, " - f"validation loss: {mean_val_loss:.4f}" + f"training loss ({self.train_loss_func_name}): {mean_train_loss:.4f}, " + f"validation {self.val_metric_func_name}: {mean_val_loss:.4f}" ) mean_loss = mean_val_loss else: logger.info( - f"Epoch {epoch:03d} - training loss: {mean_train_loss:.4f}" + f"Epoch {epoch:03d} - training loss ({self.train_loss_func_name}): {mean_train_loss:.4f}" ) mean_loss = mean_train_loss @@ -440,9 +442,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward( - inputs, training=False, n_sampling_times=n_sampling_times - ) + results = self.model.forward(inputs, n_sampling_times=n_sampling_times) imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) diff --git a/pypots/imputation/informer/core.py b/pypots/imputation/informer/core.py index e9199b02..a06eb709 100644 --- a/pypots/imputation/informer/core.py 
+++ b/pypots/imputation/informer/core.py @@ -69,7 +69,7 @@ def __init__( self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] # WDU: the original Informer paper isn't proposed for imputation task. Hence the model doesn't take @@ -91,7 +91,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] loss, ORT_loss, MIT_loss = self.saits_loss_func( reconstruction, X_ori, missing_mask, indicating_mask diff --git a/pypots/imputation/informer/model.py b/pypots/imputation/informer/model.py index 85b2b1be..df72b5ed 100644 --- a/pypots/imputation/informer/model.py +++ b/pypots/imputation/informer/model.py @@ -68,6 +68,14 @@ class Informer(BaseNNImputer): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + optimizer : The optimizer for model training. If not given, will use a default Adam optimizer. @@ -112,7 +120,9 @@ def __init__( MIT_weight: float = 1, batch_size: int = 32, epochs: int = 100, - patience: int = None, + patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -123,12 +133,13 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, model_saving_strategy, ) - self.n_steps = n_steps self.n_features = n_features # model hype-parameters @@ -278,7 +289,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return diff --git a/pypots/imputation/itransformer/core.py b/pypots/imputation/itransformer/core.py index 5747f12e..cba9d416 100644 --- a/pypots/imputation/itransformer/core.py +++ b/pypots/imputation/itransformer/core.py @@ -53,7 +53,7 @@ def __init__( # apply SAITS loss function to Transformer on the imputation task self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] # WDU: the original Informer paper isn't proposed for imputation task. 
Hence the model doesn't take @@ -79,7 +79,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] loss, ORT_loss, MIT_loss = self.saits_loss_func( reconstruction, X_ori, missing_mask, indicating_mask diff --git a/pypots/imputation/itransformer/model.py b/pypots/imputation/itransformer/model.py index 045bd2dc..adae4068 100644 --- a/pypots/imputation/itransformer/model.py +++ b/pypots/imputation/itransformer/model.py @@ -81,6 +81,14 @@ class iTransformer(BaseNNImputer): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + optimizer : The optimizer for model training. If not given, will use a default Adam optimizer. @@ -128,6 +136,8 @@ def __init__( batch_size: int = 32, epochs: int = 100, patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -138,12 +148,13 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, model_saving_strategy, ) - if d_model != n_heads * d_k: logger.warning( "‼️ d_model must = n_heads * d_k, it should be divisible by n_heads " @@ -283,7 +294,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) diff --git a/pypots/imputation/mrnn/core.py b/pypots/imputation/mrnn/core.py index 0cdf7084..2092d10e 100644 --- a/pypots/imputation/mrnn/core.py +++ b/pypots/imputation/mrnn/core.py @@ -18,7 +18,7 @@ def __init__(self, n_steps, n_features, rnn_hidden_size): super().__init__() self.backbone = BackboneMRNN(n_steps, n_features, rnn_hidden_size) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X = inputs["forward"]["X"] M = inputs["forward"]["missing_mask"] @@ -30,7 +30,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: RNN_loss = calc_rmse(RNN_estimation, X, M) FCN_loss = calc_rmse(FCN_estimation, RNN_imputed_data) reconstruction_loss = RNN_loss + FCN_loss diff --git a/pypots/imputation/mrnn/model.py b/pypots/imputation/mrnn/model.py index e3527432..e4326aab 100644 --- a/pypots/imputation/mrnn/model.py +++ b/pypots/imputation/mrnn/model.py @@ -46,6 +46,14 @@ class MRNN(BaseNNImputer): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. 
+ + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + optimizer : The optimizer for model training. If not given, will use a default Adam optimizer. @@ -84,6 +92,8 @@ def __init__( batch_size: int = 32, epochs: int = 100, patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -94,12 +104,13 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, model_saving_strategy, ) - self.n_steps = n_steps self.n_features = n_features self.rnn_hidden_size = rnn_hidden_size @@ -240,7 +251,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) diff --git a/pypots/imputation/nonstationary_transformer/core.py b/pypots/imputation/nonstationary_transformer/core.py index 9ca21e1d..cf17c8f8 100644 --- a/pypots/imputation/nonstationary_transformer/core.py +++ b/pypots/imputation/nonstationary_transformer/core.py @@ -72,7 +72,7 @@ def __init__( self.output_projection = nn.Linear(d_model, n_features) self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] X_enc, means, stdev = nonstationary_norm(X, missing_mask) @@ -98,7 +98,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] loss, ORT_loss, MIT_loss = self.saits_loss_func( reconstruction, X_ori, missing_mask, indicating_mask diff --git a/pypots/imputation/nonstationary_transformer/model.py b/pypots/imputation/nonstationary_transformer/model.py index 9786ccd7..12c455cf 100644 --- a/pypots/imputation/nonstationary_transformer/model.py +++ b/pypots/imputation/nonstationary_transformer/model.py @@ -73,6 +73,14 @@ class NonstationaryTransformer(BaseNNImputer): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + optimizer : The optimizer for model training. If not given, will use a default Adam optimizer. 
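The NonstationaryTransformer core above normalizes its input with `nonstationary_norm` before encoding and maps the reconstruction back with `nonstationary_denorm`. A rough sketch of the idea behind that pair of helpers, assuming per-sample statistics computed only over observed time steps (the real functions live in `pypots.nn.functional` and may differ in detail):

import torch

def sketch_nonstationary_norm(X, missing_mask, eps=1e-5):
    # X, missing_mask: (batch, steps, features); 1 in the mask means observed
    obs_count = missing_mask.sum(dim=1, keepdim=True)
    means = (X * missing_mask).sum(dim=1, keepdim=True) / (obs_count + eps)
    centered = (X - means) * missing_mask
    stdev = torch.sqrt((centered ** 2).sum(dim=1, keepdim=True) / (obs_count + eps) + eps)
    return centered / stdev, means, stdev

def sketch_nonstationary_denorm(X_norm, means, stdev):
    # undo the normalization so imputations come back on the original scale
    return X_norm * stdev + means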
@@ -118,7 +126,9 @@ def __init__( MIT_weight: float = 1, batch_size: int = 32, epochs: int = 100, - patience: int = None, + patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -129,6 +139,8 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, @@ -290,7 +302,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return diff --git a/pypots/imputation/patchtst/core.py b/pypots/imputation/patchtst/core.py index 9a356173..f51c18eb 100644 --- a/pypots/imputation/patchtst/core.py +++ b/pypots/imputation/patchtst/core.py @@ -53,7 +53,7 @@ def __init__( self.output_projection = nn.Linear(d_model, n_features) self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] # WDU: the original PatchTST paper isn't proposed for imputation task. Hence the model doesn't take @@ -80,7 +80,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: "imputed_data": imputed_data, } - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] loss, ORT_loss, MIT_loss = self.saits_loss_func( reconstruction, X_ori, missing_mask, indicating_mask diff --git a/pypots/imputation/patchtst/model.py b/pypots/imputation/patchtst/model.py index f4033a49..745cc20b 100644 --- a/pypots/imputation/patchtst/model.py +++ b/pypots/imputation/patchtst/model.py @@ -85,6 +85,14 @@ class PatchTST(BaseNNImputer): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + optimizer : The optimizer for model training. If not given, will use a default Adam optimizer. 
@@ -133,7 +141,9 @@ def __init__( MIT_weight: float = 1, batch_size: int = 32, epochs: int = 100, - patience: int = None, + patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -144,6 +154,8 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, @@ -316,7 +328,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return diff --git a/pypots/imputation/pyraformer/core.py b/pypots/imputation/pyraformer/core.py index 3087d90a..7acb0817 100644 --- a/pypots/imputation/pyraformer/core.py +++ b/pypots/imputation/pyraformer/core.py @@ -52,7 +52,7 @@ def __init__( self.output_projection = nn.Linear((len(window_size) + 1) * d_model, n_features) self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] # WDU: the original Pyraformer paper isn't proposed for imputation task. Hence the model doesn't take @@ -73,7 +73,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] loss, ORT_loss, MIT_loss = self.saits_loss_func( reconstruction, X_ori, missing_mask, indicating_mask diff --git a/pypots/imputation/pyraformer/model.py b/pypots/imputation/pyraformer/model.py index 757e96f3..7bc351bf 100644 --- a/pypots/imputation/pyraformer/model.py +++ b/pypots/imputation/pyraformer/model.py @@ -74,6 +74,14 @@ class Pyraformer(BaseNNImputer): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + optimizer : The optimizer for model training. If not given, will use a default Adam optimizer. 
@@ -120,7 +128,9 @@ def __init__( MIT_weight: float = 1, batch_size: int = 32, epochs: int = 100, - patience: int = None, + patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -131,12 +141,13 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, model_saving_strategy, ) - self.n_steps = n_steps self.n_features = n_features # model hype-parameters @@ -290,7 +301,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return diff --git a/pypots/imputation/saits/core.py b/pypots/imputation/saits/core.py index f5189ab3..93f79aea 100644 --- a/pypots/imputation/saits/core.py +++ b/pypots/imputation/saits/core.py @@ -32,7 +32,7 @@ def __init__( diagonal_attention_mask: bool = True, ORT_weight: float = 1, MIT_weight: float = 1, - customized_loss_func: Callable = calc_mae, + loss_func: Callable = calc_mae, ): super().__init__() self.n_layers = n_layers @@ -40,7 +40,7 @@ def __init__( self.diagonal_attention_mask = diagonal_attention_mask self.ORT_weight = ORT_weight self.MIT_weight = MIT_weight - self.customized_loss_func = customized_loss_func + self.loss_func = loss_func self.encoder = BackboneSAITS( n_steps, @@ -59,13 +59,12 @@ def forward( self, inputs: dict, diagonal_attention_mask: bool = True, - training: bool = True, ) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] # determine the attention mask - if (training and self.diagonal_attention_mask) or ( - (not training) and diagonal_attention_mask + if (self.training and self.diagonal_attention_mask) or ( + (not self.training) and diagonal_attention_mask ): diagonal_attention_mask = (1 - torch.eye(self.n_steps)).to(X.device) # then broadcast on the batch axis @@ -95,21 +94,21 @@ def forward( } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] # calculate loss for the observed reconstruction task (ORT) # this calculation is more complicated that pypots.nn.modules.saits.SaitsLoss because # SAITS model structure has three parts of representation ORT_loss = 0 - ORT_loss += self.customized_loss_func(X_tilde_1, X, missing_mask) - ORT_loss += self.customized_loss_func(X_tilde_2, X, missing_mask) - ORT_loss += self.customized_loss_func(X_tilde_3, X, missing_mask) + ORT_loss += self.loss_func(X_tilde_1, X, missing_mask) + ORT_loss += self.loss_func(X_tilde_2, X, missing_mask) + ORT_loss += self.loss_func(X_tilde_3, X, missing_mask) ORT_loss /= 3 ORT_loss = self.ORT_weight * ORT_loss # calculate loss for the masked imputation task (MIT) - MIT_loss = self.MIT_weight * self.customized_loss_func( + MIT_loss = self.MIT_weight * self.loss_func( X_tilde_3, X_ori, indicating_mask ) # `loss` is always the item for backward propagating to update the model diff --git a/pypots/imputation/saits/model.py b/pypots/imputation/saits/model.py index ad0fd97b..bf523e05 100644 --- a/pypots/imputation/saits/model.py +++ b/pypots/imputation/saits/model.py @@ -6,7 +6,7 @@ # Created by Wenjie Du # License: BSD-3-Clause -from typing import Union, Optional, Callable 
+from typing import Union, Optional import numpy as np import torch @@ -20,7 +20,7 @@ from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.logging import logger -from ...utils.metrics import calc_mae +from ...utils.metrics import calc_mae, calc_mse class SAITS(BaseNNImputer): @@ -84,9 +84,13 @@ class SAITS(BaseNNImputer): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. - customized_loss_func: - The customized loss function designed by users for the model to optimize. - If not given, will use the default MAE loss as claimed in the original paper. + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. optimizer : The optimizer for model training. @@ -136,7 +140,8 @@ def __init__( batch_size: int = 32, epochs: int = 100, patience: Optional[int] = None, - customized_loss_func: Callable = calc_mae, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -147,6 +152,8 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, @@ -163,6 +170,14 @@ def __init__( f"⚠️ d_model is reset to {d_model} = n_heads ({n_heads}) * d_k ({d_k})" ) + # set default training loss function and validation metric function if not given + if train_loss_func is None: + self.train_loss_func = calc_mae + self.train_loss_func_name = "MAE" + if val_metric_func is None: + self.val_metric_func = calc_mse + self.val_metric_func_name = "MSE" + self.n_steps = n_steps self.n_features = n_features # model hype-parameters @@ -193,13 +208,11 @@ def __init__( self.diagonal_attention_mask, self.ORT_weight, self.MIT_weight, + self.train_loss_func, ) self._print_model_size() self._send_model_to_given_device() - # set up the loss function - self.customized_loss_func = customized_loss_func - # set up the optimizer self.optimizer = optimizer self.optimizer.init_optimizer(self.model.parameters()) @@ -332,9 +345,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward( - inputs, diagonal_attention_mask, training=False - ) + results = self.model.forward(inputs, diagonal_attention_mask) imputation_collector.append(results["imputed_data"]) if return_latent_vars: diff --git a/pypots/imputation/template/model.py b/pypots/imputation/template/model.py index 9f135893..577e631e 100644 --- a/pypots/imputation/template/model.py +++ b/pypots/imputation/template/model.py @@ -34,6 +34,8 @@ def __init__( batch_size: int = 32, epochs: int = 100, patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -44,12 +46,13 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, model_saving_strategy, - ) - # set up the hyper-parameters + ) # set up the hyper-parameters # TODO: set up your model's hyper-parameters here # set up the model diff --git 
a/pypots/imputation/timesnet/core.py b/pypots/imputation/timesnet/core.py index 15aefecf..10503bbb 100644 --- a/pypots/imputation/timesnet/core.py +++ b/pypots/imputation/timesnet/core.py @@ -51,7 +51,7 @@ def __init__( # for the imputation task, the output dim is the same as input dim self.projection = nn.Linear(d_model, n_features) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] if self.apply_nonstationary_norm: @@ -75,7 +75,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: "imputed_data": imputed_data, } - if training: + if self.training: # `loss` is always the item for backward propagating to update the model loss = calc_mse(dec_out, inputs["X_ori"], inputs["indicating_mask"]) results["loss"] = loss diff --git a/pypots/imputation/timesnet/model.py b/pypots/imputation/timesnet/model.py index 34aba691..a38e46a5 100644 --- a/pypots/imputation/timesnet/model.py +++ b/pypots/imputation/timesnet/model.py @@ -67,6 +67,14 @@ class TimesNet(BaseNNImputer): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + optimizer : The optimizer for model training. If not given, will use a default Adam optimizer. @@ -110,7 +118,9 @@ def __init__( apply_nonstationary_norm: bool = False, batch_size: int = 32, epochs: int = 100, - patience: int = None, + patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -121,12 +131,13 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, model_saving_strategy, ) - self.n_steps = n_steps self.n_features = n_features # model hype-parameters @@ -274,7 +285,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return diff --git a/pypots/imputation/transformer/core.py b/pypots/imputation/transformer/core.py index e769a3aa..75ea8e2d 100644 --- a/pypots/imputation/transformer/core.py +++ b/pypots/imputation/transformer/core.py @@ -56,7 +56,7 @@ def __init__( # apply SAITS loss function to Transformer on the imputation task self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] # apply the SAITS embedding strategy, concatenate X and missing mask for input @@ -76,7 +76,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] loss, ORT_loss, MIT_loss = self.saits_loss_func( reconstruction, X_ori, missing_mask, indicating_mask diff --git 
a/pypots/imputation/transformer/model.py b/pypots/imputation/transformer/model.py index d7d59097..9393ccff 100644 --- a/pypots/imputation/transformer/model.py +++ b/pypots/imputation/transformer/model.py @@ -82,6 +82,14 @@ class Transformer(BaseNNImputer): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + optimizer : The optimizer for model training. If not given, will use a default Adam optimizer. @@ -129,6 +137,8 @@ def __init__( batch_size: int = 32, epochs: int = 100, patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -139,12 +149,13 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, model_saving_strategy, ) - if d_model != n_heads * d_k: logger.warning( "‼️ d_model must = n_heads * d_k, it should be divisible by n_heads " @@ -284,7 +295,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) diff --git a/pypots/imputation/usgan/core.py b/pypots/imputation/usgan/core.py index a0b04ade..5522821c 100644 --- a/pypots/imputation/usgan/core.py +++ b/pypots/imputation/usgan/core.py @@ -37,7 +37,6 @@ def forward( self, inputs: dict, training_object: str = "generator", - training: bool = True, ) -> dict: assert training_object in [ "generator", @@ -45,26 +44,18 @@ def forward( ], 'training_object should be "generator" or "discriminator"' results = {} - if training: + if self.training: if training_object == "discriminator": imputed_data, discrimination_loss = self.backbone( - inputs, training_object, training + inputs, training_object ) loss = discrimination_loss else: - imputed_data, generation_loss = self.backbone( - inputs, - training_object, - training, - ) + imputed_data, generation_loss = self.backbone(inputs, training_object) loss = generation_loss results["loss"] = loss else: - imputed_data = self.backbone( - inputs, - training_object, - training, - ) + imputed_data = self.backbone(inputs, training_object) results["imputed_data"] = imputed_data return results diff --git a/pypots/imputation/usgan/model.py b/pypots/imputation/usgan/model.py index 69f3bdd1..7c715758 100644 --- a/pypots/imputation/usgan/model.py +++ b/pypots/imputation/usgan/model.py @@ -114,6 +114,8 @@ def __init__( batch_size: int = 32, epochs: int = 100, patience: Optional[int] = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, G_optimizer: Optional[Optimizer] = Adam(), D_optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, @@ -125,6 +127,8 @@ def __init__( batch_size, epochs, patience, + train_loss_func, + val_metric_func, num_workers, device, saving_path, @@ -282,7 +286,7 @@ def _train_model( with torch.no_grad(): for idx, data in enumerate(val_loader): inputs = 
self._assemble_input_for_validating(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_mse = ( calc_mse( results["imputed_data"], @@ -425,7 +429,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) diff --git a/pypots/nn/modules/csdi/backbone.py b/pypots/nn/modules/csdi/backbone.py index 26051060..ca6bf27b 100644 --- a/pypots/nn/modules/csdi/backbone.py +++ b/pypots/nn/modules/csdi/backbone.py @@ -83,23 +83,19 @@ def set_input_to_diffmodel(self, noisy_data, observed_data, cond_mask): return total_input - def calc_loss_valid( - self, observed_data, cond_mask, indicating_mask, side_info, is_train - ): + def calc_loss_valid(self, observed_data, cond_mask, indicating_mask, side_info): loss_sum = 0 for t in range(self.n_diffusion_steps): # calculate loss for all t loss = self.calc_loss( - observed_data, cond_mask, indicating_mask, side_info, is_train, set_t=t + observed_data, cond_mask, indicating_mask, side_info, set_t=t ) loss_sum += loss.detach() return loss_sum / self.n_diffusion_steps - def calc_loss( - self, observed_data, cond_mask, indicating_mask, side_info, is_train, set_t=-1 - ): + def calc_loss(self, observed_data, cond_mask, indicating_mask, side_info, set_t=-1): B, K, L = observed_data.shape device = observed_data.device - if is_train != 1: # for validation + if self.training != 1: # for validation t = (torch.ones(B) * set_t).long().to(device) else: t = torch.randint(0, self.n_diffusion_steps, [B]).to(device) diff --git a/pypots/nn/modules/usgan/backbone.py b/pypots/nn/modules/usgan/backbone.py index 42d7f430..fdc5bcbd 100644 --- a/pypots/nn/modules/usgan/backbone.py +++ b/pypots/nn/modules/usgan/backbone.py @@ -43,7 +43,6 @@ def forward( self, inputs: dict, training_object: str = "generator", - training: bool = True, ) -> Tuple[torch.Tensor, ...]: ( imputed_data, @@ -56,7 +55,7 @@ def forward( ) = self.generator(inputs) # if in training mode, return results with losses - if training: + if self.training: forward_X = inputs["forward"]["X"] forward_missing_mask = inputs["forward"]["missing_mask"] From 620a95c610be65347c13e3595722d18e1206e569 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 8 May 2024 18:50:33 +0800 Subject: [PATCH 2/9] refactor: move pypots.utils.metrics to pypots.nn.functional; --- pypots/classification/base.py | 2 +- pypots/clustering/crli/core.py | 2 +- pypots/clustering/vader/core.py | 2 +- pypots/clustering/vader/model.py | 2 +- pypots/forecasting/base.py | 2 +- pypots/imputation/base.py | 2 +- pypots/imputation/gpvae/data.py | 1 + pypots/imputation/gpvae/model.py | 2 +- pypots/imputation/mrnn/core.py | 2 +- .../nonstationary_transformer/core.py | 2 +- pypots/imputation/saits/core.py | 2 +- pypots/imputation/saits/model.py | 2 +- pypots/imputation/timesnet/core.py | 2 +- pypots/imputation/transformer/model.py | 2 +- pypots/imputation/usgan/model.py | 2 +- pypots/nn/functional/__init__.py | 49 +++++++++++++++++++ .../functional}/classification.py | 0 .../metrics => nn/functional}/clustering.py | 0 .../{utils/metrics => nn/functional}/error.py | 0 pypots/nn/modules/film/__init__.py | 2 +- pypots/nn/modules/transformer/attention.py | 2 +- pypots/optim/lr_scheduler/__init__.py | 9 ++-- pypots/utils/metrics/__init__.py | 13 +++-- 23 files changed, 80 
insertions(+), 24 deletions(-) rename pypots/{utils/metrics => nn/functional}/classification.py (100%) rename pypots/{utils/metrics => nn/functional}/clustering.py (100%) rename pypots/{utils/metrics => nn/functional}/error.py (100%) diff --git a/pypots/classification/base.py b/pypots/classification/base.py index fef6675b..81340288 100644 --- a/pypots/classification/base.py +++ b/pypots/classification/base.py @@ -15,8 +15,8 @@ from torch.utils.data import DataLoader from ..base import BaseModel, BaseNNModel +from ..nn.functional import calc_acc from ..utils.logging import logger -from ..utils.metrics import calc_acc try: import nni diff --git a/pypots/clustering/crli/core.py b/pypots/clustering/crli/core.py index 74bc7605..fb90ad39 100644 --- a/pypots/clustering/crli/core.py +++ b/pypots/clustering/crli/core.py @@ -13,8 +13,8 @@ import torch.nn.functional as F from sklearn.cluster import KMeans +from ...nn.functional import calc_mse from ...nn.modules.crli import BackboneCRLI -from ...utils.metrics import calc_mse class _CRLI(nn.Module): diff --git a/pypots/clustering/vader/core.py b/pypots/clustering/vader/core.py index bb19c1e3..0268c7cd 100644 --- a/pypots/clustering/vader/core.py +++ b/pypots/clustering/vader/core.py @@ -11,8 +11,8 @@ import torch import torch.nn as nn +from ...nn.functional import calc_mse from ...nn.modules.vader import BackboneVaDER -from ...utils.metrics import calc_mse def inverse_softplus(x: np.ndarray) -> np.ndarray: diff --git a/pypots/clustering/vader/model.py b/pypots/clustering/vader/model.py index d44257db..caff4453 100644 --- a/pypots/clustering/vader/model.py +++ b/pypots/clustering/vader/model.py @@ -16,8 +16,8 @@ from sklearn.mixture import GaussianMixture from torch.utils.data import DataLoader -from .data import DatasetForVaDER from .core import inverse_softplus, _VaDER +from .data import DatasetForVaDER from ..base import BaseNNClusterer from ...optim.adam import Adam from ...optim.base import Optimizer diff --git a/pypots/forecasting/base.py b/pypots/forecasting/base.py index fb88f7f0..427544b4 100644 --- a/pypots/forecasting/base.py +++ b/pypots/forecasting/base.py @@ -14,8 +14,8 @@ from torch.utils.data import DataLoader from ..base import BaseModel, BaseNNModel +from ..nn.functional import calc_mse from ..utils.logging import logger -from ..utils.metrics.error import calc_mse try: import nni diff --git a/pypots/imputation/base.py b/pypots/imputation/base.py index a7335a6c..828c7b58 100644 --- a/pypots/imputation/base.py +++ b/pypots/imputation/base.py @@ -15,8 +15,8 @@ from torch.utils.data import DataLoader from ..base import BaseModel, BaseNNModel +from ..nn.functional import calc_mse from ..utils.logging import logger -from ..utils.metrics import calc_mse try: import nni diff --git a/pypots/imputation/gpvae/data.py b/pypots/imputation/gpvae/data.py index af61ace3..4b00a10a 100644 --- a/pypots/imputation/gpvae/data.py +++ b/pypots/imputation/gpvae/data.py @@ -9,6 +9,7 @@ import torch from pygrinder import fill_and_get_mask_torch + from ...data.dataset import BaseDataset diff --git a/pypots/imputation/gpvae/model.py b/pypots/imputation/gpvae/model.py index 0ef68d40..b85ea782 100644 --- a/pypots/imputation/gpvae/model.py +++ b/pypots/imputation/gpvae/model.py @@ -27,7 +27,7 @@ from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.logging import logger -from ...utils.metrics import calc_mse +from ...nn.functional import calc_mse class GPVAE(BaseNNImputer): diff --git a/pypots/imputation/mrnn/core.py 
b/pypots/imputation/mrnn/core.py index 2092d10e..b9490c29 100644 --- a/pypots/imputation/mrnn/core.py +++ b/pypots/imputation/mrnn/core.py @@ -9,8 +9,8 @@ import torch.nn as nn +from ...nn.functional import calc_rmse from ...nn.modules.mrnn import BackboneMRNN -from ...utils.metrics import calc_rmse class _MRNN(nn.Module): diff --git a/pypots/imputation/nonstationary_transformer/core.py b/pypots/imputation/nonstationary_transformer/core.py index cf17c8f8..cf631777 100644 --- a/pypots/imputation/nonstationary_transformer/core.py +++ b/pypots/imputation/nonstationary_transformer/core.py @@ -8,12 +8,12 @@ import torch.nn as nn +from ...nn.functional import nonstationary_norm, nonstationary_denorm from ...nn.modules.nonstationary_transformer import ( NonstationaryTransformerEncoder, Projector, ) from ...nn.modules.saits import SaitsLoss, SaitsEmbedding -from ...nn.functional.normalization import nonstationary_norm, nonstationary_denorm class _NonstationaryTransformer(nn.Module): diff --git a/pypots/imputation/saits/core.py b/pypots/imputation/saits/core.py index 93f79aea..9c7f327b 100644 --- a/pypots/imputation/saits/core.py +++ b/pypots/imputation/saits/core.py @@ -12,8 +12,8 @@ import torch import torch.nn as nn +from ...nn.functional import calc_mae from ...nn.modules.saits import BackboneSAITS -from ...utils.metrics import calc_mae class _SAITS(nn.Module): diff --git a/pypots/imputation/saits/model.py b/pypots/imputation/saits/model.py index bf523e05..6340503f 100644 --- a/pypots/imputation/saits/model.py +++ b/pypots/imputation/saits/model.py @@ -17,10 +17,10 @@ from ..base import BaseNNImputer from ...data.checking import key_in_data_set from ...data.dataset import BaseDataset +from ...nn.functional import calc_mae, calc_mse from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.logging import logger -from ...utils.metrics import calc_mae, calc_mse class SAITS(BaseNNImputer): diff --git a/pypots/imputation/timesnet/core.py b/pypots/imputation/timesnet/core.py index 10503bbb..cd49ba30 100644 --- a/pypots/imputation/timesnet/core.py +++ b/pypots/imputation/timesnet/core.py @@ -7,10 +7,10 @@ import torch.nn as nn +from ...nn.functional import calc_mse from ...nn.functional import nonstationary_norm, nonstationary_denorm from ...nn.modules.timesnet import BackboneTimesNet from ...nn.modules.transformer.embedding import DataEmbedding -from ...utils.metrics import calc_mse class _TimesNet(nn.Module): diff --git a/pypots/imputation/transformer/model.py b/pypots/imputation/transformer/model.py index 9393ccff..d2676194 100644 --- a/pypots/imputation/transformer/model.py +++ b/pypots/imputation/transformer/model.py @@ -15,8 +15,8 @@ from .core import _Transformer from .data import DatasetForTransformer from ..base import BaseNNImputer -from ...data.dataset import BaseDataset from ...data.checking import key_in_data_set +from ...data.dataset import BaseDataset from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.logging import logger diff --git a/pypots/imputation/usgan/model.py b/pypots/imputation/usgan/model.py index cb890c95..8b497d30 100644 --- a/pypots/imputation/usgan/model.py +++ b/pypots/imputation/usgan/model.py @@ -17,10 +17,10 @@ from .data import DatasetForUSGAN from ..base import BaseNNImputer from ...data.checking import key_in_data_set +from ...nn.functional import calc_mse from ...optim.adam import Adam from ...optim.base import Optimizer from ...utils.logging import logger -from ...utils.metrics import calc_mse try: import nni 
diff --git a/pypots/nn/functional/__init__.py b/pypots/nn/functional/__init__.py index 36df2bc6..368260ab 100644 --- a/pypots/nn/functional/__init__.py +++ b/pypots/nn/functional/__init__.py @@ -5,10 +5,59 @@ # Created by Wenjie Du # License: BSD-3-Clause +from .classification import ( + calc_binary_classification_metrics, + calc_precision_recall_f1, + calc_pr_auc, + calc_roc_auc, + calc_acc, +) +from .clustering import ( + calc_rand_index, + calc_adjusted_rand_index, + calc_cluster_purity, + calc_nmi, + calc_chs, + calc_dbs, + calc_silhouette, + calc_internal_cluster_validation_metrics, + calc_external_cluster_validation_metrics, +) +from .error import ( + calc_mae, + calc_mse, + calc_rmse, + calc_mre, + calc_quantile_crps, + calc_quantile_crps_sum, +) from .normalization import nonstationary_norm, nonstationary_denorm __all__ = [ # normalization functions "nonstationary_norm", "nonstationary_denorm", + # error + "calc_mae", + "calc_mse", + "calc_rmse", + "calc_mre", + "calc_quantile_crps", + "calc_quantile_crps_sum", + # classification + "calc_binary_classification_metrics", + "calc_precision_recall_f1", + "calc_pr_auc", + "calc_roc_auc", + "calc_acc", + # clustering + "calc_rand_index", + "calc_adjusted_rand_index", + "calc_cluster_purity", + "calc_nmi", + "calc_chs", + "calc_dbs", + "calc_silhouette", + "calc_internal_cluster_validation_metrics", + "calc_external_cluster_validation_metrics", ] diff --git a/pypots/utils/metrics/classification.py b/pypots/nn/functional/classification.py similarity index 100% rename from pypots/utils/metrics/classification.py rename to pypots/nn/functional/classification.py diff --git a/pypots/utils/metrics/clustering.py b/pypots/nn/functional/clustering.py similarity index 100% rename from pypots/utils/metrics/clustering.py rename to pypots/nn/functional/clustering.py diff --git a/pypots/utils/metrics/error.py b/pypots/nn/functional/error.py similarity index 100% rename from pypots/utils/metrics/error.py rename to pypots/nn/functional/error.py diff --git a/pypots/nn/modules/film/__init__.py b/pypots/nn/modules/film/__init__.py index 4f97f20b..8d12cd09 100644 --- a/pypots/nn/modules/film/__init__.py +++ b/pypots/nn/modules/film/__init__.py @@ -17,8 +17,8 @@ # License: BSD-3-Clause -from .layers import HiPPO_LegT, SpectralConv1d from .backbone import BackboneFiLM +from .layers import HiPPO_LegT, SpectralConv1d __all__ = [ "HiPPO_LegT", diff --git a/pypots/nn/modules/transformer/attention.py b/pypots/nn/modules/transformer/attention.py index ecc4f85e..20f1c3f9 100644 --- a/pypots/nn/modules/transformer/attention.py +++ b/pypots/nn/modules/transformer/attention.py @@ -11,12 +11,12 @@ # Created by Wenjie Du # License: BSD-3-Clause +from abc import abstractmethod from typing import Tuple, Optional import torch import torch.nn as nn import torch.nn.functional as F -from abc import abstractmethod class AttentionOperator(nn.Module): diff --git a/pypots/optim/lr_scheduler/__init__.py b/pypots/optim/lr_scheduler/__init__.py index 02d5cfe4..1a90802e 100644 --- a/pypots/optim/lr_scheduler/__init__.py +++ b/pypots/optim/lr_scheduler/__init__.py @@ -9,14 +9,13 @@ # Created by Wenjie Du # License: BSD-3-Clause -from .lambda_lrs import LambdaLR -from .multiplicative_lrs import MultiplicativeLR -from .step_lrs import StepLR -from .multistep_lrs import MultiStepLR from .constant_lrs import ConstantLR from .exponential_lrs import ExponentialLR +from .lambda_lrs import LambdaLR from .linear_lrs import LinearLR - +from .multiplicative_lrs import MultiplicativeLR +from 
.multistep_lrs import MultiStepLR +from .step_lrs import StepLR __all__ = [ "LambdaLR", diff --git a/pypots/utils/metrics/__init__.py b/pypots/utils/metrics/__init__.py index ed309ce3..5453993c 100644 --- a/pypots/utils/metrics/__init__.py +++ b/pypots/utils/metrics/__init__.py @@ -5,14 +5,16 @@ # Created by Wenjie Du # License: BSD-3-Clause -from .classification import ( + +from ..logging import logger +from ...nn.functional.classification import ( calc_binary_classification_metrics, calc_precision_recall_f1, calc_pr_auc, calc_roc_auc, calc_acc, ) -from .clustering import ( +from ...nn.functional.clustering import ( calc_rand_index, calc_adjusted_rand_index, calc_cluster_purity, @@ -23,7 +25,7 @@ calc_internal_cluster_validation_metrics, calc_external_cluster_validation_metrics, ) -from .error import ( +from ...nn.functional.error import ( calc_mae, calc_mse, calc_rmse, @@ -32,6 +34,11 @@ calc_quantile_crps_sum, ) +logger.warning( + "🚨 Importing metrics from pypots.utils.metrics is deprecated. " + "Please import from pypots.nn.functional instead." +) + __all__ = [ # error "calc_mae", From c7dc6beaf401fbdcf4175a3fe4c271b81cc53a03 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 9 May 2024 12:27:55 +0800 Subject: [PATCH 3/9] feat: add BaseLoss and BaseMetric; --- pypots/nn/modules/loss.py | 28 ++++++++++++++++++++++++++++ pypots/nn/modules/metric.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) create mode 100644 pypots/nn/modules/loss.py create mode 100644 pypots/nn/modules/metric.py diff --git a/pypots/nn/modules/loss.py b/pypots/nn/modules/loss.py new file mode 100644 index 00000000..0868d2ca --- /dev/null +++ b/pypots/nn/modules/loss.py @@ -0,0 +1,28 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from .metric import BaseMetric +from ..functional import calc_mse + + +class BaseLoss(BaseMetric): + def __init__( + self, + ): + super().__init__() + + def forward(self, prediction, target): + raise NotImplementedError + + +class MAE_Loss(BaseLoss): + def __init__(self): + super().__init__() + + def forward(self, prediction, target, mask=None): + return calc_mse(prediction, target, mask) diff --git a/pypots/nn/modules/metric.py b/pypots/nn/modules/metric.py new file mode 100644 index 00000000..faea3d78 --- /dev/null +++ b/pypots/nn/modules/metric.py @@ -0,0 +1,29 @@ +""" + +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +import torch.nn as nn + +from ..functional import calc_pr_auc + + +class BaseMetric(nn.Module): + def __init__(self, lower_better: bool = True): + super().__init__() + self.lower_better = lower_better + + def forward(self, prediction, target): + raise NotImplementedError + + +class PR_AUC(BaseMetric): + def __init__(self): + super().__init__(lower_better=False) + + def forward(self, prediction, target): + pr_auc, _, _, _ = calc_pr_auc(prediction, target) + return pr_auc From f674a3c25ea4d7affe251554dbe896b8a66d79eb Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Fri, 27 Sep 2024 14:30:49 +0800 Subject: [PATCH 4/9] refactor: make FITS able to apply customized loss func; --- pypots/imputation/fits/core.py | 4 +-- pypots/imputation/fits/model.py | 53 ++++++++++++++++----------------- 2 files changed, 27 insertions(+), 30 deletions(-) diff --git a/pypots/imputation/fits/core.py b/pypots/imputation/fits/core.py index 701ec4ca..ba3f2661 100644 --- a/pypots/imputation/fits/core.py +++ b/pypots/imputation/fits/core.py @@ -46,7 +46,7 @@ def __init__( self.output_projection = nn.Linear(n_features, n_features) 
self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] if self.apply_nonstationary_norm: @@ -75,7 +75,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss diff --git a/pypots/imputation/fits/model.py b/pypots/imputation/fits/model.py index 2664da26..b5c9dc7a 100644 --- a/pypots/imputation/fits/model.py +++ b/pypots/imputation/fits/model.py @@ -33,26 +33,11 @@ class FITS(BaseNNImputer): n_features : The number of features in the time-series data sample. - n_layers : - The number of layers in the FITS model. + cut_freq : + The cut-off frequency for the Fourier transformation. - d_model : - The dimension of the model. - - n_heads : - The number of heads in each layer of FITS. - - d_ffn : - The dimension of the feed-forward network. - - factor : - The factor of the auto correlation mechanism for the FITS model. - - moving_avg_window_size : - The window size of moving average. - - dropout : - The dropout rate for the model. + individual : + Whether to use individual Fourier transformation for each feature. ORT_weight : The weight for the ORT loss, the same as SAITS. @@ -71,6 +56,14 @@ class FITS(BaseNNImputer): stopped when the model does not perform better after that number of epochs. Leaving it default as None will disable the early-stopping. + train_loss_func: + The customized loss function designed by users for training the model. + If not given, will use the default loss as claimed in the original paper. + + val_metric_func: + The customized metric function designed by users for validating the model. + If not given, will use the default MSE metric. + optimizer : The optimizer for model training. If not given, will use a default Adam optimizer. 
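The docstring hunk above defines the user-facing contract for the new train_loss_func and val_metric_func arguments. A minimal usage sketch follows, tying it to the BaseLoss module introduced in PATCH 3/9; the import paths follow the package layout in this series, while the placeholder constructor values and the assumption that a callable loss module is accepted (the __init__ hunk below types the argument as Optional[dict]) are illustrative only.

    from pypots.imputation import FITS
    from pypots.nn.functional import calc_mae
    from pypots.nn.modules.loss import BaseLoss

    class MaskedMAELoss(BaseLoss):
        """Hypothetical user-defined training loss built on the new BaseLoss."""

        def forward(self, prediction, target, mask=None):
            # Masked MAE on the reconstruction, with the same call shape as
            # MAE_Loss from PATCH 3/9.
            return calc_mae(prediction, target, mask)

    # Placeholder hyperparameters; only train_loss_func is the point of this sketch.
    model = FITS(
        n_steps=48,
        n_features=37,
        cut_freq=12,
        individual=False,
        train_loss_func=MaskedMAELoss(),  # assumed: a callable loss module is accepted here
    )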
@@ -115,6 +108,8 @@ def __init__( batch_size: int = 32, epochs: int = 100, patience: int = None, + train_loss_func: Optional[dict] = None, + val_metric_func: Optional[dict] = None, optimizer: Optional[Optimizer] = Adam(), num_workers: int = 0, device: Optional[Union[str, torch.device, list]] = None, @@ -123,14 +118,16 @@ def __init__( verbose: bool = True, ): super().__init__( - batch_size, - epochs, - patience, - num_workers, - device, - saving_path, - model_saving_strategy, - verbose, + batch_size=batch_size, + epochs=epochs, + patience=patience, + train_loss_func=train_loss_func, + val_metric_func=val_metric_func, + num_workers=num_workers, + device=device, + saving_path=saving_path, + model_saving_strategy=model_saving_strategy, + verbose=verbose, ) self.n_steps = n_steps @@ -272,7 +269,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return From 40e9602f6de20dc75bcd5f65388fe9fb1ec02809 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Fri, 27 Sep 2024 15:42:52 +0800 Subject: [PATCH 5/9] refactor: replace arg training with self attribute in new added models; --- pypots/imputation/grud/core.py | 4 ++-- pypots/imputation/grud/model.py | 2 +- pypots/imputation/imputeformer/core.py | 4 ++-- pypots/imputation/imputeformer/model.py | 2 +- pypots/imputation/koopa/model.py | 2 +- pypots/imputation/micn/core.py | 4 ++-- pypots/imputation/micn/model.py | 2 +- pypots/imputation/moderntcn/core.py | 4 ++-- pypots/imputation/moderntcn/model.py | 2 +- pypots/imputation/reformer/core.py | 4 ++-- pypots/imputation/reformer/model.py | 2 +- pypots/imputation/revinscinet/core.py | 4 ++-- pypots/imputation/revinscinet/model.py | 2 +- pypots/imputation/scinet/core.py | 4 ++-- pypots/imputation/scinet/model.py | 2 +- pypots/imputation/stemgnn/core.py | 4 ++-- pypots/imputation/stemgnn/model.py | 2 +- pypots/imputation/tcn/core.py | 4 ++-- pypots/imputation/tcn/model.py | 2 +- pypots/imputation/tefn/core.py | 5 +++-- pypots/imputation/tefn/model.py | 2 +- pypots/imputation/tide/core.py | 4 ++-- pypots/imputation/tide/model.py | 2 +- pypots/imputation/timemixer/core.py | 5 +++-- pypots/imputation/timemixer/model.py | 2 +- 25 files changed, 39 insertions(+), 37 deletions(-) diff --git a/pypots/imputation/grud/core.py b/pypots/imputation/grud/core.py index 98f368e0..713259c2 100644 --- a/pypots/imputation/grud/core.py +++ b/pypots/imputation/grud/core.py @@ -33,7 +33,7 @@ def __init__( ) self.output_projection = nn.Linear(rnn_hidden_size, n_features) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: """Forward processing of GRU-D. 
Parameters @@ -66,7 +66,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: results["loss"] = calc_mse(reconstruction, X, missing_mask) return results diff --git a/pypots/imputation/grud/model.py b/pypots/imputation/grud/model.py index 008408a9..b2ea4a0d 100644 --- a/pypots/imputation/grud/model.py +++ b/pypots/imputation/grud/model.py @@ -225,7 +225,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) diff --git a/pypots/imputation/imputeformer/core.py b/pypots/imputation/imputeformer/core.py index ceb81630..46d49aed 100644 --- a/pypots/imputation/imputeformer/core.py +++ b/pypots/imputation/imputeformer/core.py @@ -92,7 +92,7 @@ def __init__( # apply SAITS loss function to Transformer on the imputation task self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: x, missing_mask = inputs["X"], inputs["missing_mask"] # x: (batch_size, in_steps, num_nodes) @@ -132,7 +132,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss diff --git a/pypots/imputation/imputeformer/model.py b/pypots/imputation/imputeformer/model.py index 2c1e229c..334badd9 100644 --- a/pypots/imputation/imputeformer/model.py +++ b/pypots/imputation/imputeformer/model.py @@ -283,7 +283,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputed_data = results["imputed_data"] imputation_collector.append(imputed_data) diff --git a/pypots/imputation/koopa/model.py b/pypots/imputation/koopa/model.py index 76c72c1f..ab780553 100644 --- a/pypots/imputation/koopa/model.py +++ b/pypots/imputation/koopa/model.py @@ -295,7 +295,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return diff --git a/pypots/imputation/micn/core.py b/pypots/imputation/micn/core.py index 11bfa394..153413e2 100644 --- a/pypots/imputation/micn/core.py +++ b/pypots/imputation/micn/core.py @@ -60,7 +60,7 @@ def __init__( # for the imputation task, the output dim is the same as input dim self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] seasonal_init, trend_init = self.decomp_multi(X) @@ -82,7 +82,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] 
loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss diff --git a/pypots/imputation/micn/model.py b/pypots/imputation/micn/model.py index 3456d539..59061d93 100644 --- a/pypots/imputation/micn/model.py +++ b/pypots/imputation/micn/model.py @@ -276,7 +276,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return diff --git a/pypots/imputation/moderntcn/core.py b/pypots/imputation/moderntcn/core.py index 3ca8e8f5..4c64d2b8 100644 --- a/pypots/imputation/moderntcn/core.py +++ b/pypots/imputation/moderntcn/core.py @@ -66,7 +66,7 @@ def __init__( individual, ) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] if self.apply_nonstationary_norm: @@ -88,7 +88,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: loss = calc_mse(reconstruction, inputs["X_ori"], inputs["indicating_mask"]) results["loss"] = loss diff --git a/pypots/imputation/moderntcn/model.py b/pypots/imputation/moderntcn/model.py index 68ed84ba..145a70d8 100644 --- a/pypots/imputation/moderntcn/model.py +++ b/pypots/imputation/moderntcn/model.py @@ -306,7 +306,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return diff --git a/pypots/imputation/reformer/core.py b/pypots/imputation/reformer/core.py index ec55c7ad..6c74f01d 100644 --- a/pypots/imputation/reformer/core.py +++ b/pypots/imputation/reformer/core.py @@ -54,7 +54,7 @@ def __init__( self.output_projection = nn.Linear(d_model, n_features) self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] # WDU: the original Reformer paper isn't proposed for imputation task. 
Hence the model doesn't take @@ -75,7 +75,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss diff --git a/pypots/imputation/reformer/model.py b/pypots/imputation/reformer/model.py index 072c6894..c61f6a85 100644 --- a/pypots/imputation/reformer/model.py +++ b/pypots/imputation/reformer/model.py @@ -295,7 +295,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return diff --git a/pypots/imputation/revinscinet/core.py b/pypots/imputation/revinscinet/core.py index 16d199d3..a3188d5d 100644 --- a/pypots/imputation/revinscinet/core.py +++ b/pypots/imputation/revinscinet/core.py @@ -59,7 +59,7 @@ def __init__( # for the imputation task, the output dim is the same as input dim self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] X = self.revin(X, missing_mask, mode="norm") @@ -80,7 +80,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss diff --git a/pypots/imputation/revinscinet/model.py b/pypots/imputation/revinscinet/model.py index 9056f552..4d10967c 100644 --- a/pypots/imputation/revinscinet/model.py +++ b/pypots/imputation/revinscinet/model.py @@ -300,7 +300,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return diff --git a/pypots/imputation/scinet/core.py b/pypots/imputation/scinet/core.py index 4d2b02a1..1df706c7 100644 --- a/pypots/imputation/scinet/core.py +++ b/pypots/imputation/scinet/core.py @@ -57,7 +57,7 @@ def __init__( # for the imputation task, the output dim is the same as input dim self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] # WDU: the original SCINet paper isn't proposed for imputation task. 
Hence the model doesn't take @@ -76,7 +76,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss diff --git a/pypots/imputation/scinet/model.py b/pypots/imputation/scinet/model.py index 4dd0fa27..29dac999 100644 --- a/pypots/imputation/scinet/model.py +++ b/pypots/imputation/scinet/model.py @@ -302,7 +302,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return diff --git a/pypots/imputation/stemgnn/core.py b/pypots/imputation/stemgnn/core.py index d8d51efb..650e7bb5 100644 --- a/pypots/imputation/stemgnn/core.py +++ b/pypots/imputation/stemgnn/core.py @@ -48,7 +48,7 @@ def __init__( self.output_projection = nn.Linear(d_model, n_features) self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] # WDU: the original StemGNN paper isn't proposed for imputation task. Hence the model doesn't take @@ -69,7 +69,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss diff --git a/pypots/imputation/stemgnn/model.py b/pypots/imputation/stemgnn/model.py index ecee2c80..2f75f4d6 100644 --- a/pypots/imputation/stemgnn/model.py +++ b/pypots/imputation/stemgnn/model.py @@ -276,7 +276,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return diff --git a/pypots/imputation/tcn/core.py b/pypots/imputation/tcn/core.py index c38390b5..7274af14 100644 --- a/pypots/imputation/tcn/core.py +++ b/pypots/imputation/tcn/core.py @@ -45,7 +45,7 @@ def __init__( self.output_projection = nn.Linear(channel_sizes[-1], n_features) self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] # WDU: the original TCN paper isn't proposed for imputation task. 
Hence the model doesn't take @@ -68,7 +68,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss diff --git a/pypots/imputation/tcn/model.py b/pypots/imputation/tcn/model.py index 2b33251a..28e987f3 100644 --- a/pypots/imputation/tcn/model.py +++ b/pypots/imputation/tcn/model.py @@ -270,7 +270,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return diff --git a/pypots/imputation/tefn/core.py b/pypots/imputation/tefn/core.py index f71927a6..ca11825e 100644 --- a/pypots/imputation/tefn/core.py +++ b/pypots/imputation/tefn/core.py @@ -32,7 +32,7 @@ def __init__( n_fod, ) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] if self.apply_nonstationary_norm: @@ -51,7 +51,8 @@ def forward(self, inputs: dict, training: bool = True) -> dict: "imputed_data": imputed_data, } - if training: + # if in training mode, return results with losses + if self.training: # `loss` is always the item for backward propagating to update the model loss = calc_mse(out, inputs["X_ori"], inputs["indicating_mask"]) results["loss"] = loss diff --git a/pypots/imputation/tefn/model.py b/pypots/imputation/tefn/model.py index 2925d8a6..6d55bd7c 100644 --- a/pypots/imputation/tefn/model.py +++ b/pypots/imputation/tefn/model.py @@ -250,7 +250,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return diff --git a/pypots/imputation/tide/core.py b/pypots/imputation/tide/core.py index e826cbeb..876b3ec4 100644 --- a/pypots/imputation/tide/core.py +++ b/pypots/imputation/tide/core.py @@ -82,7 +82,7 @@ def __init__( # self.output_projection = nn.Linear(d_model, n_features) self.saits_loss_func = SaitsLoss(ORT_weight, MIT_weight) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] # # WDU: the original TiDE paper isn't proposed for imputation task. 
Hence the model doesn't take @@ -112,7 +112,7 @@ def forward(self, inputs: dict, training: bool = True) -> dict: } # if in training mode, return results with losses - if training: + if self.training: X_ori, indicating_mask = inputs["X_ori"], inputs["indicating_mask"] loss, ORT_loss, MIT_loss = self.saits_loss_func(reconstruction, X_ori, missing_mask, indicating_mask) results["ORT_loss"] = ORT_loss diff --git a/pypots/imputation/tide/model.py b/pypots/imputation/tide/model.py index 7b14a5a6..693e5c5d 100644 --- a/pypots/imputation/tide/model.py +++ b/pypots/imputation/tide/model.py @@ -282,7 +282,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return diff --git a/pypots/imputation/timemixer/core.py b/pypots/imputation/timemixer/core.py index c094d2ff..f988c6fc 100644 --- a/pypots/imputation/timemixer/core.py +++ b/pypots/imputation/timemixer/core.py @@ -56,7 +56,7 @@ def __init__( use_future_temporal_feature=False, ) - def forward(self, inputs: dict, training: bool = True) -> dict: + def forward(self, inputs: dict) -> dict: X, missing_mask = inputs["X"], inputs["missing_mask"] if self.apply_nonstationary_norm: @@ -75,7 +75,8 @@ def forward(self, inputs: dict, training: bool = True) -> dict: "imputed_data": imputed_data, } - if training: + # if in training mode, return results with losses + if self.training: # `loss` is always the item for backward propagating to update the model loss = calc_mse(dec_out, inputs["X_ori"], inputs["indicating_mask"]) results["loss"] = loss diff --git a/pypots/imputation/timemixer/model.py b/pypots/imputation/timemixer/model.py index 89b24011..7ebf10ca 100644 --- a/pypots/imputation/timemixer/model.py +++ b/pypots/imputation/timemixer/model.py @@ -307,7 +307,7 @@ def predict( with torch.no_grad(): for idx, data in enumerate(test_loader): inputs = self._assemble_input_for_testing(data) - results = self.model.forward(inputs, training=False) + results = self.model.forward(inputs) imputation_collector.append(results["imputed_data"]) # Step 3: output collection and return From cc286eea7461434c9516500dc9cf48a8c0a030c5 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 30 Sep 2024 17:39:12 +0800 Subject: [PATCH 6/9] fix: globally replace importing from with ; --- pypots/imputation/grud/core.py | 2 +- pypots/imputation/moderntcn/core.py | 2 +- pypots/imputation/tefn/core.py | 2 +- pypots/imputation/timemixer/core.py | 2 +- pypots/imputation/timesnet/core.py | 2 +- pypots/nn/modules/brits/backbone.py | 2 +- pypots/nn/modules/saits/loss.py | 2 +- pypots/nn/modules/usgan/backbone.py | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pypots/imputation/grud/core.py b/pypots/imputation/grud/core.py index 713259c2..44f85e44 100644 --- a/pypots/imputation/grud/core.py +++ b/pypots/imputation/grud/core.py @@ -9,8 +9,8 @@ import torch.nn as nn +from ...nn.functional import calc_mse from ...nn.modules.grud import BackboneGRUD -from ...utils.metrics import calc_mse class _GRUD(nn.Module): diff --git a/pypots/imputation/moderntcn/core.py b/pypots/imputation/moderntcn/core.py index 4c64d2b8..91a43495 100644 --- a/pypots/imputation/moderntcn/core.py +++ b/pypots/imputation/moderntcn/core.py @@ -9,9 +9,9 @@ import torch.nn as nn from ...nn.functional import nonstationary_norm, nonstationary_denorm +from 
...nn.functional import calc_mse from ...nn.modules.moderntcn import BackboneModernTCN from ...nn.modules.patchtst.layers import FlattenHead -from ...utils.metrics import calc_mse class _ModernTCN(nn.Module): diff --git a/pypots/imputation/tefn/core.py b/pypots/imputation/tefn/core.py index ca11825e..c5bdc6d7 100644 --- a/pypots/imputation/tefn/core.py +++ b/pypots/imputation/tefn/core.py @@ -7,9 +7,9 @@ import torch.nn as nn +from ...nn.functional import calc_mse from ...nn.functional import nonstationary_norm, nonstationary_denorm from ...nn.modules.tefn import BackboneTEFN -from ...utils.metrics import calc_mse class _TEFN(nn.Module): diff --git a/pypots/imputation/timemixer/core.py b/pypots/imputation/timemixer/core.py index f988c6fc..8127e04a 100644 --- a/pypots/imputation/timemixer/core.py +++ b/pypots/imputation/timemixer/core.py @@ -11,8 +11,8 @@ nonstationary_norm, nonstationary_denorm, ) +from ...nn.functional import calc_mse from ...nn.modules.timemixer import BackboneTimeMixer -from ...utils.metrics import calc_mse class _TimeMixer(nn.Module): diff --git a/pypots/imputation/timesnet/core.py b/pypots/imputation/timesnet/core.py index 26f8424a..c4b203db 100644 --- a/pypots/imputation/timesnet/core.py +++ b/pypots/imputation/timesnet/core.py @@ -7,8 +7,8 @@ import torch.nn as nn -from ...nn.functional import calc_mse from ...nn.functional import nonstationary_norm, nonstationary_denorm +from ...nn.functional import calc_mse from ...nn.modules.timesnet import BackboneTimesNet from ...nn.modules.transformer.embedding import DataEmbedding diff --git a/pypots/nn/modules/brits/backbone.py b/pypots/nn/modules/brits/backbone.py index eef07cc2..c779cbc9 100644 --- a/pypots/nn/modules/brits/backbone.py +++ b/pypots/nn/modules/brits/backbone.py @@ -12,7 +12,7 @@ from .layers import FeatureRegression from ..grud.layers import TemporalDecay -from ....utils.metrics import calc_mae +from ....nn.functional import calc_mae class BackboneRITS(nn.Module): diff --git a/pypots/nn/modules/saits/loss.py b/pypots/nn/modules/saits/loss.py index 0052dce2..dc19ad4a 100644 --- a/pypots/nn/modules/saits/loss.py +++ b/pypots/nn/modules/saits/loss.py @@ -10,7 +10,7 @@ import torch.nn as nn -from ....utils.metrics import calc_mae +from ....nn.functional import calc_mae class SaitsLoss(nn.Module): diff --git a/pypots/nn/modules/usgan/backbone.py b/pypots/nn/modules/usgan/backbone.py index 9b7fa079..4ecbfef5 100644 --- a/pypots/nn/modules/usgan/backbone.py +++ b/pypots/nn/modules/usgan/backbone.py @@ -13,7 +13,7 @@ from .layers import UsganDiscriminator from ..brits import BackboneBRITS -from ....utils.metrics import calc_mse +from ....nn.functional import calc_mse class BackboneUSGAN(nn.Module): From ba4588ce16c030f864b407fe963680e212da33e8 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Mon, 30 Sep 2024 20:51:26 +0800 Subject: [PATCH 7/9] refactor: still keep pypots.utils.metrics for future compatibility; --- pypots/utils/metrics/classification.py | 23 +++++++++++++++++++++++ pypots/utils/metrics/clustering.py | 24 ++++++++++++++++++++++++ pypots/utils/metrics/error.py | 21 +++++++++++++++++++++ 3 files changed, 68 insertions(+) create mode 100644 pypots/utils/metrics/classification.py create mode 100644 pypots/utils/metrics/clustering.py create mode 100644 pypots/utils/metrics/error.py diff --git a/pypots/utils/metrics/classification.py b/pypots/utils/metrics/classification.py new file mode 100644 index 00000000..20481631 --- /dev/null +++ b/pypots/utils/metrics/classification.py @@ -0,0 +1,23 @@ +""" 
+Evaluation metrics related to classification. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + + +from ..logging import logger +from ...nn.functional.classification import * + +# pypots.utils.metrics.classification is deprecated, and moved to pypots.nn.functional.classification +logger.warning( + "🚨 Please import from pypots.nn.functional.classification instead of pypots.utils.metrics.classification" +) + +__all__ = [ + "calc_binary_classification_metrics", + "calc_precision_recall_f1", + "calc_pr_auc", + "calc_roc_auc", + "calc_acc", +] diff --git a/pypots/utils/metrics/clustering.py b/pypots/utils/metrics/clustering.py new file mode 100644 index 00000000..eb6a668b --- /dev/null +++ b/pypots/utils/metrics/clustering.py @@ -0,0 +1,24 @@ +""" +Evaluation metrics related to clustering. +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from ..logging import logger +from ...nn.functional.clustering import * + +# pypots.utils.metrics.clustering is deprecated, and moved to pypots.nn.functional.clustering +logger.warning("🚨 Please import from pypots.nn.functional.clustering instead of pypots.utils.metrics.clustering") + +__all__ = [ + "calc_rand_index", + "calc_adjusted_rand_index", + "calc_cluster_purity", + "calc_nmi", + "calc_chs", + "calc_dbs", + "calc_silhouette", + "calc_internal_cluster_validation_metrics", + "calc_external_cluster_validation_metrics", +] diff --git a/pypots/utils/metrics/error.py b/pypots/utils/metrics/error.py new file mode 100644 index 00000000..a25e7647 --- /dev/null +++ b/pypots/utils/metrics/error.py @@ -0,0 +1,21 @@ +""" +Evaluation metrics related to error calculation (like in tasks regression, imputation etc). +""" + +# Created by Wenjie Du +# License: BSD-3-Clause + +from ..logging import logger +from ...nn.functional.error import * + +# pypots.utils.metrics.error is deprecated, and moved to pypots.nn.functional.error +logger.warning("🚨 Please import from pypots.nn.functional.error instead of pypots.utils.metrics.error") + +__all__ = [ + "calc_mae", + "calc_mse", + "calc_rmse", + "calc_mre", + "calc_quantile_crps", + "calc_quantile_crps_sum", +] From 86340d0bb2fceebfe593bddd0a889f10297665c0 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Thu, 3 Oct 2024 23:10:33 +0800 Subject: [PATCH 8/9] refactor: do not expose by default; --- pypots/utils/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pypots/utils/__init__.py b/pypots/utils/__init__.py index d564bf31..24542ad4 100644 --- a/pypots/utils/__init__.py +++ b/pypots/utils/__init__.py @@ -10,7 +10,7 @@ # content files in this package "file", "logging", - "metrics", + # "metrics", # deprecated and everything is moved to nn.functional, hence do not import it by default "random", "visual", ] From b4b5b489c860894ccac7f07e3aa6e2842adb9a07 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Tue, 8 Oct 2024 21:07:05 +0800 Subject: [PATCH 9/9] refactor: remove lingting issues; --- pypots/utils/metrics/__init__.py | 5 +---- pypots/utils/metrics/classification.py | 8 +++++++- pypots/utils/metrics/clustering.py | 12 +++++++++++- pypots/utils/metrics/error.py | 9 ++++++++- 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/pypots/utils/metrics/__init__.py b/pypots/utils/metrics/__init__.py index 5453993c..e3627f9c 100644 --- a/pypots/utils/metrics/__init__.py +++ b/pypots/utils/metrics/__init__.py @@ -34,10 +34,7 @@ calc_quantile_crps_sum, ) -logger.warning( - "🚨 Importing metrics from pypots.utils.metrics is deprecated. " - "Please import from pypots.nn.functional instead." 
-) +logger.warning("‼️ `pypots.utils.metrics` is deprecated. Please import from `pypots.nn.functional` instead.") __all__ = [ # error diff --git a/pypots/utils/metrics/classification.py b/pypots/utils/metrics/classification.py index 20481631..eaef8e68 100644 --- a/pypots/utils/metrics/classification.py +++ b/pypots/utils/metrics/classification.py @@ -7,7 +7,13 @@ from ..logging import logger -from ...nn.functional.classification import * +from ...nn.functional.classification import ( + calc_binary_classification_metrics, + calc_precision_recall_f1, + calc_pr_auc, + calc_roc_auc, + calc_acc, +) # pypots.utils.metrics.classification is deprecated, and moved to pypots.nn.functional.classification logger.warning( diff --git a/pypots/utils/metrics/clustering.py b/pypots/utils/metrics/clustering.py index eb6a668b..a37dba58 100644 --- a/pypots/utils/metrics/clustering.py +++ b/pypots/utils/metrics/clustering.py @@ -6,7 +6,17 @@ # License: BSD-3-Clause from ..logging import logger -from ...nn.functional.clustering import * +from ...nn.functional.clustering import ( + calc_rand_index, + calc_adjusted_rand_index, + calc_cluster_purity, + calc_nmi, + calc_chs, + calc_dbs, + calc_silhouette, + calc_internal_cluster_validation_metrics, + calc_external_cluster_validation_metrics, +) # pypots.utils.metrics.clustering is deprecated, and moved to pypots.nn.functional.clustering logger.warning("🚨 Please import from pypots.nn.functional.clustering instead of pypots.utils.metrics.clustering") diff --git a/pypots/utils/metrics/error.py b/pypots/utils/metrics/error.py index a25e7647..0ff62e28 100644 --- a/pypots/utils/metrics/error.py +++ b/pypots/utils/metrics/error.py @@ -6,7 +6,14 @@ # License: BSD-3-Clause from ..logging import logger -from ...nn.functional.error import * +from ...nn.functional.error import ( + calc_mae, + calc_mse, + calc_rmse, + calc_mre, + calc_quantile_crps, + calc_quantile_crps_sum, +) # pypots.utils.metrics.error is deprecated, and moved to pypots.nn.functional.error logger.warning("🚨 Please import from pypots.nn.functional.error instead of pypots.utils.metrics.error")
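Taken together, the functional re-exports at the top of this series and the compatibility shims in the last three patches mean downstream code should migrate its metric imports. A short sketch of the intended usage, for illustration only; the argument order mirrors the calc_mse(reconstruction, X_ori, indicating_mask) calls in the hunks above.

    import torch

    # Deprecated location: still importable thanks to the shim modules above, but now logs a warning.
    # from pypots.utils.metrics import calc_mae, calc_mse

    # Recommended location going forward, re-exported by pypots/nn/functional/__init__.py.
    from pypots.nn.functional import calc_mae, calc_mse

    predictions = torch.zeros(8, 48, 37)
    targets = torch.zeros(8, 48, 37)
    mask = torch.ones(8, 48, 37)  # selects which entries count toward the error

    mae = calc_mae(predictions, targets, mask)
    mse = calc_mse(predictions, targets, mask)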