diff --git a/src/__init__.py b/tcn_hpl/__init__.py similarity index 100% rename from src/__init__.py rename to tcn_hpl/__init__.py diff --git a/src/data/__init__.py b/tcn_hpl/data/__init__.py similarity index 100% rename from src/data/__init__.py rename to tcn_hpl/data/__init__.py diff --git a/src/data/components/PTG_dataset.py b/tcn_hpl/data/components/PTG_dataset.py similarity index 100% rename from src/data/components/PTG_dataset.py rename to tcn_hpl/data/components/PTG_dataset.py diff --git a/src/data/components/__init__.py b/tcn_hpl/data/components/__init__.py similarity index 100% rename from src/data/components/__init__.py rename to tcn_hpl/data/components/__init__.py diff --git a/src/data/components/augmentations.py b/tcn_hpl/data/components/augmentations.py similarity index 100% rename from src/data/components/augmentations.py rename to tcn_hpl/data/components/augmentations.py diff --git a/src/data/mnist_datamodule.py b/tcn_hpl/data/mnist_datamodule.py similarity index 100% rename from src/data/mnist_datamodule.py rename to tcn_hpl/data/mnist_datamodule.py diff --git a/src/data/ptg_datamodule.py b/tcn_hpl/data/ptg_datamodule.py similarity index 100% rename from src/data/ptg_datamodule.py rename to tcn_hpl/data/ptg_datamodule.py diff --git a/src/data/utils/ptg_datagenerator.py b/tcn_hpl/data/utils/ptg_datagenerator.py similarity index 98% rename from src/data/utils/ptg_datagenerator.py rename to tcn_hpl/data/utils/ptg_datagenerator.py index 34ae8ef5f..df6c13647 100644 --- a/src/data/utils/ptg_datagenerator.py +++ b/tcn_hpl/data/utils/ptg_datagenerator.py @@ -43,12 +43,12 @@ "test": [f"all_activities_{x}" for x in [20, 33, 39, 50, 51, 52, 53, 54]], } # Coffee specific -feat_version = 2 +feat_version = 4 ##################### # Output ##################### -exp_name = f"coffee_conf_10_all_hands_feat_v{str(feat_version)}_fixed" +exp_name = f"coffee_conf_10_all_hands_feat_v{str(feat_version)}" output_data_dir = f"{data_dir}/TCN_data/{exp_name}" if not os.path.exists(output_data_dir): os.makedirs(output_data_dir) diff --git a/src/eval.py b/tcn_hpl/eval.py similarity index 100% rename from src/eval.py rename to tcn_hpl/eval.py diff --git a/src/models/__init__.py b/tcn_hpl/models/__init__.py similarity index 100% rename from src/models/__init__.py rename to tcn_hpl/models/__init__.py diff --git a/src/models/components/__init__.py b/tcn_hpl/models/components/__init__.py similarity index 100% rename from src/models/components/__init__.py rename to tcn_hpl/models/components/__init__.py diff --git a/src/models/components/focal_loss.py b/tcn_hpl/models/components/focal_loss.py similarity index 100% rename from src/models/components/focal_loss.py rename to tcn_hpl/models/components/focal_loss.py diff --git a/src/models/components/ms_tcs_net.py b/tcn_hpl/models/components/ms_tcs_net.py similarity index 100% rename from src/models/components/ms_tcs_net.py rename to tcn_hpl/models/components/ms_tcs_net.py diff --git a/src/models/components/simple_dense_net.py b/tcn_hpl/models/components/simple_dense_net.py similarity index 100% rename from src/models/components/simple_dense_net.py rename to tcn_hpl/models/components/simple_dense_net.py diff --git a/src/models/mnist_module.py b/tcn_hpl/models/mnist_module.py similarity index 100% rename from src/models/mnist_module.py rename to tcn_hpl/models/mnist_module.py diff --git a/src/models/ptg_module.py b/tcn_hpl/models/ptg_module.py similarity index 99% rename from src/models/ptg_module.py rename to tcn_hpl/models/ptg_module.py index 2178fe726..b92dbd3c0 100644 --- a/src/models/ptg_module.py +++ b/tcn_hpl/models/ptg_module.py @@ -247,6 +247,8 @@ def on_validation_epoch_end(self) -> None: self.logger.experiment.track(Image(fig), name=f'CM Validation Epoch') + plt.close(fig) + self.validation_step_outputs_target.clear() self.validation_step_outputs_pred.clear() @@ -296,6 +298,8 @@ def on_test_epoch_end(self) -> None: self.logger.experiment.track(Image(fig), name=f'CM Test Epoch') + plt.close(fig) + self.validation_step_outputs_target.clear() self.validation_step_outputs_pred.clear() diff --git a/src/train.py b/tcn_hpl/train.py similarity index 100% rename from src/train.py rename to tcn_hpl/train.py diff --git a/src/utils/__init__.py b/tcn_hpl/utils/__init__.py similarity index 100% rename from src/utils/__init__.py rename to tcn_hpl/utils/__init__.py diff --git a/src/utils/instantiators.py b/tcn_hpl/utils/instantiators.py similarity index 100% rename from src/utils/instantiators.py rename to tcn_hpl/utils/instantiators.py diff --git a/src/utils/logging_utils.py b/tcn_hpl/utils/logging_utils.py similarity index 100% rename from src/utils/logging_utils.py rename to tcn_hpl/utils/logging_utils.py diff --git a/src/utils/pylogger.py b/tcn_hpl/utils/pylogger.py similarity index 100% rename from src/utils/pylogger.py rename to tcn_hpl/utils/pylogger.py diff --git a/src/utils/rich_utils.py b/tcn_hpl/utils/rich_utils.py similarity index 100% rename from src/utils/rich_utils.py rename to tcn_hpl/utils/rich_utils.py diff --git a/src/utils/utils.py b/tcn_hpl/utils/utils.py similarity index 100% rename from src/utils/utils.py rename to tcn_hpl/utils/utils.py