From 731c629e2435fdf4821afbdf2c506afe56e0735e Mon Sep 17 00:00:00 2001
From: Peri Akiva
Date: Tue, 8 Oct 2024 11:25:08 -0400
Subject: [PATCH 1/5] small cleanup

---
 configs/experiment/m2/feat_v6.yaml |  2 +-
 configs/experiment/m5/feat_v6.yaml |  2 +-
 tcn_hpl/eval.py                    |  6 +++--
 tcn_hpl/models/ptg_module.py       | 43 +++++++++++++++++++++++++-----
 tcn_hpl/train.py                   |  5 ++--
 5 files changed, 45 insertions(+), 13 deletions(-)

diff --git a/configs/experiment/m2/feat_v6.yaml b/configs/experiment/m2/feat_v6.yaml
index 5546bb2f0..a4fd3778d 100644
--- a/configs/experiment/m2/feat_v6.yaml
+++ b/configs/experiment/m2/feat_v6.yaml
@@ -86,7 +86,7 @@ data:
   num_workers: 16
   epoch_length: 20000
   window_size: 25
-  sample_rate: 2
+  sample_rate: 1
 
   all_transforms:
     train_order: [] #["MoveCenterPts", "NormalizePixelPts"]
diff --git a/configs/experiment/m5/feat_v6.yaml b/configs/experiment/m5/feat_v6.yaml
index 5c08d692b..dbbd216f4 100644
--- a/configs/experiment/m5/feat_v6.yaml
+++ b/configs/experiment/m5/feat_v6.yaml
@@ -73,7 +73,7 @@ data:
   num_workers: 12
   epoch_length: 20000
   window_size: 25
-  sample_rate: 2
+  sample_rate: 1
 
   all_transforms:
diff --git a/tcn_hpl/eval.py b/tcn_hpl/eval.py
index 841d79d17..fc7be4b8f 100644
--- a/tcn_hpl/eval.py
+++ b/tcn_hpl/eval.py
@@ -2,10 +2,12 @@
 
 import hydra
 import rootutils
-from lightning import LightningDataModule, LightningModule, Trainer
-from lightning.pytorch.loggers import Logger
 from omegaconf import DictConfig
 
+import pytorch_lightning as L
+from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
+from pytorch_lightning.loggers import Logger
+
 rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 # ------------------------------------------------------------------------------------ #
 # the setup_root above is equivalent to:
diff --git a/tcn_hpl/models/ptg_module.py b/tcn_hpl/models/ptg_module.py
index fd5ab8d8b..2c3ab7258 100644
--- a/tcn_hpl/models/ptg_module.py
+++ b/tcn_hpl/models/ptg_module.py
@@ -10,6 +10,7 @@
 
 from torchmetrics import MaxMetric, MeanMetric
 from torchmetrics.classification.accuracy import Accuracy
+from torchmetrics.classification import F1Score, Recall, Precision
 from torchmetrics.classification import MulticlassConfusionMatrix
 from sklearn.metrics import confusion_matrix
 
@@ -125,6 +126,15 @@ def __init__(
         self.test_acc = Accuracy(
             task="multiclass", average="weighted", num_classes=num_classes
         )
+
+        self.val_f1 = F1Score(num_classes=num_classes, average="none", task="multiclass")
+        self.test_f1 = F1Score(num_classes=num_classes, average="none", task="multiclass")
+
+        self.val_recall = Recall(num_classes=num_classes, average="none", task="multiclass")
+        self.test_recall = Recall(num_classes=num_classes, average="none", task="multiclass")
+
+        self.val_precision = Precision(num_classes=num_classes, average="none", task="multiclass")
+        self.test_precision = Precision(num_classes=num_classes, average="none", task="multiclass")
 
         # for averaging loss across batches
         self.train_loss = MeanMetric()
@@ -467,6 +477,9 @@ def validation_step(
         windowed_ys = torch.tensor(windowed_ys).to(targets)
 
         self.val_acc(preds, targets[:, -1])
+
+
+
         self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
         self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
 
@@ -480,13 +493,6 @@ def on_validation_epoch_end(self) -> None:
         "Lightning hook that is called when a validation epoch ends."
         acc = self.val_acc.compute()  # get current val acc
-        if self.current_epoch >= 15:
-            self.val_acc_best(acc)  # update best so far val acc
-            # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
-            # otherwise metric would be reset by lightning after each epoch
-            best_val_acc = self.val_acc_best.compute()
-            self.log("val/acc_best", best_val_acc, sync_dist=True, prog_bar=True)
-
         all_targets = torch.cat(self.validation_step_outputs_target)  # shape: #frames
         all_preds = torch.cat(self.validation_step_outputs_pred)  # shape: #frames
         all_probs = torch.cat(
@@ -495,6 +501,29 @@ def on_validation_epoch_end(self) -> None:
         all_source_vids = torch.cat(self.validation_step_outputs_source_vid)
         all_source_frames = torch.cat(self.validation_step_outputs_source_frame)
 
+        if self.current_epoch >= 15:
+            current_best_val_acc = self.val_acc_best.compute()
+            self.val_acc_best(acc)  # update best so far val acc
+            # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
+            # otherwise metric would be reset by lightning after each epoch
+            best_val_acc = self.val_acc_best.compute()
+
+            if best_val_acc > current_best_val_acc:
+                val_f1_score = self.val_f1(all_preds, all_targets)
+                val_recall_score = self.val_recall(all_preds, all_targets)
+                val_precision_score = self.val_precision(all_preds, all_targets)
+
+                # print(f"preds: {all_preds}")
+                # print(f"all_targets: {all_targets}")
+                print(f"validation f1 score: {val_f1_score}")
+                print(f"validation recall score: {val_recall_score}")
+                print(f"validation precision score: {val_precision_score}")
+
+            self.log("val/acc_best", best_val_acc, sync_dist=True, prog_bar=True)
+
+
+
         # Load val videos
         if self.val_frames is None:
             self.val_frames = {}
diff --git a/tcn_hpl/train.py b/tcn_hpl/train.py
index fff4f05d2..2e7e90da3 100644
--- a/tcn_hpl/train.py
+++ b/tcn_hpl/train.py
@@ -84,11 +84,12 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
 
     if cfg.get("train"):
         log.info("Starting training!")
-        # trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
-        trainer.fit(model=model, datamodule=datamodule)
+        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
+        # trainer.fit(model=model, datamodule=datamodule)
 
     train_metrics = trainer.callback_metrics
 
+
     if cfg.get("test"):
         log.info("Starting testing!")
         ckpt_path = trainer.checkpoint_callback.best_model_path

From c992f02fbe1e4b37b0380a9164be1c27162b2127 Mon Sep 17 00:00:00 2001
From: Peri Akiva
Date: Tue, 8 Oct 2024 15:50:51 -0400
Subject: [PATCH 2/5] add GT to kwcoco script, tcn dataset object, generated
 sample kwcoco file, config file corrections

---
 tcn_hpl/data/add_gt_to_kwcoco.py | 125 +++++++++++++++++++++++++++
 tcn_hpl/data/tcn_dataset.py      | 142 +++++++++++++++++++++++++++++++
 2 files changed, 267 insertions(+)
 create mode 100644 tcn_hpl/data/add_gt_to_kwcoco.py
 create mode 100644 tcn_hpl/data/tcn_dataset.py

diff --git a/tcn_hpl/data/add_gt_to_kwcoco.py b/tcn_hpl/data/add_gt_to_kwcoco.py
new file mode 100644
index 000000000..78a775b9d
--- /dev/null
+++ b/tcn_hpl/data/add_gt_to_kwcoco.py
@@ -0,0 +1,125 @@
+import kwcoco
+from angel_system.data.medical.data_paths import TASK_TO_NAME
+from angel_system.data.medical.data_paths import LAB_TASK_TO_NAME
+import yaml
+import argparse
+import tcn_hpl.utils.utils as utils
+import ubelt as ub
+
+def text_to_labels(text_file: str, num_frames: int, task: str, mapping: dict):
+
+    # set background to everything first
+    activity_gt_list = [0 for x in range(num_frames)]
+    f = open(text_file, "r")
+    text = f.read()
+    f.close()
+    text = text.replace("\n", "\t")
+    text_list = text.split("\t")
+    if text_list[-1] == "":
+        text_list = text_list[:-1]
+
+    # this check handles inconsistencies in the GT we get from BBN
+    if task == "r18" or task=="m3":
+        jump = 4
+    elif task=="m2" or task=="m5":
+        jump = 3
+
+    for index in range(0, len(text_list), jump):
+        triplet = text_list[index : index + jump]
+        start_frame = int(triplet[0])
+        end_frame = int(triplet[1])
+        desc = triplet[jump-1]
+
+        gt_label = mapping[desc]
+
+        if end_frame - 1 > num_frames:
+            ### address issue with GT activity labels
+            print(
+                "Max frame in GT is larger than number of frames in the video"
+            )
+
+        for label_index in range(
+            start_frame, min(end_frame - 1, num_frames)
+        ):
+            # print(f"label_index: {label_index}")
+            activity_gt_list[label_index] = gt_label
+
+    return activity_gt_list
+
+def main(config_path: str):
+
+    with open(config_path, "r") as f:
+        config = yaml.load(f, Loader=yaml.FullLoader)
+
+    task_name = config["task"]
+    raw_data_root = f"{config['data_gen']['raw_data_root']}/{LAB_TASK_TO_NAME[task_name]}/"
+
+    dset = kwcoco.CocoDataset(config["data_gen"]["dataset_kwcoco"])
+
+
+    with open(config['data_gen']['activity_config_fn'], "r") as stream:
+        activity_config = yaml.safe_load(stream)
+    activity_labels = activity_config["labels"]
+
+    activity_labels_desc_mapping = {}
+    activity_labels_label_mapping = {}
+    for label in activity_labels:
+        i = label["id"]
+        label_str = label["label"]
+        if "description" in label.keys():
+            # activity_labels_desc_mapping[label["description"]] = label["label"]
+            activity_labels_desc_mapping[label["description"]] = label["id"]
+        elif "full_str" in label.keys():
+            # activity_labels_desc_mapping[label["full_str"]] = label["label"]
+            activity_labels_desc_mapping[label["full_str"]] = label["id"]
+        activity_labels_label_mapping[label_str] = label["id"]
+        if label_str == "done":
+            continue
+
+    gt_paths_to_names_dict = {}
+    gt_paths = utils.dictionary_contents(raw_data_root, types=['*.txt'])
+    for gt_path in gt_paths:
+        name = gt_path.split('/')[-1].split('.')[0]
+        gt_paths_to_names_dict[name] = gt_path
+
+    print(gt_paths_to_names_dict)
+
+    if not "activity_gt" in list(dset.imgs.values())[0].keys():
+        print("adding activity ground truth to the dataset")
+        for video_id in ub.ProgIter(dset.index.videos.keys()):
+            video = dset.index.videos[video_id]
+            video_name = video["name"]
+            if video_name in gt_paths_to_names_dict.keys():
+                gt_text = gt_paths_to_names_dict[video_name]
+            else:
+                print(f"GT file does not exist for {video_name}. Continue...")
+                continue
+
+            image_ids = dset.index.vidid_to_gids[video_id]
+            num_frames = len(image_ids)
+
+            activity_gt_list = text_to_labels(gt_text, num_frames, task_name, activity_labels_desc_mapping)
+
+            for index, img_id in enumerate(image_ids):
+                im = dset.index.imgs[img_id]
+                frame_index = int(im["frame_index"])
+                dset.index.imgs[img_id]["activity_gt"] = activity_gt_list[frame_index]
+                # print(f"video: {video}")
+
+    dset.dump("test.mscoco.json", newlines=True)
+    # print(raw_data_root)
+    # print(activity_labels)
+    print(activity_labels_desc_mapping)
+    # print(paths)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--config", default="/home/local/KHQ/peri.akiva/projects/TCN_HPL/configs/experiment/r18/feat_v6.yaml", help=""
+    )
+
+    args = parser.parse_args()
+
+    main(args.config)
\ No newline at end of file
diff --git a/tcn_hpl/data/tcn_dataset.py b/tcn_hpl/data/tcn_dataset.py
new file mode 100644
index 000000000..9dd69de1b
--- /dev/null
+++ b/tcn_hpl/data/tcn_dataset.py
@@ -0,0 +1,142 @@
+import torch
+from torch.utils.data import Dataset
+import kwcoco
+from tqdm import tqdm
+import logging
+import tcn_hpl.utils.utils as utils
+
+logger = logging.getLogger()
+logging.basicConfig(level=logging.INFO)
+
+
+def standarize_online_inputs(inputs):
+    pass
+
+
+def standarize_offline_inputs(data):
+    dset = data[0]
+    frames = data[1]
+
+
+
+def collect_inputs(data, offline=True):
+    """
+    Collects inputs from either an offline or real-time source.
+
+    Args:
+        data: data stream input. For offline, a coco dataset. For online, a set of ROS messages.
+
+        offline (bool): If True, fetches data from the offline source; otherwise, real-time.
+
+    Returns:
+        inputs: The collected inputs from either a real-time or offline source in standardized format.
+            Designed as a dict of lists. All lists must have the size of the window_size:
+            {object_dets: [], pose_estimations: [], left_hand: [], right_hand: []}
+
+    """
+    if offline:
+        # Fetch from offline data source
+        inputs = standarize_offline_inputs(data)
+    else:
+        # Fetch from real-time source (dummy logic here, replace with actual)
+        inputs = standarize_online_inputs(data)
+
+    return inputs
+
+
+def define_tcn_vector(inputs):
+    """
+    Define the TCN vector using the collected inputs.
+
+    Args:
+        inputs: The inputs collected from either real-time or offline source.
+
+    Returns:
+        tcn_vector: The defined TCN vector.
+    """
+    # Example TCN vector calculation (modify according to your actual logic)
+    tcn_vector = torch.tensor(inputs).float()  # Dummy transformation
+    return tcn_vector
+
+
+class TCNDataset(Dataset):
+
+    def __init__(self, kwcoco_path: str, sample_rate: int, window_size: int):
+        """
+        Initializes the dataset.
+
+        Args:
+            kwcoco_path: Path to the kwcoco dataset file to load frames and annotations from.
+            sample_rate: Stride, in frames, between the starts of consecutive windows.
+            window_size: Number of consecutive frames in each window.
+        """
+        self.sample_rate = sample_rate
+        self.window_size = window_size
+
+        self.dset = kwcoco.CocoDataset(kwcoco_path)
+
+        # for offline training, pre-cut videos into clips according to window size for easy batching
+        self.frames = []
+
+        logger.info(f"Generating dataset with {len(list(self.dset.index.videos.keys()))} videos")
+        pber = tqdm(self.dset.index.videos.keys(), total=len(list(self.dset.index.videos.keys())))
+        for vid in pber:
+            video_dict = self.dset.index.videos[vid]
+
+            vid_frames = self.dset.index.vidid_to_gids[vid]
+
+            for index in range(0, len(vid_frames)-window_size-1, sample_rate):
+                video_slice = vid_frames[index: index+window_size]
+                window_frame_dicts = [self.dset.index.imgs[gid] for gid in video_slice]
+
+                # start_frame = window_frame_dicts[0]['frame_index']
+                # end_frame = window_frame_dicts[-1]['frame_index']
+
+                # n_frames = end_frame - start_frame + 1
+
+                self.frames.append(window_frame_dicts)
+
+    def __getitem__(self, index):
+        """
+        Fetches the data point and defines its TCN vector.
+
+        Args:
+            index: The index of the data point.
+
+        Returns:
+            tcn_vector: The TCN vector for the given index.
+        """
+
+        data = self.frames[index]
+
+        inputs = collect_inputs([self.dset, data], offline=True)
+
+        # tcn_vector = define_tcn_vector(inputs)
+
+        return
+
+    def __len__(self):
+        """
+        Returns the length of the dataset.
+
+        Returns:
+            length: Length of the dataset.
+        """
+        return len(self.frames)
+
+
+if __name__ == "__main__":
+# Example usage:
+    kwcoco_path = "/data/PTG/medical/training/yolo_object_detector/detect/r18_all/r18_all_all_obj_results_with_dets_and_pose.mscoco.json"
+
+    dataset = TCNDataset(kwcoco_path=kwcoco_path,
+                         sample_rate=1,
+                         window_size=25)
+
+    print(f"dataset: {len(dataset)}")
+    data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
+
+    for index, batch in enumerate(data_loader):
+        print(batch)  # This will print the TCN vectors for the batch
+        if index > 15:
+            exit()

From 19adc05c43af6fabd66ae9847fce9244cf8bc09b Mon Sep 17 00:00:00 2001
From: Peri Akiva
Date: Tue, 8 Oct 2024 15:55:01 -0400
Subject: [PATCH 3/5] add GT to kwcoco script, tcn dataset object, generated
 sample kwcoco file, config file corrections

---
 tcn_hpl/data/tcn_dataset.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/tcn_hpl/data/tcn_dataset.py b/tcn_hpl/data/tcn_dataset.py
index 9dd69de1b..05a75946d 100644
--- a/tcn_hpl/data/tcn_dataset.py
+++ b/tcn_hpl/data/tcn_dataset.py
@@ -16,8 +16,7 @@ def standarize_online_inputs(inputs):
 def standarize_offline_inputs(data):
     dset = data[0]
     frames = data[1]
-
-
+
 
 def collect_inputs(data, offline=True):
     """
@@ -35,10 +34,8 @@ def collect_inputs(data, offline=True):
     """
     if offline:
-        # Fetch from offline data source
         inputs = standarize_offline_inputs(data)
     else:
-        # Fetch from real-time source (dummy logic here, replace with actual)
         inputs = standarize_online_inputs(data)
 
     return inputs
@@ -54,7 +51,6 @@ def define_tcn_vector(inputs):
 
     Returns:
         tcn_vector: The defined TCN vector.
     """
-    # Example TCN vector calculation (modify according to your actual logic)
     tcn_vector = torch.tensor(inputs).float()  # Dummy transformation
     return tcn_vector

From 5d3b52698b2f3b5e8c8909e301b7ce5782ec986d Mon Sep 17 00:00:00 2001
From: Paul Tunison
Date: Wed, 9 Oct 2024 16:18:03 -0400
Subject: [PATCH 4/5] Black reformatting

---
 tcn_hpl/data/add_gt_to_kwcoco.py | 65 ++++++++++++++++----------------
 tcn_hpl/data/tcn_dataset.py      | 20 +++++-----
 2 files changed, 44 insertions(+), 41 deletions(-)

diff --git a/tcn_hpl/data/add_gt_to_kwcoco.py b/tcn_hpl/data/add_gt_to_kwcoco.py
index 78a775b9d..fdd6867b9 100644
--- a/tcn_hpl/data/add_gt_to_kwcoco.py
+++ b/tcn_hpl/data/add_gt_to_kwcoco.py
@@ -6,8 +6,8 @@
 import tcn_hpl.utils.utils as utils
 import ubelt as ub
 
+
 def text_to_labels(text_file: str, num_frames: int, task: str, mapping: dict):
-
     # set background to everything first
     activity_gt_list = [0 for x in range(num_frames)]
     f = open(text_file, "r")
@@ -17,50 +17,47 @@ def text_to_labels(text_file: str, num_frames: int, task: str, mapping: dict):
     text_list = text.split("\t")
     if text_list[-1] == "":
         text_list = text_list[:-1]
-
+
     # this check handles inconsistencies in the GT we get from BBN
-    if task == "r18" or task=="m3":
+    if task == "r18" or task == "m3":
         jump = 4
-    elif task=="m2" or task=="m5":
+    elif task == "m2" or task == "m5":
         jump = 3
 
     for index in range(0, len(text_list), jump):
         triplet = text_list[index : index + jump]
         start_frame = int(triplet[0])
         end_frame = int(triplet[1])
-        desc = triplet[jump-1]
-
+        desc = triplet[jump - 1]
+
         gt_label = mapping[desc]
 
         if end_frame - 1 > num_frames:
             ### address issue with GT activity labels
-            print(
-                "Max frame in GT is larger than number of frames in the video"
-            )
+            print("Max frame in GT is larger than number of frames in the video")
 
-        for label_index in range(
-            start_frame, min(end_frame - 1, num_frames)
-        ):
+        for label_index in range(start_frame, min(end_frame - 1, num_frames)):
             # print(f"label_index: {label_index}")
             activity_gt_list[label_index] = gt_label
-
+
     return activity_gt_list
 
+
 def main(config_path: str):
-
     with open(config_path, "r") as f:
         config = yaml.load(f, Loader=yaml.FullLoader)
-
+
     task_name = config["task"]
-    raw_data_root = f"{config['data_gen']['raw_data_root']}/{LAB_TASK_TO_NAME[task_name]}/"
-
+    raw_data_root = (
+        f"{config['data_gen']['raw_data_root']}/{LAB_TASK_TO_NAME[task_name]}/"
+    )
+
     dset = kwcoco.CocoDataset(config["data_gen"]["dataset_kwcoco"])
 
-
-    with open(config['data_gen']['activity_config_fn'], "r") as stream:
+    with open(config["data_gen"]["activity_config_fn"], "r") as stream:
         activity_config = yaml.safe_load(stream)
     activity_labels = activity_config["labels"]
-
+
     activity_labels_desc_mapping = {}
     activity_labels_label_mapping = {}
     for label in activity_labels:
@@ -75,15 +72,15 @@ def main(config_path: str):
         activity_labels_label_mapping[label_str] = label["id"]
         if label_str == "done":
             continue
-
+
     gt_paths_to_names_dict = {}
-    gt_paths = utils.dictionary_contents(raw_data_root, types=['*.txt'])
+    gt_paths = utils.dictionary_contents(raw_data_root, types=["*.txt"])
     for gt_path in gt_paths:
-        name = gt_path.split('/')[-1].split('.')[0]
+        name = gt_path.split("/")[-1].split(".")[0]
         gt_paths_to_names_dict[name] = gt_path
-
+
     print(gt_paths_to_names_dict)
-
+
     if not "activity_gt" in list(dset.imgs.values())[0].keys():
         print("adding activity ground truth to the dataset")
         for video_id in ub.ProgIter(dset.index.videos.keys()):
@@ -94,18 +91,20 @@ def main(config_path: str):
             else:
                 print(f"GT file does not exist for {video_name}. Continue...")
                 continue
-
+
             image_ids = dset.index.vidid_to_gids[video_id]
             num_frames = len(image_ids)
-
-            activity_gt_list = text_to_labels(gt_text, num_frames, task_name, activity_labels_desc_mapping)
-
+
+            activity_gt_list = text_to_labels(
+                gt_text, num_frames, task_name, activity_labels_desc_mapping
+            )
+
             for index, img_id in enumerate(image_ids):
                 im = dset.index.imgs[img_id]
                 frame_index = int(im["frame_index"])
                 dset.index.imgs[img_id]["activity_gt"] = activity_gt_list[frame_index]
                 # print(f"video: {video}")
-
+
     dset.dump("test.mscoco.json", newlines=True)
     # print(raw_data_root)
     # print(activity_labels)
@@ -117,9 +116,11 @@ def main(config_path: str):
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
-        "--config", default="/home/local/KHQ/peri.akiva/projects/TCN_HPL/configs/experiment/r18/feat_v6.yaml", help=""
+        "--config",
+        default="/home/local/KHQ/peri.akiva/projects/TCN_HPL/configs/experiment/r18/feat_v6.yaml",
+        help="",
     )
 
     args = parser.parse_args()
 
-    main(args.config)
\ No newline at end of file
+    main(args.config)
diff --git a/tcn_hpl/data/tcn_dataset.py b/tcn_hpl/data/tcn_dataset.py
index 05a75946d..45cb7dc6a 100644
--- a/tcn_hpl/data/tcn_dataset.py
+++ b/tcn_hpl/data/tcn_dataset.py
@@ -56,7 +56,6 @@ def define_tcn_vector(inputs):
 
 
 class TCNDataset(Dataset):
-
     def __init__(self, kwcoco_path: str, sample_rate: int, window_size: int):
         """
         Initializes the dataset.
@@ -74,15 +73,20 @@ def __init__(self, kwcoco_path: str, sample_rate: int, window_size: int):
         # for offline training, pre-cut videos into clips according to window size for easy batching
         self.frames = []
 
-        logger.info(f"Generating dataset with {len(list(self.dset.index.videos.keys()))} videos")
-        pber = tqdm(self.dset.index.videos.keys(), total=len(list(self.dset.index.videos.keys())))
+        logger.info(
+            f"Generating dataset with {len(list(self.dset.index.videos.keys()))} videos"
+        )
+        pber = tqdm(
+            self.dset.index.videos.keys(),
+            total=len(list(self.dset.index.videos.keys())),
+        )
         for vid in pber:
             video_dict = self.dset.index.videos[vid]
 
             vid_frames = self.dset.index.vidid_to_gids[vid]
 
-            for index in range(0, len(vid_frames)-window_size-1, sample_rate):
-                video_slice = vid_frames[index: index+window_size]
+            for index in range(0, len(vid_frames) - window_size - 1, sample_rate):
+                video_slice = vid_frames[index : index + window_size]
                 window_frame_dicts = [self.dset.index.imgs[gid] for gid in video_slice]
 
                 # start_frame = window_frame_dicts[0]['frame_index']
@@ -122,12 +126,10 @@ def __len__(self):
 
 
 if __name__ == "__main__":
-# Example usage:
+    # Example usage:
     kwcoco_path = "/data/PTG/medical/training/yolo_object_detector/detect/r18_all/r18_all_all_obj_results_with_dets_and_pose.mscoco.json"
 
-    dataset = TCNDataset(kwcoco_path=kwcoco_path,
-                         sample_rate=1,
-                         window_size=25)
+    dataset = TCNDataset(kwcoco_path=kwcoco_path, sample_rate=1, window_size=25)
 
     print(f"dataset: {len(dataset)}")
     data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)

From 4ce16f107610b9f81f760a2b1da4c90981e440c6 Mon Sep 17 00:00:00 2001
From: Paul Tunison
Date: Wed, 9 Oct 2024 17:58:33 -0400
Subject: [PATCH 5/5] Some formatting and commenting

---
 tcn_hpl/data/add_gt_to_kwcoco.py | 38 +++++++++++++++++-----------
 tcn_hpl/data/tcn_dataset.py      |  7 +++---
 tcn_hpl/models/ptg_module.py     | 43 +++++++++++++++++++-------------
 3 files changed, 53 insertions(+), 35 deletions(-)

diff --git a/tcn_hpl/data/add_gt_to_kwcoco.py b/tcn_hpl/data/add_gt_to_kwcoco.py
index fdd6867b9..baaeeda53 100644
--- a/tcn_hpl/data/add_gt_to_kwcoco.py
+++ b/tcn_hpl/data/add_gt_to_kwcoco.py
@@ -1,15 +1,32 @@
-import kwcoco
-from angel_system.data.medical.data_paths import TASK_TO_NAME
-from angel_system.data.medical.data_paths import LAB_TASK_TO_NAME
-import yaml
 import argparse
+
+import kwcoco
+import numpy as np
+import numpy.typing as npt
 import tcn_hpl.utils.utils as utils
 import ubelt as ub
+import yaml
+
+from angel_system.data.medical.data_paths import LAB_TASK_TO_NAME
 
 
-def text_to_labels(text_file: str, num_frames: int, task: str, mapping: dict):
-    # set background to everything first
-    activity_gt_list = [0 for x in range(num_frames)]
+def text_to_labels(
+    text_file: str, num_frames: int, task: str, mapping: dict
+) -> npt.NDArray[int]:
+    """
+    Convert a "skill_labels_by_frame" text truth file from BBN into labels for
+    the given task and number of frames.
+
+    :param text_file: Filesystem path to the BBN activity text file.
+    :param num_frames: Number of frames in the video the truth file is related
+        to.
+    :param task: The identifying name of the task, e.g. "m2", "m3", "r18", etc.
+    :param mapping: Mapping of task step descriptions to the integer label
+        value for that step.
+    :return: Per-frame array of integer activity label IDs.
+    """
+    # set background to everything first (assuming value 0).
+    activity_gt_list = np.zeros(num_frames, dtype=int)
     f = open(text_file, "r")
     text = f.read()
     f.close()
@@ -37,7 +54,6 @@ def text_to_labels(text_file: str, num_frames: int, task: str, mapping: dict):
             print("Max frame in GT is larger than number of frames in the video")
 
         for label_index in range(start_frame, min(end_frame - 1, num_frames)):
-            # print(f"label_index: {label_index}")
             activity_gt_list[label_index] = gt_label
 
     return activity_gt_list
@@ -64,10 +80,8 @@ def main(config_path: str):
         i = label["id"]
         label_str = label["label"]
         if "description" in label.keys():
-            # activity_labels_desc_mapping[label["description"]] = label["label"]
             activity_labels_desc_mapping[label["description"]] = label["id"]
         elif "full_str" in label.keys():
-            # activity_labels_desc_mapping[label["full_str"]] = label["label"]
             activity_labels_desc_mapping[label["full_str"]] = label["id"]
         activity_labels_label_mapping[label_str] = label["id"]
         if label_str == "done":
             continue
@@ -103,13 +117,9 @@ def main(config_path: str):
                 im = dset.index.imgs[img_id]
                 frame_index = int(im["frame_index"])
                 dset.index.imgs[img_id]["activity_gt"] = activity_gt_list[frame_index]
-                # print(f"video: {video}")
 
     dset.dump("test.mscoco.json", newlines=True)
-    # print(raw_data_root)
-    # print(activity_labels)
     print(activity_labels_desc_mapping)
-    # print(paths)
 
 
 if __name__ == "__main__":
diff --git a/tcn_hpl/data/tcn_dataset.py b/tcn_hpl/data/tcn_dataset.py
index 45cb7dc6a..d0d3d4590 100644
--- a/tcn_hpl/data/tcn_dataset.py
+++ b/tcn_hpl/data/tcn_dataset.py
@@ -73,12 +73,13 @@ def __init__(self, kwcoco_path: str, sample_rate: int, window_size: int):
         # for offline training, pre-cut videos into clips according to window size for easy batching
         self.frames = []
 
+        vid_list = list(self.dset.index.videos.keys())
         logger.info(
-            f"Generating dataset with {len(list(self.dset.index.videos.keys()))} videos"
+            f"Generating dataset with {len(vid_list)} videos"
         )
         pber = tqdm(
-            self.dset.index.videos.keys(),
-            total=len(list(self.dset.index.videos.keys())),
+            vid_list,
+            total=len(vid_list),
         )
         for vid in pber:
             video_dict = self.dset.index.videos[vid]
diff --git a/tcn_hpl/models/ptg_module.py b/tcn_hpl/models/ptg_module.py
index 2c3ab7258..a281e2d18 100644
--- a/tcn_hpl/models/ptg_module.py
+++ b/tcn_hpl/models/ptg_module.py
@@ -126,15 +126,27 @@ def __init__(
         self.test_acc = Accuracy(
             task="multiclass", average="weighted", num_classes=num_classes
         )
-
-        self.val_f1 = F1Score(num_classes=num_classes, average="none", task="multiclass")
-        self.test_f1 = F1Score(num_classes=num_classes, average="none", task="multiclass")
-
-        self.val_recall = Recall(num_classes=num_classes, average="none", task="multiclass")
-        self.test_recall = Recall(num_classes=num_classes, average="none", task="multiclass")
-
-        self.val_precision = Precision(num_classes=num_classes, average="none", task="multiclass")
-        self.test_precision = Precision(num_classes=num_classes, average="none", task="multiclass")
+
+        self.val_f1 = F1Score(
+            num_classes=num_classes, average="none", task="multiclass"
+        )
+        self.test_f1 = F1Score(
+            num_classes=num_classes, average="none", task="multiclass"
+        )
+
+        self.val_recall = Recall(
+            num_classes=num_classes, average="none", task="multiclass"
+        )
+        self.test_recall = Recall(
+            num_classes=num_classes, average="none", task="multiclass"
+        )
+
+        self.val_precision = Precision(
+            num_classes=num_classes, average="none", task="multiclass"
+        )
+        self.test_precision = Precision(
+            num_classes=num_classes, average="none", task="multiclass"
+        )
 
         # for averaging loss across batches
         self.train_loss = MeanMetric()
@@ -477,9 +489,7 @@ def validation_step(
         windowed_ys = torch.tensor(windowed_ys).to(targets)
 
         self.val_acc(preds, targets[:, -1])
-
-
-
+
         self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
         self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
@@ -507,23 +517,20 @@ def on_validation_epoch_end(self) -> None:
             # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
             # otherwise metric would be reset by lightning after each epoch
             best_val_acc = self.val_acc_best.compute()
-
+
             if best_val_acc > current_best_val_acc:
                 val_f1_score = self.val_f1(all_preds, all_targets)
                 val_recall_score = self.val_recall(all_preds, all_targets)
                 val_precision_score = self.val_precision(all_preds, all_targets)
-
+
                 # print(f"preds: {all_preds}")
                 # print(f"all_targets: {all_targets}")
                 print(f"validation f1 score: {val_f1_score}")
                 print(f"validation recall score: {val_recall_score}")
                 print(f"validation precision score: {val_precision_score}")
-
-            self.log("val/acc_best", best_val_acc, sync_dist=True, prog_bar=True)
+
+            self.log("val/acc_best", best_val_acc, sync_dist=True, prog_bar=True)
 
-
-
         # Load val videos
         if self.val_frames is None:
             self.val_frames = {}
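
For reference, below is a minimal standalone sketch of the sliding-window clip generation that TCNDataset performs in [PATCH 2/5]. Only the indexing mirrors the patch; the frame count and the make_windows helper name are illustrative assumptions, not part of the changeset.

# Sketch of TCNDataset's window pre-cutting (see tcn_hpl/data/tcn_dataset.py).
# frame_ids stands in for dset.index.vidid_to_gids[vid]; values are invented.

def make_windows(frame_ids, window_size, sample_rate):
    """Pre-cut a video's frame ids into fixed-size windows for easy batching."""
    windows = []
    # Mirrors the patch: range(0, len(frame_ids) - window_size - 1, sample_rate)
    for start in range(0, len(frame_ids) - window_size - 1, sample_rate):
        windows.append(frame_ids[start : start + window_size])
    return windows

frames = list(range(100))  # a hypothetical 100-frame video
clips = make_windows(frames, window_size=25, sample_rate=1)
print(len(clips), len(clips[0]))  # -> 74 25

Each element of clips corresponds to one self.frames entry in the dataset, which __getitem__ later resolves to the per-frame image dicts for that window.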
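
Similarly, a self-contained sketch of the BBN truth-file parsing that text_to_labels implements after [PATCH 5/5]: tab/newline-separated (start frame, end frame, ..., description) records, read 3 columns at a time for m2/m5 and 4 for r18/m3. The sample text, label mapping, and the parse_gt name are made up for illustration.

import numpy as np

def parse_gt(text, num_frames, columns, mapping):
    labels = np.zeros(num_frames, dtype=int)  # 0 = background
    fields = text.replace("\n", "\t").split("\t")
    if fields and fields[-1] == "":
        fields = fields[:-1]
    for i in range(0, len(fields), columns):
        start, end = int(fields[i]), int(fields[i + 1])
        desc = fields[i + columns - 1]
        # Mirrors the patch: frames [start, min(end - 1, num_frames)) get the label
        labels[start : min(end - 1, num_frames)] = mapping[desc]
    return labels

text = "0\t10\tapply tourniquet\n10\t20\ttighten strap\n"  # invented example
mapping = {"apply tourniquet": 1, "tighten strap": 2}
print(parse_gt(text, num_frames=25, columns=3, mapping=mapping))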