Skip to content

Commit

Permalink
Merge pull request #9 from periakiva/check_feature_vector
Browse files Browse the repository at this point in the history
Check feature vector
  • Loading branch information
periakiva authored Apr 9, 2024
2 parents 2b9737f + ac1cb8d commit a585e40
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 96 deletions.
6 changes: 3 additions & 3 deletions configs/experiment/m3/feat_v6.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ model:
dim: 188 # length of feature vector

data:
num_classes: 9 # activities: includes background
num_classes: 6 # activities: includes background
batch_size: 512
num_workers: 0
epoch_length: 20000
Expand All @@ -51,12 +51,12 @@ data:

MoveCenterPts:
feat_version: 6
num_obj_classes: 9 # not including background
num_obj_classes: 4 # not including background
NormalizeFromCenter:
feat_version: 6
NormalizePixelPts:
feat_version: 6
num_obj_classes: 9 # not including background
num_obj_classes: 4 # not including background

data_gen:
reshuffle_datasets: true
Expand Down
11 changes: 9 additions & 2 deletions configs/experiment/r18/feat_v6.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ defaults:
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
# exp_name: "p_m2_tqt_data_test_feat_v6_with_pose" #[_v2_aug_False]
exp_name: "p_r18_feat_v6_with_pose_v3_aug_False_reshuffle_True" #[_v2_aug_False]
# exp_name: "p_r18_feat_v6_with_pose_v3_aug_False_reshuffle_True" #[_v2_aug_False]
exp_name: "p_r18_feat_v6_with_pose_v3_aug_False_reshuffle_True_bbn" #[_v2_aug_False]
# exp_name: "p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_False_reshuffle_False" #[_v2_aug_False]
# exp_name: "p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_True_reshuffle_True" #[_v2_aug_False]
# exp_name: "p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_True_reshuffle_False" #[_v2_aug_False]
Expand Down Expand Up @@ -70,14 +71,20 @@ data_gen:
val_vid_ids: [3, 7, 10, 18, 27, 32, 41]
test_vid_ids: [50, 13, 47, 25]

train_vid_ids_bbn: [1, 2, 3, 4, 5, 6, 7, 8]
val_vid_ids_bbn: [9]
test_vid_ids_bbn: [10]

paths:
data_dir: "/data/PTG/TCN_data/r18/p_r18_feat_v6_with_pose_v3_aug_False_reshuffle_True" #[_v2_aug_False]
# data_dir: "/data/PTG/TCN_data/r18/p_r18_feat_v6_with_pose_v3_aug_False_reshuffle_True" #[_v2_aug_False]
data_dir: "/data/PTG/TCN_data/r18/p_r18_feat_v6_with_pose_v3_aug_False_reshuffle_True_bbn" #[_v2_aug_False]
# data_dir: "/data/PTG/TCN_data/m2/p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_False_reshuffle_False" #[_v2_aug_False]
# data_dir: "/data/PTG/TCN_data/m2/p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_True_reshuffle_True" #[_v2_aug_False]
# data_dir: "/data/PTG/TCN_data/m2/p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_True_reshuffle_False" #[_v2_aug_False]
# root_dir: "/data/PTG/medical/training/activity_classifier/TCN_HPL"
root_dir: "/data/users/peri.akiva/PTG/medical/training/activity_classifier/TCN_HPL"
dataset_kwcoco: "/data/users/peri.akiva/datasets/ptg/r18_all_all_obj_results_with_dets_and_pose.mscoco.json"
dataset_kwcoco_lab: "/data/users/peri.akiva/datasets/ptg/r18_all_bbn_lab_data_all_obj_results_with_dets_and_pose.mscoco.json"
activity_config_root: "/home/local/KHQ/peri.akiva/projects/angel_system/config/activity_labels"
activity_config_fn: "${paths.activity_config_root}/${task}"
ptg_root: "/home/local/KHQ/peri.akiva/angel_system"
Expand Down
2 changes: 1 addition & 1 deletion configs/model/ptg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ _target_: tcn_hpl.models.ptg_module.PTGLitModule
optimizer:
_target_: torch.optim.Adam
_partial_: true
lr: 0.0001
lr: 0.000001
weight_decay: 0.0001

