diff --git a/tcn_hpl/data/utils/ptg_datagenerator.py b/tcn_hpl/data/utils/ptg_datagenerator.py index 6a5cde3d1..6a7eed5f6 100644 --- a/tcn_hpl/data/utils/ptg_datagenerator.py +++ b/tcn_hpl/data/utils/ptg_datagenerator.py @@ -14,6 +14,7 @@ from pathlib import Path +from angel_system.data.medical.data_paths import GrabData from angel_system.data.common.load_data import ( activities_from_dive_csv, objs_as_dataframe, @@ -72,7 +73,8 @@ def load_yaml_as_dict(yaml_path): } -def main(task: str, ptg_root: str, config_root: str, data_type: str): +def main(task: str, ptg_root: str, config_root: str, + data_type: str, data_gen_yaml: str): """ Main function that orchestrates the process of loading configurations, setting up directories, processing datasets, and generating features for training activity classification models. @@ -104,7 +106,16 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str): num_augs = 1 aug_trans_range, aug_rot_range = None, None - # the "data_type" parameter is new to the BBN lab data. old experiments dont have that parameter in their experiment name + # the "data_type" parameter is new to the BBN lab data. + # old experiments dont have that parameter in their experiment name + """ + feat_type details wrt to feature generation: + no_pose: we only use object detections, and object-hands intersection + with_pose: we use patient pose to calculate joint-hands and joint-objects offset vectors + only_hands_joints: we use patient pose to calculate only joint-hands offset vectors + only_objects_joints: we use patient pose to calulcate only joint-objects offset vectors + """ + exp_name = f"p_{config['task']}_feat_v6_{config['data_gen']['feat_type']}_v3_aug_{augment}_reshuffle_{reshuffle_datasets}_{data_type}" #[p_m2_tqt_data_test_feat_v6_with_pose, p_m2_tqt_data_test_feat_v6_only_hands_joints, p_m2_tqt_data_test_feat_v6_only_objects_joints, p_m2_tqt_data_test_feat_v6_no_pose] output_data_dir = f"{config['paths']['output_data_dir_root']}/{config['task']}/{exp_name}" @@ -159,9 +170,9 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str): # ground truth preparation, feature computation, and data organization for training and evaluation. print(f"Generating features for task: {task}") - if data_type == "gyges": + if data_type == "pro": dset = kwcoco.CocoDataset(config['paths']['dataset_kwcoco']) - elif data_type == "bbn": + elif data_type == "lab": dset = kwcoco.CocoDataset(config['paths']['dataset_kwcoco_lab']) if reshuffle_datasets: @@ -169,7 +180,7 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str): train_img_ids, val_img_ids, test_img_ids = [], [], [] ## The data directory format is different for "Professional" and lab data, so we handle the split differently - if data_type == "gyges": + if data_type == "pro": task_name = config['task'].upper() train_vidids = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['train_vid_ids'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()] val_vivids = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['val_vid_ids'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()] @@ -189,7 +200,7 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str): val_vivids = [x for x in val_vivids if x not in vidids_blue_gloves] test_vivds = [x for x in test_vivds if x not in vidids_blue_gloves] - elif data_type == "bbn": + elif data_type == "lab": task_name = config['task'].upper() all_vids = sorted(list(dset.index.name_to_video.keys())) train_vidids = [dset.index.name_to_video[vid_name]['id'] for vid_name in all_vids if dset.index.name_to_video[vid_name]['id'] in config['data_gen']['train_vid_ids_bbn']] @@ -225,10 +236,12 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str): test_dset = dset.subset(gids=test_img_ids, copy=True) # again, both data types use different directory formatting. We handle both here - if data_type == "gyges": - skill_data_root = f"{config['paths']['bbn_data_dir']}/Release_v0.5/v0.56/{TASK_TO_NAME[task]}/Data" - elif data_type == "bbn": - skill_data_root = f"{config['paths']['bbn_data_dir']}/lab_data/{LAB_TASK_TO_NAME[task]}" + data_grabber = GrabData(yaml_path=data_gen_yaml) + if data_type == "pro": + # skill_data_root = f"{config['paths']['bbn_data_dir']}/Release_v0.5/v0.56/{TASK_TO_NAME[task]}/Data" + skill_data_root = f"{data_grabber.bbn_data_root}/{TASK_TO_NAME[task]}/Data" + elif data_type == "lab": + skill_data_root = f"{data_grabber.lab_bbn_data_root}/{LAB_TASK_TO_NAME[task]}" for dset, split in zip([train_dset, val_dset, test_dset], ["train_activity", "val", "test"]): for video_id in ub.ProgIter(dset.index.videos.keys()): @@ -249,10 +262,10 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str): # begin activity GT section # The GT given data is provided in different formats for the lab and professional data collections. # we handle both here. - if data_type == "gyges": + if data_type == "pro": video_root = f"{skill_data_root}/{video_name}" activity_gt_file = f"{video_root}/{video_name}.skill_labels_by_frame.txt" - elif data_type == "bbn": + elif data_type == "lab": activity_gt_file = f"{skill_data_root}/{video_name}.skill_labels_by_frame.txt" if not os.path.exists(activity_gt_file): print(f"activity_gt_file {activity_gt_file} doesnt exists. Trying a different way") @@ -268,7 +281,7 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str): activityi_gt_list = ["background" for x in range(num_images)] - if data_type == "gyges": + if data_type == "pro": text = text.replace('\n', '\t') text_list = text.split("\t")[:-1] for index in range(0, len(text_list), 3): @@ -287,7 +300,7 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str): for label_index in range(start_frame, min(end_frame-1, num_images)): activityi_gt_list[label_index] = gt_label - elif data_type == "bbn": + elif data_type == "lab": text = text.replace('\n', '\t') text_list = text.split("\t")#[:-1] for index in range(0, len(text_list), 4): @@ -427,10 +440,19 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str): parser.add_argument( "--data-type", - default="gyges", - help="gyges=proferssional data, bbn=use lab data", + default="pro", + help="pro=proferssional data, lab=use lab data", + type=str, + ) + + parser.add_argument( + "--data-gen-yaml", + default="/home/local/KHQ/peri.akiva/projects/angel_system/config/data_generation/bbn_gyges.yaml", + help="Path to data generation yaml file", type=str, ) args = parser.parse_args() - main(task=args.task, ptg_root=args.ptg_root, config_root=args.config_root, data_type=args.data_type) \ No newline at end of file + main(task=args.task, ptg_root=args.ptg_root, + config_root=args.config_root, data_type=args.data_type, + data_gen_yaml=args.data_gen_yaml) \ No newline at end of file