From 96feea0c128695aa3762ad301059e8dfe0b1c9e2 Mon Sep 17 00:00:00 2001 From: Peri Akiva Date: Mon, 1 Apr 2024 16:07:36 -0400 Subject: [PATCH 1/3] clean up, debugging --- tcn_hpl/data/utils/ptg_datagenerator.py | 7 +++++-- tcn_hpl/models/components/ms_tcs_net.py | 3 ++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tcn_hpl/data/utils/ptg_datagenerator.py b/tcn_hpl/data/utils/ptg_datagenerator.py index 4af580deb..b4d12afb6 100644 --- a/tcn_hpl/data/utils/ptg_datagenerator.py +++ b/tcn_hpl/data/utils/ptg_datagenerator.py @@ -106,7 +106,10 @@ def main(task: str): for label in activity_labels: i = label["id"] label_str = label["label"] - activity_labels_desc_mapping[label['description']] = label["label"] + if "description" in label.keys(): + activity_labels_desc_mapping[label['description']] = label["label"] + elif "full_str" in label.keys(): + activity_labels_desc_mapping[label['full_str']] = label["label"] if label_str == "done": continue mapping.write(f"{i} {label_str}\n") @@ -344,5 +347,5 @@ def main(task: str): print(f"Saved training data to {output_data_dir}") if __name__ == '__main__': - task = "m3" + task = "r18" main(task=task) \ No newline at end of file diff --git a/tcn_hpl/models/components/ms_tcs_net.py b/tcn_hpl/models/components/ms_tcs_net.py index 7a4344be4..13e1e6762 100644 --- a/tcn_hpl/models/components/ms_tcs_net.py +++ b/tcn_hpl/models/components/ms_tcs_net.py @@ -64,7 +64,8 @@ def __init__(self, # self.fc = nn.Linear(1280, 2048) def forward(self, x, mask): - b, d, c = x.shape + b, d, c = x.shape #[batch_size, feat_dim, window_size] + # mask shape: [batch_size, window_size] # print(f"x: {x.shape}") # print(f"mask: {mask.shape}") From 1ebfa47361e9c21e2f06523b37c4c0a5339daf10 Mon Sep 17 00:00:00 2001 From: Peri Akiva Date: Mon, 8 Apr 2024 17:01:55 -0400 Subject: [PATCH 2/3] bbn lab data pipeline --- configs/experiment/m3/feat_v6.yaml | 6 +- configs/experiment/r18/feat_v6.yaml | 11 +- configs/model/ptg.yaml | 2 +- .../utils/pose_generation/configs/main.yaml | 4 + .../pose_generation/generate_pose_data.py | 16 +- tcn_hpl/data/utils/ptg_datagenerator.py | 182 +++++++++++++----- tcn_hpl/models/ptg_module.py | 14 +- 7 files changed, 173 insertions(+), 62 deletions(-) diff --git a/configs/experiment/m3/feat_v6.yaml b/configs/experiment/m3/feat_v6.yaml index 372fd22dd..d90fb84bc 100644 --- a/configs/experiment/m3/feat_v6.yaml +++ b/configs/experiment/m3/feat_v6.yaml @@ -38,7 +38,7 @@ model: dim: 188 # length of feature vector data: - num_classes: 9 # activities: includes background + num_classes: 6 # activities: includes background batch_size: 512 num_workers: 0 epoch_length: 20000 @@ -51,12 +51,12 @@ data: MoveCenterPts: feat_version: 6 - num_obj_classes: 9 # not including background + num_obj_classes: 4 # not including background NormalizeFromCenter: feat_version: 6 NormalizePixelPts: feat_version: 6 - num_obj_classes: 9 # not including background + num_obj_classes: 4 # not including background data_gen: reshuffle_datasets: true diff --git a/configs/experiment/r18/feat_v6.yaml b/configs/experiment/r18/feat_v6.yaml index 0595d6446..1ad29e0b4 100644 --- a/configs/experiment/r18/feat_v6.yaml +++ b/configs/experiment/r18/feat_v6.yaml @@ -14,7 +14,8 @@ defaults: # all parameters below will be merged with parameters from default configurations set above # this allows you to overwrite only specified parameters # exp_name: "p_m2_tqt_data_test_feat_v6_with_pose" #[_v2_aug_False] -exp_name: "p_r18_feat_v6_with_pose_v3_aug_False_reshuffle_True" #[_v2_aug_False] +# exp_name: "p_r18_feat_v6_with_pose_v3_aug_False_reshuffle_True" #[_v2_aug_False] +exp_name: "p_r18_feat_v6_with_pose_v3_aug_False_reshuffle_True_bbn" #[_v2_aug_False] # exp_name: "p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_False_reshuffle_False" #[_v2_aug_False] # exp_name: "p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_True_reshuffle_True" #[_v2_aug_False] # exp_name: "p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_True_reshuffle_False" #[_v2_aug_False] @@ -70,14 +71,20 @@ data_gen: val_vid_ids: [3, 7, 10, 18, 27, 32, 41] test_vid_ids: [50, 13, 47, 25] + train_vid_ids_bbn: [1, 2, 3, 4, 5, 6, 7, 8] + val_vid_ids_bbn: [9] + test_vid_ids_bbn: [10] + paths: - data_dir: "/data/PTG/TCN_data/r18/p_r18_feat_v6_with_pose_v3_aug_False_reshuffle_True" #[_v2_aug_False] + # data_dir: "/data/PTG/TCN_data/r18/p_r18_feat_v6_with_pose_v3_aug_False_reshuffle_True" #[_v2_aug_False] + data_dir: "/data/PTG/TCN_data/r18/p_r18_feat_v6_with_pose_v3_aug_False_reshuffle_True_bbn" #[_v2_aug_False] # data_dir: "/data/PTG/TCN_data/m2/p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_False_reshuffle_False" #[_v2_aug_False] # data_dir: "/data/PTG/TCN_data/m2/p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_True_reshuffle_True" #[_v2_aug_False] # data_dir: "/data/PTG/TCN_data/m2/p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_True_reshuffle_False" #[_v2_aug_False] # root_dir: "/data/PTG/medical/training/activity_classifier/TCN_HPL" root_dir: "/data/users/peri.akiva/PTG/medical/training/activity_classifier/TCN_HPL" dataset_kwcoco: "/data/users/peri.akiva/datasets/ptg/r18_all_all_obj_results_with_dets_and_pose.mscoco.json" + dataset_kwcoco_lab: "/data/users/peri.akiva/datasets/ptg/r18_all_bbn_lab_data_all_obj_results_with_dets_and_pose.mscoco.json" activity_config_root: "/home/local/KHQ/peri.akiva/projects/angel_system/config/activity_labels" activity_config_fn: "${paths.activity_config_root}/${task}" ptg_root: "/home/local/KHQ/peri.akiva/angel_system" diff --git a/configs/model/ptg.yaml b/configs/model/ptg.yaml index b78c980f3..53d64a36a 100644 --- a/configs/model/ptg.yaml +++ b/configs/model/ptg.yaml @@ -3,7 +3,7 @@ _target_: tcn_hpl.models.ptg_module.PTGLitModule optimizer: _target_: torch.optim.Adam _partial_: true - lr: 0.0001 + lr: 0.00001 weight_decay: 0.0001 scheduler: diff --git a/tcn_hpl/data/utils/pose_generation/configs/main.yaml b/tcn_hpl/data/utils/pose_generation/configs/main.yaml index 5b2b4e809..dd88734f5 100644 --- a/tcn_hpl/data/utils/pose_generation/configs/main.yaml +++ b/tcn_hpl/data/utils/pose_generation/configs/main.yaml @@ -1,4 +1,5 @@ task: r18 +data_type: bbn #[bbn, gyges] root: /data/PTG/medical/bbn_data/Release_v0.5/v0.52/M2_Tourniquet/Data img_save_path: /data/users/peri.akiva/datasets/m2_tourniquet/imgs pose_model_config: /home/local/KHQ/peri.akiva/projects/TCN_HPL/tcn_hpl/data/utils/pose_generation/configs/ViTPose_base_medic_casualty_256x192.py @@ -20,3 +21,6 @@ data: m5: /data/PTG/medical/training/yolo_object_detector/detect//m5_all/m5_all_all_obj_results.mscoco.json r18: /data/PTG/medical/training/yolo_object_detector/detect//r18_all/r18_all_all_obj_results.mscoco.json save_root: /data/users/peri.akiva/datasets/ptg +bbn_lab: + r18: /data/PTG/medical/training/yolo_object_detector/detect//r18_all_bbn_lab_data/r18_all_bbn_lab_data_all_obj_results.mscoco.json + save_root: /data/users/peri.akiva/datasets/ptg \ No newline at end of file diff --git a/tcn_hpl/data/utils/pose_generation/generate_pose_data.py b/tcn_hpl/data/utils/pose_generation/generate_pose_data.py index 331a21823..815239b6c 100644 --- a/tcn_hpl/data/utils/pose_generation/generate_pose_data.py +++ b/tcn_hpl/data/utils/pose_generation/generate_pose_data.py @@ -43,10 +43,16 @@ def __init__(self, config: dict) -> None: self.config = config self.root_path = config['root'] - self.dataset = kwcoco.CocoDataset(config['data'][config['task']]) + if config['data_type'] == "bbn": + self.config_data_key = "bbn_lab" + else: + self.config_data_key = "data" + + self.dataset = kwcoco.CocoDataset(config[self.config_data_key][config['task']]) self.patient_cid = self.dataset.add_category('patient') self.user_cid = self.dataset.add_category('user') + self.keypoints_cats = [ "nose", "mouth", "throat","chest","stomach","left_upper_arm", "right_upper_arm","left_lower_arm","right_lower_arm","left_wrist", @@ -59,7 +65,7 @@ def __init__(self, config: dict) -> None: self.dataset.dataset['keypoint_categories'] = self.keypoints_cats_dset - self.dataset_path_name = self.config['data'][self.config['task']][:-12].split('/')[-1] #remove .mscoco.json + self.dataset_path_name = self.config[self.config_data_key][self.config['task']][:-12].split('/')[-1] #remove .mscoco.json self.args = get_parser(self.config['detection_model_config']).parse_args() detecron_cfg = setup_detectron_cfg(self.args) @@ -257,13 +263,13 @@ def generate_bbs_and_pose(self, dset: kwcoco.CocoDataset, save_intermediate: boo # import matplotlib.pyplot as plt # image_show = dset.draw_image(gid=img_id) # plt.imshow(image_show) - # plt.savefig(f"myfig_{self.config['task']}_{index}.png") + # plt.savefig(f"figs/myfig_{self.config['task']}_{index}.png") # if index >= 20: # exit() if save_intermediate: if (index % 45000) == 0: - dset_inter_name = f"{self.config['data']['save_root']}/{self.dataset_path_name}_{index}_with_dets_and_pose.mscoco.json" + dset_inter_name = f"{self.config[self.config_data_key]['save_root']}/{self.dataset_path_name}_{index}_with_dets_and_pose.mscoco.json" self.dataset.dump(dset_inter_name, newlines=True) print(f"Saved intermediate dataset at index {index} to: {dset_inter_name}") @@ -293,7 +299,7 @@ def run(self) -> None: """ self.dataset = self.generate_bbs_and_pose(self.dataset) - dataset_path_with_pose = f"{self.config['data']['save_root']}/{self.dataset_path_name}_with_dets_and_pose.mscoco.json" + dataset_path_with_pose = f"{self.config[self.config_data_key]['save_root']}/{self.dataset_path_name}_with_dets_and_pose.mscoco.json" self.dataset.dump(dataset_path_with_pose, newlines=True) print(f"Saved test dataset to: {dataset_path_with_pose}") return diff --git a/tcn_hpl/data/utils/ptg_datagenerator.py b/tcn_hpl/data/utils/ptg_datagenerator.py index b4d12afb6..99d9ef681 100644 --- a/tcn_hpl/data/utils/ptg_datagenerator.py +++ b/tcn_hpl/data/utils/ptg_datagenerator.py @@ -40,6 +40,13 @@ def load_yaml_as_dict(yaml_path): 'r18': "R18_Chest_Seal", } +LAB_TASK_TO_NAME = { + 'm2': "M2_Lab_Skills", + 'm3': "M3_Lab_Skills", + 'm5': "M5_Lab_Skills", + 'r18': "R18_Lab_Skills", +} + FEAT_TO_BOOL = { "no_pose": [False, False], "with_pose": [True, True], @@ -47,11 +54,11 @@ def load_yaml_as_dict(yaml_path): "only_objects_joints": [True, False] } -def main(task: str): - config_path = f"/home/local/KHQ/peri.akiva/projects/TCN_HPL/configs/experiment/{task}/feat_v6.yaml" + +def main(task: str, ptg_root: str, config_root: str, data_type: str): + config_path = f"{config_root}/experiment/{task}/feat_v6.yaml" config = load_yaml_as_dict(config_path) - ptg_root = "/home/local/KHQ/peri.akiva/angel_system" activity_config_path = f"{ptg_root}/config/activity_labels/medical" feat_version = 6 @@ -72,7 +79,7 @@ def main(task: str): aug_trans_range, aug_rot_range = None, None - exp_name = f"p_{config['task']}_feat_v6_{config['data_gen']['feat_type']}_v3_aug_{augment}_reshuffle_{reshuffle_datasets}" #[p_m2_tqt_data_test_feat_v6_with_pose, p_m2_tqt_data_test_feat_v6_only_hands_joints, p_m2_tqt_data_test_feat_v6_only_objects_joints, p_m2_tqt_data_test_feat_v6_no_pose] + exp_name = f"p_{config['task']}_feat_v6_{config['data_gen']['feat_type']}_v3_aug_{augment}_reshuffle_{reshuffle_datasets}_{data_type}" #[p_m2_tqt_data_test_feat_v6_with_pose, p_m2_tqt_data_test_feat_v6_only_hands_joints, p_m2_tqt_data_test_feat_v6_only_objects_joints, p_m2_tqt_data_test_feat_v6_no_pose] output_data_dir = f"{config['paths']['output_data_dir_root']}/{config['task']}/{exp_name}" @@ -121,30 +128,47 @@ def main(task: str): ##################### ### create splits dsets - dset = kwcoco.CocoDataset(config['paths']['dataset_kwcoco']) + print(f"Generating features for task: {task}") + + if data_type == "gyges": + dset = kwcoco.CocoDataset(config['paths']['dataset_kwcoco']) + elif data_type == "bbn": + dset = kwcoco.CocoDataset(config['paths']['dataset_kwcoco_lab']) if reshuffle_datasets: train_img_ids, val_img_ids, test_img_ids = [], [], [] - - task_name = config['task'].upper() - train_vidids = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['train_vid_ids'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()] - val_vivids = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['val_vid_ids'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()] - test_vivds = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['test_vid_ids'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()] - - if config['data_gen']['filter_black_gloves']: - vidids_black_gloves = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['names_black_gloves'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()] + if data_type == "gyges": + task_name = config['task'].upper() + train_vidids = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['train_vid_ids'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()] + val_vivids = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['val_vid_ids'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()] + test_vivds = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['test_vid_ids'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()] - train_vidids = [x for x in train_vidids if x not in vidids_black_gloves] - val_vivids = [x for x in val_vivids if x not in vidids_black_gloves] - test_vivds = [x for x in test_vivds if x not in vidids_black_gloves] - - if config['data_gen']['filter_blue_gloves']: - vidids_blue_gloves = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['names_blue_gloves'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()] - - train_vidids = [x for x in train_vidids if x not in vidids_blue_gloves] - val_vivids = [x for x in val_vivids if x not in vidids_blue_gloves] - test_vivds = [x for x in test_vivds if x not in vidids_blue_gloves] + if config['data_gen']['filter_black_gloves']: + vidids_black_gloves = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['names_black_gloves'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()] + + train_vidids = [x for x in train_vidids if x not in vidids_black_gloves] + val_vivids = [x for x in val_vivids if x not in vidids_black_gloves] + test_vivds = [x for x in test_vivds if x not in vidids_black_gloves] + + if config['data_gen']['filter_blue_gloves']: + vidids_blue_gloves = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['names_blue_gloves'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()] + + train_vidids = [x for x in train_vidids if x not in vidids_blue_gloves] + val_vivids = [x for x in val_vivids if x not in vidids_blue_gloves] + test_vivds = [x for x in test_vivds if x not in vidids_blue_gloves] + + elif data_type == "bbn": + task_name = config['task'].upper() + print(f"task name: {task_name}") + print(f"dset.index.name_to_video.keys(): {dset.index.name_to_video}") + all_vids = sorted(list(dset.index.name_to_video.keys())) + print(f"all vids: {all_vids}") + train_vidids = [dset.index.name_to_video[vid_name]['id'] for vid_name in all_vids if dset.index.name_to_video[vid_name]['id'] in config['data_gen']['train_vid_ids_bbn']] + val_vivids = [dset.index.name_to_video[vid_name]['id'] for vid_name in all_vids if dset.index.name_to_video[vid_name]['id'] in config['data_gen']['val_vid_ids_bbn']] + test_vivds = [dset.index.name_to_video[vid_name]['id'] for vid_name in all_vids if dset.index.name_to_video[vid_name]['id'] in config['data_gen']['test_vid_ids_bbn']] + # val_vivids = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['val_vid_ids_bbn'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()] + # test_vivds = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['test_vid_ids_bbn'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()] ## individual splits by gids @@ -192,7 +216,10 @@ def main(task: str): val_dset = dset.subset(gids=val_img_ids, copy=True) test_dset = dset.subset(gids=test_img_ids, copy=True) - skill_data_root = f"{config['paths']['bbn_data_dir']}/Release_v0.5/v0.56/{TASK_TO_NAME[task]}/Data" + if data_type == "gyges": + skill_data_root = f"{config['paths']['bbn_data_dir']}/Release_v0.5/v0.56/{TASK_TO_NAME[task]}/Data" + elif data_type == "bbn": + skill_data_root = f"{config['paths']['bbn_data_dir']}/lab_data/{LAB_TASK_TO_NAME[task]}" # videos_names = os.listdir(skill_data_root) # for split in ["train_activity", "val", "test"]: for dset, split in zip([train_dset, val_dset, test_dset], ["train_activity", "val", "test"]): @@ -217,37 +244,70 @@ def main(task: str): video_dset = dset.subset(gids=image_ids, copy=True) #### begin activity GT section - video_root = f"{skill_data_root}/{video_name}" - activity_gt_file = f"{video_root}/{video_name}.skill_labels_by_frame.txt" + if data_type == "gyges": + video_root = f"{skill_data_root}/{video_name}" + activity_gt_file = f"{video_root}/{video_name}.skill_labels_by_frame.txt" + elif data_type == "bbn": + activity_gt_file = f"{skill_data_root}/{video_name}.skill_labels_by_frame.txt" + if not os.path.exists(activity_gt_file): + print(f"activity_gt_file {activity_gt_file} doesnt exists. Trying a different way") + activity_gt_file = f"{skill_data_root}/{video_name}.txt" if not os.path.exists(activity_gt_file): + print(f"activity_gt_file {activity_gt_file} doesnt exists. continueing") continue f = open(activity_gt_file, "r") text = f.read() f.close() - # print(f"str type: {type(text)}") - text = text.replace('\n', '\t') - text_list = text.split("\t")[:-1] + + # print(f"str type: {text}") activityi_gt_list = ["background" for x in range(num_images)] - for index in range(0, len(text_list), 3): - triplet = text_list[index:index+3] - # print(f"index: {index}, {text_list[index]}") - # print(f"triplet: {text_list[index:index+3]}") - start_frame = int(triplet[0]) - end_frame = int(triplet[1]) - desc = triplet[2] - gt_label = activity_labels_desc_mapping[desc] - - if end_frame-1 > num_images: - ### address issue with GT activity labels - print("Max frame in GT is larger than number of frames in the video") + # print(f"activity_gt_file: {activity_gt_file}") + # print(f"activity_labels_desc_mapping: {activity_labels_desc_mapping}") + # print(f"text_list: {text_list}") + if data_type == "gtges": + text = text.replace('\n', '\t') + text_list = text.split("\t")[:-1] + for index in range(0, len(text_list), 3): + triplet = text_list[index:index+3] + # print(f"index: {index}, {text_list[index]}") + # print(f"triplet: {text_list[index:index+3]}") + start_frame = int(triplet[0]) + end_frame = int(triplet[1]) + desc = triplet[2] + gt_label = activity_labels_desc_mapping[desc] + + if end_frame-1 > num_images: + ### address issue with GT activity labels + print("Max frame in GT is larger than number of frames in the video") + + for label_index in range(start_frame, min(end_frame-1, num_images)): + # print(f"label_index: {label_index}") + activityi_gt_list[label_index] = gt_label + + elif data_type == "bbn": + text = text.replace('\n', '\t') + text_list = text.split("\t")#[:-1] + # print(f"text_list: {text_list}") + for index in range(0, len(text_list), 4): + triplet = text_list[index:index+4] + # print(f"index: {index}, {text_list[index]}") + # print(f"triplet: {text_list[index:index+4]}") + start_frame = int(triplet[0]) + end_frame = int(triplet[1]) + desc = triplet[3] + gt_label = activity_labels_desc_mapping[desc] - for label_index in range(start_frame, min(end_frame-1, num_images)): - # print(f"label_index: {label_index}") - activityi_gt_list[label_index] = gt_label + if end_frame-1 > num_images: + ### address issue with GT activity labels + print("Max frame in GT is larger than number of frames in the video") + + for label_index in range(start_frame, min(end_frame-1, num_images)): + # print(f"label_index: {label_index}") + activityi_gt_list[label_index] = gt_label # print(f"start: {start_frame}, end: {end_frame}, label: {gt_label}, activityi_gt_list: {len(activityi_gt_list)}, num images: {num_images}") @@ -347,5 +407,35 @@ def main(task: str): print(f"Saved training data to {output_data_dir}") if __name__ == '__main__': - task = "r18" - main(task=task) \ No newline at end of file + + import argparse + parser = argparse.ArgumentParser() + parser.add_argument( + "--task", + help="Object detections in kwcoco format for the train set", + type=str, + ) + + parser.add_argument( + "--ptg-root", + default='/home/local/KHQ/peri.akiva/angel_system', + help="root to angel_system", + type=str, + ) + + parser.add_argument( + "--config-root", + default="/home/local/KHQ/peri.akiva/projects/TCN_HPL/configs", + help="root to TCN configs", + type=str, + ) + + parser.add_argument( + "--data-type", + default="gyges", + help="gyges=proferssional data, bbn=use lab data", + type=str, + ) + + args = parser.parse_args() + main(task=args.task, ptg_root=args.ptg_root, config_root=args.config_root, data_type=args.data_type) \ No newline at end of file diff --git a/tcn_hpl/models/ptg_module.py b/tcn_hpl/models/ptg_module.py index e907221c9..634589a83 100644 --- a/tcn_hpl/models/ptg_module.py +++ b/tcn_hpl/models/ptg_module.py @@ -90,6 +90,7 @@ def __init__( self.save_hyperparameters(logger=False) self.net = net + # Get Action Names mapping_file = f"{self.hparams.data_dir}/{mapping_file_name}" @@ -102,6 +103,9 @@ def __init__( self.class_ids = list(actions_dict.values()) self.classes = list(actions_dict.keys()) + # print(f"CLASSES IN MODEL: {self.classes}") + # print(f"actions_dict: {actions_dict}") + # exit() self.action_id_to_str = dict(zip(self.class_ids, self.classes)) # loss functions @@ -661,11 +665,11 @@ def on_test_epoch_end(self) -> None: """Lightning hook that is called when a test epoch ends.""" # update and log metrics - all_targets = torch.concat(self.validation_step_outputs_target) # shape: #frames - all_preds = torch.concat(self.validation_step_outputs_pred) # shape: #frames - all_probs = torch.concat(self.validation_step_outputs_prob) # shape (#frames, #act labels) - all_source_vids = torch.concat(self.validation_step_outputs_source_vid) - all_source_frames = torch.concat(self.validation_step_outputs_source_frame) + all_targets = torch.cat(self.validation_step_outputs_target) # shape: #frames + all_preds = torch.cat(self.validation_step_outputs_pred) # shape: #frames + all_probs = torch.cat(self.validation_step_outputs_prob) # shape (#frames, #act labels) + all_source_vids = torch.cat(self.validation_step_outputs_source_vid) + all_source_frames = torch.cat(self.validation_step_outputs_source_frame) # Load test vidoes if self.test_frames is None: From ac1cb8d26642a75ce80a8b23aab663585a561afc Mon Sep 17 00:00:00 2001 From: Peri Akiva Date: Tue, 9 Apr 2024 12:54:28 -0400 Subject: [PATCH 3/3] documentation and comments --- configs/model/ptg.yaml | 2 +- tcn_hpl/data/utils/ptg_datagenerator.py | 93 ++++++++++++------------- 2 files changed, 45 insertions(+), 50 deletions(-) diff --git a/configs/model/ptg.yaml b/configs/model/ptg.yaml index 53d64a36a..7df107add 100644 --- a/configs/model/ptg.yaml +++ b/configs/model/ptg.yaml @@ -3,7 +3,7 @@ _target_: tcn_hpl.models.ptg_module.PTGLitModule optimizer: _target_: torch.optim.Adam _partial_: true - lr: 0.00001 + lr: 0.000001 weight_decay: 0.0001 scheduler: diff --git a/tcn_hpl/data/utils/ptg_datagenerator.py b/tcn_hpl/data/utils/ptg_datagenerator.py index 99d9ef681..6a5cde3d1 100644 --- a/tcn_hpl/data/utils/ptg_datagenerator.py +++ b/tcn_hpl/data/utils/ptg_datagenerator.py @@ -1,3 +1,7 @@ +# This script is designed for processing and preparing datasets for training activity classifiers. +# It performs several key functions: reading configuration files, generating feature matrices, +# preparing ground truth labels, and organizing the data into a structured format suitable for th TCN model. + import os import yaml import glob @@ -20,9 +24,19 @@ data_loader, compute_feats, ) -# from angel_system.data.data_paths import grab_data, data_dir + + def load_yaml_as_dict(yaml_path): + """ + Loads a YAML file and returns it as a Python dictionary. + + Args: + yaml_path (str): The path to the YAML configuration file. + + Returns: + dict: The YAML file's content as a dictionary. + """ with open(yaml_path, 'r') as f: config_dict = yaml.load(f, Loader=yaml.FullLoader) return config_dict @@ -31,6 +45,7 @@ def load_yaml_as_dict(yaml_path): # Inputs ##################### +# Mapping task identifiers to their descriptive names. TASK_TO_NAME = { 'm1': "M1_Trauma_Assessment", 'm2': "M2_Tourniquet", @@ -40,6 +55,7 @@ def load_yaml_as_dict(yaml_path): 'r18': "R18_Chest_Seal", } +# Mapping lab data task identifiers to their descriptive names. LAB_TASK_TO_NAME = { 'm2': "M2_Lab_Skills", 'm3': "M3_Lab_Skills", @@ -47,6 +63,7 @@ def load_yaml_as_dict(yaml_path): 'r18': "R18_Lab_Skills", } +# Mapping feature settings to boolean flags indicating the inclusion of pose or object joint information. FEAT_TO_BOOL = { "no_pose": [False, False], "with_pose": [True, True], @@ -56,6 +73,16 @@ def load_yaml_as_dict(yaml_path): def main(task: str, ptg_root: str, config_root: str, data_type: str): + """ + Main function that orchestrates the process of loading configurations, setting up directories, + processing datasets, and generating features for training activity classification models. + + Args: + task (str): The task identifier. + ptg_root (str): Path to the root of the angel_system project. + config_root (str): Path to the root of the configuration files. + data_type (str): Specifies the type of data, either 'gyges' for professional data or 'bbn' for lab data. + """ config_path = f"{config_root}/experiment/{task}/feat_v6.yaml" config = load_yaml_as_dict(config_path) @@ -65,12 +92,11 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str): ##################### # Output ##################### - # reshuffle_datasets = True - # augment = False - # num_augs = 5 reshuffle_datasets, augment, num_augs = config['data_gen']['reshuffle_datasets'], config['data_gen']['augment'], config['data_gen']['num_augs'] + + ## augmentation is not currently used. if augment: num_augs = num_augs aug_trans_range, aug_rot_range = [-5, 5], [-5, 5] @@ -78,7 +104,7 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str): num_augs = 1 aug_trans_range, aug_rot_range = None, None - + # the "data_type" parameter is new to the BBN lab data. old experiments dont have that parameter in their experiment name exp_name = f"p_{config['task']}_feat_v6_{config['data_gen']['feat_type']}_v3_aug_{augment}_reshuffle_{reshuffle_datasets}_{data_type}" #[p_m2_tqt_data_test_feat_v6_with_pose, p_m2_tqt_data_test_feat_v6_only_hands_joints, p_m2_tqt_data_test_feat_v6_only_objects_joints, p_m2_tqt_data_test_feat_v6_no_pose] output_data_dir = f"{config['paths']['output_data_dir_root']}/{config['task']}/{exp_name}" @@ -127,7 +153,10 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str): # bundles ##################### ### create splits dsets - + + # Data processing and feature generation based on the data type (gyges or bbn). + # The detailed implementation handles data loading, augmentation (if specified), + # ground truth preparation, feature computation, and data organization for training and evaluation. print(f"Generating features for task: {task}") if data_type == "gyges": @@ -138,6 +167,8 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str): if reshuffle_datasets: train_img_ids, val_img_ids, test_img_ids = [], [], [] + + ## The data directory format is different for "Professional" and lab data, so we handle the split differently if data_type == "gyges": task_name = config['task'].upper() train_vidids = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['train_vid_ids'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()] @@ -160,32 +191,10 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str): elif data_type == "bbn": task_name = config['task'].upper() - print(f"task name: {task_name}") - print(f"dset.index.name_to_video.keys(): {dset.index.name_to_video}") all_vids = sorted(list(dset.index.name_to_video.keys())) - print(f"all vids: {all_vids}") train_vidids = [dset.index.name_to_video[vid_name]['id'] for vid_name in all_vids if dset.index.name_to_video[vid_name]['id'] in config['data_gen']['train_vid_ids_bbn']] val_vivids = [dset.index.name_to_video[vid_name]['id'] for vid_name in all_vids if dset.index.name_to_video[vid_name]['id'] in config['data_gen']['val_vid_ids_bbn']] test_vivds = [dset.index.name_to_video[vid_name]['id'] for vid_name in all_vids if dset.index.name_to_video[vid_name]['id'] in config['data_gen']['test_vid_ids_bbn']] - # val_vivids = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['val_vid_ids_bbn'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()] - # test_vivds = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['test_vid_ids_bbn'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()] - - - ## individual splits by gids - # total = len(dset.index.imgs) - # inds = [i for i in range(1, total+1)] - # train_size, val_size = int(0.8*total), int(0.1*total) - # test_size = total - train_size - val_size - # train_inds = set(list(np.random.choice(inds, size=train_size, replace=False))) - # remaining_inds = list(set(inds) - train_inds) - # val_inds = set(list(np.random.choice(remaining_inds, size=val_size, replace=False))) - # test_inds = list(set(remaining_inds) - val_inds) - # # test_inds = list(np.random.choice(remaining_inds, size=test_size, replace=False)) - - # train_img_ids = [dset.index.imgs[i]['id'] for i in train_inds] - # val_img_ids = [dset.index.imgs[i]['id'] for i in val_inds] - # test_img_ids = [dset.index.imgs[i]['id'] for i in test_inds] - # print(f"train: {len(train_inds)}, val: {len(val_inds)}, test: {len(test_inds)}") for vid in train_vidids: if vid in dset.index.vidid_to_gids.keys(): @@ -199,7 +208,6 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str): val_img_ids.extend(list(dset.index.vidid_to_gids[vid])) else: print(f"{vid} not in the val dataset") - # val_img_ids = set(val_img_ids) + set(dset.index.vidid_to_gids[vid]) for vid in test_vivds: @@ -216,20 +224,15 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str): val_dset = dset.subset(gids=val_img_ids, copy=True) test_dset = dset.subset(gids=test_img_ids, copy=True) + # again, both data types use different directory formatting. We handle both here if data_type == "gyges": skill_data_root = f"{config['paths']['bbn_data_dir']}/Release_v0.5/v0.56/{TASK_TO_NAME[task]}/Data" elif data_type == "bbn": skill_data_root = f"{config['paths']['bbn_data_dir']}/lab_data/{LAB_TASK_TO_NAME[task]}" - # videos_names = os.listdir(skill_data_root) - # for split in ["train_activity", "val", "test"]: + for dset, split in zip([train_dset, val_dset, test_dset], ["train_activity", "val", "test"]): - # org_num_vidoes = len(list(dset.index.videos.keys())) - # new_num_videos = 0 for video_id in ub.ProgIter(dset.index.videos.keys()): - # if video_id != 10: - # continue - video = dset.index.videos[video_id] video_name = video["name"] @@ -243,7 +246,9 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str): video_dset = dset.subset(gids=image_ids, copy=True) - #### begin activity GT section + # begin activity GT section + # The GT given data is provided in different formats for the lab and professional data collections. + # we handle both here. if data_type == "gyges": video_root = f"{skill_data_root}/{video_name}" activity_gt_file = f"{video_root}/{video_name}.skill_labels_by_frame.txt" @@ -261,14 +266,9 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str): text = f.read() f.close() - - # print(f"str type: {text}") - activityi_gt_list = ["background" for x in range(num_images)] - # print(f"activity_gt_file: {activity_gt_file}") - # print(f"activity_labels_desc_mapping: {activity_labels_desc_mapping}") - # print(f"text_list: {text_list}") - if data_type == "gtges": + + if data_type == "gyges": text = text.replace('\n', '\t') text_list = text.split("\t")[:-1] for index in range(0, len(text_list), 3): @@ -285,17 +285,13 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str): print("Max frame in GT is larger than number of frames in the video") for label_index in range(start_frame, min(end_frame-1, num_images)): - # print(f"label_index: {label_index}") activityi_gt_list[label_index] = gt_label elif data_type == "bbn": text = text.replace('\n', '\t') text_list = text.split("\t")#[:-1] - # print(f"text_list: {text_list}") for index in range(0, len(text_list), 4): triplet = text_list[index:index+4] - # print(f"index: {index}, {text_list[index]}") - # print(f"triplet: {text_list[index:index+4]}") start_frame = int(triplet[0]) end_frame = int(triplet[1]) desc = triplet[3] @@ -335,7 +331,6 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str): #### end activity GT section - # features ( act_map,