diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml index 094adefaf..140ce14cf 100644 --- a/configs/callbacks/default.yaml +++ b/configs/callbacks/default.yaml @@ -3,6 +3,7 @@ defaults: - early_stopping - model_summary - rich_progress_bar + - learning_rate_monitor - plot_metrics - _self_ @@ -11,11 +12,10 @@ defaults: model_checkpoint: dirpath: ${paths.output_dir}/checkpoints filename: "epoch_{epoch:03d}" - # monitor: "val/loss" - # mode: "min" monitor: "val/f1" mode: "max" save_last: True + save_top_k: 10 # save k best models (determined by above metric) auto_insert_metric_name: False early_stopping: diff --git a/configs/callbacks/learning_rate_monitor.yaml b/configs/callbacks/learning_rate_monitor.yaml new file mode 100644 index 000000000..9347987d6 --- /dev/null +++ b/configs/callbacks/learning_rate_monitor.yaml @@ -0,0 +1,4 @@ +learning_rate_monitor: + _target_: pytorch_lightning.callbacks.LearningRateMonitor + logging_interval: epoch + log_momentum: false diff --git a/configs/data/ptg.yaml b/configs/data/ptg.yaml index 497e9e683..eeb96f595 100644 --- a/configs/data/ptg.yaml +++ b/configs/data/ptg.yaml @@ -30,6 +30,14 @@ val_dataset: test_dataset: ${data.val_dataset} +# Match the test dataset's configuration **sans** augmentations. +pred_dataset: + _target_: tcn_hpl.data.tcn_dataset.TCNDataset + window_size: ${data.test_dataset.window_size} + window_label_idx: ${data.test_dataset.window_label_idx} + vectorize: ${data.test_dataset.vectorize} + transform_frame_data: null + coco_train_activities: "" coco_train_objects: "" coco_train_poses: "" @@ -42,5 +50,5 @@ coco_test_poses: "" batch_size: 128 num_workers: 0 target_framerate: 15 -epoch_length: 10000 +epoch_sample_factor: 1 pin_memory: True diff --git a/configs/experiment/m2/feat_locsconfs.yaml b/configs/experiment/m2/feat_locsconfs.yaml index 6a5505f5a..5ff7d9843 100644 --- a/configs/experiment/m2/feat_locsconfs.yaml +++ b/configs/experiment/m2/feat_locsconfs.yaml @@ -28,12 +28,6 @@ tags: ["m2", "ms_tcn", "debug"] seed: 12345 -#callbacks: -# model_checkpoint: -# # save all ~80MB checkpoints for post-training investigation. -# # Total: ~45GB -# save_top_k: 500 - trainer: min_epochs: 50 max_epochs: 500 @@ -42,16 +36,15 @@ trainer: model: num_classes: 9 # number of activity classification classes compile: false + scheduler: + # Code change to track train/loss instead of val/loss. + factor: 0.9 + patience: 10 net: # Length of feature vector for a single frame. # Currently derived from the parameterization of dataset vectorizer. dim: 102 -# # Once upon a time defaults -# num_stages: 4 -# num_layers: 10 -# num_f_maps: 64 - data: coco_train_activities: "${paths.coco_file_root}/TRAIN-activity_truth.coco.json" coco_train_objects: "${paths.coco_file_root}/TRAIN-object_detections.coco.json" @@ -68,10 +61,7 @@ data: batch_size: 512 num_workers: 16 target_framerate: 15 # BBN Hololens2 Framerate - # This is a bit more than the number of windows in the training dataset so - # the weighted sampler has more of an opportunity to sample the space - # proportionally. 
- epoch_length: 300000 + epoch_sample_factor: 1 # 1x the dataset size iterations for train/val train_dataset: window_size: 25 @@ -99,7 +89,15 @@ data: pose_throughput_std: 0.2 fixed_pattern: false - _target_: tcn_hpl.data.frame_data_aug.rotate_scale_translate_jitter.FrameDataRotateScaleTranslateJitter - # Using default parameters to start + translate: 0.05 + scale: [0.9, 1.1] + rotate: [-5, 5] + det_loc_jitter: 0.02 + det_wh_jitter: 0.02 + pose_kp_loc_jitter: 0.005 + dets_score_jitter: 0. + pose_score_jitter: 0. + pose_kp_score_jitter: 0. val_dataset: # Augmentations on windows of frame data before performing vectorization. # Sharing transform with training dataset as it is only the drop-out aug to diff --git a/configs/experiment/r18/feat_locsconfs.yaml b/configs/experiment/r18/feat_locsconfs.yaml index 082cfa172..3a0902b4b 100644 --- a/configs/experiment/r18/feat_locsconfs.yaml +++ b/configs/experiment/r18/feat_locsconfs.yaml @@ -36,16 +36,15 @@ trainer: model: num_classes: 6 # number of activity classification classes compile: false + scheduler: + # Code change to track train/loss instead of val/loss. + factor: 0.9 + patience: 10 net: # Length of feature vector for a single frame. # Currently derived from the parameterization of dataset vectorizer. dim: 102 -# # Once upon a time defaults -# num_stages: 4 -# num_layers: 10 -# num_f_maps: 64 - data: coco_train_activities: "${paths.coco_file_root}/TRAIN-activity_truth.coco.json" coco_train_objects: "${paths.coco_file_root}/TRAIN-object_detections.coco.json" @@ -62,10 +61,7 @@ data: batch_size: 512 num_workers: 16 target_framerate: 15 # BBN Hololens2 Framerate - # This is a bit more than the number of windows in the training dataset so - # the weighted sampler has more of an opportunity to sample the space - # proportionally. - epoch_length: 300000 + epoch_sample_factor: 1 # 1x the dataset size iterations for train/val train_dataset: window_size: 25 @@ -93,7 +89,15 @@ data: pose_throughput_std: 0.2 fixed_pattern: false - _target_: tcn_hpl.data.frame_data_aug.rotate_scale_translate_jitter.FrameDataRotateScaleTranslateJitter - # Using default parameters to start + translate: 0.05 + scale: [0.9, 1.1] + rotate: [-5, 5] + det_loc_jitter: 0.02 + det_wh_jitter: 0.02 + pose_kp_loc_jitter: 0.005 + dets_score_jitter: 0. + pose_score_jitter: 0. + pose_kp_score_jitter: 0. val_dataset: # Augmentations on windows of frame data before performing vectorization. # Sharing transform with training dataset as it is only the drop-out aug to diff --git a/configs/model/ptg.yaml b/configs/model/ptg.yaml index 37613986c..2a3509333 100644 --- a/configs/model/ptg.yaml +++ b/configs/model/ptg.yaml @@ -20,7 +20,6 @@ net: num_stages: 4 num_layers: 5 num_f_maps: 128 - # dim: 204 dim: 128 num_classes: ${model.num_classes} diff --git a/tcn_hpl/callbacks/plot_metrics.py b/tcn_hpl/callbacks/plot_metrics.py index 7fdd99309..850702bf8 100644 --- a/tcn_hpl/callbacks/plot_metrics.py +++ b/tcn_hpl/callbacks/plot_metrics.py @@ -174,8 +174,8 @@ def on_train_epoch_end( all_source_frames = torch.cat(self._train_all_source_frames) # shape: #frames current_epoch = pl_module.current_epoch - curr_acc = pl_module.train_acc.compute() - curr_f1 = pl_module.train_f1.compute() + curr_acc = pl_module.train_metrics.acc.compute() + curr_f1 = pl_module.train_metrics.f1.compute() # # Plot per-video class predictions vs. 
GT across progressive frames in @@ -265,8 +265,8 @@ def on_validation_epoch_end( all_source_frames = torch.cat(self._val_all_source_frames) # shape: #frames current_epoch = pl_module.current_epoch - curr_acc = pl_module.val_acc.compute() - curr_f1 = pl_module.val_f1.compute() + curr_acc = pl_module.val_metrics.acc.compute() + curr_f1 = pl_module.val_metrics.f1.compute() best_f1 = pl_module.val_f1_best.compute() # @@ -359,8 +359,8 @@ def on_test_epoch_end( all_source_frames = torch.cat(self._val_all_source_frames) # shape: #frames current_epoch = pl_module.current_epoch - test_acc = pl_module.test_acc.compute() - test_f1 = pl_module.test_f1.compute() + test_acc = pl_module.test_metrics.acc.compute() + test_f1 = pl_module.test_metrics.f1.compute() # # Plot per-video class predictions vs. GT across progressive frames in diff --git a/tcn_hpl/data/frame_data.py b/tcn_hpl/data/frame_data.py index 890e33e84..e716789ac 100644 --- a/tcn_hpl/data/frame_data.py +++ b/tcn_hpl/data/frame_data.py @@ -38,6 +38,16 @@ def __post_init__(self): def __bool__(self): return bool(self.boxes.size) + def __eq__(self, other): + return ( + (self.boxes == other.boxes).all() + and (self.labels == other.labels).all() + and (self.scores == other.scores).all() + ) + + def __ne__(self, other): + return not (self == other) + @dataclass class FramePoses: @@ -51,7 +61,8 @@ class FramePoses: # Array of scores for each pose. Ostensibly the bbox score. Shape: (num_poses,) scores: npt.NDArray[float] # Pose join 2D positions in ascending joint ID order. If the joint is not - # present, 0s are used. Shape: (num_poses, num_joints, 2) + # present, 0s are used. Points in (x, y) format. + # Shape: (num_poses, num_joints, 2) joint_positions: npt.NDArray[float] # Poise joint scores. Shape: (num_poses, num_joints) joint_scores: npt.NDArray[float] @@ -67,6 +78,16 @@ def __post_init__(self): def __bool__(self): return bool(self.scores.size) + def __eq__(self, other): + return ( + (self.scores == other.scores).all() + and (self.joint_positions == other.joint_positions).all() + and (self.joint_scores == other.joint_scores).all() + ) + + def __ne__(self, other): + return not (self == other) + @dataclass class FrameData: @@ -114,3 +135,13 @@ def __bool__(self): not. """ return bool(self.object_detections) or bool(self.poses) + + def __eq__(self, other): + return ( + (self.object_detections == other.object_detections) + and (self.poses == other.poses) + and (self.size == other.size) + ) + + def __ne__(self, other): + return not (self == other) diff --git a/tcn_hpl/data/frame_data_aug/rotate_scale_translate_jitter.py b/tcn_hpl/data/frame_data_aug/rotate_scale_translate_jitter.py index e3252694c..68b5877c2 100644 --- a/tcn_hpl/data/frame_data_aug/rotate_scale_translate_jitter.py +++ b/tcn_hpl/data/frame_data_aug/rotate_scale_translate_jitter.py @@ -45,39 +45,60 @@ class FrameDataRotateScaleTranslateJitter(torch.nn.Module): Range in degrees for which we will randomly rotate the scene about a chosen center of rotation. The values given should be in ascending order. - location_jitter: - A relative amount to randomly adjust the position, height and width - of frame data. Locations jitter is relative to the height and width - of the frame. Box width and height jitter is relative to the width - and height of the affected box. This should be in the [0, 1] range. + det_loc_jitter: + A relative amount to randomly adjust the position of frame data + object detection locations. 
This relative amount is applied to the + width and height dimensions of individual detection boxes to yield + the true amount of movement. This should be in the [0, 1] range. + det_wh_jitter: + A relative amount to randomly adjust the width and height of frame + data object detection boxes. This relative amount is applied to the + width and height dimensions of individual detections boxes to yield + the true amount of adjustment. This should be in the [0, 1] range. dets_score_jitter: Randomly adjust the object detection confidence value within +/- - the given value. The resulting value is clamped within the [0, 1] - range. + the given value as multiplied against the input score. + The resulting value is clamped within the [0, 1] range. pose_score_jitter: - Randomly adjust the pose keypoint confidence value within +/- - the given value. The resulting value is clamped within the [0, 1] + Randomly adjust the pose detection confidence value within +/- + the given value as multiplied against the input score. + The resulting value is clamped within the [0, 1] range. + pose_kp_loc_jitter: + A relative amount to randomly adjust the location of keypoint + locations. This relative amount is applied to the width and height + dimensions of the **axis-aligned footprint** of the pose skeleton + to yield the true amount of movement. This should be in the [0, 1] range. + pose_kp_score_jitter: + Randomly adjust the pose keypoint confidence value within +/- + the given value as multiplied against the input score. + The resulting value is clamped within the [0, 1] range. TODO: Some maximum tolerance for accepting boxes that are outside of the image space? """ def __init__( self, - translate: float = 0.1, - scale: tg.Sequence[float] = (0.9, 1.1), - rotate: tg.Sequence[float] = (-10, 10), - location_jitter: float = 0.025, - dets_score_jitter: float = 0.1, - pose_score_jitter: float = 0.1, + translate: float = 0., + scale: tg.Sequence[float] = (1., 1.), + rotate: tg.Sequence[float] = (0., 0.), + det_loc_jitter: float = 0., + det_wh_jitter: float = 0., + dets_score_jitter: float = 0., + pose_score_jitter: float = 0., + pose_kp_loc_jitter: float = 0., + pose_kp_score_jitter: float = 0., ): super().__init__() self.translate = translate self.scale = scale self.rotate = rotate - self.location_jitter = location_jitter + self.det_loc_jitter = det_loc_jitter + self.det_wh_jitter = det_wh_jitter self.dets_score_jitter = dets_score_jitter self.pose_score_jitter = pose_score_jitter + self.pose_kp_loc_jitter = pose_kp_loc_jitter + self.pose_kp_score_jitter = pose_kp_score_jitter def forward(self, window: tg.Sequence[FrameData]) -> tg.List[FrameData]: # Extract frame size from the first frame (assuming all frames have the same size) @@ -112,7 +133,13 @@ def forward(self, window: tg.Sequence[FrameData]) -> tg.List[FrameData]: translation=(translate_x, translate_y), ) t_to_base = SimilarityTransform(translation=(center_x, center_y)) - transform = t_to_base + t_rotate_scale + t_to_origin + transform = t_to_origin + t_rotate_scale + t_to_base + # While we could just do `transform(2d_coors)`, it's not specifically + # the fastest. Instead, we'll just do the dot product ourselves. + tform = lambda pts: ( + transform.params + @ np.concatenate([pts, np.ones((pts.shape[0], 1))], axis=1).T + )[:2].T # Collect all frame detection boxes and scores in order to batch # transform and jitter. 
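# Illustrative sketch (not part of the patch): the manual dot-product point
# transform above is equivalent to calling the composed SimilarityTransform
# directly, just with less per-call overhead. All values below are made up
# for the example.
import numpy as np
from skimage.transform import SimilarityTransform

pts = np.array([[10.0, 20.0], [250.0, 250.0], [480.0, 30.0]])  # (N, 2) x/y points
t = (
    SimilarityTransform(translation=(-250, -250))  # move rotation center to origin
    + SimilarityTransform(scale=1.05, rotation=np.deg2rad(3), translation=(5, -5))
    + SimilarityTransform(translation=(250, 250))  # move back into frame coordinates
)
# Homogeneous-coordinate application: [x', y', 1]^T = P @ [x, y, 1]^T
pts_h = np.concatenate([pts, np.ones((pts.shape[0], 1))], axis=1)
manual = (t.params @ pts_h.T)[:2].T
assert np.allclose(manual, t(pts))  # matches the transform's own __call__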
@@ -136,10 +163,6 @@ def forward(self, window: tg.Sequence[FrameData]) -> tg.List[FrameData]: all_pose_kp_scores.append(frame.poses.joint_scores) frame_pose_indices.extend([i] * len(frame.poses.scores)) - location_jitter = self.location_jitter - # Maximum x/y location jitter based on parameter and frame width/height - xy_jitter_max = np.asarray([frame_width, frame_height]) * location_jitter - if all_box_coords: # Pull together all xywh boxes. all_box_coords = np.concatenate(all_box_coords) # shape: [n_dets, 4] @@ -152,15 +175,24 @@ def forward(self, window: tg.Sequence[FrameData]) -> tg.List[FrameData]: # pass. det_rand = torch.rand(n_dets, 5).numpy() - # Jitter box locations and sizes. + # Jitter box locations. + # Basing the jitter on per-box width and height dimensions. The + # intuition is that we should jitter boxes in proportion to their + # screen space so that (a) their relative spatial applicability is + # maintained and (b) an object detector is less likely to provide + # noisy boxes beyond the relative scope of the object's footprint. + xy_jitter_max = all_box_coords[:, 2:] * self.det_loc_jitter xy_jitter = (xy_jitter_max * 2 * det_rand[:, :2]) - xy_jitter_max all_box_coords[:, :2] += xy_jitter - wh_jitter_max = all_box_coords[:, 2:] * location_jitter + # Jitter box sizes. + wh_jitter_max = all_box_coords[:, 2:] * self.det_wh_jitter wh_jitter = (wh_jitter_max * 2 * det_rand[:, 2:4]) - wh_jitter_max all_box_coords[:, 2:] += wh_jitter # Jitter box scores and clamp - score_jitter = (self.dets_score_jitter * 2 * det_rand[:, 4]) - self.dets_score_jitter - all_box_scores += score_jitter + score_jitter = ( + self.dets_score_jitter * 2 * det_rand[:, 4] + ) - self.dets_score_jitter + all_box_scores *= (1 + score_jitter) all_box_scores[all_box_scores < 0] = 0 all_box_scores[all_box_scores > 1] = 1 @@ -185,10 +217,9 @@ def forward(self, window: tg.Sequence[FrameData]) -> tg.List[FrameData]: ).T # shape: [n_dets, 4, 2] # Reshape into a flat list for applying the transform to. corners = corners.reshape(n_dets * 4, 2) - # corners shape now: [4 * n_dets, 2] - corners = transform(corners) # shape: [4* n_dets, 2] + corners = tform(corners) corners = corners.reshape(n_dets, 4, 2) - # corners shape now: [n_dets, 4, 2] + # Get min and max values of transformed boxes for each # detection to create new axis-aligned bounding boxes. x_min = corners[:, :, 0].min(axis=1) # shape: [n_dets] @@ -198,7 +229,10 @@ def forward(self, window: tg.Sequence[FrameData]) -> tg.List[FrameData]: all_box_coords = np.asarray([x_min, y_min, x_max - x_min, y_max - y_min]).T # Create mask for dets that are at least partially in the frame. 
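# Small numeric sketch (illustrative, not part of the patch) of the relative
# jitter above: location/size offsets are bounded by each box's own w/h, and
# score jitter now scales the input score instead of adding to it. The boxes,
# scores, and jitter fractions here are made up for the example.
import numpy as np

rng = np.random.default_rng(0)
boxes = np.array([[10.0, 20.0, 30.0, 40.0],       # x, y, w, h
                  [100.0, 100.0, 200.0, 150.0]])
scores = np.array([0.9, 0.4])
det_loc_jitter, det_wh_jitter, dets_score_jitter = 0.02, 0.02, 0.05

rand = rng.random((len(boxes), 5))
xy_jitter_max = boxes[:, 2:] * det_loc_jitter     # e.g. [0.6, 0.8] px for the first box
boxes[:, :2] += (xy_jitter_max * 2 * rand[:, :2]) - xy_jitter_max
wh_jitter_max = boxes[:, 2:] * det_wh_jitter
boxes[:, 2:] += (wh_jitter_max * 2 * rand[:, 2:4]) - wh_jitter_max
score_jitter = (dets_score_jitter * 2 * rand[:, 4]) - dets_score_jitter  # in [-0.05, 0.05]
scores = np.clip(scores * (1 + score_jitter), 0.0, 1.0)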
in_frame = ( - (x_max > 0) & (y_max > 0) & (x_min < frame_width) & (y_min < frame_height) + (x_max > 0) + & (y_max > 0) + & (x_min < frame_width) + & (y_min < frame_height) ) # Filter down detections to those in the frame all_box_coords = all_box_coords[in_frame] @@ -220,20 +254,39 @@ def forward(self, window: tg.Sequence[FrameData]) -> tg.List[FrameData]: pose_score_rand = torch.rand(n_poses).numpy() # Jitter pose score & clamp - score_jitter = (self.pose_score_jitter * 2 * pose_score_rand) - self.pose_score_jitter - all_pose_scores += score_jitter + score_jitter = ( + self.pose_score_jitter * 2 * pose_score_rand + ) - self.pose_score_jitter + all_pose_scores *= (1 + score_jitter) + all_pose_scores[all_pose_scores < 0] = 0 + all_pose_scores[all_pose_scores > 1] = 1 # Jitter pose keypoint locations + # Calculate per-pose containing box width and height to base jitter + # off of instead of the raw frame width and height. The intuition + # is that a pose estimation model is less likely to provide noisy + # keypoints beyond the relative scope of a skeleton's footprint. + # This will mean poses that cover a small footprint are be smeared + # out of proportion relative to poses that cover a much larger + # footprint, which would likely be afforded more noise. + xy_jitter_max = ( + np.asarray([ + all_pose_kps[:, :, 0].max(axis=1) - all_pose_kps[:, :, 0].min(axis=1), + all_pose_kps[:, :, 1].max(axis=1) - all_pose_kps[:, :, 1].min(axis=1), + ]).T * self.pose_kp_loc_jitter + )[:, None, :] # shape: [n_poses, 1, 2] xy_jitter = (xy_jitter_max * 2 * pose_kp_rand[:, :, :2]) - xy_jitter_max all_pose_kps += xy_jitter # Jitter pose keypoint scores & clamp - score_jitter = (self.pose_score_jitter * 2 * pose_kp_rand[:, :, 2]) - self.pose_score_jitter - all_pose_kp_scores += score_jitter + score_jitter = ( + self.pose_kp_score_jitter * 2 * pose_kp_rand[:, :, 2] + ) - self.pose_kp_score_jitter + all_pose_kp_scores *= (1 + score_jitter) all_pose_kp_scores[all_pose_kp_scores < 0] = 0 all_pose_kp_scores[all_pose_kp_scores > 1] = 1 # Transform keypoint locations all_pose_kps = all_pose_kps.reshape(n_poses * n_kps, 2) - all_pose_kps = transform(all_pose_kps) + all_pose_kps = tform(all_pose_kps) all_pose_kps = all_pose_kps.reshape(n_poses, n_kps, 2) # Zero out the scores for any joints that are now out of the frame @@ -252,7 +305,7 @@ def forward(self, window: tg.Sequence[FrameData]) -> tg.List[FrameData]: if window[i].object_detections is not None: # Make sure we emit an instance if there was an instance input # even if the arrays are empty. - frame_i_mask = (frame_dets_indices == i) + frame_i_mask = frame_dets_indices == i new_frame.object_detections = FrameObjectDetections( boxes=all_box_coords[frame_i_mask], labels=all_box_labels[frame_i_mask], @@ -261,7 +314,7 @@ def forward(self, window: tg.Sequence[FrameData]) -> tg.List[FrameData]: if window[i].poses is not None: # Make sure we emit an instance if there was an instance input # even if the arrays are empty. 
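# Illustrative sketch (not part of the patch): the pose keypoint jitter bound
# is each skeleton's axis-aligned footprint (width, height) scaled by
# pose_kp_loc_jitter, broadcast over that pose's joints. Values are made up.
import numpy as np

rng = np.random.default_rng(1)
kps = rng.uniform(100, 300, size=(3, 22, 2))  # (n_poses, n_joints, xy)
pose_kp_loc_jitter = 0.005
footprint_wh = np.stack(
    [
        kps[:, :, 0].max(axis=1) - kps[:, :, 0].min(axis=1),  # per-pose width
        kps[:, :, 1].max(axis=1) - kps[:, :, 1].min(axis=1),  # per-pose height
    ],
    axis=1,
)  # shape: (n_poses, 2)
xy_jitter_max = (footprint_wh * pose_kp_loc_jitter)[:, None, :]  # (n_poses, 1, 2)
rand = rng.random(kps.shape)
kps += (xy_jitter_max * 2 * rand) - xy_jitter_max  # each joint moves within +/- its bound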
- frame_i_mask = (frame_pose_indices == i) + frame_i_mask = frame_pose_indices == i new_frame.poses = FramePoses( scores=all_pose_scores[frame_i_mask], joint_positions=all_pose_kps[frame_i_mask], @@ -276,21 +329,29 @@ def test(): import matplotlib.pyplot as plt from tcn_hpl.data.frame_data import FrameObjectDetections, FramePoses - torch.manual_seed(0) + torch.manual_seed(42) # Prime the pump torch.rand(1) rng = np.random.RandomState(0) n_poses = 3 pose_scores = rng.uniform(0, 1, n_poses) - pose_joint_locs = rng.randint(0, 500, (n_poses, 22, 2)).astype(float) pose_joint_scores = rng.uniform(0, 1, (n_poses, 22)) + n_kps = 22 + pose_joint_locs = np.stack([ + # One "pose" in the upper-right + np.stack([rng.uniform(250, 500, n_kps), rng.uniform(0, 250, n_kps)], axis=1), + # One "pose" covering the bottom half + np.stack([rng.uniform(0, 500, n_kps), rng.uniform(250, 500, n_kps)], axis=1), + # One "pose" that's small in the center + np.stack([rng.uniform(225, 275, n_kps), rng.uniform(225, 275, n_kps)], axis=1), + ]) target_wh = 500, 500 frame1 = FrameData( # 3 detections object_detections=FrameObjectDetections( - boxes=np.array([[10., 20, 30, 40], [50, 60, 70, 80], [90, 100, 110, 120]]), + boxes=np.array([[10.0, 20, 30, 40], [50, 60, 70, 80], [90, 100, 110, 120]]), labels=np.array([1, 2, 3]), scores=np.array([0.9, 0.75, 0.11]), ), @@ -304,18 +365,31 @@ def test(): ) window = [frame1] * 25 - augment = FrameDataRotateScaleTranslateJitter() + augment = FrameDataRotateScaleTranslateJitter( + translate=0.05, + scale=[0.9, 1.1], + rotate=[-5, 5], + det_loc_jitter=0.02, + det_wh_jitter=0.02, + pose_kp_loc_jitter=0.005, + # dets_score_jitter=0.05, + # pose_score_jitter=0.05, + # pose_kp_score_jitter=0.05, + ) + + # Idiot check. + augment(window) ipython = get_ipython() if ipython is not None: ipython.run_line_magic("timeit", "augment(window)") # Visualize detection boxes before augmentation - fig, axes = plt.subplots(3, 3, figsize=(15, 15)) + fig, axes = plt.subplots(5, 5, figsize=(25, 25)) axes = axes.ravel() - axes[0].set_title("Before Augmentation of Frame 0") + axes[0].set_title("Before Augmentation") axes[0].set_xlim(0, target_wh[0]) - axes[0].set_ylim(0, target_wh[1]) + axes[0].set_ylim(target_wh[1], 0) for frame in window: for box in frame.object_detections.boxes: x, y, w, h = box @@ -327,11 +401,11 @@ def test(): axes[0].plot(pose_kp[:, 0], pose_kp[:, 1]) # For every other axes, show a different augmentation example - for ax_i in range(1, 9): + for ax_i in range(1, len(axes)): aug_window = augment(window) - axes[ax_i].set_title(f"Separate Augmentation [{ax_i}] of Frame 0") + axes[ax_i].set_title(f"Separate Augmentation [{ax_i}]") axes[ax_i].set_xlim(0, target_wh[0]) - axes[ax_i].set_ylim(0, target_wh[1]) + axes[ax_i].set_ylim(target_wh[1], 0) for frame in aug_window: for box in frame.object_detections.boxes: x, y, w, h = box @@ -352,7 +426,7 @@ def test(): for f_i, frame in enumerate(augmented_window): axes[f_i].set_title(f"Augmented Frame [{f_i}]") axes[f_i].set_xlim(0, target_wh[0]) - axes[f_i].set_ylim(0, target_wh[1]) + axes[f_i].set_ylim(target_wh[1], 0) for box in frame.object_detections.boxes: x, y, w, h = box rect = plt.Rectangle( @@ -365,5 +439,6 @@ def test(): plt.savefig("FrameDataRotateScaleTranslateJitter_allFrames.png") plt.close(fig) + if __name__ == "__main__": test() diff --git a/tcn_hpl/data/ptg_datamodule.py b/tcn_hpl/data/ptg_datamodule.py index 9d9ba070a..68e514b8d 100644 --- a/tcn_hpl/data/ptg_datamodule.py +++ b/tcn_hpl/data/ptg_datamodule.py @@ -14,7 +14,7 @@ def 
create_dataset_from_hydra( model_hydra_conf: Path, - split: str = "test", + split: str = "pred", ) -> "TCNDataset": """ Create a TCNDataset for some specified split based on the Hydra @@ -95,6 +95,7 @@ def __init__( train_dataset: TCNDataset, val_dataset: TCNDataset, test_dataset: TCNDataset, + pred_dataset: TCNDataset, coco_train_activities: str, coco_train_objects: str, coco_train_poses: str, @@ -107,8 +108,8 @@ def __init__( batch_size: int, num_workers: int, target_framerate: float, - epoch_length: int, - pin_memory: bool, + epoch_sample_factor: float, + pin_memory: bool = False, ) -> None: """Initialize a `PTGDataModule`. @@ -130,10 +131,13 @@ def __init__( object detections to use for training. :param coco_test_poses: Path to the COCO file with test-split pose estimations to use for training. - :vector_cache_dir: Directory path to store cache files related to - dataset vectory computation. :param batch_size: The batch size. Defaults to `64`. :param num_workers: The number of workers. Defaults to `0`. + :param target_framerate: Hz rate for loaded datasets to be checked + against and normalized to if there is faster rate data in the mix. + :param epoch_sample_factor: A multiplicative factor on the size of a + dataset for a weighted random sampler to sample over. This is + currently applicable to the train and validation dataloaders. :param pin_memory: Whether to pin memory. Defaults to `False`. """ super().__init__() @@ -141,12 +145,19 @@ def __init__( # this line allows to access init params with 'self.hparams' attribute # also ensures init params will be stored in ckpt self.save_hyperparameters( - logger=False, ignore=["train_dataset", "val_dataset", "test_dataset"] + logger=False, + ignore=[ + "train_dataset", + "val_dataset", + "test_dataset", + "pred_dataset", + ], ) self.data_train: Optional[TCNDataset] = train_dataset self.data_val: Optional[TCNDataset] = val_dataset self.data_test: Optional[TCNDataset] = test_dataset + self.data_pred: Optional[TCNDataset] = pred_dataset def setup(self, stage: Optional[str] = None) -> None: """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`. @@ -159,19 +170,21 @@ def setup(self, stage: Optional[str] = None) -> None: :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``. 
""" # load and split datasets only if not loaded already - if not self.data_train and not self.data_val and not self.data_test: + if (stage == "train" or stage == "fit") and not self.data_train: self.data_train.load_data_offline( kwcoco.CocoDataset(self.hparams.coco_train_activities), kwcoco.CocoDataset(self.hparams.coco_train_objects), kwcoco.CocoDataset(self.hparams.coco_train_poses), self.hparams.target_framerate, ) + if stage == "validate" and not self.data_val: self.data_val.load_data_offline( kwcoco.CocoDataset(self.hparams.coco_validation_activities), kwcoco.CocoDataset(self.hparams.coco_validation_objects), kwcoco.CocoDataset(self.hparams.coco_validation_poses), self.hparams.target_framerate, ) + if stage == "test" and not self.data_test: self.data_test.load_data_offline( kwcoco.CocoDataset(self.hparams.coco_test_activities), kwcoco.CocoDataset(self.hparams.coco_test_objects), @@ -186,7 +199,7 @@ def train_dataloader(self) -> DataLoader[Any]: """ train_sampler = torch.utils.data.WeightedRandomSampler( self.data_train.window_weights, - self.hparams.epoch_length, + int(round(self.hparams.epoch_sample_factor * len(self.data_train))), replacement=True, generator=None, ) @@ -203,19 +216,19 @@ def val_dataloader(self) -> DataLoader[Any]: :return: The validation dataloader. """ - val_sampler = torch.utils.data.WeightedRandomSampler( - self.data_val.window_weights, - len(self.data_val) * 3, - replacement=True, - generator=None, - ) + # val_sampler = torch.utils.data.WeightedRandomSampler( + # self.data_val.window_weights, + # self.hparams.epoch_sample_factor * len(self.data_val), + # replacement=True, + # generator=None, + # ) return DataLoader( dataset=self.data_val, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers, pin_memory=self.hparams.pin_memory, - # shuffle=False, - sampler=val_sampler, + shuffle=False, + # sampler=val_sampler, ) def test_dataloader(self) -> DataLoader[Any]: diff --git a/tcn_hpl/data/tcn_dataset.py b/tcn_hpl/data/tcn_dataset.py index a94e1ee1f..728feddac 100644 --- a/tcn_hpl/data/tcn_dataset.py +++ b/tcn_hpl/data/tcn_dataset.py @@ -340,7 +340,9 @@ class that has pose keypoints associated with it. The current else: frame_poses = empty_pose # import ipdb; ipdb.set_trace() - vid_frame_data.append(FrameData(frame_dets, frame_poses, frame_size)) + vid_frame_data.append( + FrameData(frame_dets, frame_poses, frame_size) + ) # Compose a list of indices into frame_data that this video's # worth of content resides. @@ -520,31 +522,37 @@ def test_dataset_for_input( from tcn_hpl.data.vectorize.locs_and_confs import LocsAndConfs vectorize = LocsAndConfs( - top_k = 1, - num_classes = 7, - use_joint_confs = True, - use_pixel_norm = False, - use_joint_obj_offsets = False, - background_idx = 0, + top_k=1, + num_classes=7, + use_joint_confs=True, + use_pixel_norm=True, + use_joint_obj_offsets=False, + background_idx=0, ) # TODO: Some method of configuring which augmentations to use. 
- from tcn_hpl.data.frame_data_aug.rotate_scale_translate_jitter import FrameDataRotateScaleTranslateJitter - from tcn_hpl.data.frame_data_aug.window_frame_dropout import DropoutFrameDataTransform + from tcn_hpl.data.frame_data_aug.rotate_scale_translate_jitter import ( + FrameDataRotateScaleTranslateJitter, + ) + from tcn_hpl.data.frame_data_aug.window_frame_dropout import ( + DropoutFrameDataTransform, + ) import torchvision.transforms - transform_frame_data = torchvision.transforms.Compose([ - DropoutFrameDataTransform( - frame_rate=15, - dets_throughput_mean=14.5, - pose_throughput_mean=10, - dets_latency=0, - pose_latency=1/10, # (1 / 10) - (1 / 14.5), - dets_throughput_std=0.2, - pose_throughput_std=0.2, - ), - FrameDataRotateScaleTranslateJitter(), - ]) + transform_frame_data = torchvision.transforms.Compose( + [ + DropoutFrameDataTransform( + frame_rate=15, + dets_throughput_mean=14.5, + pose_throughput_mean=10, + dets_latency=0, + pose_latency=1 / 10, # (1 / 10) - (1 / 14.5), + dets_throughput_std=0.2, + pose_throughput_std=0.2, + ), + FrameDataRotateScaleTranslateJitter(), + ] + ) dataset = TCNDataset( window_size=window_size, diff --git a/tcn_hpl/models/ptg_module.py b/tcn_hpl/models/ptg_module.py index f396de9ef..e747e4c38 100644 --- a/tcn_hpl/models/ptg_module.py +++ b/tcn_hpl/models/ptg_module.py @@ -1,12 +1,11 @@ from typing import Any, Dict, Optional, Tuple, Union, List -from numpy.lib.utils import source from pytorch_lightning import LightningModule import torch from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT from torch import nn import torch.nn.functional as F -from torchmetrics import MaxMetric, MeanMetric +from torchmetrics import MaxMetric, MeanMetric, MetricCollection from torchmetrics.classification.accuracy import Accuracy from torchmetrics.classification import F1Score, Recall, Precision @@ -82,66 +81,50 @@ def __init__( self.criterion = criterion self.mse = nn.MSELoss(reduction="none") - # metric objects for calculating and averaging accuracy across batches - self.train_acc = Accuracy( - task="multiclass", average="weighted", num_classes=num_classes - ) - self.val_acc = Accuracy( - task="multiclass", average="weighted", num_classes=num_classes - ) - self.test_acc = Accuracy( - task="multiclass", average="weighted", num_classes=num_classes - ) - # Track per-class accuracy for separated logging - self.train_acc_perclass = Accuracy( - task="multiclass", average="none", num_classes=num_classes - ) - self.val_acc_perclass = Accuracy( - task="multiclass", average="none", num_classes=num_classes - ) - self.test_acc_perclass = Accuracy( - task="multiclass", average="none", num_classes=num_classes - ) - - self.train_f1 = F1Score( - num_classes=num_classes, average="weighted", task="multiclass" - ) - self.val_f1 = F1Score( - num_classes=num_classes, average="weighted", task="multiclass" - ) - self.test_f1 = F1Score( - num_classes=num_classes, average="weighted", task="multiclass" - ) - # Track per-class F1 for separated logging - self.train_f1_perclass = F1Score( - num_classes=num_classes, average="none", task="multiclass" - ) - self.val_f1_perclass = F1Score( - num_classes=num_classes, average="none", task="multiclass" - ) - self.test_f1_perclass = F1Score( - num_classes=num_classes, average="none", task="multiclass" - ) - - self.train_recall = Recall( - num_classes=num_classes, average="weighted", task="multiclass" - ) - self.val_recall = Recall( - num_classes=num_classes, average="weighted", task="multiclass" - ) - self.test_recall = Recall( - 
num_classes=num_classes, average="weighted", task="multiclass" - ) - - self.train_precision = Precision( - num_classes=num_classes, average="weighted", task="multiclass" - ) - self.val_precision = Precision( - num_classes=num_classes, average="weighted", task="multiclass" + # We only want to log validation metrics if training has actually + # started, i.e. not during the sanity checking phase. + self.has_training_started = False + + # Metric objects for calculating and averaging accuracy across batches + # Various metrics are reset at the **beginning** of epochs due to the + # desire to access metrics by callbacks at the **end** of epochs, + # which would be hard to do if we reset at the end before they get a + # chance to use those values... + self.train_metrics = MetricCollection( + { + "acc": Accuracy( + task="multiclass", num_classes=num_classes, average="weighted" + ), + "f1": F1Score( + task="multiclass", num_classes=num_classes, average="weighted" + ), + "recall": Recall( + task="multiclass", num_classes=num_classes, average="weighted" + ), + "precision": Precision( + task="multiclass", num_classes=num_classes, average="weighted" + ), + }, + prefix="train/" ) - self.test_precision = Precision( - num_classes=num_classes, average="weighted", task="multiclass" + self.val_metrics = self.train_metrics.clone(prefix="val/") + self.test_metrics = self.train_metrics.clone(prefix="test/") + + # Some metrics that output per-class vectors. These will have to be + # logged manually in a loop since only scalars can be logged. + self.train_vec_metrics = MetricCollection( + { + "acc": Accuracy( + task="multiclass", num_classes=num_classes, average="none" + ), + "f1": F1Score( + task="multiclass", num_classes=num_classes, average="none" + ), + }, + prefix="train/" ) + self.val_vec_metrics = self.train_vec_metrics.clone(prefix="val/") + self.test_vec_metrics = self.train_vec_metrics.clone(prefix="test/") # for averaging loss across batches self.train_loss = MeanMetric() @@ -162,14 +145,7 @@ def forward(self, x: torch.Tensor, m: torch.Tensor) -> torch.Tensor: def on_train_start(self) -> None: """Lightning hook that is called when training begins.""" - # by default lightning executes validation step sanity checks before training starts, - # so it's worth to make sure validation metrics don't store results from these checks - self.val_loss.reset() - self.val_acc.reset() - self.val_f1.reset() - self.val_recall.reset() - self.val_precision.reset() - self.val_f1_best.reset() + self.has_training_started = True def compute_loss(self, p, y, mask): """Compute the total loss for a batch @@ -269,6 +245,11 @@ def model_step( return loss, probs, preds, y, source_vid, source_frame + def on_train_epoch_start(self) -> None: + # Reset relevant metric collections + self.train_metrics.reset() + self.train_vec_metrics.reset() + def training_step( self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], @@ -285,27 +266,14 @@ def training_step( """ loss, probs, preds, targets, source_vid, source_frame = self.model_step(batch) - # update and log metrics + # update and log loss + # Don't want to log this on step because it causes some loggers (CSV) + # to create some pretty unreadable output.
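# Illustrative sketch (not part of the patch): a MetricCollection updates and
# computes all of its members with one call, and clone(prefix=...) re-labels
# the same collection for the val/test phases. num_classes and tensors are
# made up for the example.
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import Accuracy, F1Score

train_metrics = MetricCollection(
    {
        "acc": Accuracy(task="multiclass", num_classes=3, average="weighted"),
        "f1": F1Score(task="multiclass", num_classes=3, average="weighted"),
    },
    prefix="train/",
)
val_metrics = train_metrics.clone(prefix="val/")

preds = torch.tensor([0, 1, 2, 2])
target = torch.tensor([0, 1, 1, 2])
print(train_metrics(preds, target))  # {'train/acc': ..., 'train/f1': ...}
print(val_metrics(preds, target))    # {'val/acc': ..., 'val/f1': ...}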
self.train_loss(loss) - self.train_acc(preds, targets[:, self.hparams.pred_frame_index]) - - self.log( - "train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True - ) self.log( - "train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True + "train/loss", self.train_loss, prog_bar=True, on_step=False, on_epoch=True ) - self.train_acc_perclass(preds, targets[:, self.hparams.pred_frame_index]) - for c_i, c_acc in enumerate(self.train_acc_perclass.compute()): - self.log( - f"train/acc-per-class/c{c_i}", - c_acc, - prog_bar=False, - on_step=False, - on_epoch=True, - ) - # return loss or backpropagation will fail return { "loss": loss, @@ -320,21 +288,18 @@ def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: all_preds = torch.cat([o["preds"] for o in outputs]) all_targets = torch.cat([o['targets'] for o in outputs]) - self.train_f1(all_preds, all_targets) - self.train_recall(all_preds, all_targets) - self.train_precision(all_preds, all_targets) - self.log("train/f1", self.train_f1, prog_bar=True, on_epoch=True) - self.log("train/recall", self.train_recall, prog_bar=True, on_epoch=True) - self.log("train/precision", self.train_precision, prog_bar=True, on_epoch=True) - # vector metrics - self.train_f1_perclass(all_preds, all_targets) - for c_i, c_f1 in enumerate(self.train_f1_perclass.compute()): - self.log( - f"train/f1-per-class/c{c_i}", - c_f1, - prog_bar=False, - on_epoch=True, - ) + scalar_metrics = self.train_metrics(all_preds, all_targets) + self.log_dict(scalar_metrics, prog_bar=True, on_epoch=True) + + vec_metrics = self.train_vec_metrics(all_preds, all_targets) + for k, t in vec_metrics.items(): + for i, v in enumerate(t): + self.log(f"{k}-class_{i}", v) + + def on_validation_epoch_start(self) -> None: + # Reset relevant metric collections + self.val_metrics.reset() + self.val_vec_metrics.reset() def validation_step( self, @@ -353,20 +318,7 @@ def validation_step( # update and log metrics self.val_loss(loss) - self.val_acc(preds, targets[:, self.hparams.pred_frame_index]) - - self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True) - self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) - - self.val_acc_perclass(preds, targets[:, self.hparams.pred_frame_index]) - for c_i, c_acc in enumerate(self.val_acc_perclass.compute()): - self.log( - f"val/acc-per-class/c{c_i}", - c_acc, - prog_bar=False, - on_step=False, - on_epoch=True, - ) + self.log("val/loss", self.val_loss, prog_bar=True) # Only retain the truth and source vid/frame IDs for the final window # frame as this is the ultimately relevant result. 
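# Illustrative sketch (not part of the patch): metrics built with
# average="none" return one value per class, so the loops above unpack them
# into individually named scalars before logging. Data here is made up.
import torch
from torchmetrics.classification import F1Score

f1_per_class = F1Score(task="multiclass", num_classes=3, average="none")
preds = torch.tensor([0, 1, 2, 2])
target = torch.tensor([0, 1, 1, 2])
vec = f1_per_class(preds, target)           # tensor of shape (num_classes,)
for i, v in enumerate(vec):
    print(f"train/f1-class_{i}", float(v))  # stands in for self.log(...)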
@@ -380,31 +332,31 @@ def validation_step( } def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None: + if not self.has_training_started: + return + all_preds = torch.cat([o['preds'] for o in outputs]) all_targets = torch.cat([o['targets'] for o in outputs]) - self.val_f1(all_preds, all_targets) - self.val_recall(all_preds, all_targets) - self.val_precision(all_preds, all_targets) - self.log("val/f1", self.val_f1, prog_bar=True, on_epoch=True) - self.log("val/recall", self.val_recall, prog_bar=True, on_epoch=True) - self.log("val/precision", self.val_precision, prog_bar=True, on_epoch=True) - # vector metrics - self.val_f1_perclass(all_preds, all_targets) - for c_i, c_f1 in enumerate(self.val_f1_perclass.compute()): - self.log( - f"val/f1-per-class/c{c_i}", - c_f1, - prog_bar=False, - on_epoch=True, - ) + scalar_metrics = self.val_metrics(all_preds, all_targets) + self.log_dict(scalar_metrics, prog_bar=True) + + vec_metrics = self.val_vec_metrics(all_preds, all_targets) + for k, t in vec_metrics.items(): + for i, v in enumerate(t): + self.log(f"{k}-class_{i}", v) # log `val_f1_best` as a value through `.compute()` return, instead of # as a metric object otherwise metric would be reset by lightning after # each epoch. - self.val_f1_best(self.val_f1.compute()) + self.val_f1_best(self.val_metrics.f1.compute()) self.log("val/f1_best", self.val_f1_best.compute(), prog_bar=True, on_epoch=True) + def on_test_epoch_start(self) -> None: + # Reset relevant metric collections + self.test_metrics.reset() + self.test_vec_metrics.reset() + def test_step( self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], @@ -422,21 +374,7 @@ def test_step( # update and log metrics self.test_loss(loss) - self.test_acc(preds, targets[:, self.hparams.pred_frame_index]) - self.log( - "test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True - ) - self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True) - - self.test_acc_perclass(preds, targets[:, self.hparams.pred_frame_index]) - for c_i, c_acc in enumerate(self.test_acc_perclass.compute()): - self.log( - f"test/acc-per-class/c{c_i}", - c_acc, - prog_bar=False, - on_step=False, - on_epoch=True, - ) + self.log("test/loss", self.test_loss, prog_bar=True) # Only retain the truth and source vid/frame IDs for the final window # frame as this is the ultimately relevant result. 
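# Illustrative sketch (not part of the patch): val/f1_best tracks the running
# maximum of the epoch-level F1 pulled out of the validation MetricCollection
# by attribute access, as in validation_epoch_end above. Data is made up.
import torch
from torchmetrics import MaxMetric, MetricCollection
from torchmetrics.classification import F1Score

val_metrics = MetricCollection(
    {"f1": F1Score(task="multiclass", num_classes=3, average="weighted")},
    prefix="val/",
)
val_f1_best = MaxMetric()

for preds, target in [
    (torch.tensor([0, 1, 1]), torch.tensor([0, 1, 2])),
    (torch.tensor([0, 1, 2]), torch.tensor([0, 1, 2])),
]:
    val_metrics.update(preds, target)
    val_f1_best(val_metrics.f1.compute())  # best-so-far of the per-epoch F1
    val_metrics.reset()
print(val_f1_best.compute())  # highest F1 seen across the mock "epochs"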
@@ -453,22 +391,13 @@ def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> No all_preds = torch.cat([o['preds'] for o in outputs]) all_targets = torch.cat([o['targets'] for o in outputs]) - # update and log metrics - self.test_f1(all_preds, all_targets) - self.test_recall(all_preds, all_targets) - self.test_precision(all_preds, all_targets) - self.log("test/f1", self.test_f1, on_step=False, on_epoch=True, prog_bar=True) - self.log("test/recall", self.test_recall, on_step=False, on_epoch=True, prog_bar=True) - self.log("test/precision", self.test_precision, on_step=False, on_epoch=True, prog_bar=True) - # vector metrics - self.test_f1_perclass(all_preds, all_targets) - for c_i, c_f1 in enumerate(self.test_f1_perclass.compute()): - self.log( - f"test/f1-per-class/c{c_i}", - c_f1, - prog_bar=False, - on_epoch=True, - ) + scalar_metrics = self.test_metrics(all_preds, all_targets) + self.log_dict(scalar_metrics, prog_bar=True, on_epoch=True) + + vec_metrics = self.test_vec_metrics(all_preds, all_targets) + for k, t in vec_metrics.items(): + for i, v in enumerate(t): + self.log(f"{k}-class_{i}", v) def setup(self, stage: Optional[str] = None) -> None: """Lightning hook that is called at the beginning of fit (train + validate), validate, @@ -498,7 +427,7 @@ def configure_optimizers(self) -> Dict[str, Any]: "optimizer": optimizer, "lr_scheduler": { "scheduler": scheduler, - "monitor": "val/loss", + "monitor": "train/loss", "interval": "epoch", "frequency": 1, },
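# Illustrative sketch (not part of the patch): with the monitor switched to
# "train/loss", configure_optimizers pairs a ReduceLROnPlateau scheduler
# (factor/patience as in the experiment configs) with that metric. The
# optimizer choice and learning rate here are placeholders.
import torch

def configure_optimizers_sketch(model: torch.nn.Module) -> dict:
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.9, patience=10
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "train/loss",  # was "val/loss" before this change
            "interval": "epoch",
            "frequency": 1,
        },
    }

cfg = configure_optimizers_sketch(torch.nn.Linear(4, 2))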