diff --git a/configs/model/ptg.yaml b/configs/model/ptg.yaml
index 3f7ef2d2b..1867a9afc 100644
--- a/configs/model/ptg.yaml
+++ b/configs/model/ptg.yaml
@@ -38,3 +38,6 @@ num_classes: ${data.num_classes}
 
 # compile model for faster training with pytorch 2.0
 compile: false
+
+# Hydra output dir
+output_dir: ${paths.output_dir}
diff --git a/tcn_hpl/models/ptg_module.py b/tcn_hpl/models/ptg_module.py
index 457bb8510..0816ae6ae 100644
--- a/tcn_hpl/models/ptg_module.py
+++ b/tcn_hpl/models/ptg_module.py
@@ -15,8 +15,6 @@
 
 import kwcoco
 
-from hydra.core.hydra_config import HydraConfig
-
 from angel_system.data.common.load_data import (
     time_from_name,
 )
@@ -71,7 +69,8 @@ def __init__(
         data_dir: str,
         num_classes: int,
         compile: bool,
-        mapping_file_name: str = "mapping.txt"
+        mapping_file_name: str = "mapping.txt",
+        output_dir: str = None,
     ) -> None:
         """Initialize a `PTGLitModule`.
 
@@ -96,6 +95,7 @@ def __init__(
         actions_dict = dict()
         for a in actions:
             actions_dict[a.split()[1]] = int(a.split()[0])
+        self.class_ids = list(actions_dict.values())
         self.classes = list(actions_dict.keys())
 
         self.action_id_to_str = dict(zip(self.class_ids, self.classes))
@@ -123,36 +123,8 @@ def __init__(
         self.validation_step_outputs_source_vid = []
         self.validation_step_outputs_source_frame = []
 
-        hydra_cfg = HydraConfig.get()
-        self.output_dir = hydra_cfg['runtime']['output_dir']
-
-        # Load val vidoes
-        vid_list_file_val = f"{self.hparams.data_dir}/splits/val.split1.bundle"
-        with open(vid_list_file_val, "r") as val_f:
-            self.val_videos = val_f.read().split("\n")[:-1]
-
-        self.val_frames = {}
-        for video in self.val_videos:
-            # Load frame filenames for the video
-            frame_list_file_val = f"{self.hparams.data_dir}/frames/{video}"
-            with open(frame_list_file_val, "r") as val_f:
-                val_fns = val_f.read().split("\n")[:-1]
-
-            self.val_frames[video[:-4]] = val_fns
-
-        # Load test vidoes
-        vid_list_file_tst = f"{self.hparams.data_dir}/splits/test.split1.bundle"
-        with open(vid_list_file_tst, "r") as test_f:
-            self.test_videos = test_f.read().split("\n")[:-1]
-
-        self.test_frames = {}
-        for video in self.test_videos:
-            # Load frame filenames for the video
-            frame_list_file_tst = f"{self.hparams.data_dir}/frames/{video}"
-            with open(frame_list_file_tst, "r") as test_f:
-                test_fns = test_f.read().split("\n")[:-1]
-
-            self.test_frames[video[:-4]] = test_fns
+        self.val_frames = None
+        self.test_frames = None
 
     def forward(self, x: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
         """Perform a forward pass through the model `self.net`.
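The config change above wires Hydra's `${paths.output_dir}` interpolation into the model config, so the module no longer queries `HydraConfig.get()` at construction time. A minimal sketch of how the value flows, assuming a typical Lightning + Hydra entry point (the `train.yaml` name, `paths` config group, and `instantiate` call below are illustrative assumptions, not code from this patch):

```python
import hydra
from omegaconf import DictConfig


@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> None:
    # Hydra resolves ${paths.output_dir} to a concrete path before the
    # model is built, so PTGLitModule receives output_dir as a plain
    # string argument instead of reaching into the Hydra runtime itself.
    model = hydra.utils.instantiate(cfg.model)
    print(model.hparams.output_dir)


if __name__ == "__main__":
    main()
```

Passing the directory in as a constructor argument also means the module can be restored from a checkpoint without an active Hydra context, which `HydraConfig.get()` would otherwise require.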
@@ -292,9 +264,24 @@ def on_validation_epoch_end(self) -> None:
         all_source_vids = torch.concat(self.validation_step_outputs_source_vid)
         all_source_frames = torch.concat(self.validation_step_outputs_source_frame)
 
+        # Load val videos
+        if self.val_frames is None:
+            self.val_frames = {}
+            vid_list_file_val = f"{self.hparams.data_dir}/splits/val.split1.bundle"
+            with open(vid_list_file_val, "r") as val_f:
+                self.val_videos = val_f.read().split("\n")[:-1]
+
+            for video in self.val_videos:
+                # Load frame filenames for the video
+                frame_list_file_val = f"{self.hparams.data_dir}/frames/{video}"
+                with open(frame_list_file_val, "r") as val_f:
+                    val_fns = val_f.read().split("\n")[:-1]
+
+                self.val_frames[video[:-4]] = val_fns
+
         # Save results
         dset = kwcoco.CocoDataset()
-        dset.fpath = f"{self.output_dir}/val_activity_preds_epoch{self.current_epoch}.mscoco.json"
+        dset.fpath = f"{self.hparams.output_dir}/val_activity_preds_epoch{self.current_epoch}.mscoco.json"
         dset.dataset["info"].append({"activity_labels": self.action_id_to_str})
 
         for (gt, pred, prob, source_vid, source_frame) in zip(all_targets, all_preds, all_probs, all_source_vids, all_source_frames):
@@ -324,8 +311,10 @@ def on_validation_epoch_end(self) -> None:
             normalize="true"
         )
 
-        fig, ax = plt.subplots(figsize=(20,20))
-        sns.heatmap(cm, annot=True, ax=ax, fmt=".2f")
+        num_act_classes = len(self.class_ids)
+        fig, ax = plt.subplots(figsize=(num_act_classes, num_act_classes))
+
+        sns.heatmap(cm, annot=True, ax=ax, fmt="0.0%", linewidth=0.5)
 
         # labels, title and ticks
         ax.set_xlabel('Predicted labels')
@@ -375,9 +364,24 @@ def on_test_epoch_end(self) -> None:
         all_source_vids = torch.concat(self.validation_step_outputs_source_vid)
         all_source_frames = torch.concat(self.validation_step_outputs_source_frame)
 
+        # Load test videos
+        if self.test_frames is None:
+            self.test_frames = {}
+            vid_list_file_tst = f"{self.hparams.data_dir}/splits/test.split1.bundle"
+            with open(vid_list_file_tst, "r") as test_f:
+                self.test_videos = test_f.read().split("\n")[:-1]
+
+            for video in self.test_videos:
+                # Load frame filenames for the video
+                frame_list_file_tst = f"{self.hparams.data_dir}/frames/{video}"
+                with open(frame_list_file_tst, "r") as test_f:
+                    test_fns = test_f.read().split("\n")[:-1]
+
+                self.test_frames[video[:-4]] = test_fns
+
         # Save results
         dset = kwcoco.CocoDataset()
-        dset.fpath = f"{self.output_dir}/test_activity_preds.mscoco.json"
+        dset.fpath = f"{self.hparams.output_dir}/test_activity_preds.mscoco.json"
         dset.dataset["info"].append({"activity_labels": self.action_id_to_str})
 
         for (gt, pred, prob, source_vid, source_frame) in zip(all_targets, all_preds, all_probs, all_source_vids, all_source_frames):
@@ -407,7 +411,9 @@ def on_test_epoch_end(self) -> None:
             normalize="true"
         )
 
-        fig, ax = plt.subplots(figsize=(20,20))
+        num_act_classes = len(self.class_ids)
+        fig, ax = plt.subplots(figsize=(num_act_classes, num_act_classes))
+
         sns.heatmap(cm, annot=True, ax=ax, fmt=".2f")
 
         # labels, title and ticks
@@ -417,6 +423,8 @@ def on_test_epoch_end(self) -> None:
         ax.xaxis.set_ticklabels(self.classes, rotation=90)
         ax.yaxis.set_ticklabels(self.classes, rotation=0)
 
+        fig.savefig(f"{self.hparams.output_dir}/test_confusion_mat.png", pad_inches=5)
+
         self.logger.experiment.track(Image(fig), name=f'CM Test Epoch')
 
         plt.close(fig)
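The epoch-end hooks now build the val/test frame indexes lazily on first use instead of in `__init__`, so constructing the module no longer requires the split bundles to exist on disk. A condensed sketch of the same load-once pattern in isolation (the class and its layout are hypothetical, mirroring the `data_dir/splits` and `data_dir/frames` structure used in the patch):

```python
from typing import Dict, List, Optional


class LazyFrameIndex:
    """Loads a split's per-video frame filename lists on first access only."""

    def __init__(self, data_dir: str, split: str) -> None:
        self.data_dir = data_dir
        self.split = split
        self._frames: Optional[Dict[str, List[str]]] = None  # None = not loaded

    @property
    def frames(self) -> Dict[str, List[str]]:
        if self._frames is None:
            self._frames = {}
            bundle = f"{self.data_dir}/splits/{self.split}.split1.bundle"
            with open(bundle, "r") as f:
                videos = f.read().split("\n")[:-1]
            for video in videos:
                with open(f"{self.data_dir}/frames/{video}", "r") as f:
                    # Key by the video name without its 4-character
                    # extension, matching the module's video[:-4]
                    self._frames[video[:-4]] = f.read().split("\n")[:-1]
        return self._frames
```

The confusion-matrix change follows the same data-driven idea: `figsize=(num_act_classes, num_act_classes)` gives the heatmap roughly one inch per class, so the figure scales with the label set instead of staying fixed at 20x20.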