From 701fb96df74dca90d1dad7a14a4556476696bfee Mon Sep 17 00:00:00 2001 From: cameron-a-johnson Date: Tue, 19 Nov 2024 19:01:37 -0500 Subject: [PATCH 1/7] Initial draft: write out prediction json at test step during training --- tcn_hpl/callbacks/plot_metrics.py | 32 +++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tcn_hpl/callbacks/plot_metrics.py b/tcn_hpl/callbacks/plot_metrics.py index 7fdd99309..9f07d65e2 100644 --- a/tcn_hpl/callbacks/plot_metrics.py +++ b/tcn_hpl/callbacks/plot_metrics.py @@ -10,6 +10,7 @@ from pytorch_lightning.utilities.types import STEP_OUTPUT from sklearn.metrics import confusion_matrix import torch +import kwcoco try: from aim import Image @@ -336,6 +337,7 @@ def on_test_batch_end( batch: Any, batch_idx: int, dataloader_idx: int, + preds_dset_output_fpath: Path = "./tcn_activity_predictions.kwcoco.json" ) -> None: """Called when the test batch ends.""" # Re-using validation lists since test phase does not collide with @@ -345,6 +347,7 @@ def on_test_batch_end( self._val_all_targets.append(outputs["targets"].cpu()) self._val_all_source_vids.append(outputs["source_vid"].cpu()) self._val_all_source_frames.append(outputs["source_frame"].cpu()) + self._preds_dset_output_fpath = preds_dset_output_fpath def on_test_epoch_end( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" @@ -362,6 +365,35 @@ def on_test_epoch_end( test_acc = pl_module.test_acc.compute() test_f1 = pl_module.test_f1.compute() + # Create activity predictions KWCOCO JSON + truth_dset_fpath = trainer.datamodule.hparams["coco_test_activities"] + truth_dset = kwcoco.CocoDataset(truth_dset_fpath) + acts_dset = kwcoco.CocoDataset() + acts_dset.fpath = self._preds_dset_output_fpath + acts_dset.dataset['videos'] = truth_dset.dataset['videos'] + acts_dset.dataset['images'] = truth_dset.dataset['images'] + acts_dset.index.build(acts_dset) + acts_dset.dataset['categories'] = truth_dset.dataset['categories'] + # Create numpy lookup tables + for i in range(len(all_preds)): + frame_index = all_source_frames[i] + video_id = all_source_vids[i] + # This list could be as long as the number of videos in the dset + matching_frame_indexes = torch.where(all_source_frames == frame_index)[0] + assert video_id in all_source_vids[matching_frame_indexes] + sub_index = torch.where(all_source_vids[matching_frame_indexes] == video_id) + image_id = int(matching_frame_indexes[sub_index]) + + ann = { + "score": all_probs[i][all_preds[i]], + "prob": all_probs[i], + "category_id": all_preds[i], + "image_id": image_id + } + acts_dset.add_annotation(**ann) + acts_dset.dump(acts_dset.fpath, newlines=True) + + # # Plot per-video class predictions vs. GT across progressive frames in # that video. From f63f99706a39bc063326017de23c5f76a8aa84b7 Mon Sep 17 00:00:00 2001 From: cameron-a-johnson Date: Wed, 20 Nov 2024 12:22:15 -0500 Subject: [PATCH 2/7] current state: breaks at i = 1159 --- tcn_hpl/callbacks/plot_metrics.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/tcn_hpl/callbacks/plot_metrics.py b/tcn_hpl/callbacks/plot_metrics.py index 9f07d65e2..b8c661dd9 100644 --- a/tcn_hpl/callbacks/plot_metrics.py +++ b/tcn_hpl/callbacks/plot_metrics.py @@ -361,6 +361,8 @@ def on_test_epoch_end( all_source_vids = torch.cat(self._val_all_source_vids) # shape: #frames all_source_frames = torch.cat(self._val_all_source_frames) # shape: #frames + #import ipdb; ipdb.set_trace() + current_epoch = pl_module.current_epoch test_acc = pl_module.test_acc.compute() test_f1 = pl_module.test_f1.compute() @@ -378,18 +380,26 @@ def on_test_epoch_end( for i in range(len(all_preds)): frame_index = all_source_frames[i] video_id = all_source_vids[i] + ''' # This list could be as long as the number of videos in the dset matching_frame_indexes = torch.where(all_source_frames == frame_index)[0] assert video_id in all_source_vids[matching_frame_indexes] sub_index = torch.where(all_source_vids[matching_frame_indexes] == video_id) - image_id = int(matching_frame_indexes[sub_index]) + frame_index = int(matching_frame_indexes[sub_index]) + ''' + # Now get the image_id that matches the frame_index and video_id. + sorted_img_ids_for_one_video = acts_dset.index.vidid_to_gids[int(video_id)] + image_id = sorted_img_ids_for_one_video[frame_index] + # Sanity check: this image_id corresponds to the frame_index and video_id + assert acts_dset.index.imgs[image_id]['frame_index'] == frame_index + assert acts_dset.index.imgs[image_id]['video_id'] == video_id ann = { + "image_id": image_id, + "category_id": all_preds[i], "score": all_probs[i][all_preds[i]], "prob": all_probs[i], - "category_id": all_preds[i], - "image_id": image_id - } + } acts_dset.add_annotation(**ann) acts_dset.dump(acts_dset.fpath, newlines=True) From 65dc076b1ab66f8a33077b8b8fa8f477d1e77ba0 Mon Sep 17 00:00:00 2001 From: cameron-a-johnson Date: Wed, 20 Nov 2024 13:40:55 -0500 Subject: [PATCH 3/7] torch tensors were changing when being observed! Assigning needed variables to ints fixed the problem. --- tcn_hpl/callbacks/plot_metrics.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/tcn_hpl/callbacks/plot_metrics.py b/tcn_hpl/callbacks/plot_metrics.py index b8c661dd9..9c8238571 100644 --- a/tcn_hpl/callbacks/plot_metrics.py +++ b/tcn_hpl/callbacks/plot_metrics.py @@ -337,7 +337,6 @@ def on_test_batch_end( batch: Any, batch_idx: int, dataloader_idx: int, - preds_dset_output_fpath: Path = "./tcn_activity_predictions.kwcoco.json" ) -> None: """Called when the test batch ends.""" # Re-using validation lists since test phase does not collide with @@ -347,7 +346,7 @@ def on_test_batch_end( self._val_all_targets.append(outputs["targets"].cpu()) self._val_all_source_vids.append(outputs["source_vid"].cpu()) self._val_all_source_frames.append(outputs["source_frame"].cpu()) - self._preds_dset_output_fpath = preds_dset_output_fpath + self._preds_dset_output_fpath = self.output_dir / "tcn_activity_predictions.kwcoco.json" def on_test_epoch_end( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" @@ -378,29 +377,26 @@ def on_test_epoch_end( acts_dset.dataset['categories'] = truth_dset.dataset['categories'] # Create numpy lookup tables for i in range(len(all_preds)): - frame_index = all_source_frames[i] - video_id = all_source_vids[i] - ''' - # This list could be as long as the number of videos in the dset - matching_frame_indexes = torch.where(all_source_frames == frame_index)[0] - assert video_id in all_source_vids[matching_frame_indexes] - sub_index = torch.where(all_source_vids[matching_frame_indexes] == video_id) - frame_index = int(matching_frame_indexes[sub_index]) - ''' + frame_index = all_source_frames[i].item() + video_id = all_source_vids[i].item() # Now get the image_id that matches the frame_index and video_id. sorted_img_ids_for_one_video = acts_dset.index.vidid_to_gids[int(video_id)] image_id = sorted_img_ids_for_one_video[frame_index] # Sanity check: this image_id corresponds to the frame_index and video_id - assert acts_dset.index.imgs[image_id]['frame_index'] == frame_index - assert acts_dset.index.imgs[image_id]['video_id'] == video_id + try: + assert acts_dset.index.imgs[image_id]['frame_index'] == frame_index + assert acts_dset.index.imgs[image_id]['video_id'] == video_id + except: + import ipdb; ipdb.set_trace() ann = { "image_id": image_id, - "category_id": all_preds[i], - "score": all_probs[i][all_preds[i]], - "prob": all_probs[i], + "category_id": all_preds[i].item(), + "score": all_probs[i][all_preds[i]].item(), + "prob": all_probs[i].numpy().tolist(), } acts_dset.add_annotation(**ann) + print(f"Dumping activities file to {acts_dset.fpath}") acts_dset.dump(acts_dset.fpath, newlines=True) From 0a52afbe529cada08ab3a23023304495155cacbb Mon Sep 17 00:00:00 2001 From: cameron-a-johnson <43187095+cameron-a-johnson@users.noreply.github.com> Date: Thu, 21 Nov 2024 12:13:21 -0500 Subject: [PATCH 4/7] Update plot_metrics.py --- tcn_hpl/callbacks/plot_metrics.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tcn_hpl/callbacks/plot_metrics.py b/tcn_hpl/callbacks/plot_metrics.py index 9c8238571..007727b1c 100644 --- a/tcn_hpl/callbacks/plot_metrics.py +++ b/tcn_hpl/callbacks/plot_metrics.py @@ -360,8 +360,6 @@ def on_test_epoch_end( all_source_vids = torch.cat(self._val_all_source_vids) # shape: #frames all_source_frames = torch.cat(self._val_all_source_frames) # shape: #frames - #import ipdb; ipdb.set_trace() - current_epoch = pl_module.current_epoch test_acc = pl_module.test_acc.compute() test_f1 = pl_module.test_f1.compute() From 872cc8b7c8c5c0265319d49f0f0e6d7c1553f76d Mon Sep 17 00:00:00 2001 From: cameron-a-johnson <43187095+cameron-a-johnson@users.noreply.github.com> Date: Thu, 21 Nov 2024 13:04:23 -0500 Subject: [PATCH 5/7] Update tcn_hpl/callbacks/plot_metrics.py Co-authored-by: Paul Tunison <735270+Purg@users.noreply.github.com> --- tcn_hpl/callbacks/plot_metrics.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tcn_hpl/callbacks/plot_metrics.py b/tcn_hpl/callbacks/plot_metrics.py index 007727b1c..533d28c32 100644 --- a/tcn_hpl/callbacks/plot_metrics.py +++ b/tcn_hpl/callbacks/plot_metrics.py @@ -387,13 +387,12 @@ def on_test_epoch_end( except: import ipdb; ipdb.set_trace() - ann = { - "image_id": image_id, - "category_id": all_preds[i].item(), - "score": all_probs[i][all_preds[i]].item(), - "prob": all_probs[i].numpy().tolist(), - } - acts_dset.add_annotation(**ann) + acts_dset.add_annotation( + image_id=image_id, + category_id=all_preds[i].item(), + score=all_probs[i][all_preds[i]].item(), + prob=all_probs[i].numpy().tolist(), + ) print(f"Dumping activities file to {acts_dset.fpath}") acts_dset.dump(acts_dset.fpath, newlines=True) From 3868041e137007c7a6a25c2c5b79f36709548b61 Mon Sep 17 00:00:00 2001 From: cameron-a-johnson <43187095+cameron-a-johnson@users.noreply.github.com> Date: Thu, 21 Nov 2024 13:04:33 -0500 Subject: [PATCH 6/7] Update tcn_hpl/callbacks/plot_metrics.py Co-authored-by: Paul Tunison <735270+Purg@users.noreply.github.com> --- tcn_hpl/callbacks/plot_metrics.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tcn_hpl/callbacks/plot_metrics.py b/tcn_hpl/callbacks/plot_metrics.py index 533d28c32..f5c3504d4 100644 --- a/tcn_hpl/callbacks/plot_metrics.py +++ b/tcn_hpl/callbacks/plot_metrics.py @@ -381,11 +381,8 @@ def on_test_epoch_end( sorted_img_ids_for_one_video = acts_dset.index.vidid_to_gids[int(video_id)] image_id = sorted_img_ids_for_one_video[frame_index] # Sanity check: this image_id corresponds to the frame_index and video_id - try: - assert acts_dset.index.imgs[image_id]['frame_index'] == frame_index - assert acts_dset.index.imgs[image_id]['video_id'] == video_id - except: - import ipdb; ipdb.set_trace() + assert acts_dset.index.imgs[image_id]['frame_index'] == frame_index + assert acts_dset.index.imgs[image_id]['video_id'] == video_id acts_dset.add_annotation( image_id=image_id, From f0c83727657cd54b97f47f36d4d133938df54410 Mon Sep 17 00:00:00 2001 From: cameron-a-johnson <43187095+cameron-a-johnson@users.noreply.github.com> Date: Thu, 21 Nov 2024 13:04:45 -0500 Subject: [PATCH 7/7] Update tcn_hpl/callbacks/plot_metrics.py Co-authored-by: Paul Tunison <735270+Purg@users.noreply.github.com> --- tcn_hpl/callbacks/plot_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tcn_hpl/callbacks/plot_metrics.py b/tcn_hpl/callbacks/plot_metrics.py index f5c3504d4..a78d842d0 100644 --- a/tcn_hpl/callbacks/plot_metrics.py +++ b/tcn_hpl/callbacks/plot_metrics.py @@ -371,8 +371,8 @@ def on_test_epoch_end( acts_dset.fpath = self._preds_dset_output_fpath acts_dset.dataset['videos'] = truth_dset.dataset['videos'] acts_dset.dataset['images'] = truth_dset.dataset['images'] - acts_dset.index.build(acts_dset) acts_dset.dataset['categories'] = truth_dset.dataset['categories'] + acts_dset.index.build(acts_dset) # Create numpy lookup tables for i in range(len(all_preds)): frame_index = all_source_frames[i].item()