diff --git a/configs/data/ptg.yaml b/configs/data/ptg.yaml index 8c728b944..497e9e683 100644 --- a/configs/data/ptg.yaml +++ b/configs/data/ptg.yaml @@ -3,6 +3,7 @@ _target_: tcn_hpl.data.ptg_datamodule.PTGDataModule train_dataset: _target_: tcn_hpl.data.tcn_dataset.TCNDataset window_size: 15 + window_label_idx: -1 # A vectorizer is required to complete construction of a TCN Dataset. # We are not providing a default here given how hydra merged hyperparameters. # For example: @@ -21,6 +22,7 @@ train_dataset: val_dataset: _target_: tcn_hpl.data.tcn_dataset.TCNDataset window_size: ${data.train_dataset.window_size} + window_label_idx: ${data.train_dataset.window_label_idx} vectorize: ${data.train_dataset.vectorize} transform_frame_data: _target_: torchvision.transforms.Compose diff --git a/configs/experiment/m2/feat_locsconfs.yaml b/configs/experiment/m2/feat_locsconfs.yaml index 717200902..8e7f1ec0b 100644 --- a/configs/experiment/m2/feat_locsconfs.yaml +++ b/configs/experiment/m2/feat_locsconfs.yaml @@ -1,10 +1,5 @@ # @package _global_ -# to execute this experiment run: -# python train.py experiment=example -topic: "medical" -task: "m2" - defaults: - override /data: ptg - override /model: ptg @@ -33,17 +28,23 @@ tags: ["m2", "ms_tcn", "debug"] seed: 12345 +#callbacks: +# model_checkpoint: +# # save all ~80MB checkpoints for post-training investigation. +# # Total: ~45GB +# save_top_k: 500 + trainer: min_epochs: 50 max_epochs: 500 log_every_n_steps: 1 model: - num_classes: 9 + num_classes: 9 # number of activity classification classes compile: false net: # Length of feature vector for a single frame. - # Currently derived from feature version and other hyperparameters. + # Currently derived from the parameterization of dataset vectorizer. dim: 102 # # Once upon a time defaults @@ -67,11 +68,14 @@ data: batch_size: 512 num_workers: 16 target_framerate: 15 # BBN Hololens2 Framerate - # This is a little more than the number of windows in the training dataset. - epoch_length: 100000 + # This is a bit more than the number of windows in the training dataset so + # the weighted sampler has more of an opportunity to sample the space + # proportionally. + epoch_length: 300000 train_dataset: window_size: 25 + window_label_idx: ${model.pred_frame_index} vectorize: _target_: tcn_hpl.data.vectorize.locs_and_confs.LocsAndConfs top_k: 1 @@ -93,11 +97,23 @@ data: pose_latency: 0.1 dets_throughput_std: 0.2 pose_throughput_std: 0.2 + fixed_pattern: false val_dataset: # Augmentations on windows of frame data before performing vectorization. # Sharing transform with training dataset as it is only the drop-out aug to # simulate stream processing dropout the same. - transform_frame_data: ${data.train_dataset.transform_frame_data} + transform_frame_data: + transforms: + - _target_: tcn_hpl.data.frame_data_aug.window_frame_dropout.DropoutFrameDataTransform + # Mirror training hparams, except used fixed patterns. + frame_rate: ${data.train_dataset.frame_rate} + dets_throughput_mean: ${data.train_dataset.dets_throughput_mean} + pose_throughput_mean: ${data.train_dataset.pose_throughput_mean} + dets_latency: ${data.train_dataset.dets_latency} + pose_latency: ${data.train_dataset.pose_latency} + dets_throughput_std: ${data.train_dataset.dets_throughput_std} + pose_throughput_std: ${data.train_dataset.pose_throughput_std} + fixed_pattern: true # Test dataset usually configured the same as val, unless there is some # different set of transforms that should be used during test/prediction. diff --git a/configs/experiment/r18/feat_locsconfs.yaml b/configs/experiment/r18/feat_locsconfs.yaml index 462266342..6ab708b4d 100644 --- a/configs/experiment/r18/feat_locsconfs.yaml +++ b/configs/experiment/r18/feat_locsconfs.yaml @@ -1,10 +1,5 @@ # @package _global_ -# to execute this experiment run: -# python train.py experiment=example -topic: "medical" -task: "r18" - defaults: - override /data: ptg - override /model: ptg @@ -67,11 +62,14 @@ data: batch_size: 512 num_workers: 16 target_framerate: 15 # BBN Hololens2 Framerate - # This is a little more than the number of windows in the training dataset. - epoch_length: 80000 + # This is a bit more than the number of windows in the training dataset so + # the weighted sampler has more of an opportunity to sample the space + # proportionally. + epoch_length: 300000 train_dataset: window_size: 25 + window_label_idx: ${model.pred_frame_index} vectorize: _target_: tcn_hpl.data.vectorize.locs_and_confs.LocsAndConfs top_k: 1 @@ -93,11 +91,23 @@ data: pose_latency: 0.1 dets_throughput_std: 0.2 pose_throughput_std: 0.2 + fixed_pattern: false val_dataset: # Augmentations on windows of frame data before performing vectorization. # Sharing transform with training dataset as it is only the drop-out aug to # simulate stream processing dropout the same. - transform_frame_data: ${data.train_dataset.transform_frame_data} + transform_frame_data: + transforms: + - _target_: tcn_hpl.data.frame_data_aug.window_frame_dropout.DropoutFrameDataTransform + # Mirror training hparams, except used fixed patterns. + frame_rate: ${data.train_dataset.frame_rate} + dets_throughput_mean: ${data.train_dataset.dets_throughput_mean} + pose_throughput_mean: ${data.train_dataset.pose_throughput_mean} + dets_latency: ${data.train_dataset.dets_latency} + pose_latency: ${data.train_dataset.pose_latency} + dets_throughput_std: ${data.train_dataset.dets_throughput_std} + pose_throughput_std: ${data.train_dataset.pose_throughput_std} + fixed_pattern: true # Test dataset usually configured the same as val, unless there is some # different set of transforms that should be used during test/prediction. diff --git a/configs/model/ptg.yaml b/configs/model/ptg.yaml index 1c570a0ce..37613986c 100644 --- a/configs/model/ptg.yaml +++ b/configs/model/ptg.yaml @@ -38,7 +38,10 @@ smoothing_loss: 0.0015 use_smoothing_loss: False # Number of classes -num_classes: ${data.num_classes} +num_classes: 9 # compile model for faster training with pytorch 2.0 compile: false + +# Which frame in a window of predictions should represent the while window. +pred_frame_index: -1 diff --git a/tcn_hpl/data/frame_data_aug/window_frame_dropout.py b/tcn_hpl/data/frame_data_aug/window_frame_dropout.py index f3e40d096..7dd98b476 100644 --- a/tcn_hpl/data/frame_data_aug/window_frame_dropout.py +++ b/tcn_hpl/data/frame_data_aug/window_frame_dropout.py @@ -43,6 +43,12 @@ class DropoutFrameDataTransform(torch.nn.Module): Standard deviation of the throughput rate for object detections. pose_throughput_std: Standard deviation of the throughput rate for pose estimations. + fixed_pattern: + Create a single, fixed dropout pattern to be applied to every + window based on the input throughput and latency mean values with + no random variation. This is idea to use for validation and test + dataset passes that require dropout simulation but do not want + random variation. """ def __init__( @@ -54,6 +60,7 @@ def __init__( pose_latency: Optional[float] = None, dets_throughput_std: float = 0.0, pose_throughput_std: float = 0.0, + fixed_pattern: bool = False, ): super().__init__() self.frame_rate = frame_rate @@ -68,6 +75,7 @@ def __init__( self.pose_latency = ( pose_latency if pose_latency is not None else 1.0 / pose_throughput_mean ) + self.fixed_pattern = fixed_pattern def forward(self, window: Sequence[FrameData]) -> List[FrameData]: # Starting from some latency back from the end of the window, start @@ -94,32 +102,41 @@ def forward(self, window: Sequence[FrameData]) -> List[FrameData]: # Define processing intervals (how often a frame is processed) # This cursed formatting is because of `black`. - dets_interval = ( - 1.0 - / torch.normal( - mean=self.dets_throughput_mean, - std=self.dets_throughput_std, - size=(n_frames,), - ).numpy() - ) - pose_interval = ( - 1.0 - / torch.normal( - mean=self.pose_throughput_mean, - std=self.pose_throughput_std, - size=(n_frames,), - ).numpy() - ) + if self.fixed_pattern: + dets_interval = 1.0 / np.full(n_frames, self.dets_throughput_mean) + pose_interval = 1.0 / np.full(n_frames, self.pose_throughput_mean) + # Fixed simulation of half-way into processing previous frame. + dets_initial_end = 0.5 * dets_interval[0] + pose_initial_end = 0.5 * pose_interval[0] + else: + dets_interval = ( + 1.0 + / torch.normal( + mean=self.dets_throughput_mean, + std=self.dets_throughput_std, + size=(n_frames,), + ).numpy() + ) + pose_interval = ( + 1.0 + / torch.normal( + mean=self.pose_throughput_mean, + std=self.pose_throughput_std, + size=(n_frames,), + ).numpy() + ) + dets_initial_end = torch.rand(1).item() * dets_interval[0] + pose_initial_end = torch.rand(1).item() * pose_interval[0] # Initialize end time trackers for processing detections and poses. # Simulate that agents may already be part-way through processing a # frame before the beginning of this window, utilizing the first value # from respective interval vectors. dets_processing_end = np.full( - n_frames + 1, torch.rand(1).item() * dets_interval[0] + n_frames + 1, dets_initial_end ) pose_processing_end = np.full( - n_frames + 1, torch.rand(1).item() * pose_interval[0] + n_frames + 1, pose_initial_end ) # Boolean arrays to keep track of whether a frame can be processed @@ -184,7 +201,8 @@ def forward(self, window: Sequence[FrameData]) -> List[FrameData]: def test(): - import numpy as np + from IPython.core.getipython import get_ipython + import pandas as pd from tcn_hpl.data.frame_data import FrameObjectDetections, FramePoses frame1 = FrameData( @@ -216,18 +234,24 @@ def test(): pose_latency=1/10, # (1 / 10) - (1 / 14.5), dets_throughput_std=0.2, pose_throughput_std=0.2, + fixed_pattern=True, ) modified_sequence = transform(sequence) - for idx, frame in enumerate(modified_sequence): - print( - f"Frame {idx}: Object Detections: {frame.object_detections is not None}, Poses: {frame.poses is not None}" - ) - - from IPython import get_ipython + print( + pd.DataFrame({ + "object detections": [ + frame.object_detections is not None for frame in modified_sequence + ], + "pose estimation": [ + frame.poses is not None for frame in modified_sequence + ] + }) + ) ipython = get_ipython() - ipython.run_line_magic("timeit", "transform(sequence)") + if ipython is not None: + ipython.run_line_magic("timeit", "transform(sequence)") if __name__ == "__main__": diff --git a/tcn_hpl/data/ptg_datamodule.py b/tcn_hpl/data/ptg_datamodule.py index b632ac3e7..9d9ba070a 100644 --- a/tcn_hpl/data/ptg_datamodule.py +++ b/tcn_hpl/data/ptg_datamodule.py @@ -203,12 +203,19 @@ def val_dataloader(self) -> DataLoader[Any]: :return: The validation dataloader. """ + val_sampler = torch.utils.data.WeightedRandomSampler( + self.data_val.window_weights, + len(self.data_val) * 3, + replacement=True, + generator=None, + ) return DataLoader( dataset=self.data_val, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory, - shuffle=False, + # shuffle=False, + sampler=val_sampler, ) def test_dataloader(self) -> DataLoader[Any]: diff --git a/tcn_hpl/data/tcn_dataset.py b/tcn_hpl/data/tcn_dataset.py index e89599ac7..bdded9000 100644 --- a/tcn_hpl/data/tcn_dataset.py +++ b/tcn_hpl/data/tcn_dataset.py @@ -51,25 +51,33 @@ class TCNDataset(Dataset): window_size: The size of the sliding window used to collect inputs from either a real-time or offline source. + vectorize: + Vectorization functor to convert frame data into an embedding + space. + window_label_idx: + Indicate which frame in a window of frames the window's truth label + should be drawn from. E.g. `-1` means assign the truth of the + window from the truth of the last frame in the window, `5` means + assign the truth of the window from the truth of the 5th frame in + the window, etc. transform_frame_data: Optional augmentation function that operates on a window of FrameData before being input to vectorization. Such an augmentation function should not modify the input FrameData. - vectorize: - Vectorization functor to convert frame data into an embedding - space. """ def __init__( self, window_size: int, vectorize: Vectorize, + window_label_idx: int = -1, transform_frame_data: Optional[ Callable[[Sequence[FrameData]], Sequence[FrameData]] ] = None, ): self.window_size = window_size self.vectorize = vectorize + self.window_label_idx = window_label_idx self.transform_frame_data = transform_frame_data # For offline mode, pre-cut videos into clips according to window @@ -376,7 +384,7 @@ class that has pose keypoints associated with it. The current # Collect for weighting the truth labels for the final frames of # windows, which is the truth value for the window as a whole. - window_final_class_ids = self._window_truth[:, -1] + window_final_class_ids = self._window_truth[:, self.window_label_idx] cls_ids, cls_counts = np.unique(window_final_class_ids, return_counts=True) # Some classes may not be represented in the truth, so initialize the # weights vector separately, and then assign weight values based on diff --git a/tcn_hpl/models/components/ms_tcs_net.py b/tcn_hpl/models/components/ms_tcs_net.py index 47da140c3..c9508f80a 100644 --- a/tcn_hpl/models/components/ms_tcs_net.py +++ b/tcn_hpl/models/components/ms_tcs_net.py @@ -71,6 +71,8 @@ def __init__(self, dims: Sequence[int], dropout_p): """ Simple linear skip connection block. + This happens to make use of GELU as the linear unit utilized. + :param dims: A number of internal dimensions, creating N*2 linear layers connecting each dimensional shift. :param dropout_p: P-value for the drop-out layers utilized. @@ -95,6 +97,45 @@ def forward(self, x): return x +class LinearResidual(nn.Module): + """ + Sequence of fully connected layers with residual connections. + + There is a single skip connection from the input to the output in order to + connect the two. + + The number of layer specified refers to the number of *internal* layers + between the linear layers that transforms the interface dimension + (input/output) and the layer dimensions. + + :param interface_dim: Dimension of input and output tensors. + :param layer_dim: Dimension of internal sequential residual layers. + :param n_layers: Number of internal layers. This should be 0 or greater. + """ + + def __init__( + self, + interface_dim: int, + layer_dim: int = 512, + n_layers: int = 5, + dropout_p: float = 0.25, + ): + super().__init__() + self.l_first = nn.Sequential(nn.Linear(interface_dim, layer_dim), nn.GELU(), nn.Dropout(dropout_p)) + self.l_inner = nn.ModuleList([ + nn.Sequential(nn.Linear(layer_dim, layer_dim), nn.GELU(), nn.Dropout(dropout_p)) + for i in range(n_layers) + ]) + self.l_last = nn.Sequential(nn.Linear(layer_dim, interface_dim), nn.GELU(), nn.Dropout(dropout_p)) + + def forward(self, x): + out = self.l_first(x) + for layer in self.l_inner: + out = layer(out) + out + out = self.l_last(out) + x + return out + + class SingleStageModel(nn.Module): def __init__(self, num_layers, num_f_maps, dim, num_classes): super(SingleStageModel, self).__init__() diff --git a/tcn_hpl/models/ptg_module.py b/tcn_hpl/models/ptg_module.py index 5de396dd0..f396de9ef 100644 --- a/tcn_hpl/models/ptg_module.py +++ b/tcn_hpl/models/ptg_module.py @@ -56,6 +56,7 @@ def __init__( use_smoothing_loss: bool, num_classes: int, compile: bool, + pred_frame_index: int = -1, ) -> None: """Initialize a `PTGLitModule`. @@ -63,6 +64,11 @@ def __init__( :param criterion: Loss Computation :param optimizer: The optimizer to use for training. :param scheduler: The learning rate scheduler to use for training. + :param pred_frame_index: + Index of a frame in the window whose predicted class and + probabilities should represent the window as a whole. Negative + indices are valid. Must be a valid index into the window range + specified by the dataset """ super().__init__() @@ -86,6 +92,16 @@ def __init__( self.test_acc = Accuracy( task="multiclass", average="weighted", num_classes=num_classes ) + # Track per-class accuracy for separated logging + self.train_acc_perclass = Accuracy( + task="multiclass", average="none", num_classes=num_classes + ) + self.val_acc_perclass = Accuracy( + task="multiclass", average="none", num_classes=num_classes + ) + self.test_acc_perclass = Accuracy( + task="multiclass", average="none", num_classes=num_classes + ) self.train_f1 = F1Score( num_classes=num_classes, average="weighted", task="multiclass" @@ -96,6 +112,16 @@ def __init__( self.test_f1 = F1Score( num_classes=num_classes, average="weighted", task="multiclass" ) + # Track per-class F1 for separated logging + self.train_f1_perclass = F1Score( + num_classes=num_classes, average="none", task="multiclass" + ) + self.val_f1_perclass = F1Score( + num_classes=num_classes, average="none", task="multiclass" + ) + self.test_f1_perclass = F1Score( + num_classes=num_classes, average="none", task="multiclass" + ) self.train_recall = Recall( num_classes=num_classes, average="weighted", task="multiclass" @@ -235,10 +261,11 @@ def model_step( for p in logits: loss += self.compute_loss(p, y, m) + pred_frame_index = self.hparams.pred_frame_index probs = torch.softmax( - logits[-1, :, :, -1], dim=1 + logits[-1, :, :, pred_frame_index], dim=1 ) # shape (batch size, self.hparams.num_classes) - preds = torch.argmax(logits[-1, :, :, -1], dim=1) # shape: batch size + preds = torch.argmax(logits[-1, :, :, pred_frame_index], dim=1) # shape: batch size return loss, probs, preds, y, source_vid, source_frame @@ -260,7 +287,7 @@ def training_step( # update and log metrics self.train_loss(loss) - self.train_acc(preds, targets[:, -1]) + self.train_acc(preds, targets[:, self.hparams.pred_frame_index]) self.log( "train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True @@ -269,14 +296,24 @@ def training_step( "train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True ) + self.train_acc_perclass(preds, targets[:, self.hparams.pred_frame_index]) + for c_i, c_acc in enumerate(self.train_acc_perclass.compute()): + self.log( + f"train/acc-per-class/c{c_i}", + c_acc, + prog_bar=False, + on_step=False, + on_epoch=True, + ) + # return loss or backpropagation will fail return { "loss": loss, "preds": preds, "probs": probs, - "targets": targets[:, -1], - "source_vid": source_vid[:, -1], - "source_frame": source_frame[:, -1], + "targets": targets[:, self.hparams.pred_frame_index], + "source_vid": source_vid[:, self.hparams.pred_frame_index], + "source_frame": source_frame[:, self.hparams.pred_frame_index], } def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: @@ -289,6 +326,15 @@ def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: self.log("train/f1", self.train_f1, prog_bar=True, on_epoch=True) self.log("train/recall", self.train_recall, prog_bar=True, on_epoch=True) self.log("train/precision", self.train_precision, prog_bar=True, on_epoch=True) + # vector metrics + self.train_f1_perclass(all_preds, all_targets) + for c_i, c_f1 in enumerate(self.train_f1_perclass.compute()): + self.log( + f"train/f1-per-class/c{c_i}", + c_f1, + prog_bar=False, + on_epoch=True, + ) def validation_step( self, @@ -307,20 +353,30 @@ def validation_step( # update and log metrics self.val_loss(loss) - self.val_acc(preds, targets[:, -1]) + self.val_acc(preds, targets[:, self.hparams.pred_frame_index]) self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True) self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) + self.val_acc_perclass(preds, targets[:, self.hparams.pred_frame_index]) + for c_i, c_acc in enumerate(self.val_acc_perclass.compute()): + self.log( + f"val/acc-per-class/c{c_i}", + c_acc, + prog_bar=False, + on_step=False, + on_epoch=True, + ) + # Only retain the truth and source vid/frame IDs for the final window # frame as this is the ultimately relevant result. return { "loss": loss, "preds": preds, "probs": probs, - "targets": targets[:, -1], - "source_vid": source_vid[:, -1], - "source_frame": source_frame[:, -1], + "targets": targets[:, self.hparams.pred_frame_index], + "source_vid": source_vid[:, self.hparams.pred_frame_index], + "source_frame": source_frame[:, self.hparams.pred_frame_index], } def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None: @@ -333,6 +389,15 @@ def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) self.log("val/f1", self.val_f1, prog_bar=True, on_epoch=True) self.log("val/recall", self.val_recall, prog_bar=True, on_epoch=True) self.log("val/precision", self.val_precision, prog_bar=True, on_epoch=True) + # vector metrics + self.val_f1_perclass(all_preds, all_targets) + for c_i, c_f1 in enumerate(self.val_f1_perclass.compute()): + self.log( + f"val/f1-per-class/c{c_i}", + c_f1, + prog_bar=False, + on_epoch=True, + ) # log `val_f1_best` as a value through `.compute()` return, instead of # as a metric object otherwise metric would be reset by lightning after @@ -357,21 +422,31 @@ def test_step( # update and log metrics self.test_loss(loss) - self.test_acc(preds, targets[:, -1]) + self.test_acc(preds, targets[:, self.hparams.pred_frame_index]) self.log( "test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True ) self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True) + self.test_acc_perclass(preds, targets[:, self.hparams.pred_frame_index]) + for c_i, c_acc in enumerate(self.test_acc_perclass.compute()): + self.log( + f"test/acc-per-class/c{c_i}", + c_acc, + prog_bar=False, + on_step=False, + on_epoch=True, + ) + # Only retain the truth and source vid/frame IDs for the final window # frame as this is the ultimately relevant result. return { "loss": loss, "preds": preds, "probs": probs, - "targets": targets[:, -1], - "source_vid": source_vid[:, -1], - "source_frame": source_frame[:, -1], + "targets": targets[:, self.hparams.pred_frame_index], + "source_vid": source_vid[:, self.hparams.pred_frame_index], + "source_frame": source_frame[:, self.hparams.pred_frame_index], } def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None: @@ -385,6 +460,15 @@ def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> No self.log("test/f1", self.test_f1, on_step=False, on_epoch=True, prog_bar=True) self.log("test/recall", self.test_recall, on_step=False, on_epoch=True, prog_bar=True) self.log("test/precision", self.test_precision, on_step=False, on_epoch=True, prog_bar=True) + # vector metrics + self.test_f1_perclass(all_preds, all_targets) + for c_i, c_f1 in enumerate(self.test_f1_perclass.compute()): + self.log( + f"test/f1-per-class/c{c_i}", + c_f1, + prog_bar=False, + on_epoch=True, + ) def setup(self, stage: Optional[str] = None) -> None: """Lightning hook that is called at the beginning of fit (train + validate), validate, diff --git a/tcn_hpl/train.py b/tcn_hpl/train.py index 7c8c29266..ceaaf18ff 100644 --- a/tcn_hpl/train.py +++ b/tcn_hpl/train.py @@ -82,9 +82,11 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: log.info("Logging hyperparameters!") utils.log_hyperparameters(object_dict) + configured_ckpt_path = cfg.get("ckpt_path") + if cfg.get("train"): log.info("Starting training!") - trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) + trainer.fit(model=model, datamodule=datamodule, ckpt_path=configured_ckpt_path) train_metrics = trainer.callback_metrics @@ -92,9 +94,13 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: log.info("Starting testing!") ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "": - log.warning("Best ckpt not found! Using current weights for testing...") - ckpt_path = None - trainer.test(model=model, datamodule=datamodule, ckpt_path="best") + if configured_ckpt_path is not None: + log.warning("Best ckpt not found! Using configured weights for testing...") + ckpt_path = configured_ckpt_path + else: + log.warning("Best ckpt not found! Using current weights for testing...") + ckpt_path = None + trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) log.info(f"Best ckpt path: {ckpt_path}")