Skip to content

Commit

Permalink
Merge pull request #10 from periakiva/feat/refactor_data_paths
Browse files Browse the repository at this point in the history
clean, refactor
  • Loading branch information
periakiva authored Apr 9, 2024
2 parents a585e40 + 51e6e66 commit 5ee5262
Showing 1 changed file with 39 additions and 17 deletions.
56 changes: 39 additions & 17 deletions tcn_hpl/data/utils/ptg_datagenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from pathlib import Path

from angel_system.data.medical.data_paths import GrabData
from angel_system.data.common.load_data import (
activities_from_dive_csv,
objs_as_dataframe,
Expand Down Expand Up @@ -72,7 +73,8 @@ def load_yaml_as_dict(yaml_path):
}


def main(task: str, ptg_root: str, config_root: str, data_type: str):
def main(task: str, ptg_root: str, config_root: str,
data_type: str, data_gen_yaml: str):
"""
Main function that orchestrates the process of loading configurations, setting up directories,
processing datasets, and generating features for training activity classification models.
Expand Down Expand Up @@ -104,7 +106,16 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str):
num_augs = 1
aug_trans_range, aug_rot_range = None, None

# the "data_type" parameter is new to the BBN lab data. old experiments dont have that parameter in their experiment name
# the "data_type" parameter is new to the BBN lab data.
# old experiments dont have that parameter in their experiment name
"""
feat_type details wrt to feature generation:
no_pose: we only use object detections, and object-hands intersection
with_pose: we use patient pose to calculate joint-hands and joint-objects offset vectors
only_hands_joints: we use patient pose to calculate only joint-hands offset vectors
only_objects_joints: we use patient pose to calulcate only joint-objects offset vectors
"""

exp_name = f"p_{config['task']}_feat_v6_{config['data_gen']['feat_type']}_v3_aug_{augment}_reshuffle_{reshuffle_datasets}_{data_type}" #[p_m2_tqt_data_test_feat_v6_with_pose, p_m2_tqt_data_test_feat_v6_only_hands_joints, p_m2_tqt_data_test_feat_v6_only_objects_joints, p_m2_tqt_data_test_feat_v6_no_pose]

output_data_dir = f"{config['paths']['output_data_dir_root']}/{config['task']}/{exp_name}"
Expand Down Expand Up @@ -159,17 +170,17 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str):
# ground truth preparation, feature computation, and data organization for training and evaluation.
print(f"Generating features for task: {task}")

if data_type == "gyges":
if data_type == "pro":
dset = kwcoco.CocoDataset(config['paths']['dataset_kwcoco'])
elif data_type == "bbn":
elif data_type == "lab":
dset = kwcoco.CocoDataset(config['paths']['dataset_kwcoco_lab'])

if reshuffle_datasets:

train_img_ids, val_img_ids, test_img_ids = [], [], []

## The data directory format is different for "Professional" and lab data, so we handle the split differently
if data_type == "gyges":
if data_type == "pro":
task_name = config['task'].upper()
train_vidids = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['train_vid_ids'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()]
val_vivids = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['val_vid_ids'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()]
Expand All @@ -189,7 +200,7 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str):
val_vivids = [x for x in val_vivids if x not in vidids_blue_gloves]
test_vivds = [x for x in test_vivds if x not in vidids_blue_gloves]

elif data_type == "bbn":
elif data_type == "lab":
task_name = config['task'].upper()
all_vids = sorted(list(dset.index.name_to_video.keys()))
train_vidids = [dset.index.name_to_video[vid_name]['id'] for vid_name in all_vids if dset.index.name_to_video[vid_name]['id'] in config['data_gen']['train_vid_ids_bbn']]
Expand Down Expand Up @@ -225,10 +236,12 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str):
test_dset = dset.subset(gids=test_img_ids, copy=True)

# again, both data types use different directory formatting. We handle both here
if data_type == "gyges":
skill_data_root = f"{config['paths']['bbn_data_dir']}/Release_v0.5/v0.56/{TASK_TO_NAME[task]}/Data"
elif data_type == "bbn":
skill_data_root = f"{config['paths']['bbn_data_dir']}/lab_data/{LAB_TASK_TO_NAME[task]}"
data_grabber = GrabData(yaml_path=data_gen_yaml)
if data_type == "pro":
# skill_data_root = f"{config['paths']['bbn_data_dir']}/Release_v0.5/v0.56/{TASK_TO_NAME[task]}/Data"
skill_data_root = f"{data_grabber.bbn_data_root}/{TASK_TO_NAME[task]}/Data"
elif data_type == "lab":
skill_data_root = f"{data_grabber.lab_bbn_data_root}/{LAB_TASK_TO_NAME[task]}"

for dset, split in zip([train_dset, val_dset, test_dset], ["train_activity", "val", "test"]):
for video_id in ub.ProgIter(dset.index.videos.keys()):
Expand All @@ -249,10 +262,10 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str):
# begin activity GT section
# The GT given data is provided in different formats for the lab and professional data collections.
# we handle both here.
if data_type == "gyges":
if data_type == "pro":
video_root = f"{skill_data_root}/{video_name}"
activity_gt_file = f"{video_root}/{video_name}.skill_labels_by_frame.txt"
elif data_type == "bbn":
elif data_type == "lab":
activity_gt_file = f"{skill_data_root}/{video_name}.skill_labels_by_frame.txt"
if not os.path.exists(activity_gt_file):
print(f"activity_gt_file {activity_gt_file} doesnt exists. Trying a different way")
Expand All @@ -268,7 +281,7 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str):

activityi_gt_list = ["background" for x in range(num_images)]

if data_type == "gyges":
if data_type == "pro":
text = text.replace('\n', '\t')
text_list = text.split("\t")[:-1]
for index in range(0, len(text_list), 3):
Expand All @@ -287,7 +300,7 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str):
for label_index in range(start_frame, min(end_frame-1, num_images)):
activityi_gt_list[label_index] = gt_label

elif data_type == "bbn":
elif data_type == "lab":
text = text.replace('\n', '\t')
text_list = text.split("\t")#[:-1]
for index in range(0, len(text_list), 4):
Expand Down Expand Up @@ -427,10 +440,19 @@ def main(task: str, ptg_root: str, config_root: str, data_type: str):

parser.add_argument(
"--data-type",
default="gyges",
help="gyges=proferssional data, bbn=use lab data",
default="pro",
help="pro=proferssional data, lab=use lab data",
type=str,
)

parser.add_argument(
"--data-gen-yaml",
default="/home/local/KHQ/peri.akiva/projects/angel_system/config/data_generation/bbn_gyges.yaml",
help="Path to data generation yaml file",
type=str,
)

args = parser.parse_args()
main(task=args.task, ptg_root=args.ptg_root, config_root=args.config_root, data_type=args.data_type)
main(task=args.task, ptg_root=args.ptg_root,
config_root=args.config_root, data_type=args.data_type,
data_gen_yaml=args.data_gen_yaml)

0 comments on commit 5ee5262

Please sign in to comment.