Merge pull request #28 from periakiva/dev-dset
Dataset development that supports both offline and online input.
Purg authored Oct 11, 2024
2 parents eb916b3 + 4ce16f1 commit 46e792a
Showing 7 changed files with 329 additions and 13 deletions.
2 changes: 1 addition & 1 deletion configs/experiment/m2/feat_v6.yaml
@@ -86,7 +86,7 @@ data:
  num_workers: 16
  epoch_length: 20000
  window_size: 25
  sample_rate: 2
  sample_rate: 1

  all_transforms:
    train_order: [] #["MoveCenterPts", "NormalizePixelPts"]
2 changes: 1 addition & 1 deletion configs/experiment/m5/feat_v6.yaml
@@ -73,7 +73,7 @@ data:
  num_workers: 12
  epoch_length: 20000
  window_size: 25
  sample_rate: 2
  sample_rate: 1


  all_transforms:
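For context on this change: in the new TCNDataset added below, sample_rate is the stride between successive training windows, so dropping it from 2 to 1 roughly doubles the number of windows cut from each video. A quick illustration with a hypothetical frame count:

# Windows yielded per video by the TCNDataset windowing loop:
# len(range(0, num_frames - window_size - 1, sample_rate))
num_frames, window_size = 1000, 25

for sample_rate in (2, 1):
    n_windows = len(range(0, num_frames - window_size - 1, sample_rate))
    print(sample_rate, n_windows)  # 2 -> 487 windows, 1 -> 974 windows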
136 changes: 136 additions & 0 deletions tcn_hpl/data/add_gt_to_kwcoco.py
@@ -0,0 +1,136 @@
import argparse

import kwcoco
import numpy as np
import numpy.typing as npt
import tcn_hpl.utils.utils as utils
import ubelt as ub
import yaml

from angel_system.data.medical.data_paths import LAB_TASK_TO_NAME


def text_to_labels(
    text_file: str, num_frames: int, task: str, mapping: dict
) -> npt.NDArray[np.int_]:
    """
    Convert a "skill_labels_by_frame" text truth file from BBN into labels for
    the given task and number of frames.

    :param text_file: Filesystem path to the BBN activity text file.
    :param num_frames: Number of frames in the video the truth file is related
        to.
    :param task: The identifying name of the task, e.g. "m2", "m3", "r18", etc.
    :param mapping: Mapping of task step descriptions to the integer label
        value for that step.
    :return: Array of length `num_frames` holding the integer activity label
        for each frame.
    """
    # Set everything to background first (assuming value 0).
    activity_gt_list = np.zeros(num_frames, dtype=int)
    with open(text_file, "r") as f:
        text = f.read()
    text = text.replace("\n", "\t")
    text_list = text.split("\t")
    if text_list[-1] == "":
        text_list = text_list[:-1]

    # This check handles inconsistencies in the GT we get from BBN: r18/m3
    # files have four fields per record, while m2/m5 files have three.
    if task == "r18" or task == "m3":
        jump = 4
    elif task == "m2" or task == "m5":
        jump = 3
    else:
        raise ValueError(f"Unrecognized task: {task}")

    for index in range(0, len(text_list), jump):
        triplet = text_list[index : index + jump]
        start_frame = int(triplet[0])
        end_frame = int(triplet[1])
        desc = triplet[jump - 1]

        gt_label = mapping[desc]

        if end_frame - 1 > num_frames:
            # Address issue with GT activity labels extending past the video.
            print("Max frame in GT is larger than number of frames in the video")

        for label_index in range(start_frame, min(end_frame - 1, num_frames)):
            activity_gt_list[label_index] = gt_label

    return activity_gt_list


def main(config_path: str):
    with open(config_path, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    task_name = config["task"]
    raw_data_root = (
        f"{config['data_gen']['raw_data_root']}/{LAB_TASK_TO_NAME[task_name]}/"
    )

    dset = kwcoco.CocoDataset(config["data_gen"]["dataset_kwcoco"])

    with open(config["data_gen"]["activity_config_fn"], "r") as stream:
        activity_config = yaml.safe_load(stream)
    activity_labels = activity_config["labels"]

    # Map activity step descriptions and label strings to their integer IDs.
    activity_labels_desc_mapping = {}
    activity_labels_label_mapping = {}
    for label in activity_labels:
        label_str = label["label"]
        if "description" in label:
            activity_labels_desc_mapping[label["description"]] = label["id"]
        elif "full_str" in label:
            activity_labels_desc_mapping[label["full_str"]] = label["id"]
        activity_labels_label_mapping[label_str] = label["id"]

    gt_paths_to_names_dict = {}
    gt_paths = utils.dictionary_contents(raw_data_root, types=["*.txt"])
    for gt_path in gt_paths:
        name = gt_path.split("/")[-1].split(".")[0]
        gt_paths_to_names_dict[name] = gt_path

    print(gt_paths_to_names_dict)

    if "activity_gt" not in list(dset.imgs.values())[0]:
        print("adding activity ground truth to the dataset")
        for video_id in ub.ProgIter(dset.index.videos.keys()):
            video = dset.index.videos[video_id]
            video_name = video["name"]
            if video_name in gt_paths_to_names_dict:
                gt_text = gt_paths_to_names_dict[video_name]
            else:
                print(f"GT file does not exist for {video_name}. Continuing...")
                continue

            image_ids = dset.index.vidid_to_gids[video_id]
            num_frames = len(image_ids)

            activity_gt_list = text_to_labels(
                gt_text, num_frames, task_name, activity_labels_desc_mapping
            )

            for img_id in image_ids:
                im = dset.index.imgs[img_id]
                frame_index = int(im["frame_index"])
                im["activity_gt"] = int(activity_gt_list[frame_index])

    dset.dump("test.mscoco.json", newlines=True)
    print(activity_labels_desc_mapping)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--config",
        default="/home/local/KHQ/peri.akiva/projects/TCN_HPL/configs/experiment/r18/feat_v6.yaml",
        help="Path to the experiment configuration YAML file.",
    )

    args = parser.parse_args()

    main(args.config)
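To make the parsing above concrete, here is a minimal sketch of the truth-file format text_to_labels() expects; the step descriptions, frame numbers, and mapping are hypothetical, used only for illustration:

# BBN truth files are newline/tab-separated records of
# start_frame, end_frame, description (r18/m3 records carry one
# extra field per record, hence jump=4 above).
import tempfile

from tcn_hpl.data.add_gt_to_kwcoco import text_to_labels

truth = "0\t10\tapply tourniquet\n11\t25\ttighten strap\n"
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write(truth)

mapping = {"apply tourniquet": 1, "tighten strap": 2}
labels = text_to_labels(f.name, num_frames=30, task="m2", mapping=mapping)
print(labels)  # frames 0-8 -> 1, frames 11-23 -> 2, all others 0 (background)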
141 changes: 141 additions & 0 deletions tcn_hpl/data/tcn_dataset.py
@@ -0,0 +1,141 @@
import logging

import kwcoco
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)


def standarize_online_inputs(inputs):
    # TODO: standardize a window of ROS messages into the common format
    # described in collect_inputs().
    pass


def standarize_offline_inputs(data):
    # TODO: standardize a window of kwcoco frames into the common format
    # described in collect_inputs().
    dset = data[0]  # the kwcoco.CocoDataset
    frames = data[1]  # the window's frame (image) dictionaries


def collect_inputs(data, offline=True):
    """
    Collects inputs from either an offline or real-time source.

    Args:
        data: Data stream input. For offline, a coco dataset. For online, a
            set of ROS messages.
        offline (bool): If True, fetches data from an offline source;
            otherwise, real-time.

    Returns:
        inputs: The collected inputs from either a real-time or offline
            source in standardized format. Designed as a dict of lists, where
            all lists must have the size of the window_size:
            {object_dets: [], pose_estimations: [], left_hand: [], right_hand: []}
    """
    if offline:
        inputs = standarize_offline_inputs(data)
    else:
        inputs = standarize_online_inputs(data)

    return inputs


def define_tcn_vector(inputs):
    """
    Define the TCN vector using the collected inputs.

    Args:
        inputs: The inputs collected from either a real-time or offline source.

    Returns:
        tcn_vector: The defined TCN vector.
    """
    tcn_vector = torch.tensor(inputs).float()  # Dummy transformation
    return tcn_vector


class TCNDataset(Dataset):
    def __init__(self, kwcoco_path: str, sample_rate: int, window_size: int):
        """
        Initializes the dataset.

        Args:
            kwcoco_path: Path to the kwcoco dataset file (the offline source).
            sample_rate: Stride, in frames, between successive windows.
            window_size: Number of frames in each window.
        """
        self.sample_rate = sample_rate
        self.window_size = window_size

        self.dset = kwcoco.CocoDataset(kwcoco_path)

        # For offline training, pre-cut videos into clips according to the
        # window size for easy batching.
        self.frames = []

        vid_list = list(self.dset.index.videos.keys())
        logger.info(f"Generating dataset with {len(vid_list)} videos")
        pbar = tqdm(
            vid_list,
            total=len(vid_list),
        )
        for vid in pbar:
            vid_frames = self.dset.index.vidid_to_gids[vid]

            for index in range(0, len(vid_frames) - window_size - 1, sample_rate):
                video_slice = vid_frames[index : index + window_size]
                window_frame_dicts = [self.dset.index.imgs[gid] for gid in video_slice]
                self.frames.append(window_frame_dicts)

    def __getitem__(self, index):
        """
        Fetches the data point and defines its TCN vector.

        Args:
            index: The index of the data point.

        Returns:
            The standardized window inputs for the given index (TCN vector
            construction is pending the standardization stubs above).
        """
        data = self.frames[index]

        inputs = collect_inputs([self.dset, data], offline=True)

        # TODO: once input standardization is implemented:
        # tcn_vector = define_tcn_vector(inputs)

        return inputs

    def __len__(self):
        """
        Returns the length of the dataset.

        Returns:
            length: Length of the dataset.
        """
        return len(self.frames)


if __name__ == "__main__":
    # Example usage (batches will only collate once the standardization
    # stubs above are implemented):
    kwcoco_path = "/data/PTG/medical/training/yolo_object_detector/detect/r18_all/r18_all_all_obj_results_with_dets_and_pose.mscoco.json"

    dataset = TCNDataset(kwcoco_path=kwcoco_path, sample_rate=1, window_size=25)

    print(f"dataset: {len(dataset)}")
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)

    for index, batch in enumerate(data_loader):
        print(batch)
        if index > 15:
            break
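As a sketch of the contract collect_inputs() is meant to satisfy once the standarize_* stubs are filled in, the dict-of-lists format from its docstring would look like the following; the per-frame payloads here are placeholders, not the eventual implementation:

# Hypothetical standardized window: one entry per frame in the window,
# per the collect_inputs() docstring. Real payloads would come from
# kwcoco annotations (offline) or ROS messages (online).
window_size = 25
inputs = {
    "object_dets": [None] * window_size,
    "pose_estimations": [None] * window_size,
    "left_hand": [None] * window_size,
    "right_hand": [None] * window_size,
}
assert all(len(v) == window_size for v in inputs.values())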
6 changes: 4 additions & 2 deletions tcn_hpl/eval.py
@@ -2,10 +2,12 @@

import hydra
import rootutils
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

import pytorch_lightning as L
from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
from pytorch_lightning.loggers import Logger

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
50 changes: 43 additions & 7 deletions tcn_hpl/models/ptg_module.py
@@ -10,6 +10,7 @@

from torchmetrics import MaxMetric, MeanMetric
from torchmetrics.classification.accuracy import Accuracy
from torchmetrics.classification import F1Score, Recall, Precision
from torchmetrics.classification import MulticlassConfusionMatrix
from sklearn.metrics import confusion_matrix

@@ -126,6 +127,27 @@ def __init__(
task="multiclass", average="weighted", num_classes=num_classes
)

self.val_f1 = F1Score(
num_classes=num_classes, average="none", task="multiclass"
)
self.test_f1 = F1Score(
num_classes=num_classes, average="none", task="multiclass"
)

self.val_recall = Recall(
num_classes=num_classes, average="none", task="multiclass"
)
self.test_recall = Recall(
num_classes=num_classes, average="none", task="multiclass"
)

self.val_precision = Precision(
num_classes=num_classes, average="none", task="multiclass"
)
self.test_precision = Precision(
num_classes=num_classes, average="none", task="multiclass"
)

# for averaging loss across batches
self.train_loss = MeanMetric()
self.val_loss = MeanMetric()
@@ -467,6 +489,7 @@ def validation_step(
        windowed_ys = torch.tensor(windowed_ys).to(targets)

        self.val_acc(preds, targets[:, -1])

        self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)

@@ -480,13 +503,6 @@ def on_validation_epoch_end(self) -> None:
"Lightning hook that is called when a validation epoch ends."
acc = self.val_acc.compute() # get current val acc

if self.current_epoch >= 15:
self.val_acc_best(acc) # update best so far val acc
# log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
# otherwise metric would be reset by lightning after each epoch
best_val_acc = self.val_acc_best.compute()
self.log("val/acc_best", best_val_acc, sync_dist=True, prog_bar=True)

all_targets = torch.cat(self.validation_step_outputs_target) # shape: #frames
all_preds = torch.cat(self.validation_step_outputs_pred) # shape: #frames
all_probs = torch.cat(
@@ -495,6 +511,26 @@ def on_validation_epoch_end(self) -> None:
        all_source_vids = torch.cat(self.validation_step_outputs_source_vid)
        all_source_frames = torch.cat(self.validation_step_outputs_source_frame)

        if self.current_epoch >= 15:
            current_best_val_acc = self.val_acc_best.compute()
            self.val_acc_best(acc)  # update best so far val acc
            # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
            # otherwise metric would be reset by lightning after each epoch
            best_val_acc = self.val_acc_best.compute()

            if best_val_acc > current_best_val_acc:
                val_f1_score = self.val_f1(all_preds, all_targets)
                val_recall_score = self.val_recall(all_preds, all_targets)
                val_precision_score = self.val_precision(all_preds, all_targets)

                print(f"validation f1 score: {val_f1_score}")
                print(f"validation recall score: {val_recall_score}")
                print(f"validation precision score: {val_precision_score}")

            self.log("val/acc_best", best_val_acc, sync_dist=True, prog_bar=True)

        # Load val videos
        if self.val_frames is None:
            self.val_frames = {}
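For readers unfamiliar with the two torchmetrics behaviors the new block relies on: MaxMetric.compute() returns the running maximum, so comparing its value before and after the update detects a new best epoch, and average="none" makes the classification metrics return one score per class instead of a single aggregate. A standalone sketch with made-up values:

import torch
from torchmetrics import MaxMetric
from torchmetrics.classification import F1Score

val_acc_best = MaxMetric()
val_acc_best.update(torch.tensor(0.70))

current_best = val_acc_best.compute()    # 0.70, best before this epoch
val_acc_best.update(torch.tensor(0.75))  # fold in this epoch's accuracy
new_best = val_acc_best.compute()        # 0.75

if new_best > current_best:  # a new best epoch was reached
    f1 = F1Score(task="multiclass", num_classes=3, average="none")
    preds = torch.tensor([0, 1, 2, 1, 0])
    targets = torch.tensor([0, 1, 1, 1, 2])
    print(f1(preds, targets))  # tensor with one F1 value per class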