diff --git a/configs/experiment/m3/feat_v6.yaml b/configs/experiment/m3/feat_v6.yaml index 372fd22dd..d90fb84bc 100644 --- a/configs/experiment/m3/feat_v6.yaml +++ b/configs/experiment/m3/feat_v6.yaml @@ -38,7 +38,7 @@ model: dim: 188 # length of feature vector data: - num_classes: 9 # activities: includes background + num_classes: 6 # activities: includes background batch_size: 512 num_workers: 0 epoch_length: 20000 @@ -51,12 +51,12 @@ data: MoveCenterPts: feat_version: 6 - num_obj_classes: 9 # not including background + num_obj_classes: 4 # not including background NormalizeFromCenter: feat_version: 6 NormalizePixelPts: feat_version: 6 - num_obj_classes: 9 # not including background + num_obj_classes: 4 # not including background data_gen: reshuffle_datasets: true diff --git a/configs/experiment/r18/feat_v6.yaml b/configs/experiment/r18/feat_v6.yaml index 0595d6446..1ad29e0b4 100644 --- a/configs/experiment/r18/feat_v6.yaml +++ b/configs/experiment/r18/feat_v6.yaml @@ -14,7 +14,8 @@ defaults: # all parameters below will be merged with parameters from default configurations set above # this allows you to overwrite only specified parameters # exp_name: "p_m2_tqt_data_test_feat_v6_with_pose" #[_v2_aug_False] -exp_name: "p_r18_feat_v6_with_pose_v3_aug_False_reshuffle_True" #[_v2_aug_False] +# exp_name: "p_r18_feat_v6_with_pose_v3_aug_False_reshuffle_True" #[_v2_aug_False] +exp_name: "p_r18_feat_v6_with_pose_v3_aug_False_reshuffle_True_bbn" #[_v2_aug_False] # exp_name: "p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_False_reshuffle_False" #[_v2_aug_False] # exp_name: "p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_True_reshuffle_True" #[_v2_aug_False] # exp_name: "p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_True_reshuffle_False" #[_v2_aug_False] @@ -70,14 +71,20 @@ data_gen: val_vid_ids: [3, 7, 10, 18, 27, 32, 41] test_vid_ids: [50, 13, 47, 25] + train_vid_ids_bbn: [1, 2, 3, 4, 5, 6, 7, 8] + val_vid_ids_bbn: [9] + test_vid_ids_bbn: [10] + paths: - data_dir: "/data/PTG/TCN_data/r18/p_r18_feat_v6_with_pose_v3_aug_False_reshuffle_True" #[_v2_aug_False] + # data_dir: "/data/PTG/TCN_data/r18/p_r18_feat_v6_with_pose_v3_aug_False_reshuffle_True" #[_v2_aug_False] + data_dir: "/data/PTG/TCN_data/r18/p_r18_feat_v6_with_pose_v3_aug_False_reshuffle_True_bbn" #[_v2_aug_False] # data_dir: "/data/PTG/TCN_data/m2/p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_False_reshuffle_False" #[_v2_aug_False] # data_dir: "/data/PTG/TCN_data/m2/p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_True_reshuffle_True" #[_v2_aug_False] # data_dir: "/data/PTG/TCN_data/m2/p_m2_tqt_data_test_feat_v6_with_pose_v2_aug_True_reshuffle_False" #[_v2_aug_False] # root_dir: "/data/PTG/medical/training/activity_classifier/TCN_HPL" root_dir: "/data/users/peri.akiva/PTG/medical/training/activity_classifier/TCN_HPL" dataset_kwcoco: "/data/users/peri.akiva/datasets/ptg/r18_all_all_obj_results_with_dets_and_pose.mscoco.json" + dataset_kwcoco_lab: "/data/users/peri.akiva/datasets/ptg/r18_all_bbn_lab_data_all_obj_results_with_dets_and_pose.mscoco.json" activity_config_root: "/home/local/KHQ/peri.akiva/projects/angel_system/config/activity_labels" activity_config_fn: "${paths.activity_config_root}/${task}" ptg_root: "/home/local/KHQ/peri.akiva/angel_system" diff --git a/configs/model/ptg.yaml b/configs/model/ptg.yaml index b78c980f3..7df107add 100644 --- a/configs/model/ptg.yaml +++ b/configs/model/ptg.yaml @@ -3,7 +3,7 @@ _target_: tcn_hpl.models.ptg_module.PTGLitModule optimizer: _target_: torch.optim.Adam _partial_: true - lr: 0.0001 + lr: 0.000001 
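The model hunk above lowers the Adam learning rate from 1e-4 to 1e-6 while the weight decay that follows stays at 1e-4. As a minimal sketch of how a `_partial_: true` optimizer block like this is typically consumed (assumes hydra-core >= 1.1 and torch are installed; the `torch.nn.Linear` stand-in only reuses the `dim: 188` / `num_classes: 6` values from the experiment config above and is not the real TCN):

```python
import hydra
import torch

optimizer_cfg = {
    "_target_": "torch.optim.Adam",
    "_partial_": True,       # instantiate() returns a functools.partial, not an optimizer
    "lr": 0.000001,          # new value from configs/model/ptg.yaml
    "weight_decay": 0.0001,
}
optimizer_partial = hydra.utils.instantiate(optimizer_cfg)

# Stand-in module: 188-dim feature vectors -> 6 activity classes (background included).
model = torch.nn.Linear(188, 6)
optimizer = optimizer_partial(params=model.parameters())
print(optimizer.defaults["lr"])  # 1e-06
```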
weight_decay: 0.0001 scheduler: diff --git a/tcn_hpl/data/components/PTG_dataset.py b/tcn_hpl/data/components/PTG_dataset.py index ae6680263..f59a92f7e 100644 --- a/tcn_hpl/data/components/PTG_dataset.py +++ b/tcn_hpl/data/components/PTG_dataset.py @@ -1,4 +1,3 @@ - import torch import numpy as np @@ -6,18 +5,19 @@ from typing import Optional, Callable, Dict, List from torchvision.transforms import transforms + class PTG_Dataset(torch.utils.data.Dataset): def __init__( self, - videos: List[str], + videos: List[str], num_classes: int, actions_dict: Dict[str, int], - gt_path: str, + gt_path: str, features_path: str, sample_rate: int, window_size: int, - transform: Optional[Callable]=None, - target_transform: Optional[Callable]=None + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, ): self.num_classes = num_classes self.actions_dict = actions_dict @@ -37,7 +37,7 @@ def __init__( source_frames_list = [] for v, vid in enumerate(videos): features = np.load(self.features_path + vid.split(".")[0] + ".npy") - + file_ptr = open(self.gt_path + vid, "r") content = file_ptr.read().split("\n")[:-1] @@ -47,7 +47,7 @@ def __init__( for i in range(len(classes)): classes[i] = self.actions_dict[content[i]] source_vid[i] = v - source_frame[i] = i # find filename???? + source_frame[i] = i # find filename???? # mask out the end of the window size of the end of the sequence to prevent overlap between videos. mask = np.ones_like(classes) @@ -59,55 +59,62 @@ def __init__( source_vids_list.append(source_vid[:: self.sample_rate]) source_frames_list.append(source_frame[:: self.sample_rate]) - self.feature_frames = np.concatenate(input_frames_list, axis=1, dtype=np.single).transpose() - self.target_frames = np.concatenate(target_frames_list, axis=0, dtype=int, casting='unsafe') - self.mask_frames = np.concatenate(mask_frames_list, axis=0, dtype=int, casting='unsafe') - self.source_vids = np.concatenate(source_vids_list, axis=0, dtype=int, casting='unsafe') + self.feature_frames = np.concatenate( + input_frames_list, axis=1, dtype=np.single + ).transpose() + self.target_frames = np.concatenate( + target_frames_list, axis=0, dtype=int, casting="unsafe" + ) + self.mask_frames = np.concatenate( + mask_frames_list, axis=0, dtype=int, casting="unsafe" + ) + self.source_vids = np.concatenate( + source_vids_list, axis=0, dtype=int, casting="unsafe" + ) self.source_frames = np.concatenate(source_frames_list, dtype=int, axis=0) - # Transforms/Augmentations if self.transform is not None: self.feature_frames = self.transform(self.feature_frames.copy()) if self.target_transform is not None: self.target_frames = self.target_transform(self.target_frames.copy()) - #zero_idxs = random.sample(list(range(len(self.mask_frames))), len(self.mask_frames)*0.3) - #self.mask_frames[zero_idxs] = 0 + # zero_idxs = random.sample(list(range(len(self.mask_frames))), len(self.mask_frames)*0.3) + # self.mask_frames[zero_idxs] = 0 - self.norm_stats['mean'] = self.feature_frames.mean(axis=0) - self.norm_stats['std'] = self.feature_frames.std(axis=0) - self.norm_stats['max'] = self.feature_frames.max(axis=0) - self.norm_stats['min'] = self.feature_frames.min(axis=0) + self.norm_stats["mean"] = self.feature_frames.mean(axis=0) + self.norm_stats["std"] = self.feature_frames.std(axis=0) + self.norm_stats["max"] = self.feature_frames.max(axis=0) + self.norm_stats["min"] = self.feature_frames.min(axis=0) self.dataset_size = self.target_frames.shape[0] - self.window_size - # Get weights for sampler by inverse count. 
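A standalone sketch of the inverse-count weighting that the comment above introduces and the hunk below implements: per-class weights are `1 / count`, and each window starting at `i` is weighted by the label found `window_size` frames later (the toy `target_frames` and `window_size` values are illustrative only):

```python
import numpy as np

target_frames = np.array([0, 0, 0, 1, 1, 2, 0, 0, 2, 2])  # toy per-frame labels
window_size = 3
dataset_size = target_frames.shape[0] - window_size

classes, counts = np.unique(target_frames, return_counts=True)
class_lookup = dict(zip(classes, 1.0 / counts))  # rarer class -> larger weight

weights = np.zeros(dataset_size)
for i in range(dataset_size):
    # A window starting at i is weighted by the GT at the end of that window.
    weights[i] = class_lookup[target_frames[i + window_size]]
weights[:window_size] = 0  # mirror the dataset code: the first window_size start indices are never drawn
```

Rare activities therefore receive proportionally more windows when these weights are handed to the weighted sampler in the datamodule.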
+ # Get weights for sampler by inverse count. # Weights represent the GT of the final frame of a window starting from idx class_name, counts = np.unique(self.target_frames, return_counts=True) - class_weights = 1. / counts + class_weights = 1.0 / counts class_lookup = dict() for i, cn in enumerate(class_name): class_lookup[cn] = class_weights[i] self.weights = np.zeros((self.dataset_size)) for i in range(self.dataset_size): - self.weights[i] = class_lookup[self.target_frames[i+self.window_size]] + self.weights[i] = class_lookup[self.target_frames[i + self.window_size]] # Set weights to 0 for frames before the window length # So they don't get picked - self.weights[:self.window_size] = 0 + self.weights[: self.window_size] = 0 def __len__(self): return self.dataset_size def __getitem__(self, idx): - #print(f"window idx: {idx}:{idx+self.window_size}") + # print(f"window idx: {idx}:{idx+self.window_size}") """Grab a window of frames starting at ``idx`` :param idx: The first index of the time window :return: features, targets, and mask of the window """ - + # print(f"size of dataset: {self.__len__()}") # print(f"self.feature_frames: {self.feature_frames.shape}") # print(f"self.target_frames: {self.target_frames.shape}") @@ -115,18 +122,18 @@ def __getitem__(self, idx): # print(f"self.source_vids: {self.source_vids.shape}") # print(f"self.source_frames: {self.source_frames.shape}") # print(f"self.mask_frames: {self.mask_frames}") - - features = self.feature_frames[idx:idx+self.window_size, :] - target = self.target_frames[idx:idx+self.window_size] - mask = self.mask_frames[idx:idx+self.window_size] - source_vid = self.source_vids[idx:idx+self.window_size] - source_frame = self.source_frames[idx:idx+self.window_size] - + + features = self.feature_frames[idx : idx + self.window_size, :] + target = self.target_frames[idx : idx + self.window_size] + mask = self.mask_frames[idx : idx + self.window_size] + source_vid = self.source_vids[idx : idx + self.window_size] + source_frame = self.source_frames[idx : idx + self.window_size] + # print(f"mask: {mask}") # print(f"features: {features.shape}") # print(f"target: {target.shape}") # print(f"mask: {mask.shape}") # print(f"source_vid: {source_vid.shape}") # print(f"source_frame: {source_frame.shape}") - + return features, target, mask, np.array(source_vid), source_frame diff --git a/tcn_hpl/data/components/augmentations.py b/tcn_hpl/data/components/augmentations.py index 029592c76..df7db8f83 100644 --- a/tcn_hpl/data/components/augmentations.py +++ b/tcn_hpl/data/components/augmentations.py @@ -10,7 +10,14 @@ class MoveCenterPts(torch.nn.Module): """ def __init__( - self, hand_dist_delta, obj_dist_delta, window_size, im_w, im_h, num_obj_classes, feat_version + self, + hand_dist_delta, + obj_dist_delta, + window_size, + im_w, + im_h, + num_obj_classes, + feat_version, ): """ :param hand_dist_delta: Decimal percentage to calculate the +-offset in @@ -116,21 +123,19 @@ def forward(self, features): elif self.feat_version == 3: # Right and left hand distances - right_idx1 = 1; right_idx2 = 2; - left_idx1 = 4; left_idx2 = 5 + right_idx1 = 1 + right_idx2 = 2 + left_idx1 = 4 + left_idx2 = 5 for hand_delta_x, hand_delta_y, start_idx, end_idx in zip( [rhand_delta_x, lhand_delta_x], [rhand_delta_y, lhand_delta_y], [right_idx1, left_idx1], [right_idx2, left_idx2], ): - frame[start_idx] = ( - frame[start_idx] + hand_delta_x - ) - - frame[end_idx] = ( - frame[end_idx] + hand_delta_y - ) + frame[start_idx] = frame[start_idx] + hand_delta_x + + frame[end_idx] = 
frame[end_idx] + hand_delta_y # Object distances start_idx = 10 @@ -260,10 +265,12 @@ def forward(self, features): ) # Distance between hands - hands_dist_idx = left_dist_idx2 + hands_dist_idx = left_dist_idx2 features[:, hands_dist_idx] = features[:, hands_dist_idx] / self.im_w - features[:, hands_dist_idx + 1] = features[:, hands_dist_idx + 1] / self.im_h + features[:, hands_dist_idx + 1] = ( + features[:, hands_dist_idx + 1] / self.im_h + ) elif self.feat_version == 3: # Distances are from the center, skip @@ -278,9 +285,10 @@ def __repr__(self) -> str: detail = f"(im_w={self.im_w}, im_h={self.im_h}, num_obj_classes={self.num_obj_classes}, feat_version={self.feat_version})" return f"{self.__class__.__name__}{detail}" + class NormalizeFromCenter(torch.nn.Module): """Normalize the distances from -1 to 1 with respect to the image center - + Missing objects will be set to (2, 2) """ @@ -310,19 +318,17 @@ def forward(self, features): elif self.feat_version == 3: # Right and left hand distances - right_idx1 = 1; right_idx2 = 2; - left_idx1 = 4; left_idx2 = 5 + right_idx1 = 1 + right_idx2 = 2 + left_idx1 = 4 + left_idx2 = 5 for start_idx, end_idx in zip( [right_idx1, left_idx1], [right_idx2, left_idx2], ): - - features[:, start_idx] = ( - features[:, start_idx] / self.half_w - ) - features[:, end_idx] = ( - features[:, end_idx] / self.half_h - ) + + features[:, start_idx] = features[:, start_idx] / self.half_w + features[:, end_idx] = features[:, end_idx] / self.half_h # Object distances start_idx = 10 @@ -337,6 +343,7 @@ def forward(self, features): return features def __repr__(self) -> str: - detail = f"(im_w={self.im_w}, im_h={self.im_h}, feat_version={self.feat_version})" + detail = ( + f"(im_w={self.im_w}, im_h={self.im_h}, feat_version={self.feat_version})" + ) return f"{self.__class__.__name__}{detail}" - diff --git a/tcn_hpl/data/mnist_datamodule.py b/tcn_hpl/data/mnist_datamodule.py index a77053034..f98a43ed2 100644 --- a/tcn_hpl/data/mnist_datamodule.py +++ b/tcn_hpl/data/mnist_datamodule.py @@ -114,8 +114,12 @@ def setup(self, stage: Optional[str] = None) -> None: """ # load and split datasets only if not loaded already if not self.data_train and not self.data_val and not self.data_test: - trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms) - testset = MNIST(self.hparams.data_dir, train=False, transform=self.transforms) + trainset = MNIST( + self.hparams.data_dir, train=True, transform=self.transforms + ) + testset = MNIST( + self.hparams.data_dir, train=False, transform=self.transforms + ) dataset = ConcatDataset(datasets=[trainset, testset]) self.data_train, self.data_val, self.data_test = random_split( dataset=dataset, diff --git a/tcn_hpl/data/ptg_datamodule.py b/tcn_hpl/data/ptg_datamodule.py index 7260662ad..64a94d89c 100644 --- a/tcn_hpl/data/ptg_datamodule.py +++ b/tcn_hpl/data/ptg_datamodule.py @@ -86,17 +86,15 @@ def __init__( # data transformations transforms_list = [] - for transform_name in all_transforms['train_order']: + for transform_name in all_transforms["train_order"]: transforms_list.append(all_transforms[transform_name]) self.train_transform = transforms.Compose(transforms_list) - transforms_list = [] - for transform_name in all_transforms['test_order']: + for transform_name in all_transforms["test_order"]: transforms_list.append(all_transforms[transform_name]) self.val_transform = transforms.Compose(transforms_list) - self.data_train: Optional[Dataset] = None self.data_val: Optional[Dataset] = None self.data_test: Optional[Dataset] = 
None @@ -133,10 +131,15 @@ def setup(self, stage: Optional[str] = None) -> None: if not self.data_train and not self.data_val and not self.data_test: exp_data = self.hparams.data_dir - - vid_list_file = f"{exp_data}/splits/train_activity.split{self.hparams.split}.bundle" - vid_list_file_val = f"{exp_data}/splits/val.split{self.hparams.split}.bundle" - vid_list_file_tst = f"{exp_data}/splits/test.split{self.hparams.split}.bundle" + vid_list_file = ( + f"{exp_data}/splits/train_activity.split{self.hparams.split}.bundle" + ) + vid_list_file_val = ( + f"{exp_data}/splits/val.split{self.hparams.split}.bundle" + ) + vid_list_file_tst = ( + f"{exp_data}/splits/test.split{self.hparams.split}.bundle" + ) features_path = f"{exp_data}/features/" gt_path = f"{exp_data}/groundTruth/" @@ -159,7 +162,6 @@ def setup(self, stage: Optional[str] = None) -> None: with open(vid_list_file, "r") as train_f: train_videos = train_f.read().split("\n")[:-1] - # print(f"train_vids: {train_videos}") # exit() # Load validation vidoes @@ -171,26 +173,40 @@ def setup(self, stage: Optional[str] = None) -> None: test_videos = test_f.read().split("\n")[:-1] self.data_train = PTG_Dataset( - train_videos, self.hparams.num_classes, actions_dict, gt_path, - features_path, self.hparams.sample_rate, self.hparams.window_size, - transform=self.train_transform + train_videos, + self.hparams.num_classes, + actions_dict, + gt_path, + features_path, + self.hparams.sample_rate, + self.hparams.window_size, + transform=self.train_transform, ) - + # print(f"size:{self.data_train.__len__()}") # exit() self.data_val = PTG_Dataset( - val_videos, self.hparams.num_classes, actions_dict, gt_path, - features_path, self.hparams.sample_rate, self.hparams.window_size, - transform=self.val_transform + val_videos, + self.hparams.num_classes, + actions_dict, + gt_path, + features_path, + self.hparams.sample_rate, + self.hparams.window_size, + transform=self.val_transform, ) self.data_test = PTG_Dataset( - test_videos, self.hparams.num_classes, actions_dict, gt_path, - features_path, self.hparams.sample_rate, self.hparams.window_size, - transform=self.val_transform + test_videos, + self.hparams.num_classes, + actions_dict, + gt_path, + features_path, + self.hparams.sample_rate, + self.hparams.window_size, + transform=self.val_transform, ) - def train_dataloader(self) -> DataLoader[Any]: """Create and return the train dataloader. @@ -198,10 +214,10 @@ def train_dataloader(self) -> DataLoader[Any]: :return: The train dataloader. """ train_sampler = torch.utils.data.WeightedRandomSampler( - self.data_train.weights, - self.hparams.epoch_length, - replacement=True, - generator=None + self.data_train.weights, + self.hparams.epoch_length, + replacement=True, + generator=None, ) return DataLoader( dataset=self.data_train, @@ -216,7 +232,7 @@ def val_dataloader(self) -> DataLoader[Any]: :return: The validation dataloader. """ - + return DataLoader( dataset=self.data_val, batch_size=self.hparams.batch_size, diff --git a/tcn_hpl/data/utils/__init__.py b/tcn_hpl/data/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tcn_hpl/data/utils/pose_generation/__init__.py b/tcn_hpl/data/utils/pose_generation/__init__.py new file mode 100644 index 000000000..a82aa51e4 --- /dev/null +++ b/tcn_hpl/data/utils/pose_generation/__init__.py @@ -0,0 +1,6 @@ +# # Copyright (c) OpenMMLab. All rights reserved. 
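Those per-window weights are what `train_dataloader` above feeds to `torch.utils.data.WeightedRandomSampler`, so one "epoch" is `epoch_length` windows drawn with replacement, biased toward rare classes. A minimal runnable sketch with a toy `TensorDataset` standing in for `PTG_Dataset` (the uniform `weights` and the 100/188/6 shapes are placeholders; `batch_size: 512`, `num_workers: 0`, and `epoch_length: 20000` are the values from the experiment config):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

features = torch.randn(100, 188)   # toy windows of 188-dim features
labels = torch.randint(0, 6, (100,))
dataset = TensorDataset(features, labels)

# Real code uses inverse class counts, zeroed for the first window_size entries.
weights = torch.ones(len(dataset))

sampler = WeightedRandomSampler(weights, num_samples=20000, replacement=True)
loader = DataLoader(dataset, batch_size=512, sampler=sampler, num_workers=0)
```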
+# from .vit import ViT + +# __all__ = [ +# 'ViT' +# ] diff --git a/tcn_hpl/data/utils/pose_generation/configs/ViTPose_base_medic_casualty_256x192.py b/tcn_hpl/data/utils/pose_generation/configs/ViTPose_base_medic_casualty_256x192.py index cc2940933..ca6b10ed1 100644 --- a/tcn_hpl/data/utils/pose_generation/configs/ViTPose_base_medic_casualty_256x192.py +++ b/tcn_hpl/data/utils/pose_generation/configs/ViTPose_base_medic_casualty_256x192.py @@ -1,34 +1,36 @@ -_base_ = [ - 'default_runtime.py', - 'medic_patient.py' -] -evaluation = dict(interval=210, metric='mAP', save_best='AP') +_base_ = ["default_runtime.py", "medic_patient.py"] +evaluation = dict(interval=210, metric="mAP", save_best="AP") -optimizer = dict(type='AdamW', lr=5e-4, betas=(0.9, 0.999), weight_decay=0.1, - constructor='LayerDecayOptimizerConstructor', - paramwise_cfg=dict( - num_layers=12, - layer_decay_rate=0.75, - custom_keys={ - 'bias': dict(decay_multi=0.), - 'pos_embed': dict(decay_mult=0.), - 'relative_position_bias_table': dict(decay_mult=0.), - 'norm': dict(decay_mult=0.) - } - ) - ) +optimizer = dict( + type="AdamW", + lr=5e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + constructor="LayerDecayOptimizerConstructor", + paramwise_cfg=dict( + num_layers=12, + layer_decay_rate=0.75, + custom_keys={ + "bias": dict(decay_multi=0.0), + "pos_embed": dict(decay_mult=0.0), + "relative_position_bias_table": dict(decay_mult=0.0), + "norm": dict(decay_mult=0.0), + }, + ), +) -optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2)) +optimizer_config = dict(grad_clip=dict(max_norm=1.0, norm_type=2)) # learning policy lr_config = dict( - policy='step', - warmup='linear', + policy="step", + warmup="linear", warmup_iters=500, warmup_ratio=0.001, - step=[170, 200]) + step=[170, 200], +) total_epochs = 210 -target_type = 'GaussianHeatmap' +target_type = "GaussianHeatmap" channel_cfg = dict( num_output_channels=22, dataset_joints=22, @@ -36,15 +38,37 @@ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], ], inference_channel=[ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21 - ]) + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + ], +) # model settings model = dict( - type='TopDown', + type="TopDown", pretrained=None, backbone=dict( - type='ViT', + type="ViT", img_size=(256, 192), patch_size=16, embed_dim=768, @@ -57,30 +81,35 @@ drop_path_rate=0.3, ), keypoint_head=dict( - type='TopdownHeatmapSimpleHead', + type="TopdownHeatmapSimpleHead", in_channels=768, num_deconv_layers=2, num_deconv_filters=(256, 256), num_deconv_kernels=(4, 4), - extra=dict(final_conv_kernel=1, ), - out_channels=channel_cfg['num_output_channels'], - loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)), + extra=dict( + final_conv_kernel=1, + ), + out_channels=channel_cfg["num_output_channels"], + loss_keypoint=dict(type="JointsMSELoss", use_target_weight=True), + ), train_cfg=dict(), test_cfg=dict( flip_test=True, - post_process='default', + post_process="default", shift_heatmap=False, target_type=target_type, modulate_kernel=11, - use_udp=True)) + use_udp=True, + ), +) data_cfg = dict( image_size=[192, 256], heatmap_size=[48, 64], - num_output_channels=channel_cfg['num_output_channels'], - num_joints=channel_cfg['dataset_joints'], - dataset_channel=channel_cfg['dataset_channel'], - inference_channel=channel_cfg['inference_channel'], + num_output_channels=channel_cfg["num_output_channels"], + 
num_joints=channel_cfg["dataset_joints"], + dataset_channel=channel_cfg["dataset_channel"], + inference_channel=channel_cfg["inference_channel"], soft_nms=False, nms_thr=1.0, oks_thr=0.9, @@ -88,83 +117,85 @@ use_gt_bbox=False, det_bbox_thr=0.0, # bbox_file='/shared/niudt/DATASET/Medical/final_version_coco_jan21/bbox_detections_results/casualty.json', - bbox_file='/home/local/KHQ/peri.akiva/projects/Medical-Partial-Body-Pose-Estimation/bbox_detection_results/bbox_detections.json', + bbox_file="/home/local/KHQ/peri.akiva/projects/Medical-Partial-Body-Pose-Estimation/bbox_detection_results/bbox_detections.json", ) train_pipeline = [ - dict(type='LoadImageFromFile'), - dict(type='TopDownRandomFlip', flip_prob=0.5), - dict( - type='TopDownHalfBodyTransform', - num_joints_half_body=8, - prob_half_body=0.3), + dict(type="LoadImageFromFile"), + dict(type="TopDownRandomFlip", flip_prob=0.5), + dict(type="TopDownHalfBodyTransform", num_joints_half_body=8, prob_half_body=0.3), + dict(type="TopDownGetRandomScaleRotation", rot_factor=40, scale_factor=0.5), + dict(type="TopDownAffine", use_udp=True), + dict(type="ToTensor"), + dict(type="NormalizeTensor", mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), dict( - type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5), - dict(type='TopDownAffine', use_udp=True), - dict(type='ToTensor'), - dict( - type='NormalizeTensor', - mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]), - dict( - type='TopDownGenerateTarget', - sigma=2, - encoding='UDP', - target_type=target_type), + type="TopDownGenerateTarget", sigma=2, encoding="UDP", target_type=target_type + ), dict( - type='Collect', - keys=['img', 'target', 'target_weight'], + type="Collect", + keys=["img", "target", "target_weight"], meta_keys=[ - 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', - 'rotation', 'bbox_score', 'flip_pairs' - ]), + "image_file", + "joints_3d", + "joints_3d_visible", + "center", + "scale", + "rotation", + "bbox_score", + "flip_pairs", + ], + ), ] val_pipeline = [ - dict(type='LoadImageFromFile'), - dict(type='TopDownAffine', use_udp=True), - dict(type='ToTensor'), - dict( - type='NormalizeTensor', - mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]), + dict(type="LoadImageFromFile"), + dict(type="TopDownAffine", use_udp=True), + dict(type="ToTensor"), + dict(type="NormalizeTensor", mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), dict( - type='Collect', - keys=['img'], + type="Collect", + keys=["img"], meta_keys=[ - 'image_file', 'center', 'scale', 'rotation', 'bbox_score', - 'flip_pairs' - ]), + "image_file", + "center", + "scale", + "rotation", + "bbox_score", + "flip_pairs", + ], + ), ] test_pipeline = val_pipeline -data_root = '/shared/niudt/DATASET/Medical' +data_root = "/shared/niudt/DATASET/Medical" data = dict( samples_per_gpu=64, workers_per_gpu=4, val_dataloader=dict(samples_per_gpu=32), test_dataloader=dict(samples_per_gpu=32), train=dict( - type='TopDownCocoDataset', - ann_file=f'{data_root}/7_3_challenging_vitpose_all_images/casualty_train.json', - img_prefix=f'{data_root}/images/', + type="TopDownCocoDataset", + ann_file=f"{data_root}/7_3_challenging_vitpose_all_images/casualty_train.json", + img_prefix=f"{data_root}/images/", data_cfg=data_cfg, pipeline=train_pipeline, - dataset_info={{_base_.dataset_info}}), + dataset_info={{_base_.dataset_info}}, + ), val=dict( - type='TopDownCocoDataset', - ann_file=f'{data_root}/7_3_challenging_vitpose_all_images/casualty_val.json', - img_prefix=f'{data_root}/images/', + 
type="TopDownCocoDataset", + ann_file=f"{data_root}/7_3_challenging_vitpose_all_images/casualty_val.json", + img_prefix=f"{data_root}/images/", data_cfg=data_cfg, pipeline=val_pipeline, - dataset_info={{_base_.dataset_info}}), + dataset_info={{_base_.dataset_info}}, + ), test=dict( - type='TopDownCocoDataset', - ann_file=f'{data_root}/7_3_challenging_vitpose_all_images/casualty_train.json', - img_prefix=f'{data_root}/images/', + type="TopDownCocoDataset", + ann_file=f"{data_root}/7_3_challenging_vitpose_all_images/casualty_train.json", + img_prefix=f"{data_root}/images/", data_cfg=data_cfg, pipeline=test_pipeline, - dataset_info={{_base_.dataset_info}}), + dataset_info={{_base_.dataset_info}}, + ), ) - diff --git a/tcn_hpl/data/utils/pose_generation/configs/ViTPose_base_medic_user_256x192.py b/tcn_hpl/data/utils/pose_generation/configs/ViTPose_base_medic_user_256x192.py index f8f79ab68..3f1aed2cc 100644 --- a/tcn_hpl/data/utils/pose_generation/configs/ViTPose_base_medic_user_256x192.py +++ b/tcn_hpl/data/utils/pose_generation/configs/ViTPose_base_medic_user_256x192.py @@ -1,50 +1,51 @@ -_base_ = [ - 'default_runtime.py', - 'medic_user.py' -] -evaluation = dict(interval=210, metric='mAP', save_best='AP') +_base_ = ["default_runtime.py", "medic_user.py"] +evaluation = dict(interval=210, metric="mAP", save_best="AP") -optimizer = dict(type='AdamW', lr=5e-4, betas=(0.9, 0.999), weight_decay=0.1, - constructor='LayerDecayOptimizerConstructor', - paramwise_cfg=dict( - num_layers=12, - layer_decay_rate=0.75, - custom_keys={ - 'bias': dict(decay_multi=0.), - 'pos_embed': dict(decay_mult=0.), - 'relative_position_bias_table': dict(decay_mult=0.), - 'norm': dict(decay_mult=0.) - } - ) - ) +optimizer = dict( + type="AdamW", + lr=5e-4, + betas=(0.9, 0.999), + weight_decay=0.1, + constructor="LayerDecayOptimizerConstructor", + paramwise_cfg=dict( + num_layers=12, + layer_decay_rate=0.75, + custom_keys={ + "bias": dict(decay_multi=0.0), + "pos_embed": dict(decay_mult=0.0), + "relative_position_bias_table": dict(decay_mult=0.0), + "norm": dict(decay_mult=0.0), + }, + ), +) -optimizer_config = dict(grad_clip=dict(max_norm=1., norm_type=2)) +optimizer_config = dict(grad_clip=dict(max_norm=1.0, norm_type=2)) # learning policy lr_config = dict( - policy='step', - warmup='linear', + policy="step", + warmup="linear", warmup_iters=500, warmup_ratio=0.001, - step=[170, 200]) + step=[170, 200], +) total_epochs = 210 -target_type = 'GaussianHeatmap' +target_type = "GaussianHeatmap" channel_cfg = dict( num_output_channels=6, dataset_joints=6, dataset_channel=[ [0, 1, 2, 3, 4, 5], ], - inference_channel=[ - 0, 1, 2, 3, 4, 5 - ]) + inference_channel=[0, 1, 2, 3, 4, 5], +) # model settings model = dict( - type='TopDown', + type="TopDown", pretrained=None, backbone=dict( - type='ViT', + type="ViT", img_size=(256, 192), patch_size=16, embed_dim=768, @@ -57,113 +58,120 @@ drop_path_rate=0.3, ), keypoint_head=dict( - type='TopdownHeatmapSimpleHead', + type="TopdownHeatmapSimpleHead", in_channels=768, num_deconv_layers=2, num_deconv_filters=(256, 256), num_deconv_kernels=(4, 4), - extra=dict(final_conv_kernel=1, ), - out_channels=channel_cfg['num_output_channels'], - loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)), + extra=dict( + final_conv_kernel=1, + ), + out_channels=channel_cfg["num_output_channels"], + loss_keypoint=dict(type="JointsMSELoss", use_target_weight=True), + ), train_cfg=dict(), test_cfg=dict( flip_test=True, - post_process='default', + post_process="default", shift_heatmap=False, 
target_type=target_type, modulate_kernel=11, - use_udp=True)) + use_udp=True, + ), +) data_cfg = dict( image_size=[192, 256], heatmap_size=[48, 64], - num_output_channels=channel_cfg['num_output_channels'], - num_joints=channel_cfg['dataset_joints'], - dataset_channel=channel_cfg['dataset_channel'], - inference_channel=channel_cfg['inference_channel'], + num_output_channels=channel_cfg["num_output_channels"], + num_joints=channel_cfg["dataset_joints"], + dataset_channel=channel_cfg["dataset_channel"], + inference_channel=channel_cfg["inference_channel"], soft_nms=False, nms_thr=1.0, oks_thr=0.9, vis_thr=0.2, use_gt_bbox=False, det_bbox_thr=0.0, - bbox_file='/home/local/KHQ/peri.akiva/projects/Medical-Partial-Body-Pose-Estimation/bbox_detection_results/bbox_detections.json', + bbox_file="/home/local/KHQ/peri.akiva/projects/Medical-Partial-Body-Pose-Estimation/bbox_detection_results/bbox_detections.json", ) train_pipeline = [ - dict(type='LoadImageFromFile'), - dict(type='TopDownRandomFlip', flip_prob=0.5), - dict( - type='TopDownHalfBodyTransform', - num_joints_half_body=8, - prob_half_body=0.3), - dict( - type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5), - dict(type='TopDownAffine', use_udp=True), - dict(type='ToTensor'), + dict(type="LoadImageFromFile"), + dict(type="TopDownRandomFlip", flip_prob=0.5), + dict(type="TopDownHalfBodyTransform", num_joints_half_body=8, prob_half_body=0.3), + dict(type="TopDownGetRandomScaleRotation", rot_factor=40, scale_factor=0.5), + dict(type="TopDownAffine", use_udp=True), + dict(type="ToTensor"), + dict(type="NormalizeTensor", mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), dict( - type='NormalizeTensor', - mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]), - dict( - type='TopDownGenerateTarget', - sigma=2, - encoding='UDP', - target_type=target_type), + type="TopDownGenerateTarget", sigma=2, encoding="UDP", target_type=target_type + ), dict( - type='Collect', - keys=['img', 'target', 'target_weight'], + type="Collect", + keys=["img", "target", "target_weight"], meta_keys=[ - 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale', - 'rotation', 'bbox_score', 'flip_pairs' - ]), + "image_file", + "joints_3d", + "joints_3d_visible", + "center", + "scale", + "rotation", + "bbox_score", + "flip_pairs", + ], + ), ] val_pipeline = [ - dict(type='LoadImageFromFile'), - dict(type='TopDownAffine', use_udp=True), - dict(type='ToTensor'), + dict(type="LoadImageFromFile"), + dict(type="TopDownAffine", use_udp=True), + dict(type="ToTensor"), + dict(type="NormalizeTensor", mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), dict( - type='NormalizeTensor', - mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]), - dict( - type='Collect', - keys=['img'], + type="Collect", + keys=["img"], meta_keys=[ - 'image_file', 'center', 'scale', 'rotation', 'bbox_score', - 'flip_pairs' - ]), + "image_file", + "center", + "scale", + "rotation", + "bbox_score", + "flip_pairs", + ], + ), ] test_pipeline = val_pipeline -data_root = '/shared/niudt/DATASET/Medical' +data_root = "/shared/niudt/DATASET/Medical" data = dict( samples_per_gpu=64, workers_per_gpu=4, val_dataloader=dict(samples_per_gpu=32), test_dataloader=dict(samples_per_gpu=32), train=dict( - type='TopDownCocoDataset', - ann_file=f'{data_root}/7_3_challenging_vitpose_all_images/user_train.json', - img_prefix=f'{data_root}/images/', + type="TopDownCocoDataset", + ann_file=f"{data_root}/7_3_challenging_vitpose_all_images/user_train.json", + 
img_prefix=f"{data_root}/images/", data_cfg=data_cfg, pipeline=train_pipeline, - dataset_info={{_base_.dataset_info}}), + dataset_info={{_base_.dataset_info}}, + ), val=dict( - type='TopDownCocoDataset', - ann_file=f'{data_root}/7_3_challenging_vitpose_all_images/user_val.json', - img_prefix=f'{data_root}/images/', + type="TopDownCocoDataset", + ann_file=f"{data_root}/7_3_challenging_vitpose_all_images/user_val.json", + img_prefix=f"{data_root}/images/", data_cfg=data_cfg, pipeline=val_pipeline, - dataset_info={{_base_.dataset_info}}), + dataset_info={{_base_.dataset_info}}, + ), test=dict( - type='TopDownCocoDataset', - ann_file=f'{data_root}/7_3_challenging_vitpose_all_images/user_train.json', - img_prefix=f'{data_root}/images/', + type="TopDownCocoDataset", + ann_file=f"{data_root}/7_3_challenging_vitpose_all_images/user_train.json", + img_prefix=f"{data_root}/images/", data_cfg=data_cfg, pipeline=test_pipeline, - dataset_info={{_base_.dataset_info}}), + dataset_info={{_base_.dataset_info}}, + ), ) - diff --git a/tcn_hpl/data/utils/pose_generation/configs/default_runtime.py b/tcn_hpl/data/utils/pose_generation/configs/default_runtime.py index d5fbc3d42..832b748c9 100644 --- a/tcn_hpl/data/utils/pose_generation/configs/default_runtime.py +++ b/tcn_hpl/data/utils/pose_generation/configs/default_runtime.py @@ -3,17 +3,20 @@ log_config = dict( interval=10, hooks=[ - dict(type='TextLoggerHook'), + dict(type="TextLoggerHook"), # dict(type='TensorboardLoggerHook') - ]) + ], +) -log_level = 'INFO' -load_from = '/shared/niudt/pose_estimation/vitpose/ViTPose/pretrained_model/vitpose-b.pth' +log_level = "INFO" +load_from = ( + "/shared/niudt/pose_estimation/vitpose/ViTPose/pretrained_model/vitpose-b.pth" +) resume_from = None -dist_params = dict(backend='nccl') -workflow = [('train', 1)] +dist_params = dict(backend="nccl") +workflow = [("train", 1)] # disable opencv multithreading to avoid system being overloaded opencv_num_threads = 0 # set multi-process start method as `fork` to speed up the training -mp_start_method = 'fork' +mp_start_method = "fork" diff --git a/tcn_hpl/data/utils/pose_generation/configs/main.yaml b/tcn_hpl/data/utils/pose_generation/configs/main.yaml index 20d2b0f62..612169f77 100644 --- a/tcn_hpl/data/utils/pose_generation/configs/main.yaml +++ b/tcn_hpl/data/utils/pose_generation/configs/main.yaml @@ -1,4 +1,5 @@ -task: m3 +task: r18 +data_type: bbn #[lab, pro] root: /data/PTG/medical/bbn_data/Release_v0.5/v0.52/M2_Tourniquet/Data img_save_path: /data/users/peri.akiva/datasets/m2_tourniquet/imgs pose_model_config: /home/local/KHQ/peri.akiva/projects/TCN_HPL/tcn_hpl/data/utils/pose_generation/configs/ViTPose_base_medic_casualty_256x192.py @@ -20,5 +21,6 @@ data: m5: /data/PTG/medical/training/yolo_object_detector/detect//m5_all/m5_all_all_obj_results.mscoco.json r18: /data/PTG/medical/training/yolo_object_detector/detect//r18_all/r18_all_all_obj_results.mscoco.json save_root: /data/users/peri.akiva/datasets/ptg - - \ No newline at end of file +bbn_lab: + r18: /data/PTG/medical/training/yolo_object_detector/detect//r18_all_bbn_lab_data/r18_all_bbn_lab_data_all_obj_results.mscoco.json + save_root: /data/users/peri.akiva/datasets/ptg \ No newline at end of file diff --git a/tcn_hpl/data/utils/pose_generation/configs/medic_patient.py b/tcn_hpl/data/utils/pose_generation/configs/medic_patient.py index a25104550..896c6832a 100644 --- a/tcn_hpl/data/utils/pose_generation/configs/medic_patient.py +++ b/tcn_hpl/data/utils/pose_generation/configs/medic_patient.py @@ -1,220 +1,125 @@ 
dataset_info = dict( - dataset_name='coco_medic', + dataset_name="coco_medic", paper_info=dict( - author='Lin, Tsung-Yi and Maire, Michael and ' - 'Belongie, Serge and Hays, James and ' - 'Perona, Pietro and Ramanan, Deva and ' - r'Doll{\'a}r, Piotr and Zitnick, C Lawrence', - title='Microsoft coco: Common objects in context', - container='European conference on computer vision', - year='2014', - homepage='http://cocodataset.org/', + author="Lin, Tsung-Yi and Maire, Michael and " + "Belongie, Serge and Hays, James and " + "Perona, Pietro and Ramanan, Deva and " + r"Doll{\'a}r, Piotr and Zitnick, C Lawrence", + title="Microsoft coco: Common objects in context", + container="European conference on computer vision", + year="2014", + homepage="http://cocodataset.org/", ), keypoint_info={ - 0: - dict(name='nose', - id=0, - color=[0, 215, 255], - type='upper', - swap=''), - 1: - dict( - name='mouth', - id=1, - color=[0, 215, 255], - type='upper', - swap=''), - 2: - dict( - name='throat', - id=2, - color=[0, 215, 255], - type='upper', - swap=''), - 3: - dict( - name='chest', - id=3, - color=[0, 215, 255], - type='upper', - swap=''), - 4: - dict( - name='stomach', - id=4, - color=[0, 215, 255], - type='upper', - swap=''), - 5: - dict( - name='left_upper_arm', - id=5, - color=[255, 0, 0], - type='upper', - swap=''), - 6: - dict( - name='right_upper_arm', - id=6, - color=[0, 255, 0], - type='upper', - swap=''), - 7: - dict( - name='left_lower_arm', - id=7, - color=[255, 0, 0], - type='upper', - swap=''), - 8: - dict( - name='right_lower_arm', - id=8, - color=[0, 255, 0], - type='upper', - swap=''), - 9: - dict( - name='left_wrist', - id=9, - color=[255, 0, 0], - type='upper', - swap=''), - 10: - dict( - name='right_wrist', - id=10, - color=[0, 255, 0], - type='upper', - swap=''), - 11: - dict( - name='left_hand', - id=11, - color=[255, 0, 0], - type='upper', - swap=''), - 12: - dict( - name='right_hand', - id=12, - color=[0, 255, 0], - type='upper', - swap=''), - 13: - dict( - name='left_upper_leg', - id=13, - color=[255, 0, 0], - type='lower', - swap=''), - - 14: - dict( - name='right_upper_leg', - id=14, - color=[0, 255, 0], - type='lower', - swap=''), - 15: - dict( - name='left_knee', - id=15, - color=[255, 0, 0], - type='lower', - swap=''), - 16: - dict( - name='right_knee', - id=16, - color=[0, 255, 0], - type='lower', - swap=''), - 17: - dict( - name='left_lower_leg', - id=17, - color=[255, 0, 0], - type='lower', - swap=''), - 18: - dict( - name='right_lower_leg', - id=18, - color=[0, 255, 0], - type='lower', - swap=''), - 19: - dict( - name= 'left_foot', - id=19, - color=[255, 0, 0], - type='lower', - swap=''), - 20: - dict( - name='right_foot', - id=20, - color=[0, 255, 0], - type='lower', - swap=''), - 21: - dict( - name='back', - id=21, - color=[147, 20, 255], - type='lower', - swap=''), + 0: dict(name="nose", id=0, color=[0, 215, 255], type="upper", swap=""), + 1: dict(name="mouth", id=1, color=[0, 215, 255], type="upper", swap=""), + 2: dict(name="throat", id=2, color=[0, 215, 255], type="upper", swap=""), + 3: dict(name="chest", id=3, color=[0, 215, 255], type="upper", swap=""), + 4: dict(name="stomach", id=4, color=[0, 215, 255], type="upper", swap=""), + 5: dict(name="left_upper_arm", id=5, color=[255, 0, 0], type="upper", swap=""), + 6: dict(name="right_upper_arm", id=6, color=[0, 255, 0], type="upper", swap=""), + 7: dict(name="left_lower_arm", id=7, color=[255, 0, 0], type="upper", swap=""), + 8: dict(name="right_lower_arm", id=8, color=[0, 255, 0], type="upper", swap=""), + 9: 
dict(name="left_wrist", id=9, color=[255, 0, 0], type="upper", swap=""), + 10: dict(name="right_wrist", id=10, color=[0, 255, 0], type="upper", swap=""), + 11: dict(name="left_hand", id=11, color=[255, 0, 0], type="upper", swap=""), + 12: dict(name="right_hand", id=12, color=[0, 255, 0], type="upper", swap=""), + 13: dict( + name="left_upper_leg", id=13, color=[255, 0, 0], type="lower", swap="" + ), + 14: dict( + name="right_upper_leg", id=14, color=[0, 255, 0], type="lower", swap="" + ), + 15: dict(name="left_knee", id=15, color=[255, 0, 0], type="lower", swap=""), + 16: dict(name="right_knee", id=16, color=[0, 255, 0], type="lower", swap=""), + 17: dict( + name="left_lower_leg", id=17, color=[255, 0, 0], type="lower", swap="" + ), + 18: dict( + name="right_lower_leg", id=18, color=[0, 255, 0], type="lower", swap="" + ), + 19: dict(name="left_foot", id=19, color=[255, 0, 0], type="lower", swap=""), + 20: dict(name="right_foot", id=20, color=[0, 255, 0], type="lower", swap=""), + 21: dict(name="back", id=21, color=[147, 20, 255], type="lower", swap=""), }, skeleton_info={ - 0: - {'link': ('nose', 'mouth'), 'id': 0, 'color': [0, 215, 255]}, - 1: - {'link': ('mouth', 'throat'), 'id': 1, 'color': [0, 215, 255]}, - 2: - {'link': ('throat', 'chest'), 'id': 2, 'color': [0, 215, 255]}, - 3: - {'link': ('chest', 'stomach'), 'id': 3, 'color': [0, 215, 255]}, - 4: - {'link': ('throat', 'left_upper_arm'), 'id': 4, 'color': [255, 0, 0]}, - 5: - {'link': ('throat', 'right_upper_arm'), 'id': 5, 'color': [0, 255, 0]}, - 6: - {'link': ('left_upper_arm', 'left_lower_arm'), 'id': 6, 'color': [255, 0, 0]}, - 7: - {'link': ('right_upper_arm', 'right_lower_arm'), 'id': 7, 'color': [0, 255, 0]}, - 8: - {'link': ('left_lower_arm', 'left_wrist'), 'id': 8, 'color': [255, 0, 0]}, - 9: - {'link': ('right_lower_arm', 'right_wrist'), 'id': 9, 'color': [0, 255, 0]}, - 10: - {'link': ('left_wrist', 'left_hand'), 'id': 10, 'color': [255, 0, 0]}, - 11: - {'link': ('right_wrist', 'right_hand'), 'id': 11, 'color': [0, 255, 0]}, - 12: - {'link': ('stomach', 'left_upper_leg'), 'id': 12, 'color': [255, 0, 0]}, - 13: - {'link': ('stomach', 'right_upper_leg'), 'id': 13, 'color': [0, 255, 0]}, - 14: - {'link': ('left_upper_leg', 'left_knee'), 'id': 14, 'color': [255, 0, 0]}, - 15: - {'link': ('right_upper_leg', 'right_knee'), 'id': 15, 'color': [0, 255, 0]}, - 16: - {'link': ('left_knee', 'left_lower_leg'), 'id': 16, 'color': [255, 0, 0]}, - 17: - {'link': ('right_knee', 'right_lower_leg'), 'id': 17, 'color': [0, 255, 0]}, - 18: - {'link': ('left_lower_leg', 'left_foot'), 'id': 18, 'color': [255, 0, 0]}, - 19: - {'link': ('right_lower_leg', 'right_foot'), 'id': 19, 'color': [0, 255, 0]}, - 20: - {'link': ('right_upper_leg', 'back'), 'id': 20, 'color': [147, 20, 255]}, - 21: - {'link': ('left_upper_leg', 'back'), 'id': 21, 'color': [147, 20, 255]}, + 0: {"link": ("nose", "mouth"), "id": 0, "color": [0, 215, 255]}, + 1: {"link": ("mouth", "throat"), "id": 1, "color": [0, 215, 255]}, + 2: {"link": ("throat", "chest"), "id": 2, "color": [0, 215, 255]}, + 3: {"link": ("chest", "stomach"), "id": 3, "color": [0, 215, 255]}, + 4: {"link": ("throat", "left_upper_arm"), "id": 4, "color": [255, 0, 0]}, + 5: {"link": ("throat", "right_upper_arm"), "id": 5, "color": [0, 255, 0]}, + 6: { + "link": ("left_upper_arm", "left_lower_arm"), + "id": 6, + "color": [255, 0, 0], + }, + 7: { + "link": ("right_upper_arm", "right_lower_arm"), + "id": 7, + "color": [0, 255, 0], + }, + 8: {"link": ("left_lower_arm", "left_wrist"), "id": 8, "color": [255, 0, 
0]}, + 9: {"link": ("right_lower_arm", "right_wrist"), "id": 9, "color": [0, 255, 0]}, + 10: {"link": ("left_wrist", "left_hand"), "id": 10, "color": [255, 0, 0]}, + 11: {"link": ("right_wrist", "right_hand"), "id": 11, "color": [0, 255, 0]}, + 12: {"link": ("stomach", "left_upper_leg"), "id": 12, "color": [255, 0, 0]}, + 13: {"link": ("stomach", "right_upper_leg"), "id": 13, "color": [0, 255, 0]}, + 14: {"link": ("left_upper_leg", "left_knee"), "id": 14, "color": [255, 0, 0]}, + 15: {"link": ("right_upper_leg", "right_knee"), "id": 15, "color": [0, 255, 0]}, + 16: {"link": ("left_knee", "left_lower_leg"), "id": 16, "color": [255, 0, 0]}, + 17: {"link": ("right_knee", "right_lower_leg"), "id": 17, "color": [0, 255, 0]}, + 18: {"link": ("left_lower_leg", "left_foot"), "id": 18, "color": [255, 0, 0]}, + 19: {"link": ("right_lower_leg", "right_foot"), "id": 19, "color": [0, 255, 0]}, + 20: {"link": ("right_upper_leg", "back"), "id": 20, "color": [147, 20, 255]}, + 21: {"link": ("left_upper_leg", "back"), "id": 21, "color": [147, 20, 255]}, }, - joint_weights=[ - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + ], sigmas=[ - 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067, 0.067] -) \ No newline at end of file + 0.067, + 0.067, + 0.067, + 0.067, + 0.067, + 0.067, + 0.067, + 0.067, + 0.067, + 0.067, + 0.067, + 0.067, + 0.067, + 0.067, + 0.067, + 0.067, + 0.067, + 0.067, + 0.067, + 0.067, + 0.067, + 0.067, + ], +) diff --git a/tcn_hpl/data/utils/pose_generation/configs/medic_pose_gyges.yaml b/tcn_hpl/data/utils/pose_generation/configs/medic_pose_gyges.yaml new file mode 100644 index 000000000..12716a131 --- /dev/null +++ b/tcn_hpl/data/utils/pose_generation/configs/medic_pose_gyges.yaml @@ -0,0 +1,27 @@ +_BASE_: "Base-RCNN-FPN.yaml" +MODEL: + WEIGHTS: "/home/local/KHQ/peri.akiva/projects/TCN_HPL/tcn_hpl/data/utils/pose_generation/checkpoints/model_final.pth" # please change here to the path where you put the weights + MASK_ON: False + RESNETS: + DEPTH: 101 + ROI_HEADS: + NUM_CLASSES: 2 + SCORE_THRESH_TEST: 0.0001 +INPUT: + MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) +DATASETS: + TRAIN: ("lvis_v1_train",) + TEST: ("lvis_v1_val",) +TEST: + DETECTIONS_PER_IMAGE: 300 # LVIS allows up to 300 +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.01 + STEPS: (6670, 8888) + MAX_ITER: 10000 # 180000 * 16 / 100000 ~ 28.8 epochs +DATALOADER: + SAMPLER_TRAIN: "RepeatFactorTrainingSampler" + REPEAT_THRESHOLD: 0.001 + +OUTPUT_DIR: "./output_pose_ablation/user+env" + diff --git a/tcn_hpl/data/utils/pose_generation/configs/medic_user.py b/tcn_hpl/data/utils/pose_generation/configs/medic_user.py index 05191b1b4..9c70fe685 100644 --- a/tcn_hpl/data/utils/pose_generation/configs/medic_user.py +++ b/tcn_hpl/data/utils/pose_generation/configs/medic_user.py @@ -1,69 +1,45 @@ dataset_info = dict( - dataset_name='medic_user', + dataset_name="medic_user", paper_info=dict( - author='Lin, Tsung-Yi and Maire, Michael and ' - 'Belongie, Serge and Hays, James and ' - 'Perona, Pietro and Ramanan, Deva and ' - r'Doll{\'a}r, Piotr and Zitnick, C Lawrence', - title='Microsoft coco: Common objects in context', - container='European conference on computer vision', - year='2014', - homepage='http://cocodataset.org/', + 
author="Lin, Tsung-Yi and Maire, Michael and " + "Belongie, Serge and Hays, James and " + "Perona, Pietro and Ramanan, Deva and " + r"Doll{\'a}r, Piotr and Zitnick, C Lawrence", + title="Microsoft coco: Common objects in context", + container="European conference on computer vision", + year="2014", + homepage="http://cocodataset.org/", ), - keypoint_info={ #'left_upper_arm', 'right_upper_arm', 'left_lower_arm', 'right_lower_arm', 'left_hand', 'right_hand - 0: - dict(name='left_upper_arm', - id=0, - color=[0, 0, 139], - type='upper', - swap=''), - 1: - dict( - name='right_upper_arm', - id=1, - color=[255, 112, 132], - type='upper', - swap=''), - 2: - dict( - name='left_lower_arm', - id=2, - color=[0, 0, 139], - type='upper', - swap=''), - 3: - dict( - name='right_lower_arm', - id=3, - color=[255, 112, 132], - type='upper', - swap=''), - 4: - dict( - name='left_hand', - id=4, - color=[0, 0, 139], - type='upper', - swap=''), - 5: - dict( - name='right_hand', - id=5, - color=[255, 112, 132], - type='upper', - swap='') + keypoint_info={ #'left_upper_arm', 'right_upper_arm', 'left_lower_arm', 'right_lower_arm', 'left_hand', 'right_hand + 0: dict(name="left_upper_arm", id=0, color=[0, 0, 139], type="upper", swap=""), + 1: dict( + name="right_upper_arm", id=1, color=[255, 112, 132], type="upper", swap="" + ), + 2: dict(name="left_lower_arm", id=2, color=[0, 0, 139], type="upper", swap=""), + 3: dict( + name="right_lower_arm", id=3, color=[255, 112, 132], type="upper", swap="" + ), + 4: dict(name="left_hand", id=4, color=[0, 0, 139], type="upper", swap=""), + 5: dict(name="right_hand", id=5, color=[255, 112, 132], type="upper", swap=""), }, skeleton_info={ - 0: - {'link': ('right_hand', 'right_lower_arm'), 'id': 0, 'color':[255, 112, 132]}, - 1: - {'link': ('right_upper_arm', 'right_lower_arm'), 'id': 1, 'color': [255, 112, 132]}, - 2: - {'link': ('left_upper_arm', 'left_lower_arm'), 'id': 2, 'color': [0, 0, 139]}, - 3: - {'link': ('left_lower_arm', 'left_hand'), 'id': 3, 'color': [0, 0, 139]} + 0: { + "link": ("right_hand", "right_lower_arm"), + "id": 0, + "color": [255, 112, 132], + }, + 1: { + "link": ("right_upper_arm", "right_lower_arm"), + "id": 1, + "color": [255, 112, 132], + }, + 2: { + "link": ("left_upper_arm", "left_lower_arm"), + "id": 2, + "color": [0, 0, 139], + }, + 3: {"link": ("left_lower_arm", "left_hand"), "id": 3, "color": [0, 0, 139]}, }, - joint_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0], - sigmas=[0.67, 0.67, 0.67, 0.67, 0.67, 0.67] -) \ No newline at end of file + sigmas=[0.67, 0.67, 0.67, 0.67, 0.67, 0.67], +) diff --git a/tcn_hpl/data/utils/pose_generation/generate_pose_data.py b/tcn_hpl/data/utils/pose_generation/generate_pose_data.py index 3a32f3acf..a73430c6c 100644 --- a/tcn_hpl/data/utils/pose_generation/generate_pose_data.py +++ b/tcn_hpl/data/utils/pose_generation/generate_pose_data.py @@ -1,32 +1,26 @@ """Generate bounding box detections, then generate poses for patients """ -import argparse -import glob -from glob import glob -import multiprocessing as mp import numpy as np -import os -import tempfile -import time import warnings -import cv2 +import torch import tqdm from detectron2.config import get_cfg from detectron2.data.detection_utils import read_image from detectron2.utils.logger import setup_logger -import json -from predictor import VisualizationDemo +from tcn_hpl.data.utils.pose_generation.predictor import VisualizationDemo + # import tcn_hpl.utils.utils as utils -from mmpose.apis import (inference_top_down_pose_model, init_pose_model, - vis_pose_result) 
-import utils +from mmpose.apis import inference_top_down_pose_model, init_pose_model, vis_pose_result +from tcn_hpl.data.utils.pose_generation.utils import get_parser, load_yaml_as_dict import kwcoco from mmpose.datasets import DatasetInfo -print(f"utils: {utils.__file__}") + +# print(f"utils: {utils.__file__}") import warnings + warnings.filterwarnings("ignore") @@ -41,7 +35,9 @@ def setup_detectron_cfg(args): # Set score_threshold for builtin models cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.confidence_threshold cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold - cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = args.confidence_threshold + cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = ( + args.confidence_threshold + ) cfg.freeze() return cfg @@ -49,44 +45,153 @@ def setup_detectron_cfg(args): class PosesGenerator(object): def __init__(self, config: dict) -> None: self.config = config - self.root_path = config['root'] - - self.dataset = kwcoco.CocoDataset(config['data'][config['task']]) - + self.root_path = config["root"] + + if config["data_type"] == "lab": + self.config_data_key = "bbn_lab" + else: + self.config_data_key = "data" + + self.dataset = kwcoco.CocoDataset(config[self.config_data_key][config["task"]]) + self.patient_cid = self.dataset.add_category("patient") + self.user_cid = self.dataset.add_category("user") + self.keypoints_cats = [ - "nose", "mouth", "throat","chest","stomach","left_upper_arm", - "right_upper_arm","left_lower_arm","right_lower_arm","left_wrist", - "right_wrist","left_hand","right_hand","left_upper_leg", - "right_upper_leg","left_knee","right_knee","left_lower_leg", - "right_lower_leg", "left_foot", "right_foot", "back" - ] - - self.keypoints_cats_dset = [{'name': value, 'id': index} for index, value in enumerate(self.keypoints_cats)] - - self.dataset.dataset['keypoint_categories'] = self.keypoints_cats_dset - - self.dataset_path_name = self.config['data'][self.config['task']][:-12].split('/')[-1] #remove .mscoco.json - - self.args = utils.get_parser(self.config['detection_model_config']).parse_args() + "nose", + "mouth", + "throat", + "chest", + "stomach", + "left_upper_arm", + "right_upper_arm", + "left_lower_arm", + "right_lower_arm", + "left_wrist", + "right_wrist", + "left_hand", + "right_hand", + "left_upper_leg", + "right_upper_leg", + "left_knee", + "right_knee", + "left_lower_leg", + "right_lower_leg", + "left_foot", + "right_foot", + "back", + ] + + self.keypoints_cats_dset = [ + {"name": value, "id": index} + for index, value in enumerate(self.keypoints_cats) + ] + + self.dataset.dataset["keypoint_categories"] = self.keypoints_cats_dset + + self.dataset_path_name = self.config[self.config_data_key][self.config["task"]][ + :-12 + ].split("/")[ + -1 + ] # remove .mscoco.json + + self.args = get_parser(self.config["detection_model_config"]).parse_args() detecron_cfg = setup_detectron_cfg(self.args) self.predictor = VisualizationDemo(detecron_cfg) - - self.pose_model = init_pose_model(config['pose_model_config'], - config['pose_model_checkpoint'], - device=config['device']) - self.pose_dataset = self.pose_model.cfg.data['test']['type'] - self.pose_dataset_info = self.pose_model.cfg.data['test'].get('dataset_info', None) + self.pose_model = init_pose_model( + config["pose_model_config"], + config["pose_model_checkpoint"], + device=config["device"], + ) + + self.pose_dataset = self.pose_model.cfg.data["test"]["type"] + self.pose_dataset_info = self.pose_model.cfg.data["test"].get( + "dataset_info", None + ) 
if self.pose_dataset_info is None: warnings.warn( - 'Please set `dataset_info` in the config.' - 'Check https://github.com/open-mmlab/mmpose/pull/663 for details.', - DeprecationWarning) + "Please set `dataset_info` in the config." + "Check https://github.com/open-mmlab/mmpose/pull/663 for details.", + DeprecationWarning, + ) else: self.pose_dataset_info = DatasetInfo(self.pose_dataset_info) - - def generate_bbs_and_pose(self, dset: kwcoco.CocoDataset, save_intermediate: bool =True) -> kwcoco.CocoDataset: - + + def predict_single(self, image: torch.tensor) -> list: + + predictions, _ = self.predictor.run_on_image(image) + instances = predictions["instances"].to("cpu") + boxes = instances.pred_boxes if instances.has("pred_boxes") else None + scores = instances.scores if instances.has("scores") else None + classes = ( + instances.pred_classes.tolist() if instances.has("pred_classes") else None + ) + + boxes_list, labels_list, keypoints_list = [], [], [] + + if boxes is not None: + + # person_results = [] + for box_id, _bbox in enumerate(boxes): + + box_class = classes[box_id] + if box_class == 0: + pred_class = self.patient_cid + pred_label = "patient" + elif box_class == 1: + pred_class = self.user_cid + pred_label = "user" + + boxes_list.append(np.asarray(_bbox).tolist()) + labels_list.append(pred_label) + + current_ann = {} + # current_ann['id'] = ann_id + current_ann["image_id"] = 0 + current_ann["bbox"] = np.asarray(_bbox).tolist() # _bbox + current_ann["category_id"] = pred_class + current_ann["label"] = pred_label + current_ann["bbox_score"] = f"{scores[box_id] * 100:0.2f}" + + if box_class == 0: + person_results = [current_ann] + + pose_results, returned_outputs = inference_top_down_pose_model( + model=self.pose_model, + img_or_path=image, + person_results=person_results, + bbox_thr=None, + format="xyxy", + dataset=self.pose_dataset, + dataset_info=self.pose_dataset_info, + return_heatmap=False, + outputs=["backbone"], + ) + + pose_keypoints = pose_results[0]["keypoints"].tolist() + pose_keypoints_list = [] + for kp_index, keypoint in enumerate(pose_keypoints): + kp_dict = { + "xy": [keypoint[0], keypoint[1]], + "keypoint_category_id": kp_index, + "keypoint_category": self.keypoints_cats[kp_index], + } + pose_keypoints_list.append(kp_dict) + + keypoints_list.append(pose_keypoints_list) + # print(f"pose_keypoints_list: {pose_keypoints_list}") + current_ann["keypoints"] = pose_keypoints_list + # current_ann['image_features'] = image_features + + # dset.add_annotation(**current_ann) + + # results = [] + return boxes_list, labels_list, keypoints_list + + def generate_bbs_and_pose( + self, dset: kwcoco.CocoDataset, save_intermediate: bool = True + ) -> kwcoco.CocoDataset: + """ Generates a CocoDataset with bounding box (bbs) and pose annotations generated from the dataset's images. This method processes each image, detects bounding boxes and classifies them into 'patient' or 'user' categories, @@ -113,92 +218,115 @@ def generate_bbs_and_pose(self, dset: kwcoco.CocoDataset, save_intermediate: boo - The `kwcoco.CocoDataset` class is part of the `kwcoco` package, offering structured management of COCO-format datasets, including easy addition of annotations and categories, and saving/loading datasets. 
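A toy sketch of the kwcoco bookkeeping this docstring describes: register categories and keypoint categories once, add one annotation per detection with the pose keypoints attached, and dump the result (file names and values here are illustrative only):

```python
import kwcoco

dset = kwcoco.CocoDataset()
gid = dset.add_image(file_name="frame_000001.png")
patient_cid = dset.add_category("patient")

# Truncated keypoint category list; the real one has 22 entries.
dset.dataset["keypoint_categories"] = [{"name": "nose", "id": 0}]

dset.add_annotation(
    image_id=gid,
    category_id=patient_cid,
    bbox=[10.0, 20.0, 110.0, 220.0],  # xyxy, as produced by the detectron2 predictor
    label="patient",
    bbox_score="97.30",
    keypoints=[{"xy": [60.0, 40.0], "keypoint_category_id": 0, "keypoint_category": "nose"}],
)
dset.dump("toy_with_dets_and_pose.mscoco.json", newlines=True)
```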
""" - - patient_cid = dset.add_category('patient') - user_cid = dset.add_category('user') - pbar = tqdm.tqdm(enumerate(dset.imgs.items()), total=len(list(dset.imgs.keys()))) - + + # patient_cid = self.dataset.add_category('patient') + # user_cid = self.dataset.add_category('user') + pbar = tqdm.tqdm( + enumerate(self.dataset.imgs.items()), + total=len(list(self.dataset.imgs.keys())), + ) + for index, (img_id, img_dict) in pbar: - - path = img_dict['file_name'] - + + path = img_dict["file_name"] + img = read_image(path, format="BGR") + + # bs, ls, kps = self.predict_single(img) + + # print(f"boxes: {bs}") + # print(f"ls: {ls}") + # print(f"kps: {kps}") + + # continue + predictions, visualized_output = self.predictor.run_on_image(img) - - instances = predictions["instances"].to('cpu') + + instances = predictions["instances"].to("cpu") boxes = instances.pred_boxes if instances.has("pred_boxes") else None scores = instances.scores if instances.has("scores") else None - classes = instances.pred_classes.tolist() if instances.has("pred_classes") else None - + classes = ( + instances.pred_classes.tolist() + if instances.has("pred_classes") + else None + ) + boxes = boxes.tensor.detach().numpy() scores = scores.numpy() - - file_name = path.split('/')[-1] - + + file_name = path.split("/")[-1] + if boxes is not None: - + # person_results = [] for box_id, _bbox in enumerate(boxes): - + box_class = classes[box_id] if box_class == 0: - pred_class = patient_cid - pred_label = 'patient' + pred_class = self.patient_cid + pred_label = "patient" elif box_class == 1: - pred_class = user_cid - pred_label = 'user' - + pred_class = self.user_cid + pred_label = "user" + current_ann = {} # current_ann['id'] = ann_id - current_ann['image_id'] = img_id - current_ann['bbox'] = np.asarray(_bbox).tolist()#_bbox - current_ann['category_id'] = pred_class - current_ann['label'] = pred_label - current_ann['bbox_score'] = str(round(scores[box_id] * 100,2)) + '%' - + current_ann["image_id"] = img_id + current_ann["bbox"] = np.asarray(_bbox).tolist() # _bbox + current_ann["category_id"] = pred_class + current_ann["label"] = pred_label + current_ann["bbox_score"] = ( + str(round(scores[box_id] * 100, 2)) + "%" + ) + if box_class == 0: person_results = [current_ann] - + pose_results, returned_outputs = inference_top_down_pose_model( - self.pose_model, - path, - person_results, - bbox_thr=None, - format='xyxy', - dataset=self.pose_dataset, - dataset_info=self.pose_dataset_info, - return_heatmap=None, - outputs=['backbone']) - - pose_keypoints = pose_results[0]['keypoints'].tolist() + self.pose_model, + path, + person_results, + bbox_thr=None, + format="xyxy", + dataset=self.pose_dataset, + dataset_info=self.pose_dataset_info, + return_heatmap=None, + outputs=["backbone"], + ) + + pose_keypoints = pose_results[0]["keypoints"].tolist() pose_keypoints_list = [] for kp_index, keypoint in enumerate(pose_keypoints): - kp_dict = {'xy': [keypoint[0], keypoint[1]], - 'keypoint_category_id': kp_index, - 'keypoint_category': self.keypoints_cats[kp_index]} + kp_dict = { + "xy": [keypoint[0], keypoint[1]], + "keypoint_category_id": kp_index, + "keypoint_category": self.keypoints_cats[kp_index], + } pose_keypoints_list.append(kp_dict) - + # print(f"pose_keypoints_list: {pose_keypoints_list}") - current_ann['keypoints'] = pose_keypoints_list + current_ann["keypoints"] = pose_keypoints_list # current_ann['image_features'] = image_features - - dset.add_annotation(**current_ann) - + + self.dataset.add_annotation(**current_ann) + # import 
matplotlib.pyplot as plt # image_show = dset.draw_image(gid=img_id) # plt.imshow(image_show) - # plt.savefig(f"myfig_{self.config['task']}_{index}.png") + # plt.savefig(f"figs/myfig_{self.config['task']}_{index}.png") # if index >= 20: # exit() - + if save_intermediate: if (index % 45000) == 0: - dset_inter_name = f"{self.config['data']['save_root']}/{self.dataset_path_name}_{index}_with_dets_and_pose.mscoco.json" - dset.dump(dset_inter_name, newlines=True) - print(f"Saved intermediate dataset at index {index} to: {dset_inter_name}") - - return dset - + dset_inter_name = f"{self.config[self.config_data_key]['save_root']}/{self.dataset_path_name}_{index}_with_dets_and_pose.mscoco.json" + self.dataset.dump(dset_inter_name, newlines=True) + print( + f"Saved intermediate dataset at index {index} to: {dset_inter_name}" + ) + + return self.dataset + def run(self) -> None: """ Executes the process of generating bounding box and pose annotations for a dataset and then saves the @@ -222,19 +350,21 @@ def run(self) -> None: processing expected by `generate_bbs_and_pose`. """ self.dataset = self.generate_bbs_and_pose(self.dataset) - - dataset_path_with_pose = f"{self.config['data']['save_root']}/{self.dataset_path_name}_with_dets_and_pose.mscoco.json" + + dataset_path_with_pose = f"{self.config[self.config_data_key]['save_root']}/{self.dataset_path_name}_with_dets_and_pose.mscoco.json" self.dataset.dump(dataset_path_with_pose, newlines=True) print(f"Saved test dataset to: {dataset_path_with_pose}") return + def main(): - + main_config_path = f"configs/main.yaml" - config = utils.load_yaml_as_dict(main_config_path) + config = load_yaml_as_dict(main_config_path) PG = PosesGenerator(config) PG.run() - -if __name__ == '__main__': - main() \ No newline at end of file + + +if __name__ == "__main__": + main() diff --git a/tcn_hpl/data/utils/pose_generation/predictor.py b/tcn_hpl/data/utils/pose_generation/predictor.py index 941293b06..064396778 100644 --- a/tcn_hpl/data/utils/pose_generation/predictor.py +++ b/tcn_hpl/data/utils/pose_generation/predictor.py @@ -96,7 +96,9 @@ def process_predictions(frame, predictions): ) elif "instances" in predictions: predictions = predictions["instances"].to(self.cpu_device) - vis_frame = video_visualizer.draw_instance_predictions(frame, predictions) + vis_frame = video_visualizer.draw_instance_predictions( + frame, predictions + ) elif "sem_seg" in predictions: vis_frame = video_visualizer.draw_sem_seg( frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device) diff --git a/tcn_hpl/data/utils/pose_generation/rt_pose_generation.py b/tcn_hpl/data/utils/pose_generation/rt_pose_generation.py new file mode 100644 index 000000000..1330afdf0 --- /dev/null +++ b/tcn_hpl/data/utils/pose_generation/rt_pose_generation.py @@ -0,0 +1,117 @@ +import torch +from mmpose.apis import inference_top_down_pose_model +import numpy as np +from mmpose.datasets import DatasetInfo +import warnings + + +def predict_single(det_model, pose_model, image: torch.tensor) -> list: + + keypoints_cats = [ + "nose", + "mouth", + "throat", + "chest", + "stomach", + "left_upper_arm", + "right_upper_arm", + "left_lower_arm", + "right_lower_arm", + "left_wrist", + "right_wrist", + "left_hand", + "right_hand", + "left_upper_leg", + "right_upper_leg", + "left_knee", + "right_knee", + "left_lower_leg", + "right_lower_leg", + "left_foot", + "right_foot", + "back", + ] + + keypoints_cats_dset = [ + {"name": value, "id": index} for index, value in enumerate(keypoints_cats) + ] + + pose_dataset = 
pose_model.cfg.data["test"]["type"] + pose_dataset_info = pose_model.cfg.data["test"].get("dataset_info", None) + + pose_dataset_info = pose_model.cfg.data["test"].get("dataset_info", None) + if pose_dataset_info is None: + warnings.warn( + "Please set `dataset_info` in the config." + "Check https://github.com/open-mmlab/mmpose/pull/663 for details.", + DeprecationWarning, + ) + else: + pose_dataset_info = DatasetInfo(pose_dataset_info) + + predictions, _ = det_model.run_on_image(image) + instances = predictions["instances"].to("cpu") + boxes = instances.pred_boxes if instances.has("pred_boxes") else None + scores = instances.scores if instances.has("scores") else None + classes = instances.pred_classes.tolist() if instances.has("pred_classes") else None + + boxes_list, labels_list, keypoints_list = [], [], [] + + if boxes is not None: + + # person_results = [] + for box_id, _bbox in enumerate(boxes): + + box_class = classes[box_id] + if box_class == 0: + pred_class = box_class + pred_label = "patient" + elif box_class == 1: + pred_class = box_class + pred_label = "user" + + boxes_list.append(np.asarray(_bbox).tolist()) + labels_list.append(pred_label) + + current_ann = {} + # current_ann['id'] = ann_id + current_ann["image_id"] = 0 + current_ann["bbox"] = np.asarray(_bbox).tolist() # _bbox + current_ann["category_id"] = pred_class + current_ann["label"] = pred_label + current_ann["bbox_score"] = f"{scores[box_id] * 100:0.2f}" + + if box_class == 0: + person_results = [current_ann] + + pose_results, returned_outputs = inference_top_down_pose_model( + model=pose_model, + img_or_path=image, + person_results=person_results, + bbox_thr=None, + format="xyxy", + dataset=pose_dataset, + dataset_info=pose_dataset_info, + return_heatmap=False, + outputs=["backbone"], + ) + + pose_keypoints = pose_results[0]["keypoints"].tolist() + keypoints_list.append(pose_keypoints) + + # pose_keypoints_list = [] + # for kp_index, keypoint in enumerate(pose_keypoints): + # kp_dict = {'xy': [keypoint[0], keypoint[1]], + # 'keypoint_category_id': kp_index, + # 'keypoint_category': keypoints_cats[kp_index]} + # pose_keypoints_list.append(kp_dict) + + # # keypoints_list.append(pose_keypoints_list) + # # print(f"pose_keypoints_list: {pose_keypoints_list}") + # current_ann['keypoints'] = pose_keypoints_list + # current_ann['image_features'] = image_features + + # dset.add_annotation(**current_ann) + + # results = [] + return boxes_list, labels_list, keypoints_list diff --git a/tcn_hpl/data/utils/pose_generation/utils.py b/tcn_hpl/data/utils/pose_generation/utils.py index ad282a44d..b6c7bcd86 100644 --- a/tcn_hpl/data/utils/pose_generation/utils.py +++ b/tcn_hpl/data/utils/pose_generation/utils.py @@ -3,6 +3,8 @@ from glob import glob import os import yaml + + def get_parser(config_file): parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs") parser.add_argument( @@ -11,22 +13,28 @@ def get_parser(config_file): metavar="FILE", help="path to config file", ) - parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.") - parser.add_argument("--video-input", help="Path to video file." - # , default='/shared/niudt/detectron2/images/Videos/k2/4.MP4' + parser.add_argument( + "--webcam", action="store_true", help="Take inputs from webcam." + ) + parser.add_argument( + "--video-input", + help="Path to video file." 
+ # , default='/shared/niudt/detectron2/images/Videos/k2/4.MP4' ) parser.add_argument( "--input", nargs="+", help="A list of space separated input images; " - "or a single glob pattern such as 'directory/*.jpg'" - , default= ['/shared/niudt/DATASET/Medical/Maydemo/2023-4-25/selected_videos/new/M2-16/*.jpg'] # please change here to the path where you put the images + "or a single glob pattern such as 'directory/*.jpg'", + default=[ + "/shared/niudt/DATASET/Medical/Maydemo/2023-4-25/selected_videos/new/M2-16/*.jpg" + ], # please change here to the path where you put the images ) parser.add_argument( "--output", help="A file or directory to save output visualizations. " - "If not given, will show output in an OpenCV window." - , default='./bbox_detection_results' + "If not given, will show output in an OpenCV window.", + default="./bbox_detection_results", ) parser.add_argument( @@ -43,11 +51,13 @@ def get_parser(config_file): ) return parser + def load_yaml_as_dict(yaml_path): - with open(yaml_path, 'r') as f: + with open(yaml_path, "r") as f: config_dict = yaml.load(f, Loader=yaml.FullLoader) return config_dict + def dictionary_contents(path: str, types: list, recursive: bool = False) -> list: """ Extract files of specified types from directories, optionally recursively. @@ -72,6 +82,7 @@ def dictionary_contents(path: str, types: list, recursive: bool = False) -> list files.append(os.path.join(path, x)) return files + def create_dir_if_doesnt_exist(dir_path): """ This function creates a dictionary if it doesnt exist. @@ -83,36 +94,36 @@ def create_dir_if_doesnt_exist(dir_path): def initialize_coco_json(): - + json_file = {} - json_file['images'] = [] - json_file['annotations'] = [] - json_file['categories'] = [] + json_file["images"] = [] + json_file["annotations"] = [] + json_file["categories"] = [] # get new categorie temp_bbox = {} - temp_bbox['id'] = 1 - temp_bbox['name'] = 'patient' - temp_bbox['instances_count'] = 0 - temp_bbox['def'] = '' - temp_bbox['synonyms'] = ['patient'] - temp_bbox['image_count'] = 0 - temp_bbox['frequency'] = '' - temp_bbox['synset'] = '' + temp_bbox["id"] = 1 + temp_bbox["name"] = "patient" + temp_bbox["instances_count"] = 0 + temp_bbox["def"] = "" + temp_bbox["synonyms"] = ["patient"] + temp_bbox["image_count"] = 0 + temp_bbox["frequency"] = "" + temp_bbox["synset"] = "" - json_file['categories'].append(temp_bbox) + json_file["categories"].append(temp_bbox) temp_bbox = {} - temp_bbox['id'] = 2 - temp_bbox['name'] = 'user' - temp_bbox['instances_count'] = 0 - temp_bbox['def'] = '' - temp_bbox['synonyms'] = ['user'] - temp_bbox['image_count'] = 0 - temp_bbox['frequency'] = '' - temp_bbox['synset'] = '' - - json_file['categories'].append(temp_bbox) + temp_bbox["id"] = 2 + temp_bbox["name"] = "user" + temp_bbox["instances_count"] = 0 + temp_bbox["def"] = "" + temp_bbox["synonyms"] = ["user"] + temp_bbox["image_count"] = 0 + temp_bbox["frequency"] = "" + temp_bbox["synset"] = "" + + json_file["categories"].append(temp_bbox) ann_num = 0 - - return json_file \ No newline at end of file + + return json_file diff --git a/tcn_hpl/data/utils/pose_generation/video_to_frames.py b/tcn_hpl/data/utils/pose_generation/video_to_frames.py index 0b99d7581..735683e0e 100644 --- a/tcn_hpl/data/utils/pose_generation/video_to_frames.py +++ b/tcn_hpl/data/utils/pose_generation/video_to_frames.py @@ -9,33 +9,37 @@ import os import cv2 from glob import glob -import json +import json import utils - - def main(): # the dir including videos you want to process - videos_src_path = 
'/data/datasets/ptg/m2_tourniquet/' + videos_src_path = "/data/datasets/ptg/m2_tourniquet/" # the save path - videos_save_path = '/data/datasets/ptg/m2_tourniquet/imgs' + videos_save_path = "/data/datasets/ptg/m2_tourniquet/imgs" - videos = utils.dictionary_contents(videos_src_path, types=['*.mp4', '*.MP4'], recursive=True) + videos = utils.dictionary_contents( + videos_src_path, types=["*.mp4", "*.MP4"], recursive=True + ) # print(videos) # exit() # videos = filter(lambda x: x.endswith('MP4'), videos) - coco_json = {"info": {"description": "Medical Pose Estimation", - "year": 2023, - "contributer": "Kitware", - "version": "0.1"}, - "images": [], - "annotations": []} + coco_json = { + "info": { + "description": "Medical Pose Estimation", + "year": 2023, + "contributer": "Kitware", + "version": "0.1", + }, + "images": [], + "annotations": [], + } for index, each_video in enumerate(videos): # get the name of each video, and make the directory to save frames - video_name = each_video.split('/')[-1].split('.')[0] - print('Video Name :', video_name) + video_name = each_video.split("/")[-1].split(".")[0] + print("Video Name :", video_name) # each_video_name, _ = each_video.split('.')each_video.split('.') dir_path = f"{videos_save_path}/{video_name}" utils.create_dir_if_doesnt_exist(dir_path) @@ -52,7 +56,7 @@ def main(): success = True # 计数 num = 0 - while (success): + while success: success, frame = cap.read() if success == True: # if not os.path.exists(each_video_save_full_path + video_name): @@ -76,14 +80,14 @@ def main(): num += 1 frame_count = frame_count + 1 - print('Final frame:', num) + print("Final frame:", num) # coco_json_save_path = f"{videos_src_path}/medical_coco.json" coco_json_save_path = "medical_coco.json" - with open(coco_json_save_path, "w") as outfile: + with open(coco_json_save_path, "w") as outfile: json.dump(coco_json, outfile) + if __name__ == "__main__": main() - diff --git a/tcn_hpl/data/utils/pose_generation/vit.py b/tcn_hpl/data/utils/pose_generation/vit.py new file mode 100644 index 000000000..8990020bf --- /dev/null +++ b/tcn_hpl/data/utils/pose_generation/vit.py @@ -0,0 +1,411 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
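# The new vit.py below defines a plain ViT backbone registered through mmpose's BACKBONES
# registry. A minimal usage sketch, assuming torch, timm, and mmpose are importable (as this
# module's own imports require); with the default ratio=1 and patch_size=16, a 224x224 input
# produces a (B, embed_dim, 14, 14) feature map:
#
#     import torch
#     vit = ViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
#     vit.eval()
#     with torch.no_grad():
#         feats = vit(torch.zeros(1, 3, 224, 224))
#     print(feats.shape)  # torch.Size([1, 768, 14, 14])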
+import math + +import torch +from functools import partial +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ + +from mmpose.models.builder import BACKBONES +from mmpose.models.backbones.base_backbone import BaseBackbone + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self): + return "p={}".format(self.drop_prob) + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + attn_head_dim=None, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.dim = dim + + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + attn_head_dim=None, + ): + super().__init__() + + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + attn_head_dim=attn_head_dim, + ) + + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + 
num_patches = ( + (img_size[1] // patch_size[1]) + * (img_size[0] // patch_size[0]) + * (ratio**2) + ) + self.patch_shape = ( + int(img_size[0] // patch_size[0] * ratio), + int(img_size[1] // patch_size[1] * ratio), + ) + self.origin_patch_shape = ( + int(img_size[0] // patch_size[0]), + int(img_size[1] // patch_size[1]), + ) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=(patch_size[0] // ratio), + padding=4 + 2 * (ratio // 2 - 1), + ) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + x = self.proj(x) + Hp, Wp = x.shape[2], x.shape[3] + + x = x.flatten(2).transpose(1, 2) + return x, (Hp, Wp) + + +class HybridEmbed(nn.Module): + """CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + + def __init__( + self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768 + ): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[ + -1 + ] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + + def forward(self, x): + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +@BACKBONES.register_module() +class ViT(BaseBackbone): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=80, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + hybrid_backbone=None, + norm_layer=None, + use_checkpoint=False, + frozen_stages=-1, + ratio=1, + last_norm=True, + patch_padding="pad", + freeze_attn=False, + freeze_ffn=False, + ): + # Protect mutable default arguments + super(ViT, self).__init__() + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + self.num_classes = num_classes + self.num_features = ( + self.embed_dim + ) = embed_dim # num_features for consistency with other models + self.frozen_stages = frozen_stages + self.use_checkpoint = use_checkpoint + self.patch_padding = patch_padding + self.freeze_attn = freeze_attn + self.freeze_ffn = freeze_ffn + self.depth = depth + + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, + img_size=img_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ratio=ratio, + ) + num_patches = self.patch_embed.num_patches + + # since the pretraining model has class token + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + self.blocks = nn.ModuleList( + [ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + ) + for i in range(depth) + 
] + ) + + self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity() + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=0.02) + + self._freeze_stages() + + def _freeze_stages(self): + """Freeze parameters.""" + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = self.blocks[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + if self.freeze_attn: + for i in range(0, self.depth): + m = self.blocks[i] + m.attn.eval() + m.norm1.eval() + for param in m.attn.parameters(): + param.requires_grad = False + for param in m.norm1.parameters(): + param.requires_grad = False + + if self.freeze_ffn: + self.pos_embed.requires_grad = False + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + for i in range(0, self.depth): + m = self.blocks[i] + m.mlp.eval() + m.norm2.eval() + for param in m.mlp.parameters(): + param.requires_grad = False + for param in m.norm2.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + super().init_weights(pretrained, patch_padding=self.patch_padding) + + if pretrained is None: + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + self.apply(_init_weights) + + def get_num_layers(self): + return len(self.blocks) + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed", "cls_token"} + + def forward_features(self, x): + B, C, H, W = x.shape + x, (Hp, Wp) = self.patch_embed(x) + + if self.pos_embed is not None: + # fit for multiple GPU training + # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference + x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1] + + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + + x = self.last_norm(x) + + xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous() + + return xp + + def forward(self, x): + x = self.forward_features(x) + return x + + def train(self, mode=True): + """Convert the model into training mode.""" + super().train(mode) + self._freeze_stages() diff --git a/tcn_hpl/data/utils/ptg_datagenerator.py b/tcn_hpl/data/utils/ptg_datagenerator.py index 4af580deb..c9749fe90 100644 --- a/tcn_hpl/data/utils/ptg_datagenerator.py +++ b/tcn_hpl/data/utils/ptg_datagenerator.py @@ -1,3 +1,7 @@ +# This script is designed for processing and preparing datasets for training activity classifiers. +# It performs several key functions: reading configuration files, generating feature matrices, +# preparing ground truth labels, and organizing the data into a structured format suitable for th TCN model. 
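# A minimal invocation sketch for this script, mirroring the CLI arguments added in the
# __main__ block at the bottom of this file; the paths below are placeholders, not the
# actual defaults:
#
#     main(
#         task="r18",
#         ptg_root="/path/to/angel_system",
#         config_root="/path/to/TCN_HPL/configs",
#         data_type="lab",   # "pro" = professional (gyges) data, "lab" = BBN lab data
#         data_gen_yaml="/path/to/config/data_generation/bbn_gyges.yaml",
#     )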
+ import os import yaml import glob @@ -10,6 +14,7 @@ from pathlib import Path +from angel_system.data.medical.data_paths import GrabData from angel_system.data.common.load_data import ( activities_from_dive_csv, objs_as_dataframe, @@ -20,50 +25,84 @@ data_loader, compute_feats, ) -# from angel_system.data.data_paths import grab_data, data_dir + def load_yaml_as_dict(yaml_path): - with open(yaml_path, 'r') as f: + """ + Loads a YAML file and returns it as a Python dictionary. + + Args: + yaml_path (str): The path to the YAML configuration file. + + Returns: + dict: The YAML file's content as a dictionary. + """ + with open(yaml_path, "r") as f: config_dict = yaml.load(f, Loader=yaml.FullLoader) return config_dict + ##################### # Inputs ##################### +# Mapping task identifiers to their descriptive names. TASK_TO_NAME = { - 'm1': "M1_Trauma_Assessment", - 'm2': "M2_Tourniquet", - 'm3': "M3_Pressure_Dressing", - 'm4': "M4_Wound_Packing", - 'm5': "M5_X-Stat", - 'r18': "R18_Chest_Seal", + "m1": "M1_Trauma_Assessment", + "m2": "M2_Tourniquet", + "m3": "M3_Pressure_Dressing", + "m4": "M4_Wound_Packing", + "m5": "M5_X-Stat", + "r18": "R18_Chest_Seal", } +# Mapping lab data task identifiers to their descriptive names. +LAB_TASK_TO_NAME = { + "m2": "M2_Lab_Skills", + "m3": "M3_Lab_Skills", + "m5": "M5_Lab_Skills", + "r18": "R18_Lab_Skills", +} + +# Mapping feature settings to boolean flags indicating the inclusion of pose or object joint information. FEAT_TO_BOOL = { - "no_pose": [False, False], - "with_pose": [True, True], - "only_hands_joints": [False, True], - "only_objects_joints": [True, False] - } - -def main(task: str): - config_path = f"/home/local/KHQ/peri.akiva/projects/TCN_HPL/configs/experiment/{task}/feat_v6.yaml" + "no_pose": [False, False], + "with_pose": [True, True], + "only_hands_joints": [False, True], + "only_objects_joints": [True, False], +} + + +def main( + task: str, ptg_root: str, config_root: str, data_type: str, data_gen_yaml: str +): + """ + Main function that orchestrates the process of loading configurations, setting up directories, + processing datasets, and generating features for training activity classification models. + + Args: + task (str): The task identifier. + ptg_root (str): Path to the root of the angel_system project. + config_root (str): Path to the root of the configuration files. + data_type (str): Specifies the type of data, either 'gyges' for professional data or 'bbn' for lab data. + """ + config_path = f"{config_root}/experiment/{task}/feat_v6.yaml" config = load_yaml_as_dict(config_path) - ptg_root = "/home/local/KHQ/peri.akiva/angel_system" activity_config_path = f"{ptg_root}/config/activity_labels/medical" feat_version = 6 ##################### # Output ##################### - # reshuffle_datasets = True - # augment = False - # num_augs = 5 - reshuffle_datasets, augment, num_augs = config['data_gen']['reshuffle_datasets'], config['data_gen']['augment'], config['data_gen']['num_augs'] + reshuffle_datasets, augment, num_augs = ( + config["data_gen"]["reshuffle_datasets"], + config["data_gen"]["augment"], + config["data_gen"]["num_augs"], + ) + ## augmentation is not currently used. 
if augment: num_augs = num_augs aug_trans_range, aug_rot_range = [-5, 5], [-5, 5] @@ -71,10 +110,21 @@ def main(task: str): num_augs = 1 aug_trans_range, aug_rot_range = None, None - - exp_name = f"p_{config['task']}_feat_v6_{config['data_gen']['feat_type']}_v3_aug_{augment}_reshuffle_{reshuffle_datasets}" #[p_m2_tqt_data_test_feat_v6_with_pose, p_m2_tqt_data_test_feat_v6_only_hands_joints, p_m2_tqt_data_test_feat_v6_only_objects_joints, p_m2_tqt_data_test_feat_v6_no_pose] + # the "data_type" parameter is new to the BBN lab data. + # old experiments dont have that parameter in their experiment name + """ + feat_type details wrt to feature generation: + no_pose: we only use object detections, and object-hands intersection + with_pose: we use patient pose to calculate joint-hands and joint-objects offset vectors + only_hands_joints: we use patient pose to calculate only joint-hands offset vectors + only_objects_joints: we use patient pose to calulcate only joint-objects offset vectors + """ - output_data_dir = f"{config['paths']['output_data_dir_root']}/{config['task']}/{exp_name}" + exp_name = f"p_{config['task']}_feat_v6_{config['data_gen']['feat_type']}_v3_aug_{augment}_reshuffle_{reshuffle_datasets}_{data_type}" # [p_m2_tqt_data_test_feat_v6_with_pose, p_m2_tqt_data_test_feat_v6_only_hands_joints, p_m2_tqt_data_test_feat_v6_only_objects_joints, p_m2_tqt_data_test_feat_v6_no_pose] + + output_data_dir = ( + f"{config['paths']['output_data_dir_root']}/{config['task']}/{exp_name}" + ) gt_dir = f"{output_data_dir}/groundTruth" frames_dir = f"{output_data_dir}/frames" @@ -95,7 +145,9 @@ def main(task: str): ##################### # Mapping ##################### - activity_config_path = f"{config['paths']['activity_config_root']}/{config['task']}.yaml" + activity_config_path = ( + f"{config['paths']['activity_config_root']}/{config['task']}.yaml" + ) with open(activity_config_path, "r") as stream: activity_config = yaml.safe_load(stream) activity_labels = activity_config["labels"] @@ -106,7 +158,10 @@ def main(task: str): for label in activity_labels: i = label["id"] label_str = label["label"] - activity_labels_desc_mapping[label['description']] = label["label"] + if "description" in label.keys(): + activity_labels_desc_mapping[label["description"]] = label["label"] + elif "full_str" in label.keys(): + activity_labels_desc_mapping[label["full_str"]] = label["label"] if label_str == "done": continue mapping.write(f"{i} {label_str}\n") @@ -118,47 +173,82 @@ def main(task: str): ##################### ### create splits dsets - dset = kwcoco.CocoDataset(config['paths']['dataset_kwcoco']) + # Data processing and feature generation based on the data type (gyges or bbn). + # The detailed implementation handles data loading, augmentation (if specified), + # ground truth preparation, feature computation, and data organization for training and evaluation. 
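# Sketch of how the feat_type setting described above is consumed further down: FEAT_TO_BOOL
# maps the configured feat_type to the (objects_joints, hands_joints) flag pair that is
# passed to compute_feats, e.g.
#
#     objects_joints, hands_joints = FEAT_TO_BOOL[config["data_gen"]["feat_type"]]
#     # "no_pose"             -> (False, False)
#     # "with_pose"           -> (True,  True)
#     # "only_hands_joints"   -> (False, True)
#     # "only_objects_joints" -> (True,  False)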
+ print(f"Generating features for task: {task}") + + if data_type == "pro": + dset = kwcoco.CocoDataset(config["paths"]["dataset_kwcoco"]) + elif data_type == "lab": + dset = kwcoco.CocoDataset(config["paths"]["dataset_kwcoco_lab"]) if reshuffle_datasets: - + train_img_ids, val_img_ids, test_img_ids = [], [], [] - - task_name = config['task'].upper() - train_vidids = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['train_vid_ids'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()] - val_vivids = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['val_vid_ids'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()] - test_vivds = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['test_vid_ids'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()] - - if config['data_gen']['filter_black_gloves']: - vidids_black_gloves = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['names_black_gloves'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()] - - train_vidids = [x for x in train_vidids if x not in vidids_black_gloves] - val_vivids = [x for x in val_vivids if x not in vidids_black_gloves] - test_vivds = [x for x in test_vivds if x not in vidids_black_gloves] - - if config['data_gen']['filter_blue_gloves']: - vidids_blue_gloves = [dset.index.name_to_video[f"{task_name}-{index}"]['id'] for index in config['data_gen']['names_blue_gloves'] if f"{task_name}-{index}" in dset.index.name_to_video.keys()] - - train_vidids = [x for x in train_vidids if x not in vidids_blue_gloves] - val_vivids = [x for x in val_vivids if x not in vidids_blue_gloves] - test_vivds = [x for x in test_vivds if x not in vidids_blue_gloves] - - - ## individual splits by gids - # total = len(dset.index.imgs) - # inds = [i for i in range(1, total+1)] - # train_size, val_size = int(0.8*total), int(0.1*total) - # test_size = total - train_size - val_size - # train_inds = set(list(np.random.choice(inds, size=train_size, replace=False))) - # remaining_inds = list(set(inds) - train_inds) - # val_inds = set(list(np.random.choice(remaining_inds, size=val_size, replace=False))) - # test_inds = list(set(remaining_inds) - val_inds) - # # test_inds = list(np.random.choice(remaining_inds, size=test_size, replace=False)) - - # train_img_ids = [dset.index.imgs[i]['id'] for i in train_inds] - # val_img_ids = [dset.index.imgs[i]['id'] for i in val_inds] - # test_img_ids = [dset.index.imgs[i]['id'] for i in test_inds] - # print(f"train: {len(train_inds)}, val: {len(val_inds)}, test: {len(test_inds)}") + + ## The data directory format is different for "Professional" and lab data, so we handle the split differently + if data_type == "pro": + task_name = config["task"].upper() + train_vidids = [ + dset.index.name_to_video[f"{task_name}-{index}"]["id"] + for index in config["data_gen"]["train_vid_ids"] + if f"{task_name}-{index}" in dset.index.name_to_video.keys() + ] + val_vivids = [ + dset.index.name_to_video[f"{task_name}-{index}"]["id"] + for index in config["data_gen"]["val_vid_ids"] + if f"{task_name}-{index}" in dset.index.name_to_video.keys() + ] + test_vivds = [ + dset.index.name_to_video[f"{task_name}-{index}"]["id"] + for index in config["data_gen"]["test_vid_ids"] + if f"{task_name}-{index}" in dset.index.name_to_video.keys() + ] + + if config["data_gen"]["filter_black_gloves"]: + vidids_black_gloves = [ + 
dset.index.name_to_video[f"{task_name}-{index}"]["id"] + for index in config["data_gen"]["names_black_gloves"] + if f"{task_name}-{index}" in dset.index.name_to_video.keys() + ] + + train_vidids = [x for x in train_vidids if x not in vidids_black_gloves] + val_vivids = [x for x in val_vivids if x not in vidids_black_gloves] + test_vivds = [x for x in test_vivds if x not in vidids_black_gloves] + + if config["data_gen"]["filter_blue_gloves"]: + vidids_blue_gloves = [ + dset.index.name_to_video[f"{task_name}-{index}"]["id"] + for index in config["data_gen"]["names_blue_gloves"] + if f"{task_name}-{index}" in dset.index.name_to_video.keys() + ] + + train_vidids = [x for x in train_vidids if x not in vidids_blue_gloves] + val_vivids = [x for x in val_vivids if x not in vidids_blue_gloves] + test_vivds = [x for x in test_vivds if x not in vidids_blue_gloves] + + elif data_type == "lab": + task_name = config["task"].upper() + all_vids = sorted(list(dset.index.name_to_video.keys())) + train_vidids = [ + dset.index.name_to_video[vid_name]["id"] + for vid_name in all_vids + if dset.index.name_to_video[vid_name]["id"] + in config["data_gen"]["train_vid_ids_bbn"] + ] + val_vivids = [ + dset.index.name_to_video[vid_name]["id"] + for vid_name in all_vids + if dset.index.name_to_video[vid_name]["id"] + in config["data_gen"]["val_vid_ids_bbn"] + ] + test_vivds = [ + dset.index.name_to_video[vid_name]["id"] + for vid_name in all_vids + if dset.index.name_to_video[vid_name]["id"] + in config["data_gen"]["test_vid_ids_bbn"] + ] for vid in train_vidids: if vid in dset.index.vidid_to_gids.keys(): @@ -167,98 +257,148 @@ def main(task: str): print(f"{vid} not in the train dataset") for vid in val_vivids: - + if vid in dset.index.vidid_to_gids.keys(): val_img_ids.extend(list(dset.index.vidid_to_gids[vid])) else: print(f"{vid} not in the val dataset") - # val_img_ids = set(val_img_ids) + set(dset.index.vidid_to_gids[vid]) for vid in test_vivds: - + if vid in dset.index.vidid_to_gids.keys(): test_img_ids.extend(list(dset.index.vidid_to_gids[vid])) else: print(f"{vid} not in the test dataset") - train_img_ids, val_img_ids, test_img_ids = list(train_img_ids), list(val_img_ids), list(test_img_ids) + train_img_ids, val_img_ids, test_img_ids = ( + list(train_img_ids), + list(val_img_ids), + list(test_img_ids), + ) - print(f"train num_images: {len(train_img_ids)}, val num_images: {len(val_img_ids)}, test num_images: {len(test_img_ids)}") + print( + f"train num_images: {len(train_img_ids)}, val num_images: {len(val_img_ids)}, test num_images: {len(test_img_ids)}" + ) train_dset = dset.subset(gids=train_img_ids, copy=True) val_dset = dset.subset(gids=val_img_ids, copy=True) test_dset = dset.subset(gids=test_img_ids, copy=True) - skill_data_root = f"{config['paths']['bbn_data_dir']}/Release_v0.5/v0.56/{TASK_TO_NAME[task]}/Data" - # videos_names = os.listdir(skill_data_root) - # for split in ["train_activity", "val", "test"]: - for dset, split in zip([train_dset, val_dset, test_dset], ["train_activity", "val", "test"]): - # org_num_vidoes = len(list(dset.index.videos.keys())) - # new_num_videos = 0 + # again, both data types use different directory formatting. 
We handle both here + data_grabber = GrabData(yaml_path=data_gen_yaml) + if data_type == "pro": + # skill_data_root = f"{config['paths']['bbn_data_dir']}/Release_v0.5/v0.56/{TASK_TO_NAME[task]}/Data" + skill_data_root = f"{data_grabber.bbn_data_root}/{TASK_TO_NAME[task]}/Data" + elif data_type == "lab": + skill_data_root = f"{data_grabber.lab_bbn_data_root}/{LAB_TASK_TO_NAME[task]}" + + for dset, split in zip( + [train_dset, val_dset, test_dset], ["train_activity", "val", "test"] + ): for video_id in ub.ProgIter(dset.index.videos.keys()): - - # if video_id != 10: - # continue - + video = dset.index.videos[video_id] video_name = video["name"] - + if "_extracted" in video_name: video_name = video_name.split("_extracted")[0] image_ids = dset.index.vidid_to_gids[video_id] num_images = len(image_ids) - + print(f"there are {num_images} in video {video_id}") video_dset = dset.subset(gids=image_ids, copy=True) - - #### begin activity GT section - video_root = f"{skill_data_root}/{video_name}" - activity_gt_file = f"{video_root}/{video_name}.skill_labels_by_frame.txt" - + + # begin activity GT section + # The GT given data is provided in different formats for the lab and professional data collections. + # we handle both here. + if data_type == "pro": + video_root = f"{skill_data_root}/{video_name}" + activity_gt_file = ( + f"{video_root}/{video_name}.skill_labels_by_frame.txt" + ) + elif data_type == "lab": + activity_gt_file = ( + f"{skill_data_root}/{video_name}.skill_labels_by_frame.txt" + ) + if not os.path.exists(activity_gt_file): + print( + f"activity_gt_file {activity_gt_file} doesnt exists. Trying a different way" + ) + activity_gt_file = f"{skill_data_root}/{video_name}.txt" + if not os.path.exists(activity_gt_file): + print(f"activity_gt_file {activity_gt_file} doesnt exists. 
continueing") continue - + f = open(activity_gt_file, "r") text = f.read() f.close() - - # print(f"str type: {type(text)}") - text = text.replace('\n', '\t') - text_list = text.split("\t")[:-1] - + activityi_gt_list = ["background" for x in range(num_images)] - for index in range(0, len(text_list), 3): - triplet = text_list[index:index+3] - # print(f"index: {index}, {text_list[index]}") - # print(f"triplet: {text_list[index:index+3]}") - start_frame = int(triplet[0]) - end_frame = int(triplet[1]) - desc = triplet[2] - gt_label = activity_labels_desc_mapping[desc] - - if end_frame-1 > num_images: - ### address issue with GT activity labels - print("Max frame in GT is larger than number of frames in the video") - - for label_index in range(start_frame, min(end_frame-1, num_images)): - # print(f"label_index: {label_index}") - activityi_gt_list[label_index] = gt_label + + if data_type == "pro": + text = text.replace("\n", "\t") + text_list = text.split("\t")[:-1] + for index in range(0, len(text_list), 3): + triplet = text_list[index : index + 3] + # print(f"index: {index}, {text_list[index]}") + # print(f"triplet: {text_list[index:index+3]}") + start_frame = int(triplet[0]) + end_frame = int(triplet[1]) + desc = triplet[2] + gt_label = activity_labels_desc_mapping[desc] + + if end_frame - 1 > num_images: + ### address issue with GT activity labels + print( + "Max frame in GT is larger than number of frames in the video" + ) + + for label_index in range( + start_frame, min(end_frame - 1, num_images) + ): + activityi_gt_list[label_index] = gt_label + + elif data_type == "lab": + text = text.replace("\n", "\t") + text_list = text.split("\t") # [:-1] + for index in range(0, len(text_list), 4): + triplet = text_list[index : index + 4] + start_frame = int(triplet[0]) + end_frame = int(triplet[1]) + desc = triplet[3] + gt_label = activity_labels_desc_mapping[desc] + + if end_frame - 1 > num_images: + ### address issue with GT activity labels + print( + "Max frame in GT is larger than number of frames in the video" + ) + + for label_index in range( + start_frame, min(end_frame - 1, num_images) + ): + # print(f"label_index: {label_index}") + activityi_gt_list[label_index] = gt_label # print(f"start: {start_frame}, end: {end_frame}, label: {gt_label}, activityi_gt_list: {len(activityi_gt_list)}, num images: {num_images}") - + import collections + counter = collections.Counter(activityi_gt_list) # print(f"counter: {counter}") image_ids = [0 for x in range(num_images)] for index, img_id in enumerate(video_dset.index.imgs.keys()): im = video_dset.index.imgs[img_id] - frame_index = int(im['frame_index']) - video_dset.index.imgs[img_id]['activity_gt'] = activityi_gt_list[frame_index] - dset.index.imgs[img_id]['activity_gt'] = activityi_gt_list[frame_index] + frame_index = int(im["frame_index"]) + video_dset.index.imgs[img_id]["activity_gt"] = activityi_gt_list[ + frame_index + ] + dset.index.imgs[img_id]["activity_gt"] = activityi_gt_list[frame_index] image_ids[frame_index] = img_id - + # import matplotlib.pyplot as plt # image_show = dset.draw_image(gid=img_id) # plt.imshow(image_show) @@ -267,11 +407,10 @@ def main(task: str): # exit() # print(f"activity gt: {activityi_gt_list[frame_index]}") # print(dset.index.imgs[img_id]) - + # exit() - + #### end activity GT section - # features ( @@ -283,11 +422,11 @@ def main(task: str): act_id_to_str, ann_by_image, ) = data_loader(video_dset, activity_config) - + if split != "train_activity": num_augs = 1 aug_trans_range, aug_rot_range = None, None - + for 
aug_index in range(num_augs): X, y = compute_feats( act_map, @@ -297,11 +436,11 @@ def main(task: str): act_id_to_str, ann_by_image, feat_version=feat_version, - objects_joints=FEAT_TO_BOOL[config['data_gen']['feat_type']][0], - hands_joints=FEAT_TO_BOOL[config['data_gen']['feat_type']][1], - aug_trans_range = aug_trans_range, - aug_rot_range = aug_rot_range, - top_n_objects=3 + objects_joints=FEAT_TO_BOOL[config["data_gen"]["feat_type"]][0], + hands_joints=FEAT_TO_BOOL[config["data_gen"]["feat_type"]][1], + aug_trans_range=aug_trans_range, + aug_rot_range=aug_rot_range, + top_n_objects=3, ) X = X.T @@ -311,24 +450,25 @@ def main(task: str): video_name_new = f"{video_name}_{aug_index}" else: video_name_new = video_name - + npy_path = f"{features_dir}/{video_name_new}.npy" - + np.save(npy_path, X) print(f"Video info saved to: {npy_path}") # groundtruth - with open(f"{gt_dir}/{video_name_new}.txt", "w") as gt_f, \ - open(f"{frames_dir}/{video_name_new}.txt", "w") as frames_f: + with open(f"{gt_dir}/{video_name_new}.txt", "w") as gt_f, open( + f"{frames_dir}/{video_name_new}.txt", "w" + ) as frames_f: for ind, image_id in enumerate(image_ids): image = dset.imgs[image_id] - image_n = image["file_name"] # this is the shortened string + image_n = image["file_name"] # this is the shortened string # frame_idx, time = time_from_name(image_n) - frame_idx = int(image['frame_index']) + frame_idx = int(image["frame_index"]) # print(f"frame index: {frame_idx}, inds: {ind}") # print(f"image_n: {image_n}") - + activity_gt = image["activity_gt"] if activity_gt is None: activity_gt = "background" @@ -343,6 +483,51 @@ def main(task: str): print("Done!") print(f"Saved training data to {output_data_dir}") -if __name__ == '__main__': - task = "m3" - main(task=task) \ No newline at end of file + +if __name__ == "__main__": + + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--task", + help="Object detections in kwcoco format for the train set", + type=str, + ) + + parser.add_argument( + "--ptg-root", + default="/home/local/KHQ/peri.akiva/angel_system", + help="root to angel_system", + type=str, + ) + + parser.add_argument( + "--config-root", + default="/home/local/KHQ/peri.akiva/projects/TCN_HPL/configs", + help="root to TCN configs", + type=str, + ) + + parser.add_argument( + "--data-type", + default="pro", + help="pro=proferssional data, lab=use lab data", + type=str, + ) + + parser.add_argument( + "--data-gen-yaml", + default="/home/local/KHQ/peri.akiva/projects/angel_system/config/data_generation/bbn_gyges.yaml", + help="Path to data generation yaml file", + type=str, + ) + + args = parser.parse_args() + main( + task=args.task, + ptg_root=args.ptg_root, + config_root=args.config_root, + data_type=args.data_type, + data_gen_yaml=args.data_gen_yaml, + ) diff --git a/tcn_hpl/models/components/focal_loss.py b/tcn_hpl/models/components/focal_loss.py index 59344f278..322ac3bc9 100644 --- a/tcn_hpl/models/components/focal_loss.py +++ b/tcn_hpl/models/components/focal_loss.py @@ -1,11 +1,13 @@ import torch from torch import nn + class FocalLoss(nn.Module): """ - Multi-class Focal Loss from the paper + Multi-class Focal Loss from the paper `https://arxiv.org/pdf/1708.02002v2.pdf` """ + def __init__(self, alpha=0.25, gamma=2, weight=None, reduction="mean"): super(FocalLoss, self).__init__() self.gamma = gamma @@ -13,21 +15,19 @@ def __init__(self, alpha=0.25, gamma=2, weight=None, reduction="mean"): self.reduction = reduction # print(f"weight: {weight}, type: {type(weight)}") if weight 
== "None": - weight=None + weight = None else: weight = torch.Tensor(weight) self.ce = nn.CrossEntropyLoss( - ignore_index=-100, - reduction="none", - weight=weight + ignore_index=-100, reduction="none", weight=weight ) def forward(self, inputs, targets): ce_loss = self.ce(inputs, targets) pt = torch.exp(-ce_loss) - focal_loss = self.alpha * (1 - pt)**self.gamma * ce_loss + focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss # Check reduction option and return loss accordingly if self.reduction == "mean": @@ -35,4 +35,4 @@ def forward(self, inputs, targets): elif self.reduction == "sum": focal_loss = focal_loss.sum() - return focal_loss \ No newline at end of file + return focal_loss diff --git a/tcn_hpl/models/components/ms_tcs_net.py b/tcn_hpl/models/components/ms_tcs_net.py index 7a4344be4..30898eb18 100644 --- a/tcn_hpl/models/components/ms_tcs_net.py +++ b/tcn_hpl/models/components/ms_tcs_net.py @@ -4,14 +4,17 @@ import copy import einops + class MultiStageModel(nn.Module): - def __init__(self, - num_stages, - num_layers, - num_f_maps, - dim, - num_classes, - window_size,): + def __init__( + self, + num_stages, + num_layers, + num_f_maps, + dim, + num_classes, + window_size, + ): """Initialize a `MultiStageModel` module. :param num_stages: Nubmer of State Model Layers. @@ -31,11 +34,9 @@ def __init__(self, for s in range(num_stages - 1) ] ) - - + self.fc = nn.Sequential( - - nn.Linear(dim*window_size, 4096), + nn.Linear(dim * window_size, 4096), nn.GELU(), nn.Dropout(0.25), nn.Linear(4096, 8192), @@ -48,8 +49,8 @@ def __init__(self, nn.Dropout(0.25), nn.Linear(8192, 4096), nn.Dropout(0.25), - nn.Linear(4096, dim*window_size), - ) + nn.Linear(4096, dim * window_size), + ) # self.fc1 = nn.Linear(dim*30, 4096) # self.act = nn.GELU() # self.drop1 = nn.Dropout(0.1) @@ -60,26 +61,27 @@ def __init__(self, # self.act3 = nn.GELU() # self.drop3 = nn.Dropout(0.1) # self.fc4 = nn.Linear(16384, dim*30) - + # self.fc = nn.Linear(1280, 2048) def forward(self, x, mask): - b, d, c = x.shape + b, d, c = x.shape # [batch_size, feat_dim, window_size] + # mask shape: [batch_size, window_size] # print(f"x: {x.shape}") # print(f"mask: {mask.shape}") - - re_x = einops.rearrange(x, 'b d c -> b (d c)') + + re_x = einops.rearrange(x, "b d c -> b (d c)") re_x = self.fc(re_x) - x = einops.rearrange(re_x, 'b (d c) -> b d c', d=d, c=c) + x = einops.rearrange(re_x, "b (d c) -> b d c", d=d, c=c) # print(f"re_x: {re_x.shape}") # print(f"x: {x.shape}") - + out = self.stage1(x, mask) outputs = out.unsqueeze(0) for s in self.stages: out = s(F.softmax(out, dim=1) * mask[:, None, :], mask) outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0) - + # print(f"outputs: {outputs.shape}") return outputs @@ -97,12 +99,12 @@ def __init__(self, num_layers, num_f_maps, dim, num_classes): self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1) def forward(self, x, mask): - + out = self.conv_1x1(x) for layer in self.layers: out = layer(out, mask) out = self.conv_out(out) * mask[:, None, :] - + return out diff --git a/tcn_hpl/models/mnist_module.py b/tcn_hpl/models/mnist_module.py index b10450f19..26b12e308 100644 --- a/tcn_hpl/models/mnist_module.py +++ b/tcn_hpl/models/mnist_module.py @@ -125,8 +125,12 @@ def training_step( # update and log metrics self.train_loss(loss) self.train_acc(preds, targets) - self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True) - self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True) + self.log( + "train/loss", self.train_loss, 
on_step=False, on_epoch=True, prog_bar=True + ) + self.log( + "train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True + ) # return loss or backpropagation will fail return loss @@ -135,7 +139,9 @@ def on_train_epoch_end(self) -> None: "Lightning hook that is called when a training epoch ends." pass - def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + def validation_step( + self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int + ) -> None: """Perform a single validation step on a batch of data from the validation set. :param batch: A batch of data (a tuple) containing the input tensor of images and target @@ -156,9 +162,13 @@ def on_validation_epoch_end(self) -> None: self.val_acc_best(acc) # update best so far val acc # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object # otherwise metric would be reset by lightning after each epoch - self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True) + self.log( + "val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True + ) - def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + def test_step( + self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int + ) -> None: """Perform a single test step on a batch of data from the test set. :param batch: A batch of data (a tuple) containing the input tensor of images and target @@ -170,7 +180,9 @@ def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> # update and log metrics self.test_loss(loss) self.test_acc(preds, targets) - self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log( + "test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True + ) self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True) def on_test_epoch_end(self) -> None: diff --git a/tcn_hpl/models/ptg_module.py b/tcn_hpl/models/ptg_module.py index e907221c9..964b1071d 100644 --- a/tcn_hpl/models/ptg_module.py +++ b/tcn_hpl/models/ptg_module.py @@ -99,9 +99,12 @@ def __init__( actions_dict = dict() for a in actions: actions_dict[a.split()[1]] = int(a.split()[0]) - + self.class_ids = list(actions_dict.values()) self.classes = list(actions_dict.keys()) + # print(f"CLASSES IN MODEL: {self.classes}") + # print(f"actions_dict: {actions_dict}") + # exit() self.action_id_to_str = dict(zip(self.class_ids, self.classes)) # loss functions @@ -109,9 +112,15 @@ def __init__( self.mse = nn.MSELoss(reduction="none") # metric objects for calculating and averaging accuracy across batches - self.train_acc = Accuracy(task="multiclass", average="weighted", num_classes=num_classes) - self.val_acc = Accuracy(task="multiclass", average="weighted", num_classes=num_classes) - self.test_acc = Accuracy(task="multiclass", average="weighted", num_classes=num_classes) + self.train_acc = Accuracy( + task="multiclass", average="weighted", num_classes=num_classes + ) + self.val_acc = Accuracy( + task="multiclass", average="weighted", num_classes=num_classes + ) + self.test_acc = Accuracy( + task="multiclass", average="weighted", num_classes=num_classes + ) # for averaging loss across batches self.train_loss = MeanMetric() @@ -128,7 +137,7 @@ def __init__( self.validation_step_outputs_target = [] self.validation_step_outputs_source_vid = [] self.validation_step_outputs_source_frame = [] - + self.training_step_outputs_target = [] self.training_step_outputs_source_vid = [] 
self.training_step_outputs_source_frame = [] @@ -139,15 +148,14 @@ def __init__( self.val_frames = None self.test_frames = None - - def plot_gt_vs_preds(self, per_video_frame_gt_preds, split='train', max_items=30): + def plot_gt_vs_preds(self, per_video_frame_gt_preds, split="train", max_items=30): # fig = plt.figure(figsize=(10,15)) - + for index, video in enumerate(per_video_frame_gt_preds.keys()): - + if index >= max_items: return - + fig, ax = plt.subplots(figsize=(15, 8)) video_gt_preds = per_video_frame_gt_preds[video] frame_inds = sorted(list(video_gt_preds.keys())) @@ -156,17 +164,27 @@ def plot_gt_vs_preds(self, per_video_frame_gt_preds, split='train', max_items=30 inds.append(int(ind)) gt.append(int(video_gt_preds[ind][0])) preds.append(int(video_gt_preds[ind][1])) - - + # plt.plot(gt, label="gt") # plt.plot(preds, label="preds") - + # sns.barplot(x=inds, y=gt, linestyle='dotted', color='magenta', label='GT', ax=ax) # ax.stackplot(inds, gt, alpha=0.5, labels=['GT'], color=['magenta']) - sns.lineplot(x=inds, y=preds, linestyle='dotted', color='blue', linewidth=1, label='Pred', ax=ax).set(title=f'{split} Step Prediction Per Frame', xlabel='Index', ylabel='Step') - sns.lineplot(x=inds, y=gt, color='magenta', label='GT', ax=ax, linewidth=3) + sns.lineplot( + x=inds, + y=preds, + linestyle="dotted", + color="blue", + linewidth=1, + label="Pred", + ax=ax, + ).set( + title=f"{split} Step Prediction Per Frame", + xlabel="Index", + ylabel="Step", + ) + sns.lineplot(x=inds, y=gt, color="magenta", label="GT", ax=ax, linewidth=3) - # ax.xaxis.label.set_visible(False) # ax.spines['bottom'].set_visible(False) ax.legend() @@ -174,13 +192,13 @@ def plot_gt_vs_preds(self, per_video_frame_gt_preds, split='train', max_items=30 # title = f"plot_pred_vs_gt_vid{video}.png" Path(folder).mkdir(parents=True, exist_ok=True) # fig.savefig(f"{self.hparams.output_dir}/{title}") root_dir = f"{self.hparams.output_dir}/steps_vs_preds" - + if not os.path.exists(root_dir): os.makedirs(root_dir) - + fig.savefig(f"{root_dir}/{split}_{video}.png", pad_inches=5) plt.close() - + def forward(self, x: torch.Tensor, m: torch.Tensor) -> torch.Tensor: """Perform a forward pass through the model `self.net`. @@ -189,7 +207,7 @@ def forward(self, x: torch.Tensor, m: torch.Tensor) -> torch.Tensor: :return: A tensor of logits. 
""" - return self.net(x,m) + return self.net(x, m) def on_train_start(self) -> None: """Lightning hook that is called when training begins.""" @@ -211,16 +229,15 @@ def compute_loss(self, p, y, mask): :return: The loss """ - - probs = torch.softmax(p, dim=1) # shape (batch size, self.hparams.num_classes) - preds = torch.argmax(probs, dim=1).float() # shape: batch size - + + probs = torch.softmax(p, dim=1) # shape (batch size, self.hparams.num_classes) + preds = torch.argmax(probs, dim=1).float() # shape: batch size + # print(f"prediction: {p}, GT: {y}"), # [bs, num_classes, window_size] # print(f"prediction: {p.shape}, GT: {y.shape}") - - + loss = torch.zeros((1)).to(p[0]) - + # print(f"loss: {loss.shape}") # print(f"p loss: {p[:,:,-1].shape}") # print(f"y: {y.view(-1).shape}") @@ -228,24 +245,24 @@ def compute_loss(self, p, y, mask): # p = einops.rearrange(p, 'b c w -> (b w) c') # print(f"prediction: {p.shape}, GT: {y.shape}") # print(f"prediction: {p}, GT: {y}"), # [bs, num_classes, window_size] - + # TODO: Use only last frame per window - + loss += self.criterion( p.transpose(2, 1).contiguous().view(-1, self.hparams.num_classes), y.view(-1), ) - + # loss += self.criterion( # p[:,:,-1], # y[:,-1], # ) # need to penalize high volatility of predictions within a window - + # std, mean = torch.std_mean(preds, dim=-1) mode, _ = torch.mode(y, dim=-1) - mode = einops.repeat(mode, 'b -> b c', c=preds.shape[-1]) + mode = einops.repeat(mode, "b -> b c", c=preds.shape[-1]) # print(f"mode: {mode.shape}") # print(f"preds: {preds.shape}") # print(f"mode: {mode[0,:]}") @@ -260,15 +277,15 @@ def compute_loss(self, p, y, mask): # print(f"p: {p.shape}") # print(f"p[:, :, 1:]: {p[0, 0, 1:]}") # print(f"F.log_softmax(p[:, :, 1:], dim=1): {F.log_softmax(p[:, :, 1:], dim=1)}") - + if self.hparams.use_smoothing_loss: loss += self.hparams.smoothing_loss * torch.mean( - self.mse( - variation_coef, - gt_variation_coef, - ), + self.mse( + variation_coef, + gt_variation_coef, + ), ) - + loss += self.hparams.smoothing_loss * torch.mean( torch.clamp( self.mse( @@ -297,24 +314,27 @@ def model_step( - A tensor of target labels. 
- A tensor of [video name, frame idx] """ - x, y, m, source_vid, source_frame = batch + x, y, m, source_vid, source_frame = batch # x shape: (batch size, window, feat dim) # y shape: (batch size, window) # m shape: (batch size, window) # source_vid shape: (batch size, window) - x = x.transpose(2, 1) # shape (batch size, feat dim, window) - logits = self.forward(x, m) # shape (4, batch size, self.hparams.num_classes, window)) + x = x.transpose(2, 1) # shape (batch size, feat dim, window) + logits = self.forward( + x, m + ) # shape (4, batch size, self.hparams.num_classes, window)) # print(f"logits: {logits.shape}") loss = torch.zeros((1)).to(x) for p in logits: loss += self.compute_loss(p, y, m) - probs = torch.softmax(logits[-1,:,:,-1], dim=1) # shape (batch size, self.hparams.num_classes) - preds = torch.argmax(logits[-1,:,:,-1], dim=1) # shape: batch size + probs = torch.softmax( + logits[-1, :, :, -1], dim=1 + ) # shape (batch size, self.hparams.num_classes) + preds = torch.argmax(logits[-1, :, :, -1], dim=1) # shape: batch size return loss, probs, preds, y, source_vid, source_frame - def training_step( self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int ) -> torch.Tensor: @@ -329,46 +349,54 @@ def training_step( # print(f"targets: {targets}") # print(f"preds: {preds}") - + # update and log metrics self.train_loss(loss) - self.train_acc(preds, targets[:,-1]) + self.train_acc(preds, targets[:, -1]) - self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True) - self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True) + self.log( + "train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True + ) + self.log( + "train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True + ) - self.training_step_outputs_target.append(targets[:,-1]) - self.training_step_outputs_source_vid.append(source_vid[:,-1]) - self.training_step_outputs_source_frame.append(source_frame[:,-1]) + self.training_step_outputs_target.append(targets[:, -1]) + self.training_step_outputs_source_vid.append(source_vid[:, -1]) + self.training_step_outputs_source_frame.append(source_frame[:, -1]) self.training_step_outputs_pred.append(preds) self.training_step_outputs_prob.append(probs) - + # return loss or backpropagation will fail return loss def on_train_epoch_end(self) -> None: "Lightning hook that is called when a training epoch ends." 
- + # acc = self.train_acc.compute() # get current val acc # self.train_acc_best(acc) # update best so far val acc - + # all_targets = torch.concat(self.training_step_outputs_target) # shape: #frames # all_preds = torch.concat(self.training_step_outputs_pred) # shape: #frames # all_probs = torch.concat(self.training_step_outputs_prob) # shape (#frames, #act labels) # all_source_vids = torch.concat(self.training_step_outputs_source_vid) # all_source_frames = torch.concat(self.training_step_outputs_source_frame) - - all_targets = torch.cat(self.training_step_outputs_target) # shape: #frames - all_preds = torch.cat(self.training_step_outputs_pred) # shape: #frames - all_probs = torch.cat(self.training_step_outputs_prob) # shape (#frames, #act labels) + + all_targets = torch.cat(self.training_step_outputs_target) # shape: #frames + all_preds = torch.cat(self.training_step_outputs_pred) # shape: #frames + all_probs = torch.cat( + self.training_step_outputs_prob + ) # shape (#frames, #act labels) all_source_vids = torch.cat(self.training_step_outputs_source_vid) all_source_frames = torch.cat(self.training_step_outputs_source_frame) - + # print(f"Training Per class occurences in GT: {torch.unique(all_targets, return_counts=True)}") - + if self.train_frames is None: self.train_frames = {} - vid_list_file_train = f"{self.hparams.data_dir}/splits/train_activity.split1.bundle" + vid_list_file_train = ( + f"{self.hparams.data_dir}/splits/train_activity.split1.bundle" + ) with open(vid_list_file_train, "r") as train_f: self.train_videos = train_f.read().split("\n")[:-1] @@ -379,22 +407,22 @@ def on_train_epoch_end(self) -> None: train_fns = train_f.read().split("\n")[:-1] self.train_frames[video[:-4]] = train_fns - + per_video_frame_gt_preds = {} - - for (gt, pred, source_vid, source_frame) in zip(all_targets, all_preds, all_source_vids, all_source_frames): + + for (gt, pred, source_vid, source_frame) in zip( + all_targets, all_preds, all_source_vids, all_source_frames + ): video_name = self.train_videos[int(source_vid)][:-4] - if video_name not in per_video_frame_gt_preds.keys(): per_video_frame_gt_preds[video_name] = {} frame = self.train_frames[video_name][int(source_frame)] # frame_idx, time = time_from_name(frame) - frame_idx = int(frame.split('/')[-1].split('.')[0].split('_')[-1]) + frame_idx = int(frame.split("/")[-1].split(".")[0].split("_")[-1]) per_video_frame_gt_preds[video_name][frame_idx] = (int(gt), int(pred)) - - + # print(f"video name: {video_name}, frame index: {frame_idx}, gt: {gt}, pred: {pred}") # dset.add_image( # file_name=frame, @@ -404,34 +432,35 @@ def on_train_epoch_end(self) -> None: # activity_pred=int(pred), # activity_conf=prob.tolist() # ) - + # print(f"per_video_frame_gt_preds: {per_video_frame_gt_preds}") # exit() - - self.plot_gt_vs_preds(per_video_frame_gt_preds, split='train') - + + self.plot_gt_vs_preds(per_video_frame_gt_preds, split="train") + cm = confusion_matrix( - all_targets.cpu().numpy(), all_preds.cpu().numpy(), + all_targets.cpu().numpy(), + all_preds.cpu().numpy(), labels=self.class_ids, - normalize="true" + normalize="true", ) num_act_classes = len(self.class_ids) fig, ax = plt.subplots(figsize=(num_act_classes, num_act_classes)) - + sns.heatmap(cm, annot=True, ax=ax, fmt=".2f", linewidth=0.5, vmin=0, vmax=1) # labels, title and ticks - ax.set_xlabel('Predicted labels') - ax.set_ylabel('True labels') - ax.set_title(f'CM Training Epoch {self.current_epoch}') + ax.set_xlabel("Predicted labels") + ax.set_ylabel("True labels") + ax.set_title(f"CM Training 
Epoch {self.current_epoch}") ax.xaxis.set_ticklabels(self.classes, rotation=25) ax.yaxis.set_ticklabels(self.classes, rotation=0) - self.logger.experiment.track(Image(fig), name=f'CM Training Epoch') + self.logger.experiment.track(Image(fig), name=f"CM Training Epoch") fig.savefig(f"{self.hparams.output_dir}/confusion_mat_train.png", pad_inches=5) - + plt.close(fig) self.training_step_outputs_target.clear() @@ -439,24 +468,26 @@ def on_train_epoch_end(self) -> None: self.training_step_outputs_source_frame.clear() self.training_step_outputs_pred.clear() self.training_step_outputs_prob.clear() - + # pass - def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + def validation_step( + self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int + ) -> None: """Perform a single validation step on a batch of data from the validation set. :param batch: A batch of data (a tuple) containing the input tensor of images and target labels. :param batch_idx: The index of the current batch. """ - _, _, mask, _, _ = batch + _, _, mask, _, _ = batch loss, probs, preds, targets, source_vid, source_frame = self.model_step(batch) # update and log metrics self.val_loss(loss) - + # print(f"preds: {preds.shape}, targets: {targets.shape}") # print(f"mask: {mask.shape}, {mask[0,:]}") - ys = targets[:,-1] + ys = targets[:, -1] # print(f"y: {ys.shape}") # print(f"y: {ys}") windowed_preds, windowed_ys = [], [] @@ -464,33 +495,31 @@ def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: i center = 7 inds = [] for i in range(preds.shape[0] - window_size + 1): - y = ys[i:i+window_size].tolist() + y = ys[i : i + window_size].tolist() # print(f"y: {y}") # print(f"len of set: {len(list(set(y)))}") if len(list(set(y))) == 1: - inds.append(i+center-1) - windowed_preds.append(preds[i+center-1]) - windowed_ys.append(ys[i+center-1]) - - + inds.append(i + center - 1) + windowed_preds.append(preds[i + center - 1]) + windowed_ys.append(ys[i + center - 1]) + windowed_preds = torch.tensor(windowed_preds).to(targets) windowed_ys = torch.tensor(windowed_ys).to(targets) - + # print(f"preds: {preds.shape}, targets: {targets.shape}") # print(f"windowed_preds: {windowed_preds.shape}, windowed_ys: {windowed_ys.shape}") - + # self.val_acc(windowed_preds, windowed_ys) - self.val_acc(preds, targets[:,-1]) + self.val_acc(preds, targets[:, -1]) self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True) self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) - self.validation_step_outputs_target.append(targets[inds,-1]) - self.validation_step_outputs_source_vid.append(source_vid[inds,-1]) - self.validation_step_outputs_source_frame.append(source_frame[inds,-1]) + self.validation_step_outputs_target.append(targets[inds, -1]) + self.validation_step_outputs_source_vid.append(source_vid[inds, -1]) + self.validation_step_outputs_source_frame.append(source_frame[inds, -1]) self.validation_step_outputs_pred.append(preds[inds]) self.validation_step_outputs_prob.append(probs[inds]) - # def plot_gt_vs_activations(self, step_gts, fname_suffix=None): # # Plot gt vs predicted class across all vid frames # fig = plt.figure() @@ -511,39 +540,38 @@ def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: i # plt.title(title) # fig.savefig(f"./outputs/{title}") # TODO: output this wherever you want # plt.close() - - def on_validation_epoch_end(self) -> None: "Lightning hook that is called when a validation epoch ends." 
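validation_step above adds a center-frame filter before results are accumulated. A simplified sketch of that selection follows; window_size=15 and center=7 are the values hard-coded in the patch, while the tensors are hypothetical stand-ins for the per-batch outputs of model_step.

import torch

# Stand-ins for one validation batch: one last-frame prediction and label per window.
preds = torch.randint(0, 6, (64,))
ys = torch.randint(0, 6, (64,))

window_size, center = 15, 7  # hard-coded in validation_step
inds, windowed_preds, windowed_ys = [], [], []
for i in range(preds.shape[0] - window_size + 1):
    neighbourhood = ys[i : i + window_size].tolist()
    # Keep a position only when its 15-sample neighbourhood has a single
    # ground-truth label, i.e. drop samples that straddle an activity boundary,
    # and record the element at the fixed center offset.
    if len(set(neighbourhood)) == 1:
        inds.append(i + center - 1)
        windowed_preds.append(preds[i + center - 1])
        windowed_ys.append(ys[i + center - 1])

windowed_preds = torch.tensor(windowed_preds).to(preds)
windowed_ys = torch.tensor(windowed_ys).to(ys)
# Only the positions in `inds` are appended to the validation_step_outputs_*
# buffers that feed the epoch-end confusion matrix and plots.

Note that the windowed accuracy call itself stays commented out in the patch: val/acc is still logged on all last-frame predictions, and the boundary-free subset only affects what is accumulated for the epoch-end outputs.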
acc = self.val_acc.compute() # get current val acc - + if self.current_epoch >= 15: self.val_acc_best(acc) # update best so far val acc - # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object - # otherwise metric would be reset by lightning after each epoch + # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object + # otherwise metric would be reset by lightning after each epoch best_val_acc = self.val_acc_best.compute() self.log("val/acc_best", best_val_acc, sync_dist=True, prog_bar=True) # import collections # counter = collections.Counter(self.validation_step_outputs_target) - # all_targets = torch.concat(self.validation_step_outputs_target) # shape: #frames # all_preds = torch.concat(self.validation_step_outputs_pred) # shape: #frames # all_probs = torch.concat(self.validation_step_outputs_prob) # shape (#frames, #act labels) # all_source_vids = torch.concat(self.validation_step_outputs_source_vid) # all_source_frames = torch.concat(self.validation_step_outputs_source_frame) - - all_targets = torch.cat(self.validation_step_outputs_target) # shape: #frames - all_preds = torch.cat(self.validation_step_outputs_pred) # shape: #frames - all_probs = torch.cat(self.validation_step_outputs_prob) # shape (#frames, #act labels) + + all_targets = torch.cat(self.validation_step_outputs_target) # shape: #frames + all_preds = torch.cat(self.validation_step_outputs_pred) # shape: #frames + all_probs = torch.cat( + self.validation_step_outputs_prob + ) # shape (#frames, #act labels) all_source_vids = torch.cat(self.validation_step_outputs_source_vid) all_source_frames = torch.cat(self.validation_step_outputs_source_frame) # print(f"Per class occurences in GT: {torch.unique(all_targets, return_counts=True)}") # print(f"all_targets: {all_targets.shape}") - + # Load val vidoes if self.val_frames is None: self.val_frames = {} @@ -566,68 +594,74 @@ def on_validation_epoch_end(self) -> None: # print(f"video_lookup: {video_lookup}") per_video_frame_gt_preds = {} - - for (gt, pred, prob, source_vid, source_frame) in zip(all_targets, all_preds, all_probs, all_source_vids, all_source_frames): + + for (gt, pred, prob, source_vid, source_frame) in zip( + all_targets, all_preds, all_probs, all_source_vids, all_source_frames + ): video_name = self.val_videos[int(source_vid)][:-4] - if video_name not in per_video_frame_gt_preds.keys(): per_video_frame_gt_preds[video_name] = {} - - - + video_lookup = dset.index.name_to_video - vid = video_lookup[video_name]["id"] if video_name in video_lookup else dset.add_video(name=video_name) + vid = ( + video_lookup[video_name]["id"] + if video_name in video_lookup + else dset.add_video(name=video_name) + ) frame = self.val_frames[video_name][int(source_frame)] - - frame_idx = int(frame.split('/')[-1].split('.')[0].split('_')[-1]) + + frame_idx = int(frame.split("/")[-1].split(".")[0].split("_")[-1]) # print(f"frame: {frame}, frame index: {frame_idx}") # frame_idx, time = time_from_name(frame) per_video_frame_gt_preds[video_name][frame_idx] = (int(gt), int(pred)) - + # print(f"video name: {video_name}, frame index: {frame_idx}, gt: {gt}, pred: {pred}") - + dset.add_image( file_name=frame, video_id=vid, frame_index=frame_idx, activity_gt=int(gt), activity_pred=int(pred), - activity_conf=prob.tolist() + activity_conf=prob.tolist(), ) # dset.dump(dset.fpath, newlines=True) # print(f"Saved dset to {dset.fpath}") - # print(f"per_video_frame_gt_preds: {per_video_frame_gt_preds}") - 
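The epoch-end hooks in this module persist per-frame results through kwcoco: the validation-side dump is left commented out above, while on_test_epoch_end below writes the file. Reduced to its core, and with hypothetical stand-in values for the tensors gathered over an epoch, the export pattern looks roughly like this.

import kwcoco

# Hypothetical stand-ins for the values accumulated over an epoch.
output_dir = "/tmp/tcn_hpl_outputs"  # stands in for self.hparams.output_dir
records = [
    # (video name, frame path, gt label id, pred label id, per-class confidences)
    ("video_7", "frames/video_7/video_7_frame_000042.png", 3, 3,
     [0.05, 0.05, 0.05, 0.75, 0.05, 0.05]),
]

dset = kwcoco.CocoDataset()
dset.fpath = f"{output_dir}/activity_preds.mscoco.json"

for video_name, frame, gt, pred, conf in records:
    # Re-use the video entry if it already exists, otherwise register it.
    video_lookup = dset.index.name_to_video
    vid = (
        video_lookup[video_name]["id"]
        if video_name in video_lookup
        else dset.add_video(name=video_name)
    )
    # The hooks recover the frame index from the trailing "_<idx>" in the file name.
    frame_idx = int(frame.split("/")[-1].split(".")[0].split("_")[-1])
    dset.add_image(
        file_name=frame,
        video_id=vid,
        frame_index=frame_idx,
        activity_gt=int(gt),
        activity_pred=int(pred),
        activity_conf=conf,
    )

dset.dump(dset.fpath, newlines=True)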
self.plot_gt_vs_preds(per_video_frame_gt_preds, split='validation') + self.plot_gt_vs_preds(per_video_frame_gt_preds, split="validation") # Create confusion matrix cm = confusion_matrix( - all_targets.cpu().numpy(), all_preds.cpu().numpy(), + all_targets.cpu().numpy(), + all_preds.cpu().numpy(), labels=self.class_ids, - normalize="true" + normalize="true", ) num_act_classes = len(self.class_ids) fig, ax = plt.subplots(figsize=(num_act_classes, num_act_classes)) - + sns.heatmap(cm, annot=True, ax=ax, fmt=".2f", linewidth=0.5, vmin=0, vmax=1) # labels, title and ticks - ax.set_xlabel('Predicted labels') - ax.set_ylabel('True labels') - ax.set_title(f'CM Validation Epoch {self.current_epoch}, Accuracy: {acc:.4f}') + ax.set_xlabel("Predicted labels") + ax.set_ylabel("True labels") + ax.set_title(f"CM Validation Epoch {self.current_epoch}, Accuracy: {acc:.4f}") ax.xaxis.set_ticklabels(self.classes, rotation=25) ax.yaxis.set_ticklabels(self.classes, rotation=0) - self.logger.experiment.track(Image(fig), name=f'CM Validation Epoch') - + self.logger.experiment.track(Image(fig), name=f"CM Validation Epoch") + if self.current_epoch >= 15: if acc >= best_val_acc: - fig.savefig(f"{self.hparams.output_dir}/confusion_mat_val_acc_{acc:.4f}.png", pad_inches=5) - + fig.savefig( + f"{self.hparams.output_dir}/confusion_mat_val_acc_{acc:.4f}.png", + pad_inches=5, + ) + plt.close(fig) self.validation_step_outputs_target.clear() @@ -636,7 +670,9 @@ def on_validation_epoch_end(self) -> None: self.validation_step_outputs_pred.clear() self.validation_step_outputs_prob.clear() - def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: + def test_step( + self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int + ) -> None: """Perform a single test step on a batch of data from the test set. 
:param batch: A batch of data (a tuple) containing the input tensor of images and target @@ -647,13 +683,15 @@ def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> # update and log metrics self.test_loss(loss) - self.test_acc(preds, targets[:,-1]) - self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True) + self.test_acc(preds, targets[:, -1]) + self.log( + "test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True + ) self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True) - self.validation_step_outputs_target.append(targets[:,-1]) - self.validation_step_outputs_source_vid.append(source_vid[:,-1]) - self.validation_step_outputs_source_frame.append(source_frame[:,-1]) + self.validation_step_outputs_target.append(targets[:, -1]) + self.validation_step_outputs_source_vid.append(source_vid[:, -1]) + self.validation_step_outputs_source_frame.append(source_frame[:, -1]) self.validation_step_outputs_pred.append(preds) self.validation_step_outputs_prob.append(probs) @@ -661,11 +699,13 @@ def on_test_epoch_end(self) -> None: """Lightning hook that is called when a test epoch ends.""" # update and log metrics - all_targets = torch.concat(self.validation_step_outputs_target) # shape: #frames - all_preds = torch.concat(self.validation_step_outputs_pred) # shape: #frames - all_probs = torch.concat(self.validation_step_outputs_prob) # shape (#frames, #act labels) - all_source_vids = torch.concat(self.validation_step_outputs_source_vid) - all_source_frames = torch.concat(self.validation_step_outputs_source_frame) + all_targets = torch.cat(self.validation_step_outputs_target) # shape: #frames + all_preds = torch.cat(self.validation_step_outputs_pred) # shape: #frames + all_probs = torch.cat( + self.validation_step_outputs_prob + ) # shape (#frames, #act labels) + all_source_vids = torch.cat(self.validation_step_outputs_source_vid) + all_source_frames = torch.cat(self.validation_step_outputs_source_frame) # Load test vidoes if self.test_frames is None: @@ -681,64 +721,72 @@ def on_test_epoch_end(self) -> None: test_fns = test_f.read().split("\n")[:-1] self.test_frames[video[:-4]] = test_fns - + # Save results dset = kwcoco.CocoDataset() dset.fpath = f"{self.hparams.output_dir}/test_activity_preds.mscoco.json" dset.dataset["info"].append({"activity_labels": self.action_id_to_str}) - per_video_frame_gt_preds = {} - for (gt, pred, prob, source_vid, source_frame) in zip(all_targets, all_preds, all_probs, all_source_vids, all_source_frames): + for (gt, pred, prob, source_vid, source_frame) in zip( + all_targets, all_preds, all_probs, all_source_vids, all_source_frames + ): video_name = self.test_videos[int(source_vid)][:-4] if video_name not in per_video_frame_gt_preds.keys(): per_video_frame_gt_preds[video_name] = {} - + video_lookup = dset.index.name_to_video - vid = video_lookup[video_name]["id"] if video_name in video_lookup else dset.add_video(name=video_name) + vid = ( + video_lookup[video_name]["id"] + if video_name in video_lookup + else dset.add_video(name=video_name) + ) frame = self.test_frames[video_name][int(source_frame)] - frame_idx = int(frame.split('/')[-1].split('.')[0].split('_')[-1]) + frame_idx = int(frame.split("/")[-1].split(".")[0].split("_")[-1]) # frame_idx, time = time_from_name(frame) per_video_frame_gt_preds[video_name][frame_idx] = (int(gt), int(pred)) - + dset.add_image( file_name=frame, video_id=vid, frame_index=frame_idx, activity_gt=int(gt), activity_pred=int(pred), - 
activity_conf=prob.tolist() + activity_conf=prob.tolist(), ) dset.dump(dset.fpath, newlines=True) print(f"Saved dset to {dset.fpath}") - - self.plot_gt_vs_preds(per_video_frame_gt_preds, split='test') + self.plot_gt_vs_preds(per_video_frame_gt_preds, split="test") # Create confusion matrix cm = confusion_matrix( - all_targets.cpu().numpy(), all_preds.cpu().numpy(), + all_targets.cpu().numpy(), + all_preds.cpu().numpy(), labels=self.class_ids, - normalize="true" + normalize="true", ) num_act_classes = len(self.class_ids) - fig, ax = plt.subplots(figsize=(num_act_classes,num_act_classes)) - + fig, ax = plt.subplots(figsize=(num_act_classes, num_act_classes)) + sns.heatmap(cm, annot=True, ax=ax, fmt=".2f", vmin=0, vmax=1) # labels, title and ticks - ax.set_xlabel('Predicted labels') - ax.set_ylabel('True labels') - ax.set_title(f'CM Test Epoch {self.current_epoch}') + ax.set_xlabel("Predicted labels") + ax.set_ylabel("True labels") + ax.set_title(f"CM Test Epoch {self.current_epoch}") ax.xaxis.set_ticklabels(self.classes, rotation=25) ax.yaxis.set_ticklabels(self.classes, rotation=0) - fig.savefig(f"{self.hparams.output_dir}/confusion_mat_test_acc_{self.test_acc.compute():0.2f}.png", pad_inches=5) + fig.savefig( + f"{self.hparams.output_dir}/confusion_mat_test_acc_{self.test_acc.compute():0.2f}.png", + pad_inches=5, + ) - self.logger.experiment.track(Image(fig), name=f'CM Test Epoch') + self.logger.experiment.track(Image(fig), name=f"CM Test Epoch") plt.close(fig) diff --git a/tcn_hpl/train.py b/tcn_hpl/train.py index dfeab17c2..402581271 100644 --- a/tcn_hpl/train.py +++ b/tcn_hpl/train.py @@ -3,6 +3,7 @@ import hydra import rootutils import torch + # import lightning as L # from lightning import Callback, LightningDataModule, LightningModule, Trainer # from lightning.pytorch.loggers import Logger @@ -56,7 +57,7 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: # print(f"datamodule: {datamodule.__dict__}") # exit() - + log.info(f"Instantiating model <{cfg.model._target_}>") model: LightningModule = hydra.utils.instantiate(cfg.model) @@ -67,7 +68,9 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) log.info(f"Instantiating trainer <{cfg.trainer._target_}>") - trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) + trainer: Trainer = hydra.utils.instantiate( + cfg.trainer, callbacks=callbacks, logger=logger + ) object_dict = { "cfg": cfg, diff --git a/tcn_hpl/utils/pylogger.py b/tcn_hpl/utils/pylogger.py index 856b3d8f9..8972632db 100644 --- a/tcn_hpl/utils/pylogger.py +++ b/tcn_hpl/utils/pylogger.py @@ -15,7 +15,15 @@ def get_pylogger(name: str = __name__) -> logging.Logger: # this ensures all logging levels get marked with the rank zero decorator # otherwise logs would get multiplied for each GPU process in multi-GPU setup - logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") + logging_levels = ( + "debug", + "info", + "warning", + "error", + "exception", + "fatal", + "critical", + ) for level in logging_levels: setattr(logger, level, rank_zero_only(getattr(logger, level))) diff --git a/tcn_hpl/utils/utils.py b/tcn_hpl/utils/utils.py index 5a5b478b8..e5b6a157c 100644 --- a/tcn_hpl/utils/utils.py +++ b/tcn_hpl/utils/utils.py @@ -20,10 +20,11 @@ def load_yaml_as_dict(yaml_path: str) -> dict: Returns: dict: A dictionary containing the configuration settings. 
""" - with open(yaml_path, 'r') as f: + with open(yaml_path, "r") as f: config_dict = yaml.load(f, Loader=yaml.FullLoader) return config_dict + def dictionary_contents(path: str, types: list, recursive: bool = False) -> list: """ Extract files of specified types from directories, optionally recursively. @@ -48,6 +49,7 @@ def dictionary_contents(path: str, types: list, recursive: bool = False) -> list files.append(os.path.join(path, x)) return files + def extras(cfg: DictConfig) -> None: """Applies optional utilities before the task is started.