diff --git a/configs/experiment/m2/feat_v6.yaml b/configs/experiment/m2/feat_v6.yaml index 5546bb2f0..a4fd3778d 100644 --- a/configs/experiment/m2/feat_v6.yaml +++ b/configs/experiment/m2/feat_v6.yaml @@ -86,7 +86,7 @@ data: num_workers: 16 epoch_length: 20000 window_size: 25 - sample_rate: 2 + sample_rate: 1 all_transforms: train_order: [] #["MoveCenterPts", "NormalizePixelPts"] diff --git a/configs/experiment/m5/feat_v6.yaml b/configs/experiment/m5/feat_v6.yaml index 5c08d692b..dbbd216f4 100644 --- a/configs/experiment/m5/feat_v6.yaml +++ b/configs/experiment/m5/feat_v6.yaml @@ -73,7 +73,7 @@ data: num_workers: 12 epoch_length: 20000 window_size: 25 - sample_rate: 2 + sample_rate: 1 all_transforms: diff --git a/tcn_hpl/data/add_gt_to_kwcoco.py b/tcn_hpl/data/add_gt_to_kwcoco.py new file mode 100644 index 000000000..baaeeda53 --- /dev/null +++ b/tcn_hpl/data/add_gt_to_kwcoco.py @@ -0,0 +1,136 @@ +import argparse + +import kwcoco +import numpy as np +import numpy.typing as npt +import tcn_hpl.utils.utils as utils +import ubelt as ub +import yaml + +from angel_system.data.medical.data_paths import LAB_TASK_TO_NAME + + +def text_to_labels( + text_file: str, num_frames: int, task: str, mapping: dict +) -> npt.NDArray[int]: + """ + Convert a "skill_labels_by_frame" text truth file from BBN into labels for + the given task and number of frames. + + :param text_file: Filesystem path to the BBN activity text file. + :param num_frames: Number of frames in the video the truth file is related + to. + :param task: The identifying name of the task, e.g. "m2", "m3", "r18", etc. + :param mapping: Mapping of task step descriptions to the integer label + value for that step. + :return: + """ + # set background to everything first (assuming value 0). + activity_gt_list = np.zeros(num_frames) + f = open(text_file, "r") + text = f.read() + f.close() + text = text.replace("\n", "\t") + text_list = text.split("\t") + if text_list[-1] == "": + text_list = text_list[:-1] + + # this check handles inconsistencies in the GT we get from BBN + if task == "r18" or task == "m3": + jump = 4 + elif task == "m2" or task == "m5": + jump = 3 + + for index in range(0, len(text_list), jump): + triplet = text_list[index : index + jump] + start_frame = int(triplet[0]) + end_frame = int(triplet[1]) + desc = triplet[jump - 1] + + gt_label = mapping[desc] + + if end_frame - 1 > num_frames: + ### address issue with GT activity labels + print("Max frame in GT is larger than number of frames in the video") + + for label_index in range(start_frame, min(end_frame - 1, num_frames)): + activity_gt_list[label_index] = gt_label + + return activity_gt_list + + +def main(config_path: str): + with open(config_path, "r") as f: + config = yaml.load(f, Loader=yaml.FullLoader) + + task_name = config["task"] + raw_data_root = ( + f"{config['data_gen']['raw_data_root']}/{LAB_TASK_TO_NAME[task_name]}/" + ) + + dset = kwcoco.CocoDataset(config["data_gen"]["dataset_kwcoco"]) + + with open(config["data_gen"]["activity_config_fn"], "r") as stream: + activity_config = yaml.safe_load(stream) + activity_labels = activity_config["labels"] + + activity_labels_desc_mapping = {} + activity_labels_label_mapping = {} + for label in activity_labels: + i = label["id"] + label_str = label["label"] + if "description" in label.keys(): + activity_labels_desc_mapping[label["description"]] = label["id"] + elif "full_str" in label.keys(): + activity_labels_desc_mapping[label["full_str"]] = label["id"] + activity_labels_label_mapping[label_str] = label["id"] + if 
label_str == "done": + continue + + gt_paths_to_names_dict = {} + gt_paths = utils.dictionary_contents(raw_data_root, types=["*.txt"]) + for gt_path in gt_paths: + name = gt_path.split("/")[-1].split(".")[0] + gt_paths_to_names_dict[name] = gt_path + + print(gt_paths_to_names_dict) + + if not "activity_gt" in list(dset.imgs.values())[0].keys(): + print("adding activity ground truth to the dataset") + for video_id in ub.ProgIter(dset.index.videos.keys()): + video = dset.index.videos[video_id] + video_name = video["name"] + if video_name in gt_paths_to_names_dict.keys(): + gt_text = gt_paths_to_names_dict[video_name] + else: + print(f"GT file does not exist for {video_name}. Continue...") + continue + + image_ids = dset.index.vidid_to_gids[video_id] + num_frames = len(image_ids) + + activity_gt_list = text_to_labels( + gt_text, num_frames, task_name, activity_labels_desc_mapping + ) + + for index, img_id in enumerate(image_ids): + im = dset.index.imgs[img_id] + frame_index = int(im["frame_index"]) + dset.index.imgs[img_id]["activity_gt"] = activity_gt_list[frame_index] + + dset.dump("test.mscoco.json", newlines=True) + print(activity_labels_desc_mapping) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--config", + default="/home/local/KHQ/peri.akiva/projects/TCN_HPL/configs/experiment/r18/feat_v6.yaml", + help="", + ) + + args = parser.parse_args() + + main(args.config) diff --git a/tcn_hpl/data/tcn_dataset.py b/tcn_hpl/data/tcn_dataset.py new file mode 100644 index 000000000..d0d3d4590 --- /dev/null +++ b/tcn_hpl/data/tcn_dataset.py @@ -0,0 +1,141 @@ +import torch +from torch.utils.data import Dataset +import kwcoco +from tqdm import tqdm +import logging +import tcn_hpl.utils.utils as utils + +logger = logging.getLogger() +logging.basicConfig(level=logging.INFO) + + +def standarize_online_inputs(inputs): + pass + + +def standarize_offline_inputs(data): + dset = data[0] + frames = data[1] + + +def collect_inputs(data, offline=True): + """ + Collects inputs from either an offline or real-time source. + + Args: + data: data stream input. For offline, a coco dataset. For online, a set of ROS messages + + offline (bool): If True, fetches data from offline source; otherwise, real-time. + + Returns: + inputs: The collected inputs from either a real-time or offline source in standarized format. + Designed as a dict of lists. all lists must have the size of the widnow_size: + {object_dets: [], pose_estimations:[], left_hand: [], right_hand: []} + + """ + if offline: + inputs = standarize_offline_inputs(data) + else: + inputs = standarize_online_inputs(data) + + return inputs + + +def define_tcn_vector(inputs): + """ + Define the TCN vector using the collected inputs. + + Args: + inputs: The inputs collected from either real-time or offline source. + + Returns: + tcn_vector: The defined TCN vector. + """ + tcn_vector = torch.tensor(inputs).float() # Dummy transformation + return tcn_vector + + +class TCNDataset(Dataset): + def __init__(self, kwcoco_path: str, sample_rate: int, window_size: int): + """ + Initializes the dataset. + + Args: + kwcoco_path: The source of data (can be real-time or offline). 
+ sample_rate: + window_size: + """ + self.sample_rate = sample_rate + self.window_size = window_size + + self.dset = kwcoco.CocoDataset(kwcoco_path) + + # for offline training, pre-cut videos into clips according to window size for easy batching + self.frames = [] + + vid_list = list(self.dset.index.videos.keys()) + logger.info( + f"Generating dataset with {len(vid_list)} videos" + ) + pber = tqdm( + vid_list, + total=len(vid_list), + ) + for vid in pber: + video_dict = self.dset.index.videos[vid] + + vid_frames = self.dset.index.vidid_to_gids[vid] + + for index in range(0, len(vid_frames) - window_size - 1, sample_rate): + video_slice = vid_frames[index : index + window_size] + window_frame_dicts = [self.dset.index.imgs[gid] for gid in video_slice] + + # start_frame = window_frame_dicts[0]['frame_index'] + # end_frame = window_frame_dicts[-1]['frame_index'] + + # n_frames = end_frame - start_frame + 1 + + self.frames.append(window_frame_dicts) + + def __getitem__(self, index): + """ + Fetches the data point and defines its TCN vector. + + Args: + index: The index of the data point. + + Returns: + tcn_vector: The TCN vector for the given index. + """ + + data = self.frames[index] + + inputs = collect_inputs([self.dset, data], offline=True) + + # tcn_vector = define_tcn_vector(inputs) + + return + + def __len__(self): + """ + Returns the length of the dataset. + + Returns: + length: Length of the dataset. + """ + return len(self.frames) + + +if __name__ == "__main__": + # Example usage: + kwcoco_path = "/data/PTG/medical/training/yolo_object_detector/detect/r18_all/r18_all_all_obj_results_with_dets_and_pose.mscoco.json" + + dataset = TCNDataset(kwcoco_path=kwcoco_path, sample_rate=1, window_size=25) + + print(f"dataset: {len(dataset)}") + data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True) + + for index, batch in enumerate(data_loader): + print(batch) # This will print the TCN vectors for the batch + if index > 15: + exit() diff --git a/tcn_hpl/eval.py b/tcn_hpl/eval.py index 841d79d17..fc7be4b8f 100644 --- a/tcn_hpl/eval.py +++ b/tcn_hpl/eval.py @@ -2,10 +2,12 @@ import hydra import rootutils -from lightning import LightningDataModule, LightningModule, Trainer -from lightning.pytorch.loggers import Logger from omegaconf import DictConfig +import pytorch_lightning as L +from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer +from pytorch_lightning.loggers import Logger + rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) # ------------------------------------------------------------------------------------ # # the setup_root above is equivalent to: diff --git a/tcn_hpl/models/ptg_module.py b/tcn_hpl/models/ptg_module.py index fd5ab8d8b..a281e2d18 100644 --- a/tcn_hpl/models/ptg_module.py +++ b/tcn_hpl/models/ptg_module.py @@ -10,6 +10,7 @@ from torchmetrics import MaxMetric, MeanMetric from torchmetrics.classification.accuracy import Accuracy +from torchmetrics.classification import F1Score, Recall, Precision from torchmetrics.classification import MulticlassConfusionMatrix from sklearn.metrics import confusion_matrix @@ -126,6 +127,27 @@ def __init__( task="multiclass", average="weighted", num_classes=num_classes ) + self.val_f1 = F1Score( + num_classes=num_classes, average="none", task="multiclass" + ) + self.test_f1 = F1Score( + num_classes=num_classes, average="none", task="multiclass" + ) + + self.val_recall = Recall( + num_classes=num_classes, average="none", task="multiclass" + ) + self.test_recall = 
Recall( + num_classes=num_classes, average="none", task="multiclass" + ) + + self.val_precision = Precision( + num_classes=num_classes, average="none", task="multiclass" + ) + self.test_precision = Precision( + num_classes=num_classes, average="none", task="multiclass" + ) + # for averaging loss across batches self.train_loss = MeanMetric() self.val_loss = MeanMetric() @@ -467,6 +489,7 @@ def validation_step( windowed_ys = torch.tensor(windowed_ys).to(targets) self.val_acc(preds, targets[:, -1]) + self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True) self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) @@ -480,13 +503,6 @@ def on_validation_epoch_end(self) -> None: "Lightning hook that is called when a validation epoch ends." acc = self.val_acc.compute() # get current val acc - if self.current_epoch >= 15: - self.val_acc_best(acc) # update best so far val acc - # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object - # otherwise metric would be reset by lightning after each epoch - best_val_acc = self.val_acc_best.compute() - self.log("val/acc_best", best_val_acc, sync_dist=True, prog_bar=True) - all_targets = torch.cat(self.validation_step_outputs_target) # shape: #frames all_preds = torch.cat(self.validation_step_outputs_pred) # shape: #frames all_probs = torch.cat( @@ -495,6 +511,26 @@ def on_validation_epoch_end(self) -> None: all_source_vids = torch.cat(self.validation_step_outputs_source_vid) all_source_frames = torch.cat(self.validation_step_outputs_source_frame) + if self.current_epoch >= 15: + current_best_val_acc = self.val_acc_best.compute() + self.val_acc_best(acc) # update best so far val acc + # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object + # otherwise metric would be reset by lightning after each epoch + best_val_acc = self.val_acc_best.compute() + + if best_val_acc > current_best_val_acc: + val_f1_score = self.val_f1(all_preds, all_targets) + val_recall_score = self.val_recall(all_preds, all_targets) + val_precision_score = self.val_precision(all_preds, all_targets) + + # print(f"preds: {all_preds}") + # print(f"all_targets: {all_targets}") + print(f"validation f1 score: {val_f1_score}") + print(f"validation recall score: {val_recall_score}") + print(f"validation precision score: {val_precision_score}") + + self.log("val/acc_best", best_val_acc, sync_dist=True, prog_bar=True) + # Load val vidoes if self.val_frames is None: self.val_frames = {} diff --git a/tcn_hpl/train.py b/tcn_hpl/train.py index fff4f05d2..2e7e90da3 100644 --- a/tcn_hpl/train.py +++ b/tcn_hpl/train.py @@ -84,11 +84,12 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: if cfg.get("train"): log.info("Starting training!") - # trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) - trainer.fit(model=model, datamodule=datamodule) + trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) + # trainer.fit(model=model, datamodule=datamodule) train_metrics = trainer.callback_metrics + if cfg.get("test"): log.info("Starting testing!") ckpt_path = trainer.checkpoint_callback.best_model_path
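Note on the truth-file format handled by text_to_labels() in add_gt_to_kwcoco.py: the BBN "skill_labels_by_frame" files are flattened into tab-separated fields, three per record for m2/m5 (start frame, end frame, step description) and four for r18/m3, with only the first two fields and the last one actually used. The stand-alone sketch below mirrors that parsing against an in-memory string; the step names and the toy mapping are made up for illustration, not the real BBN schema.

# Minimal sketch (assumed record layout) of the parsing done by text_to_labels().
import numpy as np

def parse_skill_labels(text: str, num_frames: int, mapping: dict, jump: int) -> np.ndarray:
    labels = np.zeros(num_frames)  # every frame defaults to background (label 0)
    fields = text.replace("\n", "\t").split("\t")
    if fields and fields[-1] == "":
        fields = fields[:-1]
    for i in range(0, len(fields), jump):
        record = fields[i : i + jump]
        start, end, desc = int(record[0]), int(record[1]), record[jump - 1]
        # clamp to the video length, mirroring the warning printed by the script
        for frame in range(start, min(end - 1, num_frames)):
            labels[frame] = mapping[desc]
    return labels

# toy example: two steps over a 10-frame clip, m2/m5-style records (jump=3)
toy_mapping = {"open packaging": 1, "apply tourniquet": 2}  # hypothetical step descriptions
toy_text = "0\t3\topen packaging\n4\t9\tapply tourniquet\n"
print(parse_skill_labels(toy_text, 10, toy_mapping, jump=3))
# -> [1. 1. 0. 0. 2. 2. 2. 2. 0. 0.]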
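The main() routine in add_gt_to_kwcoco.py writes the annotated dataset to a hardcoded "test.mscoco.json" in the working directory. A quick way to check the result is to reload that dump and tally the per-frame activity_gt values, as in the sketch below; the output path is simply the one hardcoded by the script, so adjust it if that changes.

# Hedged sketch: reload the dumped kwcoco dataset and count frames per activity label.
from collections import Counter

import kwcoco

dset = kwcoco.CocoDataset("test.mscoco.json")  # path hardcoded in add_gt_to_kwcoco.main()
counts = Counter(img.get("activity_gt") for img in dset.imgs.values())
print("frames per activity_gt value:", dict(counts))
# Label 0 (background) usually dominates; None means a frame never received a label,
# e.g. because its video had no matching GT text file.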
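The clip pre-cutting in TCNDataset.__init__ is a strided slice over each video's ordered frame ids: one window of window_size frames every sample_rate frames. The stand-alone sketch below reproduces that loop on a fake frame-id list so the effect of the two parameters (and of the sample_rate: 2 -> 1 config change for m2/m5) is easy to see; note that with the stop value used, len(vid_frames) - window_size - 1, the last valid window positions in each video are skipped, which may or may not be intended.

# Minimal sketch of the window pre-cutting loop in TCNDataset.__init__,
# run on a fake list of frame ids instead of a kwcoco index.
def cut_windows(vid_frames, window_size, sample_rate):
    windows = []
    for index in range(0, len(vid_frames) - window_size - 1, sample_rate):
        windows.append(vid_frames[index : index + window_size])
    return windows

fake_frames = list(range(10))  # pretend one video has frame ids 0..9

# window_size=5, sample_rate=1: windows start at 0..3, so the valid starts 4 and 5 are skipped
for w in cut_windows(fake_frames, window_size=5, sample_rate=1):
    print(w)

# sample_rate=2 keeps every other window; dropping it to 1 (as in the m2/m5 configs)
# roughly doubles the number of training windows per video
print(len(cut_windows(fake_frames, 5, 1)), "vs", len(cut_windows(fake_frames, 5, 2)))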
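The new metrics in ptg_module.py use torchmetrics with average="none", which returns one score per class instead of a single scalar, and the per-class F1/recall/precision are only computed when val_acc_best improves. The sketch below shows both pieces in isolation on dummy tensors; the class count and values are made up, and the gating only mirrors, rather than reproduces exactly, the epoch-end hook.

# Hedged sketch of per-class metrics plus the best-accuracy gating pattern.
import torch
from torchmetrics import MaxMetric
from torchmetrics.classification import Accuracy, F1Score

num_classes = 4
val_acc = Accuracy(task="multiclass", num_classes=num_classes, average="weighted")
val_acc_best = MaxMetric()
val_f1 = F1Score(task="multiclass", num_classes=num_classes, average="none")

val_acc_best.update(torch.tensor(0.5))  # pretend 0.5 was the best accuracy so far

# dummy epoch outputs
preds = torch.tensor([0, 1, 2, 2, 3, 1])
targets = torch.tensor([0, 1, 2, 3, 3, 0])

acc = val_acc(preds, targets)
current_best = val_acc_best.compute()      # best value before this epoch
val_acc_best(acc)                          # update the running maximum
if val_acc_best.compute() > current_best:  # report per-class scores only on improvement
    print("per-class f1:", val_f1(preds, targets))  # tensor with num_classes entries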