From 35422c2bd9065421046c6a271c49c5fb74c949de Mon Sep 17 00:00:00 2001 From: Peri Akiva Date: Mon, 26 Feb 2024 13:01:42 -0500 Subject: [PATCH] Dev (#1) * pose estimation integration to TCN * pose generator * version 6 features * ptg data generator * ptg datagenerator with poses and offset vectors * resplit data * top-n objects confs for vector generation * clean and organize pose estimation into a single script * pose generation in one script, union of datasets, selection by gids or vids * step visualization, ptg model change, pose data generation in one script, new loss function * ptg feat v6 with temporal features --------- Co-authored-by: Peri Akiva --- .gitignore | 4 + configs/experiment/m2/feat_v6.yaml | 73 ++++ configs/experiment/m2/feat_v6_no_pose.yaml | 64 ++++ .../m2/feat_v6_only_hands_joints.yaml | 64 ++++ .../m2/feat_v6_only_object_joints.yaml | 64 ++++ configs/model/ptg.yaml | 24 +- environment.yaml | 2 +- requirements.txt | 14 +- setup.py | 4 +- tcn_hpl/data/components/PTG_dataset.py | 16 + tcn_hpl/data/ptg_datamodule.py | 6 + tcn_hpl/data/utils/pose_generation/README.md | 7 + .../configs/Base-RCNN-FPN.yaml | 42 +++ .../ViTPose_base_medic_casualty_256x192.py | 170 +++++++++ .../ViTPose_base_medic_user_256x192.py | 169 +++++++++ .../configs/default_runtime.py | 19 ++ .../utils/pose_generation/configs/main.yaml | 19 ++ .../pose_generation/configs/medic_patient.py | 220 ++++++++++++ .../pose_generation/configs/medic_pose.yaml | 27 ++ .../pose_generation/configs/medic_user.py | 69 ++++ .../pose_generation/generate_pose_data.py | 206 +++++++++++ .../data/utils/pose_generation/predictor.py | 221 ++++++++++++ tcn_hpl/data/utils/pose_generation/utils.py | 118 +++++++ .../utils/pose_generation/video_to_frames.py | 89 +++++ tcn_hpl/data/utils/ptg_datagenerator.py | 261 +++++++++++--- tcn_hpl/models/components/focal_loss.py | 8 +- tcn_hpl/models/components/ms_tcs_net.py | 58 +++- tcn_hpl/models/ptg_module.py | 323 ++++++++++++++++-- tcn_hpl/train.py | 10 +- tcn_hpl/utils/utils.py | 43 ++- 30 files changed, 2317 insertions(+), 97 deletions(-) create mode 100644 configs/experiment/m2/feat_v6.yaml create mode 100644 configs/experiment/m2/feat_v6_no_pose.yaml create mode 100644 configs/experiment/m2/feat_v6_only_hands_joints.yaml create mode 100644 configs/experiment/m2/feat_v6_only_object_joints.yaml create mode 100644 tcn_hpl/data/utils/pose_generation/README.md create mode 100644 tcn_hpl/data/utils/pose_generation/configs/Base-RCNN-FPN.yaml create mode 100644 tcn_hpl/data/utils/pose_generation/configs/ViTPose_base_medic_casualty_256x192.py create mode 100644 tcn_hpl/data/utils/pose_generation/configs/ViTPose_base_medic_user_256x192.py create mode 100644 tcn_hpl/data/utils/pose_generation/configs/default_runtime.py create mode 100644 tcn_hpl/data/utils/pose_generation/configs/main.yaml create mode 100644 tcn_hpl/data/utils/pose_generation/configs/medic_patient.py create mode 100644 tcn_hpl/data/utils/pose_generation/configs/medic_pose.yaml create mode 100644 tcn_hpl/data/utils/pose_generation/configs/medic_user.py create mode 100644 tcn_hpl/data/utils/pose_generation/generate_pose_data.py create mode 100644 tcn_hpl/data/utils/pose_generation/predictor.py create mode 100644 tcn_hpl/data/utils/pose_generation/utils.py create mode 100644 tcn_hpl/data/utils/pose_generation/video_to_frames.py diff --git a/.gitignore b/.gitignore index 04a064844..07cdd3de4 100644 --- a/.gitignore +++ b/.gitignore @@ -143,6 +143,10 @@ dmypy.json *.h5 *.tar *.tar.gz +*.pth +*.png +*.jpg +*.jpeg # 
Lightning-Hydra-Template configs/local/default.yaml diff --git a/configs/experiment/m2/feat_v6.yaml b/configs/experiment/m2/feat_v6.yaml new file mode 100644 index 000000000..e76223a71 --- /dev/null +++ b/configs/experiment/m2/feat_v6.yaml @@ -0,0 +1,73 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example + +defaults: + - override /data: ptg + - override /model: ptg + - override /callbacks: default + - override /trainer: gpu + - override /paths: default + - override /logger: aim + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters +# exp_name: "p_m2_tqt_data_test_feat_v6_with_pose" #[_v2_aug_False] +exp_name: "p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_False_reshuffle_True" #[_v2_aug_False] +# exp_name: "p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_False_reshuffle_False" #[_v2_aug_False] +# exp_name: "p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_True_reshuffle_True" #[_v2_aug_False] +# exp_name: "p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_True_reshuffle_False" #[_v2_aug_False] + + +tags: ["m2", "ms_tcn"] + +seed: 12345 + +trainer: + min_epochs: 50 + max_epochs: 500 + + +model: + compile: false + + net: + dim: 182 # length of feature vector + +data: + num_classes: 9 # activities: includes background + batch_size: 512 + num_workers: 0 + epoch_length: 20000 + window_size: 30 + sample_rate: 1 + + all_transforms: + train_order: [] #["MoveCenterPts", "NormalizePixelPts"] + test_order: [] #["NormalizePixelPts"] + + MoveCenterPts: + feat_version: 6 + num_obj_classes: 9 # not including background + NormalizeFromCenter: + feat_version: 6 + NormalizePixelPts: + feat_version: 6 + num_obj_classes: 9 # not including background + +paths: + data_dir: "/data/PTG/TCN_data/m2/p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_False_reshuffle_True" #[_v2_aug_False] + # data_dir: "/data/PTG/TCN_data/m2/p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_False_reshuffle_False" #[_v2_aug_False] + # data_dir: "/data/PTG/TCN_data/m2/p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_True_reshuffle_True" #[_v2_aug_False] + # data_dir: "/data/PTG/TCN_data/m2/p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_True_reshuffle_False" #[_v2_aug_False] + + # root_dir: "/data/PTG/medical/training/activity_classifier/TCN_HPL" + root_dir: "/data/users/peri.akiva/PTG/medical/training/activity_classifier/TCN_HPL" + +logger: + aim: + experiment: ${exp_name} + capture_terminal_logs: true + +task_name: ${exp_name} \ No newline at end of file diff --git a/configs/experiment/m2/feat_v6_no_pose.yaml b/configs/experiment/m2/feat_v6_no_pose.yaml new file mode 100644 index 000000000..d3b5a4b15 --- /dev/null +++ b/configs/experiment/m2/feat_v6_no_pose.yaml @@ -0,0 +1,64 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example + +defaults: + - override /data: ptg + - override /model: ptg + - override /callbacks: default + - override /trainer: gpu + - override /paths: default + - override /logger: aim + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters +exp_name: "p_m2_tqt_data_test_feat_v6_no_pose" #[p_m2_tqt_data_test_feat_v6] + +tags: ["m2", "ms_tcn"] + +seed: 12345 + +trainer: + min_epochs: 50 + max_epochs: 400 + + +model: + compile: false + + net: + dim: 288 # length of feature vector + +data: + num_classes: 9 # activities: includes background + batch_size: 256 + num_workers: 0 
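Each of these experiment configs pins `model.net.dim` to the length of the per-frame feature vector for its feature variant, so a mismatch with the features written by the data generator only surfaces at training time. A minimal sanity-check sketch, assuming the generator stores one `(num_frames, dim)` NumPy array per video under a `features/` directory inside `paths.data_dir` (that file layout is an assumption):

```python
# Hedged sketch: verify model.net.dim in an experiment config against the
# per-frame feature length on disk before launching training.
from pathlib import Path

import numpy as np
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/experiment/m2/feat_v6.yaml")
feat_dir = Path(cfg.paths.data_dir) / "features"      # assumed directory layout

sample = next(feat_dir.glob("*.npy"))                  # assumed: one array per video
feats = np.load(sample)
assert feats.shape[1] == cfg.model.net.dim, (
    f"feature length {feats.shape[1]} != model.net.dim {cfg.model.net.dim}"
)
```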
+ epoch_length: 20000 + window_size: 30 + #sample_rate: 2 + + all_transforms: + train_order: [] #["MoveCenterPts", "NormalizePixelPts"] + test_order: [] #["NormalizePixelPts"] + + MoveCenterPts: + feat_version: 5 + num_obj_classes: 12 # not including background + NormalizeFromCenter: + feat_version: 5 + NormalizePixelPts: + feat_version: 5 + num_obj_classes: 12 # not including background + +paths: + data_dir: "/data/PTG/TCN_data/m2/p_m2_tqt_data_test_feat_v6_no_pose" #[p_m2_tqt_data_test_feat_v6] + # root_dir: "/data/PTG/medical/training/activity_classifier/TCN_HPL" + root_dir: "/data/users/peri.akiva/PTG/medical/training/activity_classifier/TCN_HPL" + +logger: + aim: + experiment: ${exp_name} + capture_terminal_logs: true + +task_name: ${exp_name} \ No newline at end of file diff --git a/configs/experiment/m2/feat_v6_only_hands_joints.yaml b/configs/experiment/m2/feat_v6_only_hands_joints.yaml new file mode 100644 index 000000000..13ea786bc --- /dev/null +++ b/configs/experiment/m2/feat_v6_only_hands_joints.yaml @@ -0,0 +1,64 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example + +defaults: + - override /data: ptg + - override /model: ptg + - override /callbacks: default + - override /trainer: gpu + - override /paths: default + - override /logger: aim + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters +exp_name: "p_m2_tqt_data_test_feat_v6_only_hands_joints_v2" #[p_m2_tqt_data_test_feat_v6] + +tags: ["m2", "ms_tcn"] + +seed: 12345 + +trainer: + min_epochs: 50 + max_epochs: 400 + + +model: + compile: false + + net: + dim: 332 # length of feature vector + +data: + num_classes: 9 # activities: includes background + batch_size: 256 + num_workers: 0 + epoch_length: 20000 + window_size: 30 + #sample_rate: 2 + + all_transforms: + train_order: [] #["MoveCenterPts", "NormalizePixelPts"] + test_order: [] #["NormalizePixelPts"] + + MoveCenterPts: + feat_version: 6 + num_obj_classes: 12 # not including background + NormalizeFromCenter: + feat_version: 6 + NormalizePixelPts: + feat_version: 6 + num_obj_classes: 12 # not including background + +paths: + data_dir: "/data/PTG/TCN_data/m2/p_m2_tqt_data_test_feat_v6_only_hands_joints" #[p_m2_tqt_data_test_feat_v6] + # root_dir: "/data/PTG/medical/training/activity_classifier/TCN_HPL" + root_dir: "/data/users/peri.akiva/PTG/medical/training/activity_classifier/TCN_HPL" + +logger: + aim: + experiment: ${exp_name} + capture_terminal_logs: true + +task_name: ${exp_name} \ No newline at end of file diff --git a/configs/experiment/m2/feat_v6_only_object_joints.yaml b/configs/experiment/m2/feat_v6_only_object_joints.yaml new file mode 100644 index 000000000..cbdf232c4 --- /dev/null +++ b/configs/experiment/m2/feat_v6_only_object_joints.yaml @@ -0,0 +1,64 @@ +# @package _global_ + +# to execute this experiment run: +# python train.py experiment=example + +defaults: + - override /data: ptg + - override /model: ptg + - override /callbacks: default + - override /trainer: gpu + - override /paths: default + - override /logger: aim + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters +exp_name: "p_m2_tqt_data_test_feat_v6_only_objects_joints" #[p_m2_tqt_data_test_feat_v6] + +tags: ["m2", "ms_tcn"] + +seed: 12345 + +trainer: + min_epochs: 50 + max_epochs: 400 + + +model: + compile: false + + net: + dim: 310 # length of feature 
vector + +data: + num_classes: 9 # activities: includes background + batch_size: 256 + num_workers: 0 + epoch_length: 20000 + window_size: 30 + #sample_rate: 2 + + all_transforms: + train_order: [] #["MoveCenterPts", "NormalizePixelPts"] + test_order: [] #["NormalizePixelPts"] + + MoveCenterPts: + feat_version: 6 + num_obj_classes: 12 # not including background + NormalizeFromCenter: + feat_version: 6 + NormalizePixelPts: + feat_version: 6 + num_obj_classes: 12 # not including background + +paths: + data_dir: "/data/PTG/TCN_data/m2/p_m2_tqt_data_test_feat_v6_only_objects_joints" #[p_m2_tqt_data_test_feat_v6] + # root_dir: "/data/PTG/medical/training/activity_classifier/TCN_HPL" + root_dir: "/data/users/peri.akiva/PTG/medical/training/activity_classifier/TCN_HPL" + +logger: + aim: + experiment: ${exp_name} + capture_terminal_logs: true + +task_name: ${exp_name} \ No newline at end of file diff --git a/configs/model/ptg.yaml b/configs/model/ptg.yaml index 1867a9afc..63173d5fe 100644 --- a/configs/model/ptg.yaml +++ b/configs/model/ptg.yaml @@ -3,8 +3,8 @@ _target_: tcn_hpl.models.ptg_module.PTGLitModule optimizer: _target_: torch.optim.Adam _partial_: true - lr: 0.0005 - weight_decay: 0.0 + lr: 0.00001 + weight_decay: 0.0001 scheduler: _target_: torch.optim.lr_scheduler.ReduceLROnPlateau @@ -15,23 +15,27 @@ scheduler: net: _target_: tcn_hpl.models.components.ms_tcs_net.MultiStageModel - num_stages: 4 - num_layers: 10 - num_f_maps: 64 - dim: 204 + num_stages: 4 + num_layers: 5 + num_f_maps: 128 + # dim: 204 + dim: 128 num_classes: ${data.num_classes} + window_size: ${data.window_size} criterion: _target_: tcn_hpl.models.components.focal_loss.FocalLoss - alpha: 0.25 - gamma: 2 - weight: None + alpha: 0.25 # 0.25 + gamma: 1 + # weight: None + weight: [1, 1, 1, 1, 1, 0.75, 0.75, 1, 1] reduction: "mean" data_dir: ${paths.data_dir} # Smoothing loss weight -smoothing_loss: 0.015 +smoothing_loss: 0.0015 +use_smoothing_loss: True # Number of classes num_classes: ${data.num_classes} diff --git a/environment.yaml b/environment.yaml index f74ee8c72..5e2182434 100644 --- a/environment.yaml +++ b/environment.yaml @@ -5,7 +5,7 @@ # - conda allows for installing packages without requiring certain compilers or # libraries to be available in the system, since it installs precompiled binaries -name: myenv +name: ptg channels: - pytorch diff --git a/requirements.txt b/requirements.txt index d06a91b71..4d0ad6eaa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,13 @@ # --------- pytorch --------- # -torch>=2.0.0 -torchvision>=0.15.0 -lightning>=2.0.0 -torchmetrics>=0.11.4 +# torch>=2.0.0 +# torchvision>=0.15.0 +# lightning>=2.0.0 +# torchmetrics>=0.11.4 # --------- hydra --------- # -hydra-core==1.3.2 -hydra-colorlog==1.2.0 -hydra-optuna-sweeper==1.2.0 +# hydra-core==1.3.2 +# hydra-colorlog==1.2.0 +# hydra-optuna-sweeper==1.2.0 # --------- loggers --------- # # wandb diff --git a/setup.py b/setup.py index d3a66d4de..fe089fc4c 100644 --- a/setup.py +++ b/setup.py @@ -9,8 +9,8 @@ author="", author_email="", url="https://github.com/user/project", - install_requires=["lightning", "hydra-core", "hydra-colorlog"], - packages=find_packages(), + # install_requires=["lightning", "hydra-core", "hydra-colorlog"], + packages=find_packages(include=['tcn_hpl', 'tcn_hpl.*']), # use this to customize global commands available in the terminal after installing the package entry_points={ "console_scripts": [ diff --git a/tcn_hpl/data/components/PTG_dataset.py b/tcn_hpl/data/components/PTG_dataset.py index 
f3a5e2581..ae6680263 100644 --- a/tcn_hpl/data/components/PTG_dataset.py +++ b/tcn_hpl/data/components/PTG_dataset.py @@ -107,10 +107,26 @@ def __getitem__(self, idx): :return: features, targets, and mask of the window """ + + # print(f"size of dataset: {self.__len__()}") + # print(f"self.feature_frames: {self.feature_frames.shape}") + # print(f"self.target_frames: {self.target_frames.shape}") + # print(f"self.mask_frames: {self.mask_frames.shape}") + # print(f"self.source_vids: {self.source_vids.shape}") + # print(f"self.source_frames: {self.source_frames.shape}") + # print(f"self.mask_frames: {self.mask_frames}") + features = self.feature_frames[idx:idx+self.window_size, :] target = self.target_frames[idx:idx+self.window_size] mask = self.mask_frames[idx:idx+self.window_size] source_vid = self.source_vids[idx:idx+self.window_size] source_frame = self.source_frames[idx:idx+self.window_size] + # print(f"mask: {mask}") + # print(f"features: {features.shape}") + # print(f"target: {target.shape}") + # print(f"mask: {mask.shape}") + # print(f"source_vid: {source_vid.shape}") + # print(f"source_frame: {source_frame.shape}") + return features, target, mask, np.array(source_vid), source_frame diff --git a/tcn_hpl/data/ptg_datamodule.py b/tcn_hpl/data/ptg_datamodule.py index 3aefebed4..f1d16316f 100644 --- a/tcn_hpl/data/ptg_datamodule.py +++ b/tcn_hpl/data/ptg_datamodule.py @@ -157,6 +157,9 @@ def setup(self, stage: Optional[str] = None) -> None: with open(vid_list_file, "r") as train_f: train_videos = train_f.read().split("\n")[:-1] + + # print(f"train_vids: {train_videos}") + # exit() # Load validation vidoes with open(vid_list_file_val, "r") as val_f: val_videos = val_f.read().split("\n")[:-1] @@ -170,6 +173,9 @@ def setup(self, stage: Optional[str] = None) -> None: features_path, self.hparams.sample_rate, self.hparams.window_size, transform=self.train_transform ) + + # print(f"size:{self.data_train.__len__()}") + # exit() self.data_val = PTG_Dataset( val_videos, self.hparams.num_classes, actions_dict, gt_path, diff --git a/tcn_hpl/data/utils/pose_generation/README.md b/tcn_hpl/data/utils/pose_generation/README.md new file mode 100644 index 000000000..c89f70232 --- /dev/null +++ b/tcn_hpl/data/utils/pose_generation/README.md @@ -0,0 +1,7 @@ +### Steps: + +1. generate frames from videos +2. predict bounding boxes +3. predict pose +4. write pose to kwcoco file + diff --git a/tcn_hpl/data/utils/pose_generation/configs/Base-RCNN-FPN.yaml b/tcn_hpl/data/utils/pose_generation/configs/Base-RCNN-FPN.yaml new file mode 100644 index 000000000..86f00d852 --- /dev/null +++ b/tcn_hpl/data/utils/pose_generation/configs/Base-RCNN-FPN.yaml @@ -0,0 +1,42 @@ +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + BACKBONE: + NAME: "build_resnet_fpn_backbone" + RESNETS: + OUT_FEATURES: ["res2", "res3", "res4", "res5"] + FPN: + IN_FEATURES: ["res2", "res3", "res4", "res5"] + ANCHOR_GENERATOR: + SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map + ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) + RPN: + IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] + PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level + PRE_NMS_TOPK_TEST: 1000 # Per FPN level + # Detectron1 uses 2000 proposals per-batch, + # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) + # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 
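The `PTG_dataset.py` hunk above slices a fixed `window_size` run of frames out of pre-concatenated per-video arrays for every sample. A condensed sketch of that sliding-window pattern (attribute names follow the diff; how the arrays are built upstream is assumed):

```python
from torch.utils.data import Dataset


class WindowedFrames(Dataset):
    """Minimal sketch of the sliding-window sampling in PTG_Dataset.__getitem__."""

    def __init__(self, feature_frames, target_frames, mask_frames, window_size):
        self.feature_frames = feature_frames  # (num_frames, feat_dim) float array
        self.target_frames = target_frames    # (num_frames,) activity ids
        self.mask_frames = mask_frames        # (num_frames,) 0/1 validity mask
        self.window_size = window_size

    def __len__(self):
        # Last valid start index keeps idx + window_size inside the arrays.
        return len(self.feature_frames) - self.window_size + 1

    def __getitem__(self, idx):
        sl = slice(idx, idx + self.window_size)
        return (
            self.feature_frames[sl, :],
            self.target_frames[sl],
            self.mask_frames[sl],
        )
```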
+ POST_NMS_TOPK_TRAIN: 1000 + POST_NMS_TOPK_TEST: 1000 + ROI_HEADS: + NAME: "StandardROIHeads" + IN_FEATURES: ["p2", "p3", "p4", "p5"] + ROI_BOX_HEAD: + NAME: "FastRCNNConvFCHead" + NUM_FC: 2 + POOLER_RESOLUTION: 7 + ROI_MASK_HEAD: + NAME: "MaskRCNNConvUpsampleHead" + NUM_CONV: 4 + POOLER_RESOLUTION: 14 +DATASETS: + TRAIN: ("coco_2017_train",) + TEST: ("coco_2017_val",) +SOLVER: + IMS_PER_BATCH: 12 + BASE_LR: 0.08 + STEPS: (60000, 80000) + MAX_ITER: 90000 +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) +VERSION: 2 diff --git a/tcn_hpl/data/utils/pose_generation/configs/ViTPose_base_medic_casualty_256x192.py b/tcn_hpl/data/utils/pose_generation/configs/ViTPose_base_medic_casualty_256x192.py new file mode 100644 index 000000000..cc2940933 --- /dev/null +++ b/tcn_hpl/data/utils/pose_generation/configs/ViTPose_base_medic_casualty_256x192.py @@ -0,0 +1,170 @@ +_base_ = [ + 'default_runtime.py', + 'medic_patient.py' +] +evaluation = dict(interval=210, metric='mAP', save_best='AP') + +optimizer = dict(type='AdamW', lr=5e-4, betas=(0.9, 0.999), weight_decay=0.1, + constructor='LayerDecayOptimizerConstructor', + paramwise_cfg=dict( + num_layers=12, + layer_decay_rate=0.75, + custom_keys={ + 'bias': dict(decay_multi=0.), + 'pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) + } + ) + ) + +optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2)) + +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=0.001, + step=[170, 200]) +total_epochs = 210 +target_type = 'GaussianHeatmap' +channel_cfg = dict( + num_output_channels=22, + dataset_joints=22, + dataset_channel=[ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + ], + inference_channel=[ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21 + ]) + +# model settings +model = dict( + type='TopDown', + pretrained=None, + backbone=dict( + type='ViT', + img_size=(256, 192), + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + ratio=1, + use_checkpoint=False, + mlp_ratio=4, + qkv_bias=True, + drop_path_rate=0.3, + ), + keypoint_head=dict( + type='TopdownHeatmapSimpleHead', + in_channels=768, + num_deconv_layers=2, + num_deconv_filters=(256, 256), + num_deconv_kernels=(4, 4), + extra=dict(final_conv_kernel=1, ), + out_channels=channel_cfg['num_output_channels'], + loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)), + train_cfg=dict(), + test_cfg=dict( + flip_test=True, + post_process='default', + shift_heatmap=False, + target_type=target_type, + modulate_kernel=11, + use_udp=True)) + +data_cfg = dict( + image_size=[192, 256], + heatmap_size=[48, 64], + num_output_channels=channel_cfg['num_output_channels'], + num_joints=channel_cfg['dataset_joints'], + dataset_channel=channel_cfg['dataset_channel'], + inference_channel=channel_cfg['inference_channel'], + soft_nms=False, + nms_thr=1.0, + oks_thr=0.9, + vis_thr=0.2, + use_gt_bbox=False, + det_bbox_thr=0.0, + # bbox_file='/shared/niudt/DATASET/Medical/final_version_coco_jan21/bbox_detections_results/casualty.json', + bbox_file='/home/local/KHQ/peri.akiva/projects/Medical-Partial-Body-Pose-Estimation/bbox_detection_results/bbox_detections.json', +) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='TopDownRandomFlip', flip_prob=0.5), + dict( + type='TopDownHalfBodyTransform', + num_joints_half_body=8, + prob_half_body=0.3), + dict( + type='TopDownGetRandomScaleRotation', 
rot_factor=40, scale_factor=0.5), + dict(type='TopDownAffine', use_udp=True), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict( + type='TopDownGenerateTarget', + sigma=2, + encoding='UDP', + target_type=target_type), + dict( + type='Collect', + keys=['img', 'target', 'target_weight'], + meta_keys=[ + 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', + 'rotation', 'bbox_score', 'flip_pairs' + ]), +] + +val_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='TopDownAffine', use_udp=True), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'image_file', 'center', 'scale', 'rotation', 'bbox_score', + 'flip_pairs' + ]), +] + +test_pipeline = val_pipeline + +data_root = '/shared/niudt/DATASET/Medical' +data = dict( + samples_per_gpu=64, + workers_per_gpu=4, + val_dataloader=dict(samples_per_gpu=32), + test_dataloader=dict(samples_per_gpu=32), + train=dict( + type='TopDownCocoDataset', + ann_file=f'{data_root}/7_3_challenging_vitpose_all_images/casualty_train.json', + img_prefix=f'{data_root}/images/', + data_cfg=data_cfg, + pipeline=train_pipeline, + dataset_info={{_base_.dataset_info}}), + val=dict( + type='TopDownCocoDataset', + ann_file=f'{data_root}/7_3_challenging_vitpose_all_images/casualty_val.json', + img_prefix=f'{data_root}/images/', + data_cfg=data_cfg, + pipeline=val_pipeline, + dataset_info={{_base_.dataset_info}}), + test=dict( + type='TopDownCocoDataset', + ann_file=f'{data_root}/7_3_challenging_vitpose_all_images/casualty_train.json', + img_prefix=f'{data_root}/images/', + data_cfg=data_cfg, + pipeline=test_pipeline, + dataset_info={{_base_.dataset_info}}), +) + diff --git a/tcn_hpl/data/utils/pose_generation/configs/ViTPose_base_medic_user_256x192.py b/tcn_hpl/data/utils/pose_generation/configs/ViTPose_base_medic_user_256x192.py new file mode 100644 index 000000000..f8f79ab68 --- /dev/null +++ b/tcn_hpl/data/utils/pose_generation/configs/ViTPose_base_medic_user_256x192.py @@ -0,0 +1,169 @@ +_base_ = [ + 'default_runtime.py', + 'medic_user.py' +] +evaluation = dict(interval=210, metric='mAP', save_best='AP') + +optimizer = dict(type='AdamW', lr=5e-4, betas=(0.9, 0.999), weight_decay=0.1, + constructor='LayerDecayOptimizerConstructor', + paramwise_cfg=dict( + num_layers=12, + layer_decay_rate=0.75, + custom_keys={ + 'bias': dict(decay_multi=0.), + 'pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) 
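These ViTPose configs are consumed through mmpose's top-down API; a minimal single-image inference sketch under the 0.x API used elsewhere in this patch (checkpoint and image paths are placeholders, and the person box would normally come from the Detectron2 detector configured by `medic_pose.yaml`):

```python
# Hedged sketch: top-down pose inference with the casualty config above
# (mmpose 0.x API, matching the calls in generate_pose_data.py).
from mmpose.apis import inference_top_down_pose_model, init_pose_model
from mmpose.datasets import DatasetInfo

pose_model = init_pose_model(
    "ViTPose_base_medic_casualty_256x192.py",   # this config file
    "checkpoints/pose_model.pth",               # placeholder checkpoint path
    device="cuda:0",
)
dataset_info = DatasetInfo(pose_model.cfg.data["test"]["dataset_info"])

person_results = [{"bbox": [100, 100, 400, 600]}]   # one xyxy person detection
pose_results, _ = inference_top_down_pose_model(
    pose_model,
    "frame_000001.jpg",                         # placeholder image path
    person_results,
    format="xyxy",
    dataset=pose_model.cfg.data["test"]["type"],
    dataset_info=dataset_info,
)
print(pose_results[0]["keypoints"].shape)       # (22, 3): x, y, score per keypoint
```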
+ } + ) + ) + +optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2)) + +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=0.001, + step=[170, 200]) +total_epochs = 210 +target_type = 'GaussianHeatmap' +channel_cfg = dict( + num_output_channels=6, + dataset_joints=6, + dataset_channel=[ + [0, 1, 2, 3, 4, 5], + ], + inference_channel=[ + 0, 1, 2, 3, 4, 5 + ]) + +# model settings +model = dict( + type='TopDown', + pretrained=None, + backbone=dict( + type='ViT', + img_size=(256, 192), + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + ratio=1, + use_checkpoint=False, + mlp_ratio=4, + qkv_bias=True, + drop_path_rate=0.3, + ), + keypoint_head=dict( + type='TopdownHeatmapSimpleHead', + in_channels=768, + num_deconv_layers=2, + num_deconv_filters=(256, 256), + num_deconv_kernels=(4, 4), + extra=dict(final_conv_kernel=1, ), + out_channels=channel_cfg['num_output_channels'], + loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)), + train_cfg=dict(), + test_cfg=dict( + flip_test=True, + post_process='default', + shift_heatmap=False, + target_type=target_type, + modulate_kernel=11, + use_udp=True)) + +data_cfg = dict( + image_size=[192, 256], + heatmap_size=[48, 64], + num_output_channels=channel_cfg['num_output_channels'], + num_joints=channel_cfg['dataset_joints'], + dataset_channel=channel_cfg['dataset_channel'], + inference_channel=channel_cfg['inference_channel'], + soft_nms=False, + nms_thr=1.0, + oks_thr=0.9, + vis_thr=0.2, + use_gt_bbox=False, + det_bbox_thr=0.0, + bbox_file='/home/local/KHQ/peri.akiva/projects/Medical-Partial-Body-Pose-Estimation/bbox_detection_results/bbox_detections.json', +) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='TopDownRandomFlip', flip_prob=0.5), + dict( + type='TopDownHalfBodyTransform', + num_joints_half_body=8, + prob_half_body=0.3), + dict( + type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5), + dict(type='TopDownAffine', use_udp=True), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict( + type='TopDownGenerateTarget', + sigma=2, + encoding='UDP', + target_type=target_type), + dict( + type='Collect', + keys=['img', 'target', 'target_weight'], + meta_keys=[ + 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', + 'rotation', 'bbox_score', 'flip_pairs' + ]), +] + +val_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='TopDownAffine', use_udp=True), + dict(type='ToTensor'), + dict( + type='NormalizeTensor', + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + dict( + type='Collect', + keys=['img'], + meta_keys=[ + 'image_file', 'center', 'scale', 'rotation', 'bbox_score', + 'flip_pairs' + ]), +] + +test_pipeline = val_pipeline + +data_root = '/shared/niudt/DATASET/Medical' +data = dict( + samples_per_gpu=64, + workers_per_gpu=4, + val_dataloader=dict(samples_per_gpu=32), + test_dataloader=dict(samples_per_gpu=32), + train=dict( + type='TopDownCocoDataset', + ann_file=f'{data_root}/7_3_challenging_vitpose_all_images/user_train.json', + img_prefix=f'{data_root}/images/', + data_cfg=data_cfg, + pipeline=train_pipeline, + dataset_info={{_base_.dataset_info}}), + val=dict( + type='TopDownCocoDataset', + ann_file=f'{data_root}/7_3_challenging_vitpose_all_images/user_val.json', + img_prefix=f'{data_root}/images/', + data_cfg=data_cfg, + pipeline=val_pipeline, + dataset_info={{_base_.dataset_info}}), + test=dict( + 
type='TopDownCocoDataset', + ann_file=f'{data_root}/7_3_challenging_vitpose_all_images/user_train.json', + img_prefix=f'{data_root}/images/', + data_cfg=data_cfg, + pipeline=test_pipeline, + dataset_info={{_base_.dataset_info}}), +) + diff --git a/tcn_hpl/data/utils/pose_generation/configs/default_runtime.py b/tcn_hpl/data/utils/pose_generation/configs/default_runtime.py new file mode 100644 index 000000000..d5fbc3d42 --- /dev/null +++ b/tcn_hpl/data/utils/pose_generation/configs/default_runtime.py @@ -0,0 +1,19 @@ +checkpoint_config = dict(interval=10) + +log_config = dict( + interval=10, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) + +log_level = 'INFO' +load_from = '/shared/niudt/pose_estimation/vitpose/ViTPose/pretrained_model/vitpose-b.pth' +resume_from = None +dist_params = dict(backend='nccl') +workflow = [('train', 1)] + +# disable opencv multithreading to avoid system being overloaded +opencv_num_threads = 0 +# set multi-process start method as `fork` to speed up the training +mp_start_method = 'fork' diff --git a/tcn_hpl/data/utils/pose_generation/configs/main.yaml b/tcn_hpl/data/utils/pose_generation/configs/main.yaml new file mode 100644 index 000000000..2c3f60fba --- /dev/null +++ b/tcn_hpl/data/utils/pose_generation/configs/main.yaml @@ -0,0 +1,19 @@ +root: /data/PTG/medical/bbn_data/Release_v0.5/v0.52/M2_Tourniquet/Data +img_save_path: /data/users/peri.akiva/datasets/m2_tourniquet/imgs +pose_model_config: /home/local/KHQ/peri.akiva/projects/TCN_HPL/tcn_hpl/data/utils/pose_generation/configs/ViTPose_base_medic_casualty_256x192.py +detection_model_config: /home/local/KHQ/peri.akiva/projects/TCN_HPL/tcn_hpl/data/utils/pose_generation/configs/medic_pose.yaml +thickness: 2 +radius: 10 +kpt-thr: 0.8 +device: cuda +pose_model_checkpoint: /home/local/KHQ/peri.akiva/projects/TCN_HPL/tcn_hpl/data/utils/pose_generation/checkpoints/pose_model.pth +data: + # train: /data/users/peri.akiva/datasets/ptg/m2_good_images_only_no_amputation_stump_train_activity_obj_results_with_dets_and_pose.mscoco.json + # val: /data/users/peri.akiva/datasets/ptg/m2_good_images_only_no_amputation_stump_val_obj_results_with_dets_and_pose.mscoco.json + # test: /data/users/peri.akiva/datasets/ptg/m2_good_images_only_no_amputation_stump_test_obj_results_with_dets_and_pose.mscoco.json + train: /data/PTG/medical/training/yolo_object_detector/detect/m2_good_images_only_no_amputation_stump/m2_good_images_only_no_amputation_stump_train_activity_obj_results.mscoco.json + val: /data/PTG/medical/training/yolo_object_detector/detect/m2_good_images_only_no_amputation_stump/m2_good_images_only_no_amputation_stump_val_obj_results.mscoco.json + test: /data/PTG/medical/training/yolo_object_detector/detect/m2_good_images_only_no_amputation_stump/m2_good_images_only_no_amputation_stump_test_obj_results.mscoco.json + save_root: /data/users/peri.akiva/datasets/ptg + + \ No newline at end of file diff --git a/tcn_hpl/data/utils/pose_generation/configs/medic_patient.py b/tcn_hpl/data/utils/pose_generation/configs/medic_patient.py new file mode 100644 index 000000000..a25104550 --- /dev/null +++ b/tcn_hpl/data/utils/pose_generation/configs/medic_patient.py @@ -0,0 +1,220 @@ +dataset_info = dict( + dataset_name='coco_medic', + paper_info=dict( + author='Lin, Tsung-Yi and Maire, Michael and ' + 'Belongie, Serge and Hays, James and ' + 'Perona, Pietro and Ramanan, Deva and ' + r'Doll{\'a}r, Piotr and Zitnick, C Lawrence', + title='Microsoft coco: Common objects in context', + container='European 
conference on computer vision', + year='2014', + homepage='http://cocodataset.org/', + ), + keypoint_info={ + 0: + dict(name='nose', + id=0, + color=[0, 215, 255], + type='upper', + swap=''), + 1: + dict( + name='mouth', + id=1, + color=[0, 215, 255], + type='upper', + swap=''), + 2: + dict( + name='throat', + id=2, + color=[0, 215, 255], + type='upper', + swap=''), + 3: + dict( + name='chest', + id=3, + color=[0, 215, 255], + type='upper', + swap=''), + 4: + dict( + name='stomach', + id=4, + color=[0, 215, 255], + type='upper', + swap=''), + 5: + dict( + name='left_upper_arm', + id=5, + color=[255, 0, 0], + type='upper', + swap=''), + 6: + dict( + name='right_upper_arm', + id=6, + color=[0, 255, 0], + type='upper', + swap=''), + 7: + dict( + name='left_lower_arm', + id=7, + color=[255, 0, 0], + type='upper', + swap=''), + 8: + dict( + name='right_lower_arm', + id=8, + color=[0, 255, 0], + type='upper', + swap=''), + 9: + dict( + name='left_wrist', + id=9, + color=[255, 0, 0], + type='upper', + swap=''), + 10: + dict( + name='right_wrist', + id=10, + color=[0, 255, 0], + type='upper', + swap=''), + 11: + dict( + name='left_hand', + id=11, + color=[255, 0, 0], + type='upper', + swap=''), + 12: + dict( + name='right_hand', + id=12, + color=[0, 255, 0], + type='upper', + swap=''), + 13: + dict( + name='left_upper_leg', + id=13, + color=[255, 0, 0], + type='lower', + swap=''), + + 14: + dict( + name='right_upper_leg', + id=14, + color=[0, 255, 0], + type='lower', + swap=''), + 15: + dict( + name='left_knee', + id=15, + color=[255, 0, 0], + type='lower', + swap=''), + 16: + dict( + name='right_knee', + id=16, + color=[0, 255, 0], + type='lower', + swap=''), + 17: + dict( + name='left_lower_leg', + id=17, + color=[255, 0, 0], + type='lower', + swap=''), + 18: + dict( + name='right_lower_leg', + id=18, + color=[0, 255, 0], + type='lower', + swap=''), + 19: + dict( + name= 'left_foot', + id=19, + color=[255, 0, 0], + type='lower', + swap=''), + 20: + dict( + name='right_foot', + id=20, + color=[0, 255, 0], + type='lower', + swap=''), + 21: + dict( + name='back', + id=21, + color=[147, 20, 255], + type='lower', + swap=''), + }, + skeleton_info={ + 0: + {'link': ('nose', 'mouth'), 'id': 0, 'color': [0, 215, 255]}, + 1: + {'link': ('mouth', 'throat'), 'id': 1, 'color': [0, 215, 255]}, + 2: + {'link': ('throat', 'chest'), 'id': 2, 'color': [0, 215, 255]}, + 3: + {'link': ('chest', 'stomach'), 'id': 3, 'color': [0, 215, 255]}, + 4: + {'link': ('throat', 'left_upper_arm'), 'id': 4, 'color': [255, 0, 0]}, + 5: + {'link': ('throat', 'right_upper_arm'), 'id': 5, 'color': [0, 255, 0]}, + 6: + {'link': ('left_upper_arm', 'left_lower_arm'), 'id': 6, 'color': [255, 0, 0]}, + 7: + {'link': ('right_upper_arm', 'right_lower_arm'), 'id': 7, 'color': [0, 255, 0]}, + 8: + {'link': ('left_lower_arm', 'left_wrist'), 'id': 8, 'color': [255, 0, 0]}, + 9: + {'link': ('right_lower_arm', 'right_wrist'), 'id': 9, 'color': [0, 255, 0]}, + 10: + {'link': ('left_wrist', 'left_hand'), 'id': 10, 'color': [255, 0, 0]}, + 11: + {'link': ('right_wrist', 'right_hand'), 'id': 11, 'color': [0, 255, 0]}, + 12: + {'link': ('stomach', 'left_upper_leg'), 'id': 12, 'color': [255, 0, 0]}, + 13: + {'link': ('stomach', 'right_upper_leg'), 'id': 13, 'color': [0, 255, 0]}, + 14: + {'link': ('left_upper_leg', 'left_knee'), 'id': 14, 'color': [255, 0, 0]}, + 15: + {'link': ('right_upper_leg', 'right_knee'), 'id': 15, 'color': [0, 255, 0]}, + 16: + {'link': ('left_knee', 'left_lower_leg'), 'id': 16, 'color': [255, 0, 0]}, + 17: + {'link': 
('right_knee', 'right_lower_leg'), 'id': 17, 'color': [0, 255, 0]}, + 18: + {'link': ('left_lower_leg', 'left_foot'), 'id': 18, 'color': [255, 0, 0]}, + 19: + {'link': ('right_lower_leg', 'right_foot'), 'id': 19, 'color': [0, 255, 0]}, + 20: + {'link': ('right_upper_leg', 'back'), 'id': 20, 'color': [147, 20, 255]}, + 21: + {'link': ('left_upper_leg', 'back'), 'id': 21, 'color': [147, 20, 255]}, + }, + + joint_weights=[ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + sigmas=[ + 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067] +) \ No newline at end of file diff --git a/tcn_hpl/data/utils/pose_generation/configs/medic_pose.yaml b/tcn_hpl/data/utils/pose_generation/configs/medic_pose.yaml new file mode 100644 index 000000000..12716a131 --- /dev/null +++ b/tcn_hpl/data/utils/pose_generation/configs/medic_pose.yaml @@ -0,0 +1,27 @@ +_BASE_: "Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "/home/local/KHQ/peri.akiva/projects/TCN_HPL/tcn_hpl/data/utils/pose_generation/checkpoints/model_final.pth" # please change here to the path where you put the weights + MASK_ON: False + RESNETS: + DEPTH: 101 + ROI_HEADS: + NUM_CLASSES: 2 + SCORE_THRESH_TEST: 0.0001 +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) +DATASETS: + TRAIN: ("lvis_v1_train",) + TEST: ("lvis_v1_val",) +TEST: + DETECTIONS_PER_IMAGE: 300 # LVIS allows up to 300 +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.01 + STEPS: (6670, 8888) + MAX_ITER: 10000 # 180000 * 16 / 100000 ~ 28.8 epochs +DATALOADER: + SAMPLER_TRAIN: "RepeatFactorTrainingSampler" + REPEAT_THRESHOLD: 0.001 + +OUTPUT_DIR: "./output_pose_ablation/user+env" + diff --git a/tcn_hpl/data/utils/pose_generation/configs/medic_user.py b/tcn_hpl/data/utils/pose_generation/configs/medic_user.py new file mode 100644 index 000000000..05191b1b4 --- /dev/null +++ b/tcn_hpl/data/utils/pose_generation/configs/medic_user.py @@ -0,0 +1,69 @@ +dataset_info = dict( + dataset_name='medic_user', + paper_info=dict( + author='Lin, Tsung-Yi and Maire, Michael and ' + 'Belongie, Serge and Hays, James and ' + 'Perona, Pietro and Ramanan, Deva and ' + r'Doll{\'a}r, Piotr and Zitnick, C Lawrence', + title='Microsoft coco: Common objects in context', + container='European conference on computer vision', + year='2014', + homepage='http://cocodataset.org/', + ), + keypoint_info={ #'left_upper_arm', 'right_upper_arm', 'left_lower_arm', 'right_lower_arm', 'left_hand', 'right_hand + 0: + dict(name='left_upper_arm', + id=0, + color=[0, 0, 139], + type='upper', + swap=''), + 1: + dict( + name='right_upper_arm', + id=1, + color=[255, 112, 132], + type='upper', + swap=''), + 2: + dict( + name='left_lower_arm', + id=2, + color=[0, 0, 139], + type='upper', + swap=''), + 3: + dict( + name='right_lower_arm', + id=3, + color=[255, 112, 132], + type='upper', + swap=''), + 4: + dict( + name='left_hand', + id=4, + color=[0, 0, 139], + type='upper', + swap=''), + 5: + dict( + name='right_hand', + id=5, + color=[255, 112, 132], + type='upper', + swap='') + }, + skeleton_info={ + 0: + {'link': ('right_hand', 'right_lower_arm'), 'id': 0, 'color':[255, 112, 132]}, + 1: + {'link': ('right_upper_arm', 'right_lower_arm'), 'id': 1, 'color': [255, 112, 132]}, + 2: + {'link': ('left_upper_arm', 'left_lower_arm'), 'id': 2, 'color': [0, 0, 139]}, + 3: + {'link': ('left_lower_arm', 'left_hand'), 'id': 3, 'color': [0, 0, 139]} + }, + + joint_weights=[1.0, 1.0, 1.0, 1.0, 
1.0, 1.0], + sigmas=[0.67, 0.67, 0.67, 0.67, 0.67, 0.67] +) \ No newline at end of file diff --git a/tcn_hpl/data/utils/pose_generation/generate_pose_data.py b/tcn_hpl/data/utils/pose_generation/generate_pose_data.py new file mode 100644 index 000000000..028f684a6 --- /dev/null +++ b/tcn_hpl/data/utils/pose_generation/generate_pose_data.py @@ -0,0 +1,206 @@ +"""Generate bounding box detections, then generate poses for patients + """ + +import argparse +import glob +from glob import glob +import multiprocessing as mp +import numpy as np +import os +import tempfile +import time +import warnings +import cv2 +import tqdm +from detectron2.config import get_cfg +from detectron2.data.detection_utils import read_image +from detectron2.utils.logger import setup_logger +import json +from predictor import VisualizationDemo +# import tcn_hpl.utils.utils as utils +from mmpose.apis import (inference_top_down_pose_model, init_pose_model, + vis_pose_result) +import utils +import kwcoco +from mmpose.datasets import DatasetInfo +print(f"utils: {utils.__file__}") + + +import warnings +warnings.filterwarnings("ignore") + + +def setup_detectron_cfg(args): + # load config from file and command-line arguments + cfg = get_cfg() + # To use demo for Panoptic-DeepLab, please uncomment the following two lines. + # from detectron2.projects.panoptic_deeplab import add_panoptic_deeplab_config # noqa + # add_panoptic_deeplab_config(cfg) + cfg.merge_from_file(args.config_file) + cfg.merge_from_list(args.opts) + # Set score_threshold for builtin models + cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.confidence_threshold + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold + cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = args.confidence_threshold + cfg.freeze() + return cfg + + +class PosesGenerator(object): + def __init__(self, config: dict): + self.config = config + self.root_path = config['root'] + # self.paths = utils.dictionary_contents(config['root'], types=['*.JPG', '*.jpg', '*.JPEG', '*.jpeg', '*.png'], recursive=True) + + self.train_dataset = kwcoco.CocoDataset(config['data']['train']) + self.val_dataset = kwcoco.CocoDataset(config['data']['val']) + self.test_dataset = kwcoco.CocoDataset(config['data']['test']) + + self.keypoints_cats = [ + "nose", "mouth", "throat","chest","stomach","left_upper_arm", + "right_upper_arm","left_lower_arm","right_lower_arm","left_wrist", + "right_wrist","left_hand","right_hand","left_upper_leg", + "right_upper_leg","left_knee","right_knee","left_lower_leg", + "right_lower_leg", "left_foot", "right_foot", "back" + ] + + self.keypoints_cats_dset = [{'name': value, 'id': index} for index, value in enumerate(self.keypoints_cats)] + self.train_dataset.dataset['keypoint_categories'] = self.keypoints_cats_dset + self.val_dataset.dataset['keypoint_categories'] = self.keypoints_cats_dset + self.test_dataset.dataset['keypoint_categories'] = self.keypoints_cats_dset + + + self.args = utils.get_parser(self.config['detection_model_config']).parse_args() + detecron_cfg = setup_detectron_cfg(self.args) + self.predictor = VisualizationDemo(detecron_cfg) + + self.pose_model = init_pose_model(config['pose_model_config'], + config['pose_model_checkpoint'], + device=config['device']) + + self.pose_dataset = self.pose_model.cfg.data['test']['type'] + self.pose_dataset_info = self.pose_model.cfg.data['test'].get('dataset_info', None) + if self.pose_dataset_info is None: + warnings.warn( + 'Please set `dataset_info` in the config.' 
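The generator attaches detector boxes and pose keypoints directly to the kwcoco datasets loaded in `__init__`. A small sketch of that annotation shape (values are placeholders; the keypoint-name list is truncated from the 22 categories registered above):

```python
# Hedged sketch of the kwcoco bookkeeping done by generate_bbs_and_pose().
import kwcoco

dset = kwcoco.CocoDataset()
gid = dset.add_image(file_name="frame_000001.jpg", width=1280, height=720)
patient_cid = dset.add_category("patient")

keypoint_names = ["nose", "mouth", "throat"]    # truncated; the real list has 22 entries
dset.dataset["keypoint_categories"] = [
    {"name": name, "id": i} for i, name in enumerate(keypoint_names)
]

dset.add_annotation(
    image_id=gid,
    category_id=patient_cid,
    bbox=[100.0, 100.0, 400.0, 600.0],          # box as produced by the detector
    label="patient",
    keypoints=[
        {"xy": [150.0, 160.0], "keypoint_category_id": 0, "keypoint_category": "nose"},
    ],
)
dset.dump("example_with_dets_and_pose.mscoco.json", newlines=True)
```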
+ 'Check https://github.com/open-mmlab/mmpose/pull/663 for details.', + DeprecationWarning) + else: + self.pose_dataset_info = DatasetInfo(self.pose_dataset_info) + + def generate_bbs_and_pose(self, dset): + + # json_file = utils.initialize_coco_json() + patient_cid = dset.add_category('patient') + user_cid = dset.add_category('user') + # pbar = tqdm.tqdm(enumerate(self.paths), total=len(self.paths)) + pbar = tqdm.tqdm(enumerate(dset.imgs.items()), total=len(list(dset.imgs.keys()))) + + for index, (img_id, img_dict) in pbar: + + path = img_dict['file_name'] + + img = read_image(path, format="BGR") + predictions, visualized_output = self.predictor.run_on_image(img) + + instances = predictions["instances"].to('cpu') + boxes = instances.pred_boxes if instances.has("pred_boxes") else None + scores = instances.scores if instances.has("scores") else None + classes = instances.pred_classes.tolist() if instances.has("pred_classes") else None + + boxes = boxes.tensor.detach().numpy() + scores = scores.numpy() + + file_name = path.split('/')[-1] + # video_name = path.split('/')[-2] + # print(f"file name: {file_name}") + # print(f"file path: {path}") + # print(f"video name: {video_name}") + + if boxes is not None: + + # person_results = [] + for box_id, _bbox in enumerate(boxes): + + current_ann = {} + # current_ann['id'] = ann_id + current_ann['image_id'] = img_id + current_ann['bbox'] = np.asarray(_bbox).tolist()#_bbox + current_ann['category_id'] = patient_cid + current_ann['label'] = 'patient' + current_ann['bbox_score'] = str(round(scores[box_id] * 100,2)) + '%' + + + person_results = [current_ann] + + pose_results, returned_outputs = inference_top_down_pose_model( + self.pose_model, + path, + person_results, + bbox_thr=None, + format='xyxy', + dataset=self.pose_dataset, + dataset_info=self.pose_dataset_info, + return_heatmap=None, + outputs=['backbone']) + + # print(f"outputs: {type(returned_outputs[0])}") + # print(f"outputs: {len(returned_outputs)}") + # print(f"outputs: {returned_outputs[0]}") + image_features = returned_outputs[0]['backbone'][0,:,:,-1] + + # print(f"image_features: {image_features.shape}") + # exit() + pose_keypoints = pose_results[0]['keypoints'].tolist() + # bbox = pose_results[0]['bbox'].tolist() + pose_keypoints_list = [] + for index, keypoint in enumerate(pose_keypoints): + kp_dict = {'xy': [keypoint[0], keypoint[1]], + 'keypoint_category_id': index, + 'keypoint_category': self.keypoints_cats[index]} + pose_keypoints_list.append(kp_dict) + + current_ann['keypoints'] = pose_keypoints_list + current_ann['image_features'] = image_features + + dset.add_annotation(**current_ann) + + return dset + + # def predict_poses(self, dset): + # pass + + def run(self): + self.train_dataset = self.generate_bbs_and_pose(self.train_dataset) + self.val_dataset = self.generate_bbs_and_pose(self.val_dataset) + self.test_dataset = self.generate_bbs_and_pose(self.test_dataset) + + train_path_name = self.config['data']['train'][:-12].split('/')[-1] #remove .mscoco.json + val_path_name = self.config['data']['val'][:-12].split('/')[-1] #remove .mscoco.json + test_path_name = self.config['data']['test'][:-12].split('/')[-1] #remove .mscoco.json + + train_path_with_pose = f"{self.config['data']['save_root']}/{train_path_name}_with_dets_and_pose.mscoco.json" + val_path_with_pose = f"{self.config['data']['save_root']}/{val_path_name}_with_dets_and_pose.mscoco.json" + test_path_with_pose = f"{self.config['data']['save_root']}/{test_path_name}_with_dets_and_pose.mscoco.json" + # print(f"train_path: 
{train_path}") + # print(f"train_path_with_pose: {train_path_with_pose}") + self.train_dataset.dump(train_path_with_pose, newlines=True) + print(f"Saved train dataset to: {train_path_with_pose}") + + self.val_dataset.dump(val_path_with_pose, newlines=True) + print(f"Saved val dataset to: {val_path_with_pose}") + + self.test_dataset.dump(test_path_with_pose, newlines=True) + print(f"Saved test dataset to: {test_path_with_pose}") + + +def main(): + + main_config_path = f"configs/main.yaml" + config = utils.load_yaml_as_dict(main_config_path) + + PG = PosesGenerator(config) + PG.run() + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/tcn_hpl/data/utils/pose_generation/predictor.py b/tcn_hpl/data/utils/pose_generation/predictor.py new file mode 100644 index 000000000..941293b06 --- /dev/null +++ b/tcn_hpl/data/utils/pose_generation/predictor.py @@ -0,0 +1,221 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import atexit +import bisect +import multiprocessing as mp +from collections import deque +import cv2 +import torch + +from detectron2.data import MetadataCatalog +from detectron2.engine.defaults import DefaultPredictor +from detectron2.utils.video_visualizer import VideoVisualizer +from detectron2.utils.visualizer import ColorMode, Visualizer + + +class VisualizationDemo(object): + def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False): + """ + Args: + cfg (CfgNode): + instance_mode (ColorMode): + parallel (bool): whether to run the model in different processes from visualization. + Useful since the visualization logic can be slow. + """ + self.metadata = MetadataCatalog.get( + cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused" + ) + self.cpu_device = torch.device("cpu") + self.instance_mode = instance_mode + + self.parallel = parallel + if parallel: + num_gpu = torch.cuda.device_count() + self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu) + else: + self.predictor = DefaultPredictor(cfg) + + def run_on_image(self, image): + """ + Args: + image (np.ndarray): an image of shape (H, W, C) (in BGR order). + This is the format used by OpenCV. + + Returns: + predictions (dict): the output of the model. + vis_output (VisImage): the visualized image output. + """ + vis_output = None + predictions = self.predictor(image) + # print(f"predictor: {predictions.keys()}") + # Convert image from OpenCV BGR format to Matplotlib RGB format. + image = image[:, :, ::-1] + visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode) + if "panoptic_seg" in predictions: + panoptic_seg, segments_info = predictions["panoptic_seg"] + vis_output = visualizer.draw_panoptic_seg_predictions( + panoptic_seg.to(self.cpu_device), segments_info + ) + else: + if "sem_seg" in predictions: + vis_output = visualizer.draw_sem_seg( + predictions["sem_seg"].argmax(dim=0).to(self.cpu_device) + ) + if "instances" in predictions: + instances = predictions["instances"].to(self.cpu_device) + vis_output = visualizer.draw_instance_predictions(predictions=instances) + + return predictions, vis_output + + def _frame_from_video(self, video): + while video.isOpened(): + success, frame = video.read() + if success: + yield frame + else: + break + + def run_on_video(self, video): + """ + Visualizes predictions on frames of the input video. + + Args: + video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be + either a webcam or a video file. + + Yields: + ndarray: BGR visualizations of each video frame. 
+ """ + video_visualizer = VideoVisualizer(self.metadata, self.instance_mode) + + def process_predictions(frame, predictions): + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + if "panoptic_seg" in predictions: + panoptic_seg, segments_info = predictions["panoptic_seg"] + vis_frame = video_visualizer.draw_panoptic_seg_predictions( + frame, panoptic_seg.to(self.cpu_device), segments_info + ) + elif "instances" in predictions: + predictions = predictions["instances"].to(self.cpu_device) + vis_frame = video_visualizer.draw_instance_predictions(frame, predictions) + elif "sem_seg" in predictions: + vis_frame = video_visualizer.draw_sem_seg( + frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device) + ) + + # Converts Matplotlib RGB format to OpenCV BGR format + vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR) + return vis_frame + + frame_gen = self._frame_from_video(video) + if self.parallel: + buffer_size = self.predictor.default_buffer_size + + frame_data = deque() + + for cnt, frame in enumerate(frame_gen): + frame_data.append(frame) + self.predictor.put(frame) + + if cnt >= buffer_size: + frame = frame_data.popleft() + predictions = self.predictor.get() + yield process_predictions(frame, predictions) + + while len(frame_data): + frame = frame_data.popleft() + predictions = self.predictor.get() + yield process_predictions(frame, predictions) + else: + for frame in frame_gen: + yield process_predictions(frame, self.predictor(frame)) + + +class AsyncPredictor: + """ + A predictor that runs the model asynchronously, possibly on >1 GPUs. + Because rendering the visualization takes considerably amount of time, + this helps improve throughput a little bit when rendering videos. + """ + + class _StopToken: + pass + + class _PredictWorker(mp.Process): + def __init__(self, cfg, task_queue, result_queue): + self.cfg = cfg + self.task_queue = task_queue + self.result_queue = result_queue + super().__init__() + + def run(self): + predictor = DefaultPredictor(self.cfg) + + while True: + task = self.task_queue.get() + if isinstance(task, AsyncPredictor._StopToken): + break + idx, data = task + result = predictor(data) + self.result_queue.put((idx, result)) + + def __init__(self, cfg, num_gpus: int = 1): + """ + Args: + cfg (CfgNode): + num_gpus (int): if 0, will run on CPU + """ + num_workers = max(num_gpus, 1) + self.task_queue = mp.Queue(maxsize=num_workers * 3) + self.result_queue = mp.Queue(maxsize=num_workers * 3) + self.procs = [] + for gpuid in range(max(num_gpus, 1)): + cfg = cfg.clone() + cfg.defrost() + cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu" + self.procs.append( + AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue) + ) + + self.put_idx = 0 + self.get_idx = 0 + self.result_rank = [] + self.result_data = [] + + for p in self.procs: + p.start() + atexit.register(self.shutdown) + + def put(self, image): + self.put_idx += 1 + self.task_queue.put((self.put_idx, image)) + + def get(self): + self.get_idx += 1 # the index needed for this request + if len(self.result_rank) and self.result_rank[0] == self.get_idx: + res = self.result_data[0] + del self.result_data[0], self.result_rank[0] + return res + + while True: + # make sure the results are returned in the correct order + idx, res = self.result_queue.get() + if idx == self.get_idx: + return res + insert = bisect.bisect(self.result_rank, idx) + self.result_rank.insert(insert, idx) + self.result_data.insert(insert, res) + + def __len__(self): + return self.put_idx - 
self.get_idx + + def __call__(self, image): + self.put(image) + return self.get() + + def shutdown(self): + for _ in self.procs: + self.task_queue.put(AsyncPredictor._StopToken()) + + @property + def default_buffer_size(self): + return len(self.procs) * 5 diff --git a/tcn_hpl/data/utils/pose_generation/utils.py b/tcn_hpl/data/utils/pose_generation/utils.py new file mode 100644 index 000000000..ad282a44d --- /dev/null +++ b/tcn_hpl/data/utils/pose_generation/utils.py @@ -0,0 +1,118 @@ +import argparse +import glob +from glob import glob +import os +import yaml +def get_parser(config_file): + parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs") + parser.add_argument( + "--config-file", + default=config_file, + metavar="FILE", + help="path to config file", + ) + parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.") + parser.add_argument("--video-input", help="Path to video file." + # , default='/shared/niudt/detectron2/images/Videos/k2/4.MP4' + ) + parser.add_argument( + "--input", + nargs="+", + help="A list of space separated input images; " + "or a single glob pattern such as 'directory/*.jpg'" + , default= ['/shared/niudt/DATASET/Medical/Maydemo/2023-4-25/selected_videos/new/M2-16/*.jpg'] # please change here to the path where you put the images + ) + parser.add_argument( + "--output", + help="A file or directory to save output visualizations. " + "If not given, will show output in an OpenCV window." + , default='./bbox_detection_results' + ) + + parser.add_argument( + "--confidence-threshold", + type=float, + default=0.8, + help="Minimum score for instance predictions to be shown", + ) + parser.add_argument( + "--opts", + help="Modify config options using the command-line 'KEY VALUE' pairs", + default=[], + nargs=argparse.REMAINDER, + ) + return parser + +def load_yaml_as_dict(yaml_path): + with open(yaml_path, 'r') as f: + config_dict = yaml.load(f, Loader=yaml.FullLoader) + return config_dict + +def dictionary_contents(path: str, types: list, recursive: bool = False) -> list: + """ + Extract files of specified types from directories, optionally recursively. + + Parameters: + path (str): Root directory path. + types (list): List of file types (extensions) to be extracted. + recursive (bool, optional): Search for files in subsequent directories if True. Default is False. + + Returns: + list: List of file paths with full paths. + """ + files = [] + if recursive: + path = path + "/**/*" + for type in types: + if recursive: + for x in glob(path + type, recursive=True): + files.append(os.path.join(path, x)) + else: + for x in glob(path + type): + files.append(os.path.join(path, x)) + return files + +def create_dir_if_doesnt_exist(dir_path): + """ + This function creates a dictionary if it doesnt exist. 
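The helpers in this `utils.py` are what the pose-generation scripts call to resolve the YAML config and enumerate frames; a short usage example (the keys come from `configs/main.yaml`):

```python
# Example use of the pose_generation utils helpers defined in this file.
import utils

config = utils.load_yaml_as_dict("configs/main.yaml")
frame_paths = utils.dictionary_contents(
    config["img_save_path"], types=["*.jpg", "*.png"], recursive=True
)
utils.create_dir_if_doesnt_exist(config["data"]["save_root"])
print(f"found {len(frame_paths)} candidate frames")
```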
+ :param dir_path: string, dictionary path + """ + if not os.path.exists(dir_path): + os.makedirs(dir_path) + return + + +def initialize_coco_json(): + + json_file = {} + json_file['images'] = [] + json_file['annotations'] = [] + json_file['categories'] = [] + + # get new categorie + temp_bbox = {} + temp_bbox['id'] = 1 + temp_bbox['name'] = 'patient' + temp_bbox['instances_count'] = 0 + temp_bbox['def'] = '' + temp_bbox['synonyms'] = ['patient'] + temp_bbox['image_count'] = 0 + temp_bbox['frequency'] = '' + temp_bbox['synset'] = '' + + json_file['categories'].append(temp_bbox) + + temp_bbox = {} + temp_bbox['id'] = 2 + temp_bbox['name'] = 'user' + temp_bbox['instances_count'] = 0 + temp_bbox['def'] = '' + temp_bbox['synonyms'] = ['user'] + temp_bbox['image_count'] = 0 + temp_bbox['frequency'] = '' + temp_bbox['synset'] = '' + + json_file['categories'].append(temp_bbox) + ann_num = 0 + + return json_file \ No newline at end of file diff --git a/tcn_hpl/data/utils/pose_generation/video_to_frames.py b/tcn_hpl/data/utils/pose_generation/video_to_frames.py new file mode 100644 index 000000000..0b99d7581 --- /dev/null +++ b/tcn_hpl/data/utils/pose_generation/video_to_frames.py @@ -0,0 +1,89 @@ +# coding=utf-8 +""" +The data is already preprocessed into frames. + +At this point this script is not needed. +""" + + +import os +import cv2 +from glob import glob +import json +import utils + + + + +def main(): + # the dir including videos you want to process + videos_src_path = '/data/datasets/ptg/m2_tourniquet/' + # the save path + videos_save_path = '/data/datasets/ptg/m2_tourniquet/imgs' + + videos = utils.dictionary_contents(videos_src_path, types=['*.mp4', '*.MP4'], recursive=True) + # print(videos) + # exit() + # videos = filter(lambda x: x.endswith('MP4'), videos) + coco_json = {"info": {"description": "Medical Pose Estimation", + "year": 2023, + "contributer": "Kitware", + "version": "0.1"}, + "images": [], + "annotations": []} + + for index, each_video in enumerate(videos): + # get the name of each video, and make the directory to save frames + video_name = each_video.split('/')[-1].split('.')[0] + print('Video Name :', video_name) + # each_video_name, _ = each_video.split('.')each_video.split('.') + dir_path = f"{videos_save_path}/{video_name}" + utils.create_dir_if_doesnt_exist(dir_path) + # each_video_save_full_path = os.path.join(videos_save_path, each_video_name) + '/' + + # get the full path of each video, which will open the video tp extract frames + # each_video_full_path = os.path.join(videos_src_path, each_video) + + cap = cv2.VideoCapture(each_video) + + frame_count = 1 + + frame_rate = 1 + success = True + # 计数 + num = 0 + while (success): + success, frame = cap.read() + if success == True: + # if not os.path.exists(each_video_save_full_path + video_name): + # os.makedirs(each_video_save_full_path + video_name) + height, width, channels = frame.shape + # print(frame.shape) + file_name = f"{video_name}_{frame_count}.jpg" + image_dict = { + "id": f"{index}{frame_count}", + "file_name": file_name, + "video_name": f"{video_name}", + "video_id": index, + "width": width, + "height": height, + } + coco_json["images"].append(image_dict) + if frame_count % frame_rate == 0: + # cv2.imwrite(each_video_save_full_path + video_name + '/'+ "%06d.jpg" % num, frame) + cv2.imwrite(f"{dir_path}/{file_name}", frame) + image_dict["path"] = f"{dir_path}/{file_name}" + num += 1 + + frame_count = frame_count + 1 + print('Final frame:', num) + + # coco_json_save_path = 
f"{videos_src_path}/medical_coco.json" + coco_json_save_path = "medical_coco.json" + + with open(coco_json_save_path, "w") as outfile: + json.dump(coco_json, outfile) + +if __name__ == "__main__": + main() + diff --git a/tcn_hpl/data/utils/ptg_datagenerator.py b/tcn_hpl/data/utils/ptg_datagenerator.py index 3c085e200..71f8fb3a2 100644 --- a/tcn_hpl/data/utils/ptg_datagenerator.py +++ b/tcn_hpl/data/utils/ptg_datagenerator.py @@ -3,6 +3,7 @@ import glob import warnings import kwcoco +import shutil import numpy as np import ubelt as ub @@ -15,37 +16,61 @@ time_from_name, sanitize_str, ) -from angel_system.activity_hmm.train_activity_classifier import ( +from angel_system.activity_classification.train_activity_classifier import ( data_loader, compute_feats, ) from angel_system.data.data_paths import grab_data, data_dir +def load_yaml_as_dict(yaml_path): + with open(yaml_path, 'r') as f: + config_dict = yaml.load(f, Loader=yaml.FullLoader) + return config_dict ##################### # Inputs ##################### -recipe = "coffee+tea" -obj_exp_name = "coffee_base" #"coffee+tea_yolov7" +task = "m2" +obj_exp_name = "p_bbn_model_m2_m3_m5_r18_v11" #"coffee+tea_yolov7" -# obj_dets_dir = f"{data_dir}/annotations/{recipe}/results/{obj_exp_name}" -obj_dets_dir = "/data/PTG/cooking/object_anns/old_coffee/results/coffee_base/" #"/home/local/KHQ/hannah.defazio/yolov7/runs/detect/coffee+tea_yolov7/" -ptg_root = "/home/local/KHQ/hannah.defazio/angel_system/" -activity_config_path = f"{ptg_root}/config/activity_labels" -activity_config_fn = f"{activity_config_path}/recipe_{recipe}.yaml" +# f"{ptg_root}/config/activity_labels/medical" +# obj_dets_dir = f"{data_dir}/annotations/{recipe}/results/{obj_exp_name}" +# obj_dets_dir = "/data/PTG/cooking/object_anns/old_coffee/results/coffee_base/" #"/home/local/KHQ/hannah.defazio/yolov7/runs/detect/coffee+tea_yolov7/" +# obj_dets_dir = "/home/local/KHQ/peri.akiva/projects/medical-pose/bbox_detection_results/RESULTS_m2_with_lab_cleaned_fixed_data_with_steps_results_train_activity_with_patient_dets.mscoco.json" +ptg_root = "/home/local/KHQ/peri.akiva/angel_system" +activity_config_path = f"{ptg_root}/config/activity_labels/medical" +# activity_config_fn = f"{activity_config_path}/task_{task}.yaml" +activity_config_fn = "/home/local/KHQ/peri.akiva/projects/angel_system/config/activity_labels/medical_tourniquet.v3.yaml" -feat_version = 5 +feat_version = 6 using_done = False # Set the gt according to when an activity is done ##################### # Output ##################### -exp_name = "coffee_only_data_test_feat_v5"#f"coffee_and_tea_feat_v{str(feat_version)}" +reshuffle_datasets = True +augment = False +num_augs = 5 +if augment: + num_augs = num_augs + aug_trans_range, aug_rot_range = [-5, 5], [-5, 5] +else: + num_augs = 1 + aug_trans_range, aug_rot_range = None, None + +feat_type = "with_pose" #[no_pose, with_pose, only_hands_joints, only_objects_joints] +feat_to_bools = { + "no_pose": [False, False], + "with_pose": [True, True], + "only_hands_joints": [False, True], + "only_objects_joints": [True, False] +} +exp_name = f"p_m2_tqt_data_test_feat_v6_{feat_type}_v2_aug_{augment}_reshuffle_{reshuffle_datasets}" #[p_m2_tqt_data_test_feat_v6_with_pose, p_m2_tqt_data_test_feat_v6_only_hands_joints, p_m2_tqt_data_test_feat_v6_only_objects_joints, p_m2_tqt_data_test_feat_v6_no_pose] if using_done: exp_name = f"{exp_name}_done_gt" -output_data_dir = f"{data_dir}/TCN_data/{recipe}/{exp_name}" +output_data_dir = f"/data/PTG/TCN_data/m2/{exp_name}" gt_dir = 
f"{output_data_dir}/groundTruth" frames_dir = f"{output_data_dir}/frames" @@ -54,6 +79,8 @@ # Create directories for folder in [output_data_dir, gt_dir, frames_dir, bundle_dir, features_dir]: + if os.path.exists(folder): + shutil.rmtree(folder) Path(folder).mkdir(parents=True, exist_ok=True) # Clear out the bundles @@ -68,6 +95,8 @@ activity_config = yaml.safe_load(stream) activity_labels = activity_config["labels"] +print(f"activity labels: {activity_labels}") + with open(f"{output_data_dir}/mapping.txt", "w") as mapping: for label in activity_labels: i = label["id"] @@ -76,18 +105,137 @@ continue mapping.write(f"{i} {label_str}\n") -##################### +################### # Features, # groundtruth and # bundles ##################### -for split in ["train_activity", "val", "test"]: - kwcoco_file = f"{obj_dets_dir}/{obj_exp_name}_results_{split}_conf_0.1_plus_hl_hands_new_obj_labels.mscoco.json" - dset = kwcoco.CocoDataset(kwcoco_file) +### create splits dsets + +names_black_gloves = [22,23,26,24,25,27,29,28,41,42,43,44,45,46,47,48,49,78, + 79,84,88,90,80,81,82,83,85,86,87,89,91,99,110,111,121,113,115,116] +names_blue_gloves = [132,133,50,51,54,55,56,52,61,59,53,57,62,65,66,67,68,69, + 58,60,63,64,125,126,127,129,131,134,135,136,128,130,137, + 138,139] + +filter_black_gloves, filter_blue_gloves = False, False + +train_file = "/data/users/peri.akiva/datasets/ptg/m2_good_images_only_no_amputation_stump_train_activity_obj_results_with_dets_and_pose.mscoco.json" +val_file = "/data/users/peri.akiva/datasets/ptg/m2_good_images_only_no_amputation_stump_val_obj_results_with_dets_and_pose.mscoco.json" +test_file = "/data/users/peri.akiva/datasets/ptg/m2_good_images_only_no_amputation_stump_test_obj_results_with_dets_and_pose.mscoco.json" + +# train_file = "/data/PTG/medical/training/yolo_object_detector/detect/m2_good_images_only_no_amputation_stump/m2_good_images_only_no_amputation_stump_train_activity_obj_results.mscoco.json" +# val_file = "/data/PTG/medical/training/yolo_object_detector/detect/m2_good_images_only_no_amputation_stump/m2_good_images_only_no_amputation_stump_val_obj_results.mscoco.json" +# test_file = "/data/PTG/medical/training/yolo_object_detector/detect/m2_good_images_only_no_amputation_stump/m2_good_images_only_no_amputation_stump_test_obj_results.mscoco.json" +train_dset = kwcoco.CocoDataset(train_file) +val_dset = kwcoco.CocoDataset(val_file) +test_dset = kwcoco.CocoDataset(test_file) + +if reshuffle_datasets: + dset = kwcoco.CocoDataset.union(train_dset,val_dset,test_dset) + + print(f"vids: {dset.index.videos}") + + train_img_ids, val_img_ids, test_img_ids = [], [], [] + + train_names = [1, 2, 4, 8, 9, 10, 11, 12, 16, 17,18, 20, 19, 30, 31, 32, 33, 34,35,36, + 7,132,133,50,51,54,56,52,61,53,57,65,66,67,68,69,58,60,63,64,125,126, + 127,129,131,134,135,136,119,122,124,70,72,92,93,94,95,97,98,100, + 101,102,103,104,105,107,108,112,114,117,118,73] + + # bad step distribution GT: [37, 6, 76, 39, 38, 30] + val_names= [5, 59,106,130,138, 77, 123, 71] + + # test_vivds= [5,6,37,59,106,130,138] + + test_names = [3,14,55,62,96,109,128,137,139, 120, 75, 21, 13] + + train_vidids = [dset.index.name_to_video[f"M2-{index}"]['id'] for index in train_names if f"M2-{index}" in dset.index.name_to_video.keys()] + val_vivids = [dset.index.name_to_video[f"M2-{index}"]['id'] for index in val_names if f"M2-{index}" in dset.index.name_to_video.keys()] + test_vivds = [dset.index.name_to_video[f"M2-{index}"]['id'] for index in test_names if f"M2-{index}" in dset.index.name_to_video.keys()] + + if 
filter_black_gloves: + vidids_black_gloves = [dset.index.name_to_video[f"M2-{index}"]['id'] for index in names_black_gloves if f"M2-{index}" in dset.index.name_to_video.keys()] + + train_vidids = [x for x in train_vidids if x not in vidids_black_gloves] + val_vivids = [x for x in val_vivids if x not in vidids_black_gloves] + test_vivds = [x for x in test_vivds if x not in vidids_black_gloves] + + if filter_blue_gloves: + vidids_blue_gloves = [dset.index.name_to_video[f"M2-{index}"]['id'] for index in names_blue_gloves if f"M2-{index}" in dset.index.name_to_video.keys()] + + train_vidids = [x for x in train_vidids if x not in vidids_blue_gloves] + val_vivids = [x for x in val_vivids if x not in vidids_blue_gloves] + test_vivds = [x for x in test_vivds if x not in vidids_blue_gloves] + + + ## individual splits by gids + # total = len(dset.index.imgs) + # inds = [i for i in range(1, total+1)] + # train_size, val_size = int(0.8*total), int(0.1*total) + # test_size = total - train_size - val_size + # train_inds = set(list(np.random.choice(inds, size=train_size, replace=False))) + # remaining_inds = list(set(inds) - train_inds) + # val_inds = set(list(np.random.choice(remaining_inds, size=val_size, replace=False))) + # test_inds = list(set(remaining_inds) - val_inds) + # # test_inds = list(np.random.choice(remaining_inds, size=test_size, replace=False)) + + # train_img_ids = [dset.index.imgs[i]['id'] for i in train_inds] + # val_img_ids = [dset.index.imgs[i]['id'] for i in val_inds] + # test_img_ids = [dset.index.imgs[i]['id'] for i in test_inds] + # print(f"train: {len(train_inds)}, val: {len(val_inds)}, test: {len(test_inds)}") + + for vid in train_vidids: + # print(type(dset.index.vidid_to_gids[vid])) + # print(type(list(dset.index.vidid_to_gids[vid]))) + # print(list(dset.index.vidid_to_gids[vid])) + if vid in dset.index.vidid_to_gids.keys(): + train_img_ids.extend(list(dset.index.vidid_to_gids[vid])) + else: + print(f"{vid} not in the train dataset") + + for vid in val_vivids: + + if vid in dset.index.vidid_to_gids.keys(): + val_img_ids.extend(list(dset.index.vidid_to_gids[vid])) + else: + print(f"{vid} not in the val dataset") + # val_img_ids = set(val_img_ids) + set(dset.index.vidid_to_gids[vid]) + + for vid in test_vivds: + + if vid in dset.index.vidid_to_gids.keys(): + test_img_ids.extend(list(dset.index.vidid_to_gids[vid])) + else: + print(f"{vid} not in the test dataset") + # test_img_ids = set(test_img_ids) + set(dset.index.vidid_to_gids[vid]) - for video_id in ub.ProgIter( - dset.index.videos.keys(), desc=f"Creating features for videos in {split}" - ): + + + train_img_ids, val_img_ids, test_img_ids = list(train_img_ids), list(val_img_ids), list(test_img_ids) + + print(f"train num_images: {len(train_img_ids)}, val num_images: {len(val_img_ids)}, test num_images: {len(test_img_ids)}") + + train_dset = dset.subset(gids=train_img_ids, copy=True) + val_dset = dset.subset(gids=val_img_ids, copy=True) + test_dset = dset.subset(gids=test_img_ids, copy=True) +# exit() + + +# for split in ["train_activity", "val", "test"]: +for dset, split in zip([train_dset, val_dset, test_dset], ["train_activity", "val", "test"]): + # kwcoco_file = f"{obj_dets_dir}/{obj_exp_name}_results_{split}_conf_0.1_plus_hl_hands_new_obj_labels.mscoco.json" + + # print(f"dset length: {len(dset.index.imgs)}") + # exit() + # print(f"kps cats: {dset.keypoint_categories()}") + org_num_vidoes = len(list(dset.index.videos.keys())) + new_num_videos = 0 + for video_id in ub.ProgIter(dset.index.videos.keys()): + + # if video_id 
!= 10: + # continue + video = dset.index.videos[video_id] video_name = video["name"] if "_extracted" in video_name: @@ -95,6 +243,8 @@ image_ids = dset.index.vidid_to_gids[video_id] num_images = len(image_ids) + + print(f"there are {num_images} in video {video_id}") video_dset = dset.subset(gids=image_ids, copy=True) @@ -108,40 +258,59 @@ act_id_to_str, ann_by_image, ) = data_loader(video_dset, activity_config) - X, y = compute_feats( - act_map, - image_activity_gt, - image_id_to_dataset, - label_to_ind, - act_id_to_str, - ann_by_image, - feat_version=feat_version - ) + + if split != "train_activity": + num_augs = 1 + aug_trans_range, aug_rot_range = None, None + + for aug_index in range(num_augs): + X, y = compute_feats( + act_map, + image_activity_gt, + image_id_to_dataset, + label_to_ind, + act_id_to_str, + ann_by_image, + feat_version=feat_version, + objects_joints=feat_to_bools[feat_type][0], + hands_joints=feat_to_bools[feat_type][1], + aug_trans_range = aug_trans_range, + aug_rot_range = aug_rot_range, + top_n_objects=2 + ) - X = X.T - print(f"X after transpose: {X.shape}") + X = X.T + print(f"X after transpose: {X.shape}") - np.save(f"{features_dir}/{video_name}.npy", X) + if num_augs != 1: + video_name_new = f"{video_name}_{aug_index}" + else: + video_name_new = video_name + + npy_path = f"{features_dir}/{video_name_new}.npy" + + np.save(npy_path, X) + print(f"Video info saved to: {npy_path}") - # groundtruth - with open(f"{gt_dir}/{video_name}.txt", "w") as gt_f, \ - open(f"{frames_dir}/{video_name}.txt", "w") as frames_f: - for image_id in image_ids: - image = dset.imgs[image_id] - image_n = image["file_name"] # this is the shortened string + # groundtruth + with open(f"{gt_dir}/{video_name_new}.txt", "w") as gt_f, \ + open(f"{frames_dir}/{video_name_new}.txt", "w") as frames_f: + for image_id in image_ids: + image = dset.imgs[image_id] + image_n = image["file_name"] # this is the shortened string - frame_idx, time = time_from_name(image_n) - - activity_gt = image["activity_gt"] - if activity_gt is None: - activity_gt = "background" + frame_idx, time = time_from_name(image_n) + + activity_gt = image["activity_gt"] + if activity_gt is None: + activity_gt = "background" - gt_f.write(f"{activity_gt}\n") - frames_f.write(f"{image_n}\n") + gt_f.write(f"{activity_gt}\n") + frames_f.write(f"{image_n}\n") - # bundles - with open(f"{bundle_dir}/{split}.split1.bundle", "a+") as bundle: - bundle.write(f"{video_name}.txt\n") + # bundles + with open(f"{bundle_dir}/{split}.split1.bundle", "a+") as bundle: + bundle.write(f"{video_name_new}.txt\n") print("Done!") print(f"Saved training data to {output_data_dir}") diff --git a/tcn_hpl/models/components/focal_loss.py b/tcn_hpl/models/components/focal_loss.py index b4414ba73..59344f278 100644 --- a/tcn_hpl/models/components/focal_loss.py +++ b/tcn_hpl/models/components/focal_loss.py @@ -11,10 +11,16 @@ def __init__(self, alpha=0.25, gamma=2, weight=None, reduction="mean"): self.gamma = gamma self.alpha = alpha self.reduction = reduction + # print(f"weight: {weight}, type: {type(weight)}") + if weight == "None": + weight=None + else: + weight = torch.Tensor(weight) self.ce = nn.CrossEntropyLoss( ignore_index=-100, - reduction="none" + reduction="none", + weight=weight ) def forward(self, inputs, targets): diff --git a/tcn_hpl/models/components/ms_tcs_net.py b/tcn_hpl/models/components/ms_tcs_net.py index 6b9771b11..8b336c932 100644 --- a/tcn_hpl/models/components/ms_tcs_net.py +++ b/tcn_hpl/models/components/ms_tcs_net.py @@ -2,7 +2,7 @@ from torch 
import nn import torch.nn.functional as F import copy - +import einops class MultiStageModel(nn.Module): def __init__(self, @@ -10,7 +10,8 @@ def __init__(self, num_layers, num_f_maps, dim, - num_classes): + num_classes, + window_size,): """Initialize a `MultiStageModel` module. :param num_stages: Nubmer of State Model Layers. @@ -29,13 +30,56 @@ def __init__(self, for s in range(num_stages - 1) ] ) + + + self.fc = nn.Sequential( + + nn.Linear(dim*window_size, 4096), + nn.GELU(), + nn.Dropout(0.25), + nn.Linear(4096, 8192), + nn.Dropout(0.25), + nn.Linear(8192, 16384), + nn.GELU(), + nn.Dropout(0.25), + nn.Linear(16384, 8192), + nn.GELU(), + nn.Dropout(0.25), + nn.Linear(8192, 4096), + nn.Dropout(0.25), + nn.Linear(4096, dim*window_size), + ) + # self.fc1 = nn.Linear(dim*30, 4096) + # self.act = nn.GELU() + # self.drop1 = nn.Dropout(0.1) + # self.fc2 = nn.Linear(4096, 8192) + # self.drop2 = nn.Dropout(0.1) + + # self.fc3 = nn.Linear(8192, 16384) + # self.act3 = nn.GELU() + # self.drop3 = nn.Dropout(0.1) + # self.fc4 = nn.Linear(16384, dim*30) + + # self.fc = nn.Linear(1280, 2048) def forward(self, x, mask): + b, d, c = x.shape + # print(f"x: {x.shape}") + # print(f"mask: {mask.shape}") + + re_x = einops.rearrange(x, 'b d c -> b (d c)') + re_x = self.fc(re_x) + x = einops.rearrange(re_x, 'b (d c) -> b d c', d=d, c=c) + # print(f"re_x: {re_x.shape}") + # print(f"x: {x.shape}") + out = self.stage1(x, mask) outputs = out.unsqueeze(0) for s in self.stages: out = s(F.softmax(out, dim=1) * mask[:, None, :], mask) outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0) + + # print(f"outputs: {outputs.shape}") return outputs @@ -52,10 +96,12 @@ def __init__(self, num_layers, num_f_maps, dim, num_classes): self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1) def forward(self, x, mask): + out = self.conv_1x1(x) for layer in self.layers: out = layer(out, mask) out = self.conv_out(out) * mask[:, None, :] + return out @@ -66,11 +112,17 @@ def __init__(self, dilation, in_channels, out_channels): in_channels, out_channels, 3, padding=dilation, dilation=dilation ) self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1) - self.dropout = nn.Dropout() + self.dropout = nn.Dropout(0.5) + self.activation = nn.LeakyReLU(0.2) + self.norm = nn.BatchNorm1d(out_channels) + # self.pool = nn.MaxPool1d(kernel_size=3, stride=1) def forward(self, x, mask): out = F.relu(self.conv_dilated(x)) out = self.conv_1x1(out) + out = self.activation(out) + # out = self.pool(out) + out = self.norm(out) out = self.dropout(out) return (x + out) * mask[:, None, :] diff --git a/tcn_hpl/models/ptg_module.py b/tcn_hpl/models/ptg_module.py index 0816ae6ae..01d352f9b 100644 --- a/tcn_hpl/models/ptg_module.py +++ b/tcn_hpl/models/ptg_module.py @@ -1,5 +1,5 @@ from typing import Any, Dict, Tuple - +import os import torch from torch import nn import torch.nn.functional as F @@ -14,11 +14,11 @@ import seaborn as sns import kwcoco - +from statistics import mode from angel_system.data.common.load_data import ( time_from_name, ) - +import einops try: from aim import Image @@ -66,6 +66,7 @@ def __init__( optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler, smoothing_loss: float, + use_smoothing_loss: bool, data_dir: str, num_classes: int, compile: bool, @@ -116,16 +117,67 @@ def __init__( # for tracking best so far validation accuracy self.val_acc_best = MaxMetric() + # self.val_acc_best = 0 + self.train_acc_best = MaxMetric() self.validation_step_outputs_prob = [] self.validation_step_outputs_pred = [] 
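# ---------------------------------------------------------------------------
# Shape sketch for the flatten/restore step added to MultiStageModel.forward in
# ms_tcs_net.py above. The batch size, feature dimension and window size below
# are assumed example values (not read from any config in this patch); the point
# is only that the einops round-trip is lossless and that the new nn.Sequential
# block sees a (batch, dim * window_size) input.
import torch
import einops

batch, dim, window_size = 4, 182, 30
x = torch.randn(batch, dim, window_size)             # input after x.transpose(2, 1): (b, d, c)

flat = einops.rearrange(x, "b d c -> b (d c)")       # fed to the fully-connected block
assert flat.shape == (batch, dim * window_size)

restored = einops.rearrange(flat, "b (d c) -> b d c", d=dim, c=window_size)
assert torch.equal(x, restored)                      # rearrange round-trip preserves values
# ---------------------------------------------------------------------------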
self.validation_step_outputs_target = [] self.validation_step_outputs_source_vid = [] self.validation_step_outputs_source_frame = [] + + self.training_step_outputs_target = [] + self.training_step_outputs_source_vid = [] + self.training_step_outputs_source_frame = [] + self.training_step_outputs_pred = [] + self.training_step_outputs_prob = [] + self.train_frames = None self.val_frames = None self.test_frames = None + + def plot_gt_vs_preds(self, per_video_frame_gt_preds, split='train', max_items=30): + # fig = plt.figure(figsize=(10,15)) + + for index, video in enumerate(per_video_frame_gt_preds.keys()): + + if index >= max_items: + return + + fig, ax = plt.subplots(figsize=(15, 8)) + video_gt_preds = per_video_frame_gt_preds[video] + frame_inds = sorted(list(video_gt_preds.keys())) + preds, gt, inds = [], [], [] + for ind in frame_inds: + inds.append(int(ind)) + gt.append(int(video_gt_preds[ind][0])) + preds.append(int(video_gt_preds[ind][1])) + + + # plt.plot(gt, label="gt") + # plt.plot(preds, label="preds") + + # sns.barplot(x=inds, y=gt, linestyle='dotted', color='magenta', label='GT', ax=ax) + # ax.stackplot(inds, gt, alpha=0.5, labels=['GT'], color=['magenta']) + sns.lineplot(x=inds, y=preds, linestyle='dotted', color='blue', linewidth=1, label='Pred', ax=ax).set(title=f'{split} Step Prediction Per Frame', xlabel='Index', ylabel='Step') + sns.lineplot(x=inds, y=gt, color='magenta', label='GT', ax=ax, linewidth=3) + + + # ax.xaxis.label.set_visible(False) + # ax.spines['bottom'].set_visible(False) + ax.legend() + # /confusion_mat_val.png" + # title = f"plot_pred_vs_gt_vid{video}.png" Path(folder).mkdir(parents=True, exist_ok=True) + # fig.savefig(f"{self.hparams.output_dir}/{title}") + root_dir = f"{self.hparams.output_dir}/steps_vs_preds" + + if not os.path.exists(root_dir): + os.makedirs(root_dir) + + fig.savefig(f"{root_dir}/{split}_{video}.png", pad_inches=5) + plt.close() + def forward(self, x: torch.Tensor, m: torch.Tensor) -> torch.Tensor: """Perform a forward pass through the model `self.net`. 
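For reference, the `plot_gt_vs_preds` helper above expects a nested mapping from video name to per-frame `(ground truth, prediction)` pairs. A minimal sketch of how such a structure can be assembled from the flat tensors collected by the epoch hooks is shown below; the video names, frame indices and labels are invented for illustration only.

import torch

all_targets = torch.tensor([0, 1, 1, 2])         # last-frame GT label per window
all_preds = torch.tensor([0, 1, 2, 2])           # predicted label per window
all_source_vids = torch.tensor([0, 0, 1, 1])     # index into the split's video list
all_source_frames = torch.tensor([10, 11, 5, 6])
video_names = ["M2-1", "M2-3"]                   # hypothetical names; real ones come from the bundle files

per_video_frame_gt_preds = {}
for gt, pred, vid, frame in zip(all_targets, all_preds, all_source_vids, all_source_frames):
    name = video_names[int(vid)]
    # one (gt, pred) tuple per frame index, exactly what plot_gt_vs_preds iterates over
    per_video_frame_gt_preds.setdefault(name, {})[int(frame)] = (int(gt), int(pred))

# per_video_frame_gt_preds ==
# {"M2-1": {10: (0, 0), 11: (1, 1)}, "M2-3": {5: (1, 2), 6: (2, 2)}}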
@@ -143,6 +195,9 @@ def on_train_start(self) -> None: self.val_loss.reset() self.val_acc.reset() self.val_acc_best.reset() + # self.train_loss.reset() + # self.train_acc.reset() + # self.train_acc_best.reset() def compute_loss(self, p, y, mask): """Compute the total loss for a batch @@ -153,13 +208,64 @@ def compute_loss(self, p, y, mask): :return: The loss """ + + probs = torch.softmax(p, dim=1) # shape (batch size, self.hparams.num_classes) + preds = torch.argmax(probs, dim=1).float() # shape: batch size + + # print(f"prediction: {p}, GT: {y}"), # [bs, num_classes, window_size] + # print(f"prediction: {p.shape}, GT: {y.shape}") + + loss = torch.zeros((1)).to(p[0]) + + # print(f"loss: {loss.shape}") + # print(f"p loss: {p[:,:,-1].shape}") + # print(f"y: {y.view(-1).shape}") + # p = einops.rearrange(p, 'b c w -> (b w) c') + # print(f"prediction: {p.shape}, GT: {y.shape}") + # print(f"prediction: {p}, GT: {y}"), # [bs, num_classes, window_size] + + # TODO: Use only last frame per window + + # loss += self.criterion( + # p.transpose(2, 1).contiguous().view(-1, self.hparams.num_classes), + # y.view(-1), + # ) + loss += self.criterion( - p.transpose(2, 1).contiguous().view(-1, self.hparams.num_classes), - y.view(-1), + p[:,:,-1], + y[:,-1], ) + # need to penalize high volatility of predictions within a window + + # std, mean = torch.std_mean(preds, dim=-1) + mode, _ = torch.mode(y, dim=-1) + mode = einops.repeat(mode, 'b -> b c', c=preds.shape[-1]) + # print(f"mode: {mode.shape}") + # print(f"preds: {preds.shape}") + # print(f"mode: {mode[0,:]}") + # eps = 1e10 + # variation_coef = std/(mean+eps) + # variation_coef = torch.zeros_like(mode) + variation_coef = torch.abs(preds - mode) + variation_coef = torch.sum(variation_coef, dim=-1) + gt_variation_coef = torch.zeros_like(variation_coef) + # print(f"variation_coef: {variation_coef[0]}") + # print(f"variation_coef mean: {variation_coef.mean()}") + # print(f"p: {p.shape}") + # print(f"p[:, :, 1:]: {p[0, 0, 1:]}") + # print(f"F.log_softmax(p[:, :, 1:], dim=1): {F.log_softmax(p[:, :, 1:], dim=1)}") + + if self.hparams.use_smoothing_loss: + loss += self.hparams.smoothing_loss * torch.mean( + self.mse( + variation_coef, + gt_variation_coef, + ), + ) + loss += self.hparams.smoothing_loss * torch.mean( torch.clamp( self.mse( @@ -188,12 +294,14 @@ def model_step( - A tensor of target labels. 
- A tensor of [video name, frame idx] """ - x, y, m, source_vid, source_frame = batch # x shape: (batch size, window, feat dim) + x, y, m, source_vid, source_frame = batch + # x shape: (batch size, window, feat dim) # y shape: (batch size, window) # m shape: (batch size, window) # source_vid shape: (batch size, window) x = x.transpose(2, 1) # shape (batch size, feat dim, window) logits = self.forward(x, m) # shape (4, batch size, self.hparams.num_classes, window)) + # print(f"logits: {logits.shape}") loss = torch.zeros((1)).to(x) for p in logits: loss += self.compute_loss(p, y, m) @@ -216,6 +324,9 @@ def training_step( """ loss, probs, preds, targets, source_vid, source_frame = self.model_step(batch) + # print(f"targets: {targets}") + # print(f"preds: {preds}") + # update and log metrics self.train_loss(loss) self.train_acc(preds, targets[:,-1]) @@ -223,12 +334,100 @@ def training_step( self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True) self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True) + self.training_step_outputs_target.append(targets[:,-1]) + self.training_step_outputs_source_vid.append(source_vid[:,-1]) + self.training_step_outputs_source_frame.append(source_frame[:,-1]) + self.training_step_outputs_pred.append(preds) + self.training_step_outputs_prob.append(probs) + # return loss or backpropagation will fail return loss def on_train_epoch_end(self) -> None: "Lightning hook that is called when a training epoch ends." - pass + + # acc = self.train_acc.compute() # get current val acc + # self.train_acc_best(acc) # update best so far val acc + + all_targets = torch.concat(self.training_step_outputs_target) # shape: #frames + all_preds = torch.concat(self.training_step_outputs_pred) # shape: #frames + all_probs = torch.concat(self.training_step_outputs_prob) # shape (#frames, #act labels) + all_source_vids = torch.concat(self.training_step_outputs_source_vid) + all_source_frames = torch.concat(self.training_step_outputs_source_frame) + + if self.train_frames is None: + self.train_frames = {} + vid_list_file_train = f"{self.hparams.data_dir}/splits/train_activity.split1.bundle" + with open(vid_list_file_train, "r") as train_f: + self.train_videos = train_f.read().split("\n")[:-1] + + for video in self.train_videos: + # Load frame filenames for the video + frame_list_file_train = f"{self.hparams.data_dir}/frames/{video}" + with open(frame_list_file_train, "r") as train_f: + train_fns = train_f.read().split("\n")[:-1] + + self.train_frames[video[:-4]] = train_fns + + per_video_frame_gt_preds = {} + + for (gt, pred, source_vid, source_frame) in zip(all_targets, all_preds, all_source_vids, all_source_frames): + video_name = self.train_videos[int(source_vid)][:-4] + + + if video_name not in per_video_frame_gt_preds.keys(): + per_video_frame_gt_preds[video_name] = {} + + frame = self.train_frames[video_name][int(source_frame)] + frame_idx, time = time_from_name(frame) + + per_video_frame_gt_preds[video_name][frame_idx] = (int(gt), int(pred)) + + # dset.add_image( + # file_name=frame, + # video_id=vid, + # frame_index=frame_idx, + # activity_gt=int(gt), + # activity_pred=int(pred), + # activity_conf=prob.tolist() + # ) + + # print(f"per_video_frame_gt_preds: {per_video_frame_gt_preds}") + # exit() + + self.plot_gt_vs_preds(per_video_frame_gt_preds, split='train') + + cm = confusion_matrix( + all_targets.cpu().numpy(), all_preds.cpu().numpy(), + labels=self.class_ids, + normalize="true" + ) + + num_act_classes = len(self.class_ids) + 
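# Aside: a toy illustration of the window-volatility penalty that compute_loss
# adds on top of the last-frame cross-entropy. All values are invented; shapes
# follow (batch, window). This is a sketch, not code used by the module.
import torch
import einops

preds = torch.tensor([[2., 2., 3., 2., 2.],
                      [1., 1., 1., 1., 1.]])     # per-frame argmax predictions
y = torch.tensor([[2, 2, 2, 2, 2],
                  [1, 1, 1, 1, 1]])              # per-frame ground-truth labels

mode, _ = torch.mode(y, dim=-1)                  # dominant GT label per window: [2, 1]
mode = einops.repeat(mode, "b -> b c", c=preds.shape[-1])

variation_coef = torch.abs(preds - mode).sum(dim=-1)   # [1., 0.]: one off-mode frame in window 0
gt_variation_coef = torch.zeros_like(variation_coef)

mse = torch.nn.MSELoss(reduction="none")
penalty = torch.mean(mse(variation_coef, gt_variation_coef))   # 0.5 here, then scaled by smoothing_loss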
fig, ax = plt.subplots(figsize=(num_act_classes, num_act_classes)) + + sns.heatmap(cm, annot=True, ax=ax, fmt=".2f", linewidth=0.5, vmin=0, vmax=1) + + # labels, title and ticks + ax.set_xlabel('Predicted labels') + ax.set_ylabel('True labels') + ax.set_title(f'CM Training Epoch {self.current_epoch}') + ax.xaxis.set_ticklabels(self.classes, rotation=25) + ax.yaxis.set_ticklabels(self.classes, rotation=0) + + self.logger.experiment.track(Image(fig), name=f'CM Training Epoch') + + fig.savefig(f"{self.hparams.output_dir}/confusion_mat_train.png", pad_inches=5) + + plt.close(fig) + + self.training_step_outputs_target.clear() + self.training_step_outputs_source_vid.clear() + self.training_step_outputs_source_frame.clear() + self.training_step_outputs_pred.clear() + self.training_step_outputs_prob.clear() + + # pass def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: """Perform a single validation step on a batch of data from the validation set. @@ -237,26 +436,79 @@ def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: i labels. :param batch_idx: The index of the current batch. """ + _, _, mask, _, _ = batch loss, probs, preds, targets, source_vid, source_frame = self.model_step(batch) # update and log metrics self.val_loss(loss) - self.val_acc(preds, targets[:,-1]) + + # print(f"preds: {preds.shape}, targets: {targets.shape}") + # print(f"mask: {mask.shape}, {mask[0,:]}") + ys = targets[:,-1] + # print(f"y: {ys.shape}") + # print(f"y: {ys}") + windowed_preds, windowed_ys = [], [] + window_size = 45 + center = 22 + inds = [] + for i in range(preds.shape[0] - window_size + 1): + y = ys[i:i+window_size].tolist() + # print(f"y: {y}") + # print(f"len of set: {len(list(set(y)))}") + if len(list(set(y))) == 1: + inds.append(i+center-1) + windowed_preds.append(preds[i+center-1]) + windowed_ys.append(ys[i+center-1]) + + + windowed_preds = torch.tensor(windowed_preds).to(targets) + windowed_ys = torch.tensor(windowed_ys).to(targets) + + # print(f"preds: {preds.shape}, targets: {targets.shape}") + # print(f"windowed_preds: {windowed_preds.shape}, windowed_ys: {windowed_ys.shape}") + + self.val_acc(windowed_preds, windowed_ys) self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True) self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) - self.validation_step_outputs_target.append(targets[:,-1]) - self.validation_step_outputs_source_vid.append(source_vid[:,-1]) - self.validation_step_outputs_source_frame.append(source_frame[:,-1]) - self.validation_step_outputs_pred.append(preds) - self.validation_step_outputs_prob.append(probs) - + self.validation_step_outputs_target.append(targets[inds,-1]) + self.validation_step_outputs_source_vid.append(source_vid[inds,-1]) + self.validation_step_outputs_source_frame.append(source_frame[inds,-1]) + self.validation_step_outputs_pred.append(preds[inds]) + self.validation_step_outputs_prob.append(probs[inds]) + + + # def plot_gt_vs_activations(self, step_gts, fname_suffix=None): + # # Plot gt vs predicted class across all vid frames + # fig = plt.figure() + # sns.set(font_scale=1) + # step_gts = [float(i) for i in step_gts] + # plt.plot(step_gts, label="gt") + # starting_zero_value = 0 + # for i in range(len(self.avg_probs)): #TODO - swap len(self.avg_probs) with the number of activities you're tracking + # starting_zero_value -= 2 + # plot_line = np.asarray(self.activity_conf_history)[:, i]. 
#TODO: this 1D array is the activity confidences for one activity class i. + # plt.plot(2 * plot_line + starting_zero_value, label=f"act_preds[{i}]") + + # plt.legend() + # if not fname_suffix: + # fname_suffix = f"vid{vid_id}" + # recipe_type = foo # TODO: fill with the task name + # title = f"plot_pred_vs_gt_{recipe_type}_{fname_suffix}.png" + # plt.title(title) + # fig.savefig(f"./outputs/{title}") # TODO: output this wherever you want + # plt.close() + + + def on_validation_epoch_end(self) -> None: "Lightning hook that is called when a validation epoch ends." acc = self.val_acc.compute() # get current val acc - self.val_acc_best(acc) # update best so far val acc + + if self.current_epoch >= 15: + self.val_acc_best(acc) # update best so far val acc # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object # otherwise metric would be reset by lightning after each epoch - self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True) + self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True) all_targets = torch.concat(self.validation_step_outputs_target) # shape: #frames all_preds = torch.concat(self.validation_step_outputs_pred) # shape: #frames @@ -284,15 +536,26 @@ def on_validation_epoch_end(self) -> None: dset.fpath = f"{self.hparams.output_dir}/val_activity_preds_epoch{self.current_epoch}.mscoco.json" dset.dataset["info"].append({"activity_labels": self.action_id_to_str}) + # print(f"video_lookup: {video_lookup}") + per_video_frame_gt_preds = {} + for (gt, pred, prob, source_vid, source_frame) in zip(all_targets, all_preds, all_probs, all_source_vids, all_source_frames): video_name = self.val_videos[int(source_vid)][:-4] + + if video_name not in per_video_frame_gt_preds.keys(): + per_video_frame_gt_preds[video_name] = {} + + + video_lookup = dset.index.name_to_video vid = video_lookup[video_name]["id"] if video_name in video_lookup else dset.add_video(name=video_name) frame = self.val_frames[video_name][int(source_frame)] frame_idx, time = time_from_name(frame) + per_video_frame_gt_preds[video_name][frame_idx] = (int(gt), int(pred)) + dset.add_image( file_name=frame, video_id=vid, @@ -301,9 +564,12 @@ def on_validation_epoch_end(self) -> None: activity_pred=int(pred), activity_conf=prob.tolist() ) - dset.dump(dset.fpath, newlines=True) - print(f"Saved dset to {dset.fpath}") + # dset.dump(dset.fpath, newlines=True) + # print(f"Saved dset to {dset.fpath}") + + # print(f"per_video_frame_gt_preds: {per_video_frame_gt_preds}") + self.plot_gt_vs_preds(per_video_frame_gt_preds, split='validation') # Create confusion matrix cm = confusion_matrix( all_targets.cpu().numpy(), all_preds.cpu().numpy(), @@ -314,17 +580,19 @@ def on_validation_epoch_end(self) -> None: num_act_classes = len(self.class_ids) fig, ax = plt.subplots(figsize=(num_act_classes, num_act_classes)) - sns.heatmap(cm, annot=True, ax=ax, fmt="0.0%", linewidth=0.5) + sns.heatmap(cm, annot=True, ax=ax, fmt=".2f", linewidth=0.5, vmin=0, vmax=1) # labels, title and ticks ax.set_xlabel('Predicted labels') ax.set_ylabel('True labels') ax.set_title(f'CM Validation Epoch {self.current_epoch}') - ax.xaxis.set_ticklabels(self.classes, rotation=90) + ax.xaxis.set_ticklabels(self.classes, rotation=25) ax.yaxis.set_ticklabels(self.classes, rotation=0) self.logger.experiment.track(Image(fig), name=f'CM Validation Epoch') + fig.savefig(f"{self.hparams.output_dir}/confusion_mat_val.png", pad_inches=5) + plt.close(fig) self.validation_step_outputs_target.clear() 
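The window filtering added to `validation_step` above scores only frames whose surrounding ground-truth window contains a single label, so step-transition frames are excluded from the validation accuracy. Below is a scaled-down sketch of that filtering with invented labels and a 5-frame window in place of the hard-coded 45/22; the `center - 1` off-by-one is kept as in the patch.

import torch

window_size, center = 5, 2
ys = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])      # GT with one step transition
preds = torch.tensor([0, 0, 1, 0, 0, 1, 1, 0, 1, 1])   # noisy predictions near the boundary

inds, windowed_preds, windowed_ys = [], [], []
for i in range(ys.shape[0] - window_size + 1):
    window_labels = ys[i:i + window_size].tolist()
    if len(set(window_labels)) == 1:                    # keep only label-pure windows
        inds.append(i + center - 1)
        windowed_preds.append(preds[i + center - 1])
        windowed_ys.append(ys[i + center - 1])

# inds == [1, 6]: both retained predictions are correct, while the wrong
# predictions at indices 2 and 7 fall inside mixed-label windows and are dropped,
# so accuracy is measured away from step boundaries.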
@@ -384,15 +652,22 @@ def on_test_epoch_end(self) -> None: dset.fpath = f"{self.hparams.output_dir}/test_activity_preds.mscoco.json" dset.dataset["info"].append({"activity_labels": self.action_id_to_str}) + + per_video_frame_gt_preds = {} for (gt, pred, prob, source_vid, source_frame) in zip(all_targets, all_preds, all_probs, all_source_vids, all_source_frames): video_name = self.test_videos[int(source_vid)][:-4] + if video_name not in per_video_frame_gt_preds.keys(): + per_video_frame_gt_preds[video_name] = {} + video_lookup = dset.index.name_to_video vid = video_lookup[video_name]["id"] if video_name in video_lookup else dset.add_video(name=video_name) frame = self.test_frames[video_name][int(source_frame)] frame_idx, time = time_from_name(frame) + per_video_frame_gt_preds[video_name][frame_idx] = (int(gt), int(pred)) + dset.add_image( file_name=frame, video_id=vid, @@ -404,6 +679,8 @@ def on_test_epoch_end(self) -> None: dset.dump(dset.fpath, newlines=True) print(f"Saved dset to {dset.fpath}") + + self.plot_gt_vs_preds(per_video_frame_gt_preds, split='test') # Create confusion matrix cm = confusion_matrix( all_targets.cpu().numpy(), all_preds.cpu().numpy(), @@ -414,16 +691,16 @@ def on_test_epoch_end(self) -> None: num_act_classes = len(self.class_ids) fig, ax = plt.subplots(figsize=(num_act_classes,num_act_classes)) - sns.heatmap(cm, annot=True, ax=ax, fmt=".2f") + sns.heatmap(cm, annot=True, ax=ax, fmt=".2f", vmin=0, vmax=1) # labels, title and ticks ax.set_xlabel('Predicted labels') ax.set_ylabel('True labels') ax.set_title(f'CM Test Epoch {self.current_epoch}') - ax.xaxis.set_ticklabels(self.classes, rotation=90) + ax.xaxis.set_ticklabels(self.classes, rotation=25) ax.yaxis.set_ticklabels(self.classes, rotation=0) - fig.savefig(f"{self.hparams.output_dir}/test_confusion_mat.png", pad_inches=5) + fig.savefig(f"{self.hparams.output_dir}/confusion_mat_test_acc_{self.test_acc.compute():0.2f}.png", pad_inches=5) self.logger.experiment.track(Image(fig), name=f'CM Test Epoch') diff --git a/tcn_hpl/train.py b/tcn_hpl/train.py index 9e395b144..018de1034 100644 --- a/tcn_hpl/train.py +++ b/tcn_hpl/train.py @@ -49,6 +49,9 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: log.info(f"Instantiating datamodule <{cfg.data._target_}>") datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) + # print(f"datamodule: {datamodule.__dict__}") + # exit() + log.info(f"Instantiating model <{cfg.model._target_}>") model: LightningModule = hydra.utils.instantiate(cfg.model) @@ -76,7 +79,8 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: if cfg.get("train"): log.info("Starting training!") - trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) + # trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) + trainer.fit(model=model, datamodule=datamodule) train_metrics = trainer.callback_metrics @@ -86,7 +90,9 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: if ckpt_path == "": log.warning("Best ckpt not found! 
Using current weights for testing...") ckpt_path = None - trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + # trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + trainer.test(model=model, datamodule=datamodule, ckpt_path="best") + log.info(f"Best ckpt path: {ckpt_path}") test_metrics = trainer.callback_metrics diff --git a/tcn_hpl/utils/utils.py b/tcn_hpl/utils/utils.py index adb9efacf..5a5b478b8 100644 --- a/tcn_hpl/utils/utils.py +++ b/tcn_hpl/utils/utils.py @@ -1,14 +1,53 @@ import warnings from importlib.util import find_spec from typing import Any, Callable, Dict, Tuple - +import os +from glob import glob from omegaconf import DictConfig - +import yaml from tcn_hpl.utils import pylogger, rich_utils log = pylogger.get_pylogger(__name__) +def load_yaml_as_dict(yaml_path: str) -> dict: + """ + Load a YAML configuration file as a dictionary. + + Parameters: + yaml_path (str): Path to the YAML file. + + Returns: + dict: A dictionary containing the configuration settings. + """ + with open(yaml_path, 'r') as f: + config_dict = yaml.load(f, Loader=yaml.FullLoader) + return config_dict + +def dictionary_contents(path: str, types: list, recursive: bool = False) -> list: + """ + Extract files of specified types from directories, optionally recursively. + + Parameters: + path (str): Root directory path. + types (list): List of file types (extensions) to be extracted. + recursive (bool, optional): Search for files in subsequent directories if True. Default is False. + + Returns: + list: List of file paths with full paths. + """ + files = [] + if recursive: + path = path + "/**/*" + for type in types: + if recursive: + for x in glob(path + type, recursive=True): + files.append(os.path.join(path, x)) + else: + for x in glob(path + type): + files.append(os.path.join(path, x)) + return files + def extras(cfg: DictConfig) -> None: """Applies optional utilities before the task is started.