
Commit

Merge pull request #16 from PTG-Kitware/dev/remove_hydra_call
Remove hydra config call in init
Purg authored Oct 26, 2023
2 parents 58a7553 + b01296e commit 84bec29
Showing 2 changed files with 49 additions and 38 deletions.
3 changes: 3 additions & 0 deletions configs/model/ptg.yaml
@@ -38,3 +38,6 @@ num_classes: ${data.num_classes}

 # compile model for faster training with pytorch 2.0
 compile: false
+
+# Hydra output dir
+output_dir: ${paths.output_dir}
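
The new key routes the run's output directory into the model config through OmegaConf interpolation rather than a runtime Hydra lookup. A minimal sketch of how the value resolves (the `paths` group contents and the concrete directory are illustrative assumptions, not part of this diff):

from omegaconf import OmegaConf

# Hypothetical composed config mirroring this repo's layout.
cfg = OmegaConf.create({
    "paths": {"output_dir": "/tmp/run_0"},           # normally filled in by Hydra
    "model": {"output_dir": "${paths.output_dir}"},  # the key added above
})

# Interpolations resolve on access, so the module receives a plain string.
print(cfg.model.output_dir)  # -> /tmp/run_0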
84 changes: 46 additions & 38 deletions tcn_hpl/models/ptg_module.py
@@ -15,8 +15,6 @@

 import kwcoco
 
-from hydra.core.hydra_config import HydraConfig
-
 from angel_system.data.common.load_data import (
     time_from_name,
 )
@@ -71,7 +69,8 @@ def __init__(
         data_dir: str,
         num_classes: int,
         compile: bool,
-        mapping_file_name: str = "mapping.txt"
+        mapping_file_name: str = "mapping.txt",
+        output_dir: str = None,
     ) -> None:
         """Initialize a `PTGLitModule`.
@@ -96,6 +95,7 @@
         actions_dict = dict()
         for a in actions:
             actions_dict[a.split()[1]] = int(a.split()[0])
+
         self.class_ids = list(actions_dict.values())
         self.classes = list(actions_dict.keys())
         self.action_id_to_str = dict(zip(self.class_ids, self.classes))
@@ -123,36 +123,8 @@ def __init__(
         self.validation_step_outputs_source_vid = []
         self.validation_step_outputs_source_frame = []
 
-        hydra_cfg = HydraConfig.get()
-        self.output_dir = hydra_cfg['runtime']['output_dir']
-
-        # Load val videos
-        vid_list_file_val = f"{self.hparams.data_dir}/splits/val.split1.bundle"
-        with open(vid_list_file_val, "r") as val_f:
-            self.val_videos = val_f.read().split("\n")[:-1]
-
-        self.val_frames = {}
-        for video in self.val_videos:
-            # Load frame filenames for the video
-            frame_list_file_val = f"{self.hparams.data_dir}/frames/{video}"
-            with open(frame_list_file_val, "r") as val_f:
-                val_fns = val_f.read().split("\n")[:-1]
-
-            self.val_frames[video[:-4]] = val_fns
-
-        # Load test videos
-        vid_list_file_tst = f"{self.hparams.data_dir}/splits/test.split1.bundle"
-        with open(vid_list_file_tst, "r") as test_f:
-            self.test_videos = test_f.read().split("\n")[:-1]
-
-        self.test_frames = {}
-        for video in self.test_videos:
-            # Load frame filenames for the video
-            frame_list_file_tst = f"{self.hparams.data_dir}/frames/{video}"
-            with open(frame_list_file_tst, "r") as test_f:
-                test_fns = test_f.read().split("\n")[:-1]
-
-            self.test_frames[video[:-4]] = test_fns
+        self.val_frames = None
+        self.test_frames = None
 
     def forward(self, x: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
         """Perform a forward pass through the model `self.net`.
@@ -292,9 +264,24 @@ def on_validation_epoch_end(self) -> None:
         all_source_vids = torch.concat(self.validation_step_outputs_source_vid)
         all_source_frames = torch.concat(self.validation_step_outputs_source_frame)
 
+        # Load val videos
+        if self.val_frames is None:
+            self.val_frames = {}
+            vid_list_file_val = f"{self.hparams.data_dir}/splits/val.split1.bundle"
+            with open(vid_list_file_val, "r") as val_f:
+                self.val_videos = val_f.read().split("\n")[:-1]
+
+            for video in self.val_videos:
+                # Load frame filenames for the video
+                frame_list_file_val = f"{self.hparams.data_dir}/frames/{video}"
+                with open(frame_list_file_val, "r") as val_f:
+                    val_fns = val_f.read().split("\n")[:-1]
+
+                self.val_frames[video[:-4]] = val_fns
+
         # Save results
         dset = kwcoco.CocoDataset()
-        dset.fpath = f"{self.output_dir}/val_activity_preds_epoch{self.current_epoch}.mscoco.json"
+        dset.fpath = f"{self.hparams.output_dir}/val_activity_preds_epoch{self.current_epoch}.mscoco.json"
         dset.dataset["info"].append({"activity_labels": self.action_id_to_str})
 
         for (gt, pred, prob, source_vid, source_frame) in zip(all_targets, all_preds, all_probs, all_source_vids, all_source_frames):
@@ -324,8 +311,10 @@ def on_validation_epoch_end(self) -> None:
             normalize="true"
         )
 
-        fig, ax = plt.subplots(figsize=(20,20))
-        sns.heatmap(cm, annot=True, ax=ax, fmt=".2f")
+        num_act_classes = len(self.class_ids)
+        fig, ax = plt.subplots(figsize=(num_act_classes, num_act_classes))
+
+        sns.heatmap(cm, annot=True, ax=ax, fmt="0.0%", linewidth=0.5)
 
         # labels, title and ticks
         ax.set_xlabel('Predicted labels')
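
The validation confusion matrix now scales its figure with the class count and renders cells as percentages. A self-contained sketch of the same seaborn call on synthetic data (the label names and matrix values are made up for illustration):

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

classes = ["background", "step_1", "step_2"]  # dummy label set
num_act_classes = len(classes)

# Row-normalized dummy confusion matrix (rows sum to 1, like normalize="true").
cm = np.random.dirichlet(np.ones(num_act_classes), size=num_act_classes)

# Roughly one inch per class keeps cell annotations readable as classes grow.
fig, ax = plt.subplots(figsize=(num_act_classes, num_act_classes))
sns.heatmap(cm, annot=True, ax=ax, fmt="0.0%", linewidth=0.5)
fig.savefig("cm_demo.png")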
@@ -375,9 +364,24 @@ def on_test_epoch_end(self) -> None:
         all_source_vids = torch.concat(self.validation_step_outputs_source_vid)
         all_source_frames = torch.concat(self.validation_step_outputs_source_frame)
 
+        # Load test videos
+        if self.test_frames is None:
+            self.test_frames = {}
+            vid_list_file_tst = f"{self.hparams.data_dir}/splits/test.split1.bundle"
+            with open(vid_list_file_tst, "r") as test_f:
+                self.test_videos = test_f.read().split("\n")[:-1]
+
+            for video in self.test_videos:
+                # Load frame filenames for the video
+                frame_list_file_tst = f"{self.hparams.data_dir}/frames/{video}"
+                with open(frame_list_file_tst, "r") as test_f:
+                    test_fns = test_f.read().split("\n")[:-1]
+
+                self.test_frames[video[:-4]] = test_fns
+
         # Save results
         dset = kwcoco.CocoDataset()
-        dset.fpath = f"{self.output_dir}/test_activity_preds.mscoco.json"
+        dset.fpath = f"{self.hparams.output_dir}/test_activity_preds.mscoco.json"
         dset.dataset["info"].append({"activity_labels": self.action_id_to_str})
 
         for (gt, pred, prob, source_vid, source_frame) in zip(all_targets, all_preds, all_probs, all_source_vids, all_source_frames):
@@ -407,7 +411,9 @@ def on_test_epoch_end(self) -> None:
             normalize="true"
         )
 
-        fig, ax = plt.subplots(figsize=(20,20))
+        num_act_classes = len(self.class_ids)
+        fig, ax = plt.subplots(figsize=(num_act_classes,num_act_classes))
+
         sns.heatmap(cm, annot=True, ax=ax, fmt=".2f")
 
         # labels, title and ticks
@@ -417,6 +423,8 @@ def on_test_epoch_end(self) -> None:
         ax.xaxis.set_ticklabels(self.classes, rotation=90)
         ax.yaxis.set_ticklabels(self.classes, rotation=0)
 
+        fig.savefig(f"{self.hparams.output_dir}/test_confusion_mat.png", pad_inches=5)
+
         self.logger.experiment.track(Image(fig), name=f'CM Test Epoch')
 
         plt.close(fig)
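
Both epoch-end hooks now build their frame maps lazily on first use instead of in `__init__`. A sketch of that pattern factored into a helper function (the helper name and the 4-character extension assumption behind `video[:-4]` are mine, not from the diff):

from typing import Dict, List

def load_split_frames(data_dir: str, bundle_name: str) -> Dict[str, List[str]]:
    # Read the split bundle: one video file name per line.
    with open(f"{data_dir}/splits/{bundle_name}", "r") as f:
        videos = f.read().split("\n")[:-1]

    frames: Dict[str, List[str]] = {}
    for video in videos:
        # Each video has a matching frame-list file under frames/.
        with open(f"{data_dir}/frames/{video}", "r") as f:
            frames[video[:-4]] = f.read().split("\n")[:-1]  # drop the extension
    return frames

# Usage inside an epoch-end hook, mirroring the diff:
# if self.val_frames is None:
#     self.val_frames = load_split_frames(self.hparams.data_dir, "val.split1.bundle")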
