Merge pull request #44 from Purg/dev/model-testing-selectable-window-idx
Experimentally Drive Feature Enhancements
Purg authored Nov 14, 2024
2 parents 9eb64da + 252cc79 commit d904388
Showing 10 changed files with 269 additions and 68 deletions.
2 changes: 2 additions & 0 deletions configs/data/ptg.yaml
@@ -3,6 +3,7 @@ _target_: tcn_hpl.data.ptg_datamodule.PTGDataModule
 train_dataset:
   _target_: tcn_hpl.data.tcn_dataset.TCNDataset
   window_size: 15
+  window_label_idx: -1
   # A vectorizer is required to complete construction of a TCN Dataset.
   # We are not providing a default here given how hydra merged hyperparameters.
   # For example:
@@ -21,6 +22,7 @@ train_dataset:
 val_dataset:
   _target_: tcn_hpl.data.tcn_dataset.TCNDataset
   window_size: ${data.train_dataset.window_size}
+  window_label_idx: ${data.train_dataset.window_label_idx}
   vectorize: ${data.train_dataset.vectorize}
   transform_frame_data:
     _target_: torchvision.transforms.Compose
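A minimal sketch of what the new `window_label_idx` option controls (plain Python with hypothetical per-frame truth values, not code from this repository): the window's truth label is drawn from a single frame of the window.

    # Per-frame class IDs for one hypothetical 15-frame window.
    frame_truth = [0] * 10 + [3] * 5

    print(frame_truth[-1])  # window_label_idx = -1 -> 3, label by the last frame
    print(frame_truth[7])   # window_label_idx = 7  -> 0, label by a mid-window frame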
36 changes: 26 additions & 10 deletions configs/experiment/m2/feat_locsconfs.yaml
@@ -1,10 +1,5 @@
 # @package _global_
 
-# to execute this experiment run:
-# python train.py experiment=example
-topic: "medical"
-task: "m2"
-
 defaults:
   - override /data: ptg
   - override /model: ptg
@@ -33,17 +28,23 @@ tags: ["m2", "ms_tcn", "debug"]
 
 seed: 12345
 
+#callbacks:
+#  model_checkpoint:
+#    # save all ~80MB checkpoints for post-training investigation.
+#    # Total: ~45GB
+#    save_top_k: 500
 
 trainer:
   min_epochs: 50
   max_epochs: 500
+  log_every_n_steps: 1
 
 model:
-  num_classes: 9
+  num_classes: 9 # number of activity classification classes
   compile: false
   net:
     # Length of feature vector for a single frame.
-    # Currently derived from feature version and other hyperparameters.
+    # Currently derived from the parameterization of dataset vectorizer.
     dim: 102
 
 # # Once upon a time defaults
@@ -67,11 +68,14 @@ data:
   batch_size: 512
   num_workers: 16
   target_framerate: 15 # BBN Hololens2 Framerate
-  # This is a little more than the number of windows in the training dataset.
-  epoch_length: 100000
+  # This is a bit more than the number of windows in the training dataset so
+  # the weighted sampler has more of an opportunity to sample the space
+  # proportionally.
+  epoch_length: 300000
 
   train_dataset:
     window_size: 25
+    window_label_idx: ${model.pred_frame_index}
     vectorize:
       _target_: tcn_hpl.data.vectorize.locs_and_confs.LocsAndConfs
       top_k: 1
@@ -93,11 +97,23 @@ data:
           pose_latency: 0.1
           dets_throughput_std: 0.2
           pose_throughput_std: 0.2
+          fixed_pattern: false
   val_dataset:
     # Augmentations on windows of frame data before performing vectorization.
     # Sharing transform with training dataset as it is only the drop-out aug to
     # simulate stream processing dropout the same.
-    transform_frame_data: ${data.train_dataset.transform_frame_data}
+    transform_frame_data:
+      transforms:
+        - _target_: tcn_hpl.data.frame_data_aug.window_frame_dropout.DropoutFrameDataTransform
+          # Mirror training hparams, except use fixed patterns.
+          frame_rate: ${data.train_dataset.frame_rate}
+          dets_throughput_mean: ${data.train_dataset.dets_throughput_mean}
+          pose_throughput_mean: ${data.train_dataset.pose_throughput_mean}
+          dets_latency: ${data.train_dataset.dets_latency}
+          pose_latency: ${data.train_dataset.pose_latency}
+          dets_throughput_std: ${data.train_dataset.dets_throughput_std}
+          pose_throughput_std: ${data.train_dataset.pose_throughput_std}
+          fixed_pattern: true
   # Test dataset usually configured the same as val, unless there is some
   # different set of transforms that should be used during test/prediction.
 
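A minimal sketch of how the `${data.train_dataset...}` references above resolve, using OmegaConf (the interpolation engine Hydra builds on) with trimmed stand-in keys rather than the real config tree:

    from omegaconf import OmegaConf

    # Stand-in for the config above: val mirrors train via interpolation,
    # but deliberately overrides fixed_pattern.
    cfg = OmegaConf.create({
        "train_dataset": {"dets_latency": 0.1, "fixed_pattern": False},
        "val_dataset": {
            "dets_latency": "${train_dataset.dets_latency}",
            "fixed_pattern": True,
        },
    })
    print(cfg.val_dataset.dets_latency)  # -> 0.1, mirrored from train_dataset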
26 changes: 18 additions & 8 deletions configs/experiment/r18/feat_locsconfs.yaml
@@ -1,10 +1,5 @@
 # @package _global_
 
-# to execute this experiment run:
-# python train.py experiment=example
-topic: "medical"
-task: "r18"
-
 defaults:
   - override /data: ptg
   - override /model: ptg
@@ -67,11 +62,14 @@ data:
   batch_size: 512
   num_workers: 16
   target_framerate: 15 # BBN Hololens2 Framerate
-  # This is a little more than the number of windows in the training dataset.
-  epoch_length: 80000
+  # This is a bit more than the number of windows in the training dataset so
+  # the weighted sampler has more of an opportunity to sample the space
+  # proportionally.
+  epoch_length: 300000
 
   train_dataset:
     window_size: 25
+    window_label_idx: ${model.pred_frame_index}
     vectorize:
       _target_: tcn_hpl.data.vectorize.locs_and_confs.LocsAndConfs
       top_k: 1
@@ -93,11 +91,23 @@ data:
           pose_latency: 0.1
           dets_throughput_std: 0.2
           pose_throughput_std: 0.2
+          fixed_pattern: false
   val_dataset:
     # Augmentations on windows of frame data before performing vectorization.
     # Sharing transform with training dataset as it is only the drop-out aug to
     # simulate stream processing dropout the same.
-    transform_frame_data: ${data.train_dataset.transform_frame_data}
+    transform_frame_data:
+      transforms:
+        - _target_: tcn_hpl.data.frame_data_aug.window_frame_dropout.DropoutFrameDataTransform
+          # Mirror training hparams, except use fixed patterns.
+          frame_rate: ${data.train_dataset.frame_rate}
+          dets_throughput_mean: ${data.train_dataset.dets_throughput_mean}
+          pose_throughput_mean: ${data.train_dataset.pose_throughput_mean}
+          dets_latency: ${data.train_dataset.dets_latency}
+          pose_latency: ${data.train_dataset.pose_latency}
+          dets_throughput_std: ${data.train_dataset.dets_throughput_std}
+          pose_throughput_std: ${data.train_dataset.pose_throughput_std}
+          fixed_pattern: true
   # Test dataset usually configured the same as val, unless there is some
  # different set of transforms that should be used during test/prediction.
 
5 changes: 4 additions & 1 deletion configs/model/ptg.yaml
@@ -38,7 +38,10 @@ smoothing_loss: 0.0015
 use_smoothing_loss: False
 
 # Number of classes
-num_classes: ${data.num_classes}
+num_classes: 9
 
 # compile model for faster training with pytorch 2.0
 compile: false
+
+# Which frame in a window of predictions should represent the whole window.
+pred_frame_index: -1
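A hedged sketch of how `pred_frame_index` could select the frame whose prediction represents a window, matching the dataset-side `window_label_idx` (assumed tensor shapes; the model forward code is not part of this diff):

    import torch

    pred_frame_index = -1
    logits = torch.randn(8, 25, 9)                  # (batch, window_size, num_classes)
    window_logits = logits[:, pred_frame_index, :]  # (batch, num_classes)
    window_pred = window_logits.argmax(dim=1)       # one class ID per window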
76 changes: 50 additions & 26 deletions tcn_hpl/data/frame_data_aug/window_frame_dropout.py
@@ -43,6 +43,12 @@ class DropoutFrameDataTransform(torch.nn.Module):
         Standard deviation of the throughput rate for object detections.
     pose_throughput_std:
         Standard deviation of the throughput rate for pose estimations.
+    fixed_pattern:
+        Create a single, fixed dropout pattern to be applied to every
+        window based on the input throughput and latency mean values with
+        no random variation. This is ideal to use for validation and test
+        dataset passes that require dropout simulation but do not want
+        random variation.
     """
 
     def __init__(
@@ -54,6 +60,7 @@ def __init__(
         pose_latency: Optional[float] = None,
         dets_throughput_std: float = 0.0,
         pose_throughput_std: float = 0.0,
+        fixed_pattern: bool = False,
     ):
         super().__init__()
         self.frame_rate = frame_rate
@@ -68,6 +75,7 @@ def __init__(
         self.pose_latency = (
             pose_latency if pose_latency is not None else 1.0 / pose_throughput_mean
         )
+        self.fixed_pattern = fixed_pattern
 
     def forward(self, window: Sequence[FrameData]) -> List[FrameData]:
         # Starting from some latency back from the end of the window, start
@@ -94,32 +102,41 @@ def forward(self, window: Sequence[FrameData]) -> List[FrameData]:
 
         # Define processing intervals (how often a frame is processed)
         # This cursed formatting is because of `black`.
-        dets_interval = (
-            1.0
-            / torch.normal(
-                mean=self.dets_throughput_mean,
-                std=self.dets_throughput_std,
-                size=(n_frames,),
-            ).numpy()
-        )
-        pose_interval = (
-            1.0
-            / torch.normal(
-                mean=self.pose_throughput_mean,
-                std=self.pose_throughput_std,
-                size=(n_frames,),
-            ).numpy()
-        )
+        if self.fixed_pattern:
+            dets_interval = 1.0 / np.full(n_frames, self.dets_throughput_mean)
+            pose_interval = 1.0 / np.full(n_frames, self.pose_throughput_mean)
+            # Fixed simulation of half-way into processing previous frame.
+            dets_initial_end = 0.5 * dets_interval[0]
+            pose_initial_end = 0.5 * pose_interval[0]
+        else:
+            dets_interval = (
+                1.0
+                / torch.normal(
+                    mean=self.dets_throughput_mean,
+                    std=self.dets_throughput_std,
+                    size=(n_frames,),
+                ).numpy()
+            )
+            pose_interval = (
+                1.0
+                / torch.normal(
+                    mean=self.pose_throughput_mean,
+                    std=self.pose_throughput_std,
+                    size=(n_frames,),
+                ).numpy()
+            )
+            dets_initial_end = torch.rand(1).item() * dets_interval[0]
+            pose_initial_end = torch.rand(1).item() * pose_interval[0]
 
         # Initialize end time trackers for processing detections and poses.
         # Simulate that agents may already be part-way through processing a
         # frame before the beginning of this window, utilizing the first value
         # from respective interval vectors.
         dets_processing_end = np.full(
-            n_frames + 1, torch.rand(1).item() * dets_interval[0]
+            n_frames + 1, dets_initial_end
         )
         pose_processing_end = np.full(
-            n_frames + 1, torch.rand(1).item() * pose_interval[0]
+            n_frames + 1, pose_initial_end
         )
 
         # Boolean arrays to keep track of whether a frame can be processed
@@ -184,7 +201,8 @@ def forward(self, window: Sequence[FrameData]) -> List[FrameData]:
 
 
 def test():
-    import numpy as np
+    from IPython.core.getipython import get_ipython
+    import pandas as pd
     from tcn_hpl.data.frame_data import FrameObjectDetections, FramePoses
 
     frame1 = FrameData(
Expand Down Expand Up @@ -216,18 +234,24 @@ def test():
pose_latency=1/10, # (1 / 10) - (1 / 14.5),
dets_throughput_std=0.2,
pose_throughput_std=0.2,
fixed_pattern=True,
)
modified_sequence = transform(sequence)

for idx, frame in enumerate(modified_sequence):
print(
f"Frame {idx}: Object Detections: {frame.object_detections is not None}, Poses: {frame.poses is not None}"
)

from IPython import get_ipython
print(
pd.DataFrame({
"object detections": [
frame.object_detections is not None for frame in modified_sequence
],
"pose estimation": [
frame.poses is not None for frame in modified_sequence
]
})
)

ipython = get_ipython()
ipython.run_line_magic("timeit", "transform(sequence)")
if ipython is not None:
ipython.run_line_magic("timeit", "transform(sequence)")


if __name__ == "__main__":
Expand Down
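A standalone sketch of the throughput-based dropout idea implemented above, assuming a 15 Hz frame rate and a ~14.5 Hz detection agent; this illustrates the simulation logic, not the transform's exact implementation:

    import numpy as np

    frame_rate = 15.0
    n_frames = 25
    frame_times = np.arange(n_frames) / frame_rate
    interval = np.full(n_frames, 1.0 / 14.5)  # fixed_pattern=True: no jitter

    busy_until = 0.5 * interval[0]            # half-way through a prior frame
    processed = np.zeros(n_frames, dtype=bool)
    for i, t in enumerate(frame_times):
        if t >= busy_until:                   # agent is idle: this frame is processed
            processed[i] = True
            busy_until = t + interval[i]      # busy until processing finishes
    print(processed.astype(int))              # 1 = detections kept, 0 = dropped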
9 changes: 8 additions & 1 deletion tcn_hpl/data/ptg_datamodule.py
@@ -203,12 +203,19 @@ def val_dataloader(self) -> DataLoader[Any]:
         :return: The validation dataloader.
         """
+        val_sampler = torch.utils.data.WeightedRandomSampler(
+            self.data_val.window_weights,
+            len(self.data_val) * 3,
+            replacement=True,
+            generator=None,
+        )
         return DataLoader(
             dataset=self.data_val,
             batch_size=self.hparams.batch_size,
             num_workers=self.hparams.num_workers,
             pin_memory=self.hparams.pin_memory,
-            shuffle=False,
+            # shuffle=False,
+            sampler=val_sampler,
         )
 
     def test_dataloader(self) -> DataLoader[Any]:
Expand Down
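For illustration, what the new weighted validation sampler does, with made-up window weights: indices are drawn with replacement, three times the dataset length, biased toward higher-weight windows:

    import torch

    weights = torch.tensor([0.1, 0.1, 0.8])  # hypothetical per-window weights
    sampler = torch.utils.data.WeightedRandomSampler(
        weights, num_samples=len(weights) * 3, replacement=True
    )
    print(list(sampler))  # e.g. [2, 0, 2, 2, 1, 2, 2, 2, 0]; index 2 dominates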
16 changes: 12 additions & 4 deletions tcn_hpl/data/tcn_dataset.py
@@ -51,25 +51,33 @@ class TCNDataset(Dataset):
     window_size:
         The size of the sliding window used to collect inputs from either a
         real-time or offline source.
-    vectorize:
-        Vectorization functor to convert frame data into an embedding
-        space.
+    window_label_idx:
+        Indicate which frame in a window of frames the window's truth label
+        should be drawn from. E.g. `-1` means assign the truth of the
+        window from the truth of the last frame in the window, `5` means
+        assign the truth of the window from the truth of the 5th frame in
+        the window, etc.
     transform_frame_data:
         Optional augmentation function that operates on a window of
         FrameData before being input to vectorization. Such an augmentation
         function should not modify the input FrameData.
+    vectorize:
+        Vectorization functor to convert frame data into an embedding
+        space.
     """
 
     def __init__(
         self,
         window_size: int,
         vectorize: Vectorize,
+        window_label_idx: int = -1,
         transform_frame_data: Optional[
             Callable[[Sequence[FrameData]], Sequence[FrameData]]
         ] = None,
     ):
         self.window_size = window_size
         self.vectorize = vectorize
+        self.window_label_idx = window_label_idx
         self.transform_frame_data = transform_frame_data
 
         # For offline mode, pre-cut videos into clips according to window
@@ -376,7 +384,7 @@ class that has pose keypoints associated with it. The current
 
         # Collect for weighting the truth labels for the final frames of
         # windows, which is the truth value for the window as a whole.
-        window_final_class_ids = self._window_truth[:, -1]
+        window_final_class_ids = self._window_truth[:, self.window_label_idx]
         cls_ids, cls_counts = np.unique(window_final_class_ids, return_counts=True)
         # Some classes may not be represented in the truth, so initialize the
         # weights vector separately, and then assign weight values based on
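A minimal sketch (made-up truth matrix) of the weighting above with a selectable label frame: labels come from one column of the per-window truth, and rarer labels receive proportionally larger sampling weight:

    import numpy as np

    window_truth = np.array([[0, 0, 1], [1, 1, 1], [2, 2, 2], [0, 1, 2]])
    window_label_idx = -1                       # label each window by its last frame
    labels = window_truth[:, window_label_idx]  # -> [1, 1, 2, 2]
    cls_ids, cls_counts = np.unique(labels, return_counts=True)
    weights = np.zeros(window_truth.max() + 1)
    weights[cls_ids] = 1.0 / cls_counts         # classes absent from labels keep weight 0
    print(weights[labels])                      # per-window sampling weights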