scheduler:
Expand Down
4 changes: 4 additions & 0 deletions tcn_hpl/data/utils/pose_generation/configs/main.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
task: r18
data_type: bbn #[bbn, gyges]
root: /data/PTG/medical/bbn_data/Release_v0.5/v0.52/M2_Tourniquet/Data
img_save_path: /data/users/peri.akiva/datasets/m2_tourniquet/imgs
pose_model_config: /home/local/KHQ/peri.akiva/projects/TCN_HPL/tcn_hpl/data/utils/pose_generation/configs/ViTPose_base_medic_casualty_256x192.py
Expand All @@ -20,3 +21,6 @@ data:
m5: /data/PTG/medical/training/yolo_object_detector/detect//m5_all/m5_all_all_obj_results.mscoco.json
r18: /data/PTG/medical/training/yolo_object_detector/detect//r18_all/r18_all_all_obj_results.mscoco.json
save_root: /data/users/peri.akiva/datasets/ptg
bbn_lab:
r18: /data/PTG/medical/training/yolo_object_detector/detect//r18_all_bbn_lab_data/r18_all_bbn_lab_data_all_obj_results.mscoco.json
save_root: /data/users/peri.akiva/datasets/ptg
16 changes: 11 additions & 5 deletions tcn_hpl/data/utils/pose_generation/generate_pose_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,16 @@ def __init__(self, config: dict) -> None:
self.config = config
self.root_path = config['root']

self.dataset = kwcoco.CocoDataset(config['data'][config['task']])
if config['data_type'] == "bbn":
self.config_data_key = "bbn_lab"
else:
self.config_data_key = "data"

self.dataset = kwcoco.CocoDataset(config[self.config_data_key][config['task']])
self.patient_cid = self.dataset.add_category('patient')
self.user_cid = self.dataset.add_category('user')


self.keypoints_cats = [
"nose", "mouth", "throat","chest","stomach","left_upper_arm",
"right_upper_arm","left_lower_arm","right_lower_arm","left_wrist",
Expand All @@ -59,7 +65,7 @@ def __init__(self, config: dict) -> None:

self.dataset.dataset['keypoint_categories'] = self.keypoints_cats_dset

self.dataset_path_name = self.config['data'][self.config['task']][:-12].split('/')[-1] #remove .mscoco.json
self.dataset_path_name = self.config[self.config_data_key][self.config['task']][:-12].split('/')[-1] #remove .mscoco.json

self.args = get_parser(self.config['detection_model_config']).parse_args()
detecron_cfg = setup_detectron_cfg(self.args)
Expand Down Expand Up @@ -257,13 +263,13 @@ def generate_bbs_and_pose(self, dset: kwcoco.CocoDataset, save_intermediate: boo
# import matplotlib.pyplot as plt
# image_show = dset.draw_image(gid=img_id)
# plt.imshow(image_show)
# plt.savefig(f"myfig_{self.config['task']}_{index}.png")
# plt.savefig(f"figs/myfig_{self.config['task']}_{index}.png")
# if index >= 20:
# exit()

if save_intermediate:
if (index % 45000) == 0:
dset_inter_name = f"{self.config['data']['save_root']}/{self.dataset_path_name}_{index}_with_dets_and_pose.mscoco.json"
dset_inter_name = f"{self.config[self.config_data_key]['save_root']}/{self.dataset_path_name}_{index}_with_dets_and_pose.mscoco.json"
self.dataset.dump(dset_inter_name, newlines=True)
print(f"Saved intermediate dataset at index {index} to: {dset_inter_name}")

Expand Down Expand Up @@ -293,7 +299,7 @@ def run(self) -> None:
"""
self.dataset = self.generate_bbs_and_pose(self.dataset)

dataset_path_with_pose = f"{self.config['data']['save_root']}/{self.dataset_path_name}_with_dets_and_pose.mscoco.json"
dataset_path_with_pose = f"{self.config[self.config_data_key]['save_root']}/{self.dataset_path_name}_with_dets_and_pose.mscoco.json"
self.dataset.dump(dataset_path_with_pose, newlines=True)
print(f"Saved test dataset to: {dataset_path_with_pose}")
return
Expand Down
Loading

0 comments on commit a585e40

Please sign in to comment.