diff --git a/README.md b/README.md
index 6228bec..d26f71f 100644
--- a/README.md
+++ b/README.md
@@ -170,6 +170,6 @@
 player --folder path_to_folder [--save] [--imshow]
 ```
 
 ```
-cvat2slowfast --miniscene path_to_mini_scenes --dataset dataset_name --classes path_to_classes_json [--old2new path_to_old2new_json]
+cvat2slowfast --miniscene path_to_mini_scenes --dataset dataset_name --classes path_to_classes_json [--old2new path_to_old2new_json] [--no_images]
 ```
diff --git a/src/kabr_tools/cvat2slowfast.py b/src/kabr_tools/cvat2slowfast.py
index 3255f93..1a19a25 100644
--- a/src/kabr_tools/cvat2slowfast.py
+++ b/src/kabr_tools/cvat2slowfast.py
@@ -11,7 +11,7 @@
 def cvat2slowfast(path_to_mini_scenes: str, path_to_new_dataset: str,
-                  label2number: dict, old2new: Optional[dict]) -> None:
+                  label2number: dict, old2new: Optional[dict], no_images: bool) -> None:
     """
     Convert CVAT annotations to the dataset in Charades format.
 
@@ -20,6 +20,7 @@
     path_to_new_dataset - str. Path to the folder to output dataset files.
     label2number - dict. Mapping of ethogram labels to integers.
     old2new - dict [optional]. Mapping of old ethogram labels to new ethogram labels.
+    no_images - bool. If True, skip writing image files.
     """
     if not os.path.exists(path_to_new_dataset):
         os.makedirs(path_to_new_dataset)
@@ -33,12 +34,13 @@
     with open(f"{path_to_new_dataset}/annotation/classes.json", "w") as file:
         json.dump(label2number, file)
 
-    headers = {"original_vido_id": [], "video_id": pd.Series(dtype="int"), "frame_id": pd.Series(dtype="int"),
-               "path": [], "labels": []}
+    headers = ["original_vido_id", "video_id", "frame_id", "path", "labels"]
+    charades_data = []
+    charades_df = pd.DataFrame(columns=headers)
     video_id = 1
     folder_name = 1
-    flag = False
+    flag = not no_images
 
     for i, folder in enumerate(natsorted(os.listdir(path_to_mini_scenes))):
         if os.path.exists(f"{path_to_mini_scenes}/{folder}/actions"):
@@ -127,13 +129,11 @@
                     if flag:
                         cv2.imwrite(f"{output_folder}/{adjusted_index}.jpg", frame)
 
-                    # TODO: Major slow down here. Add to a list rather than dataframe,
-                    # and create dataframe at the end.
-                    charades_df.loc[len(charades_df.index)] = [f"{folder_code}",
-                                                               video_id,
-                                                               adjusted_index,
-                                                               f"{folder_code}/{adjusted_index}.jpg",
-                                                               str(label2number[behavior])]
+                    charades_data.append([f"{folder_code}",
+                                          video_id,
+                                          adjusted_index,
+                                          f"{folder_code}/{adjusted_index}.jpg",
+                                          str(label2number[behavior])])
 
                     adjusted_index += 1
 
@@ -145,9 +145,11 @@
             video_id += 1
 
             if video_id % 10 == 0:
+                charades_df = pd.DataFrame(charades_data, columns=headers)
                 charades_df.to_csv(
                     f"{path_to_new_dataset}/annotation/data.csv", sep=" ", index=False)
 
+    charades_df = pd.DataFrame(charades_data, columns=headers)
     charades_df.to_csv(
         f"{path_to_new_dataset}/annotation/data.csv", sep=" ", index=False)
 
@@ -178,6 +180,11 @@
         help="path to old to new ethogram labels json",
         required=False
     )
+    local_parser.add_argument(
+        "--no_images",
+        action="store_true",
+        help="skip writing image files"
+    )
     return local_parser.parse_args()
 
 
@@ -194,8 +201,8 @@
     else:
         old2new = None
 
-    cvat2slowfast(args.miniscene, args.dataset, label2number, old2new)
+    cvat2slowfast(args.miniscene, args.dataset, label2number, old2new, args.no_images)
 
 
 if __name__ == "__main__":
-    main()
+    main()
\ No newline at end of file
diff --git a/src/kabr_tools/miniscene2behavior.py b/src/kabr_tools/miniscene2behavior.py
index 8347350..4ce2fc1 100644
--- a/src/kabr_tools/miniscene2behavior.py
+++ b/src/kabr_tools/miniscene2behavior.py
@@ -19,12 +19,15 @@
     # https://github.com/facebookresearch/SlowFast/blob/bac7b672f40d44166a84e8c51d1a5ba367ace816/slowfast/visualization/ava_demo_precomputed_boxes.py
     seq_length = cfg.DATA.NUM_FRAMES * cfg.DATA.SAMPLING_RATE
     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    assert keyframe_idx < total_frames, f"keyframe_idx: {keyframe_idx}" \
+        f" >= total_frames: {total_frames}"
     seq = get_sequence(
         keyframe_idx,
         seq_length // 2,
         cfg.DATA.SAMPLING_RATE,
         total_frames,
    )
+
     clip = []
     for frame_idx in seq:
         cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
@@ -124,29 +127,35 @@
     # find all tracks
     tracks = []
+    frames = {}
     for track in root.iterfind("track"):
         track_id = track.attrib["id"]
         tracks.append(track_id)
+        frames[track_id] = []
 
-    # find all frames
-    # TODO: rewrite - some tracks may have different frames
-    assert len(tracks) > 0, "No tracks found in track file"
-    frames = []
-    for box in track.iterfind("box"):
-        frames.append(int(box.attrib["frame"]))
+        # find all frames
+        for box in track.iterfind("box"):
+            frames[track_id].append(int(box.attrib["frame"]))
 
     # run model on miniscene
     for track in tracks:
         video_file = f"{miniscene_path}/{track}.mp4"
         cap = cv2.VideoCapture(video_file)
-        for frame in tqdm(frames, desc=f"{track} frames"):
-            inputs = get_input_clip(cap, cfg, frame)
+        index = 0
+        for frame in tqdm(frames[track], desc=f"{track} frames"):
+            # miniscene clips are per-track, so index from 0 rather than by global frame id
+            try:
+                inputs = get_input_clip(cap, cfg, index)
+            except AssertionError as e:
+                print(e)
+                break
+            index += 1
 
             if cfg.NUM_GPUS:
                 # transfer the data to the current GPU device.
                if isinstance(inputs, (list,)):
-                    for i in range(len(inputs)):
-                        inputs[i] = inputs[i].cuda(non_blocking=True)
+                    for i, input_clip in enumerate(inputs):
+                        inputs[i] = input_clip.cuda(non_blocking=True)
                 else:
                     inputs = inputs.cuda(non_blocking=True)
 
@@ -163,6 +171,7 @@
             if frame % 20 == 0:
                 pd.DataFrame(label_data).to_csv(
                     output_path, sep=" ", index=False)
+        cap.release()
 
     pd.DataFrame(label_data).to_csv(output_path, sep=" ", index=False)
diff --git a/tests/examples/DETECTOR1/DJI_tracks.xml b/tests/examples/DETECTOR1/DJI_tracks.xml
new file mode 100644
index 0000000..22c734f
--- /dev/null
+++ b/tests/examples/DETECTOR1/DJI_tracks.xml
@@ -0,0 +1,49 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/tests/test_cvat2slowfast.py b/tests/test_cvat2slowfast.py
index 674837c..1c4bf62 100644
--- a/tests/test_cvat2slowfast.py
+++ b/tests/test_cvat2slowfast.py
@@ -1,11 +1,20 @@
 import unittest
 import sys
 import os
+import json
+import pandas as pd
+import cv2
 from kabr_tools import cvat2slowfast
+from tests.test_tracks_extractor import (
+    scene_width,
+    scene_height
+)
 from tests.utils import (
-    get_behavior,
     del_dir,
-    del_file
+    del_file,
+    dir_exists,
+    file_exists,
+    get_behavior
 )
 
 
@@ -50,6 +59,40 @@
                     "--classes", self.classes]
         run()
 
+        # check output dirs
+        self.assertTrue(dir_exists(self.dataset))
+        self.assertTrue(dir_exists(f"{self.dataset}/annotation"))
+        self.assertTrue(dir_exists(f"{self.dataset}/dataset/image"))
+        self.assertTrue(file_exists(f"{self.dataset}/annotation/classes.json"))
+        self.assertTrue(file_exists(f"{self.dataset}/annotation/data.csv"))
+
+        # check classes.json
+        with open(f"{self.dataset}/annotation/classes.json", "r", encoding="utf-8") as f:
+            classes = json.load(f)
+        with open(self.classes, "r", encoding="utf-8") as f:
+            ethogram = json.load(f)
+        self.assertEqual(classes, ethogram)
+
+        # check data.csv
+        with open(f"{self.dataset}/annotation/data.csv", "r", encoding="utf-8") as f:
+            df = pd.read_csv(f, sep=" ")
+
+        video_id = 1
+        for i, row in df.iterrows():
+            self.assertEqual(row["original_vido_id"], f"Z{video_id:04d}")
+            self.assertEqual(row["video_id"], video_id)
+            self.assertEqual(row["frame_id"], i+1)
+            self.assertEqual(row["path"], f"Z{video_id:04d}/{i+1}.jpg")
+            self.assertEqual(row["labels"], 1)
+        self.assertEqual(i, 90)
+
+        # check dataset
+        for i in range(1, 92):
+            data_im = f"{self.dataset}/dataset/image/Z{video_id:04d}/{i}.jpg"
+            self.assertTrue(file_exists(data_im))
+            data_im = cv2.imread(data_im)
+            self.assertEqual(data_im.shape, (scene_height, scene_width, 3))
+
     def test_parse_arg_min(self):
         # parse arguments
         sys.argv = [self.tool,
@@ -65,9 +108,7 @@
 
         # check default argument values
         self.assertEqual(args.old2new, None)
-
-        # run cvat2slowfast
-        run()
+        self.assertFalse(args.no_images)
 
     def test_parse_arg_full(self):
         # parse arguments
@@ -75,7 +116,8 @@
                     "--miniscene", self.miniscene,
                     "--dataset", self.dataset,
                     "--classes", self.classes,
-                    "--old2new", self.old2new]
+                    "--old2new", self.old2new,
+                    "--no_images"]
         args = cvat2slowfast.parse_args()
 
         # check parsed argument values
@@ -83,6 +125,4 @@
         self.assertEqual(args.dataset, self.dataset)
         self.assertEqual(args.classes, self.classes)
         self.assertEqual(args.old2new, self.old2new)
-
-        # run cvat2slowfast
-        run()
+        self.assertTrue(args.no_images)
diff --git a/tests/test_cvat2ultralytics.py b/tests/test_cvat2ultralytics.py
index 881245a..0fb246a 100644
--- a/tests/test_cvat2ultralytics.py
+++ b/tests/test_cvat2ultralytics.py
@@ -1,11 +1,16 @@
 import unittest
 import sys
 import os
+from lxml import etree
+import pandas as pd
+import cv2
 from kabr_tools import cvat2ultralytics
 from tests.utils import (
     del_dir,
     del_file,
-    get_detection
+    get_detection,
+    dir_exists,
+    file_exists
 )
 
 
@@ -32,7 +37,7 @@
         self.video = TestCvat2Ultralytics.dir
         self.annotation = TestCvat2Ultralytics.dir
         self.dataset = "tests/ultralytics"
-        self.skip = "5"
+        self.skip = "1"
         self.label2index = "ethogram/label2index.json"
 
     def tearDown(self):
@@ -44,9 +49,75 @@
         sys.argv = [self.tool,
                     "--video", self.video,
                     "--annotation", self.annotation,
-                    "--dataset", "tests/ultralytics"]
+                    "--dataset", "tests/ultralytics",
+                    "--skip", self.skip]
         run()
 
+        # check for output dirs
+        self.assertTrue(dir_exists(self.dataset))
+        self.assertTrue(dir_exists(f"{self.dataset}/images/test"))
+        self.assertTrue(dir_exists(f"{self.dataset}/images/train"))
+        self.assertTrue(dir_exists(f"{self.dataset}/images/val"))
+        self.assertTrue(dir_exists(f"{self.dataset}/labels/test"))
+        self.assertTrue(dir_exists(f"{self.dataset}/labels/train"))
+        self.assertTrue(dir_exists(f"{self.dataset}/labels/val"))
+
+        # check output
+        annotations = etree.parse(TestCvat2Ultralytics.annotation).getroot()
+        tracks = [list(track.findall("box")) for track in annotations.findall("track")]
+        self.assertEqual(len(tracks[0]), 21)
+        self.assertEqual(len(tracks[0]), len(tracks[1]))
+        original_size = annotations.find("meta").find("task").find("original_size")
+        height = int(original_size.find("height").text)
+        width = int(original_size.find("width").text)
+        for i in range(len(tracks[0])):
+            # check existence
+            if i < 16:
+                data_im = f"{self.dataset}/images/train/DJI_0068_{i}.jpg"
+                self.assertTrue(file_exists(data_im))
+                data_label = f"{self.dataset}/labels/train/DJI_0068_{i}.txt"
+                self.assertTrue(file_exists(data_label))
+            elif i < 18:
+                data_im = f"{self.dataset}/images/val/DJI_0068_{i}.jpg"
+                self.assertTrue(file_exists(data_im))
+                data_label = f"{self.dataset}/labels/val/DJI_0068_{i}.txt"
+                self.assertTrue(file_exists(data_label))
+            else:
+                data_im = f"{self.dataset}/images/test/DJI_0068_{i}.jpg"
+                self.assertTrue(file_exists(data_im))
+                data_label = f"{self.dataset}/labels/test/DJI_0068_{i}.txt"
+                self.assertTrue(file_exists(data_label))
+
+            # check image
+            data_im = cv2.imread(data_im)
+            self.assertEqual(data_im.shape, (height, width, 3))
+
+            # check label
+            data_label = pd.read_csv(data_label, sep=" ", header=None)
+            annotation_label = []
+            for track in tracks:
+                box = track[i]
+                x_start = float(box.attrib["xtl"])
+                y_start = float(box.attrib["ytl"])
+                x_end = float(box.attrib["xbr"])
+                y_end = float(box.attrib["ybr"])
+                x_center = (x_start + (x_end - x_start) / 2) / width
+                y_center = (y_start + (y_end - y_start) / 2) / height
+                w = (x_end - x_start) / width
+                h = (y_end - y_start) / height
+                annotation_label.append(
+                    [0, x_center, y_center, w, h]
+                )
+            self.assertEqual(len(data_label.index), len(annotation_label))
+
+            for j, label in enumerate(annotation_label):
+                self.assertEqual(data_label.iloc[j][0], label[0])
+                self.assertAlmostEqual(data_label.iloc[j][1], label[1], places=4)
+                self.assertAlmostEqual(data_label.iloc[j][2], label[2], places=4)
+                self.assertAlmostEqual(data_label.iloc[j][3], label[3], places=4)
+                self.assertAlmostEqual(data_label.iloc[j][4], label[4],
+                                       places=4)
+
+
     def test_parse_arg_min(self):
         # parse arguments
         sys.argv = [self.tool,
@@ -81,7 +152,7 @@
         self.assertEqual(args.video, self.video)
         self.assertEqual(args.annotation, self.annotation)
         self.assertEqual(args.dataset, self.dataset)
-        self.assertEqual(args.skip, 5)
+        self.assertEqual(args.skip, int(self.skip))
         self.assertEqual(args.label2index, self.label2index)
 
         # run cvat2ultralytics
diff --git a/tests/test_detector2cvat.py b/tests/test_detector2cvat.py
index 675e7bc..b00a122 100644
--- a/tests/test_detector2cvat.py
+++ b/tests/test_detector2cvat.py
@@ -1,14 +1,67 @@
 import unittest
 import sys
 import os
+from unittest.mock import MagicMock, patch
+import cv2
+from lxml import etree
+import numpy as np
 from kabr_tools import detector2cvat
+from kabr_tools.utils.yolo import YOLOv8
 from tests.utils import (
     del_dir,
     del_file,
+    file_exists,
     get_detection
 )
 
 
+class DetectionData:
+    """Mock video frames and detections backed by a CVAT track XML."""
+    def __init__(self, video_dim, video_len, annotation):
+        self.video_dim = video_dim
+        self.video_len = video_len
+        self.frame = -1
+        self.video_frame = np.zeros(video_dim, dtype=np.uint8)
+
+        annotation = etree.parse(annotation).getroot()
+        self.tracks = [[0, list(track.findall("box")),
+                        track.get("label")]
+                       for track in annotation.findall("track")]
+
+    def read(self):
+        if self.frame >= self.video_len - 1:
+            return False, None
+        self.frame += 1
+        return True, self.video_frame
+
+    def get(self, param):
+        if param == cv2.CAP_PROP_FRAME_COUNT:
+            return self.video_len
+        elif param == cv2.CAP_PROP_FRAME_HEIGHT:
+            return self.video_dim[0]
+        elif param == cv2.CAP_PROP_FRAME_WIDTH:
+            return self.video_dim[1]
+        else:
+            return None
+
+    def forward(self, data):
+        soln = []
+        for track in self.tracks:
+            ptr = track[0]
+            if ptr < len(track[1]):
+                box = track[1][ptr]
+                frame = int(box.get("frame"))
+                if frame == self.frame:
+                    soln.append([[int(float(box.get("xtl"))),
+                                  int(float(box.get("ytl"))),
+                                  int(float(box.get("xbr"))),
+                                  int(float(box.get("ybr")))],
+                                 0.95,
+                                 track[2]])
+                    track[0] += 1
+        return soln
+
+
 def run():
     detector2cvat.main()
 
@@ -33,17 +86,230 @@
         self.tool = "detector2cvat.py"
         self.video = TestDetector2Cvat.dir
         self.save = "tests/detector2cvat"
+        self.dir = "/".join(os.path.splitext(self.video)[0].split('/')[-2:])
 
     def tearDown(self):
         # delete outputs
         del_dir(self.save)
 
     def test_run(self):
+        # check if tool runs on real data
+        save = f"{self.save}/run"
+        sys.argv = [self.tool,
+                    "--video", self.video,
+                    "--save", save]
+        detector2cvat.main()
+
+        # check output exists
+        output_path = f"{save}/{self.dir}/DJI_0068.xml"
+        self.assertTrue(file_exists(output_path))
+        demo_path = f"{save}/{self.dir}/DJI_0068_demo.mp4"
+        self.assertTrue(file_exists(demo_path))
+
+    @patch('kabr_tools.detector2cvat.YOLOv8')
+    def test_mock_yolo(self, yolo):
+        # create fake YOLO
+        yolo_instance = MagicMock()
+        yolo_instance.forward.return_value = [[[0, 0, 0, 0], 0.95, 'Grevy']]
+        yolo.get_centroid.return_value = (50, 50)
+        yolo.return_value = yolo_instance
+
+        # run detector2cvat
+        save = f"{self.save}/mock/0"
         sys.argv = [self.tool,
                     "--video", self.video,
-                    "--save", self.save]
-        run()
+                    "--save", save]
+        detector2cvat.main()
+
+        # check output exists
+        output_path = f"{save}/{self.dir}/DJI_0068.xml"
+        self.assertTrue(file_exists(output_path))
+        demo_path = f"{save}/{self.dir}/DJI_0068_demo.mp4"
+        self.assertTrue(file_exists(demo_path))
+
+        # check output xml
+        xml_content = etree.parse(output_path).getroot()
+        self.assertEqual(xml_content.tag, "annotations")
+        for j, track in enumerate(xml_content.findall("track")):
+            track_len = len(track.findall("box"))
+            self.assertEqual(track.get("id"), str(j+1))
+            self.assertEqual(track.get("label"), "Grevy")
+            # TODO: Check if source should be manual
+            self.assertEqual(track.get("source"), "manual")
+            for i, box in enumerate(track.findall("box")):
+                self.assertEqual(box.get("frame"), str(i))
+                self.assertEqual(box.get("xtl"), "0.00")
+                self.assertEqual(box.get("ytl"), "0.00")
+                self.assertEqual(box.get("xbr"), "0.00")
+                self.assertEqual(box.get("ybr"), "0.00")
+                self.assertEqual(box.get("occluded"), "0")
+                self.assertEqual(box.get("keyframe"), "1")
+                self.assertEqual(box.get("z_order"), "0")
+                # tracker marks last box as outside
+                if i == track_len - 1:
+                    self.assertEqual(box.get("outside"), "1")
+                else:
+                    self.assertEqual(box.get("outside"), "0")
+
+        # check output video
+        cap = cv2.VideoCapture(demo_path)
+        video = cv2.VideoCapture(TestDetector2Cvat.video)
+        self.assertTrue(cap.isOpened())
+        self.assertTrue(video.isOpened())
+        self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_COUNT),
+                         video.get(cv2.CAP_PROP_FRAME_COUNT))
+        self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_HEIGHT),
+                         video.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_WIDTH),
+                         video.get(cv2.CAP_PROP_FRAME_WIDTH))
+        cap.release()
+        video.release()
+
+    @patch('kabr_tools.detector2cvat.YOLOv8')
+    @patch('kabr_tools.detector2cvat.cv2.VideoCapture')
+    def test_mock_with_data(self, video_capture, yolo):
+        # mock outputs CVAT data
+        ref_path = "tests/examples/MINISCENE1/metadata/DJI_tracks.xml"
+        height, width, frames = 3078, 5472, 21
+        data = DetectionData((height, width, 3),
+                             frames,
+                             ref_path)
+        yolo_instance = MagicMock()
+        yolo_instance.forward = data.forward
+        yolo.return_value = yolo_instance
+        yolo.get_centroid = MagicMock(
+            side_effect=lambda pred: YOLOv8.get_centroid(pred))
+
+        vc = MagicMock()
+        vc.read = data.read
+        vc.get = data.get
+        video_capture.return_value = vc
+
+        # run detector2cvat
+        save = f"{self.save}/mock/1"
+        sys.argv = [self.tool,
+                    "--video", self.video,
+                    "--save", save]
+        detector2cvat.main()
+
+        # check output exists
+        output_path = f"{save}/{self.dir}/DJI_0068.xml"
+        self.assertTrue(file_exists(output_path))
+        demo_path = f"{save}/{self.dir}/DJI_0068_demo.mp4"
+        self.assertTrue(file_exists(demo_path))
+
+        # check output xml
+        xml_content = etree.parse(output_path).getroot()
+        self.assertEqual(xml_content.tag, "annotations")
+        ref_content = etree.parse(ref_path).getroot()
+        ref_track = list(ref_content.findall("track"))
+        for j, track in enumerate(xml_content.findall("track")):
+            self.assertEqual(track.get("id"), str(j+1))
+            self.assertEqual(track.get("label"), "Grevy")
+            self.assertEqual(track.get("source"), "manual")
+            ref_box = list(ref_track[j].findall("box"))
+            for i, box in enumerate(track.findall("box")):
+                self.assertEqual(box.get("frame"), ref_box[i].get("frame"))
+                self.assertEqual(box.get("xtl"), ref_box[i].get("xtl"))
+                self.assertEqual(box.get("ytl"), ref_box[i].get("ytl"))
+                self.assertEqual(box.get("xbr"), ref_box[i].get("xbr"))
+                self.assertEqual(box.get("ybr"), ref_box[i].get("ybr"))
+                self.assertEqual(box.get("occluded"), "0")
+                self.assertEqual(box.get("keyframe"), "1")
+                self.assertEqual(box.get("z_order"), "0")
+                # tracker marks last box as outside
+                if i == frames - 1:
+                    self.assertEqual(box.get("outside"), "1")
+                else:
+                    self.assertEqual(box.get("outside"), "0")
+
+        # check output video
+        cap = cv2.VideoCapture(demo_path)
+        self.assertTrue(cap.isOpened())
+        self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_COUNT),
+                         frames)
+        self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_HEIGHT),
+                         height)
+        self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_WIDTH),
+                         width)
+        cap.release()
+
+    @patch('kabr_tools.detector2cvat.YOLOv8')
+    @patch('kabr_tools.detector2cvat.cv2.VideoCapture')
+    def test_mock_noncontiguous(self, video_capture, yolo):
+        # mock outputs non-contiguous frame detections
+        ref_path = "tests/examples/DETECTOR1/DJI_tracks.xml"
+        height, width, frames = 3078, 5472, 31
+        data = DetectionData((height, width, 3),
+                             frames,
+                             ref_path)
+        yolo_instance = MagicMock()
+        yolo_instance.forward = data.forward
+        yolo.return_value = yolo_instance
+        yolo.get_centroid = MagicMock(
+            side_effect=lambda pred: YOLOv8.get_centroid(pred))
+
+        vc = MagicMock()
+        vc.read = data.read
+        vc.get = data.get
+        video_capture.return_value = vc
+
+        # run detector2cvat
+        save = f"{self.save}/mock/2"
+        sys.argv = [self.tool,
+                    "--video", self.video,
+                    "--save", save]
+        detector2cvat.main()
+
+        # check output exists
+        output_path = f"{save}/{self.dir}/DJI_0068.xml"
+        self.assertTrue(file_exists(output_path))
+        demo_path = f"{save}/{self.dir}/DJI_0068_demo.mp4"
+        self.assertTrue(file_exists(demo_path))
+
+        # check output xml
+        xml_content = etree.parse(output_path).getroot()
+        self.assertEqual(xml_content.tag, "annotations")
+        ref_content = etree.parse(ref_path).getroot()
+        ref_track = list(ref_content.findall("track"))
+        for j, track in enumerate(xml_content.findall("track")):
+            self.assertEqual(track.get("id"), str(j+1))
+            self.assertEqual(track.get("label"), "Grevy")
+            self.assertEqual(track.get("source"), "manual")
+            ref_box = list(ref_track[j].findall("box"))
+            i = 0
+            frame = int(track.find("box").get("frame"))
+            for box in track.findall("box"):
+                if i + 1 < len(ref_box) and box.get("frame") == ref_box[i+1].get("frame"):
+                    i += 1
+                self.assertEqual(box.get("frame"), str(frame))
+                self.assertEqual(box.get("xtl"), ref_box[i].get("xtl"))
+                self.assertEqual(box.get("ytl"), ref_box[i].get("ytl"))
+                self.assertEqual(box.get("xbr"), ref_box[i].get("xbr"))
+                self.assertEqual(box.get("ybr"), ref_box[i].get("ybr"))
+                self.assertEqual(box.get("occluded"), "0")
+                if box.get("frame") == ref_box[i].get("frame"):
+                    self.assertEqual(box.get("keyframe"), "1")
+                else:
+                    self.assertEqual(box.get("keyframe"), "0")
+                self.assertEqual(box.get("z_order"), "0")
+                # tracker marks last box as outside
+                if frame == frames - 1:
+                    self.assertEqual(box.get("outside"), "1")
+                else:
+                    self.assertEqual(box.get("outside"), "0")
+                frame += 1
+
+        # check output video
+        cap = cv2.VideoCapture(demo_path)
+        self.assertTrue(cap.isOpened())
+        self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_COUNT),
+                         frames)
+        self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_HEIGHT),
+                         height)
+        self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_WIDTH),
+                         width)
+        cap.release()
 
     def test_parse_arg_min(self):
         # parse arguments
diff --git a/tests/test_miniscene2behavior.py b/tests/test_miniscene2behavior.py
index 7875e2d..35b589a 100644
--- a/tests/test_miniscene2behavior.py
+++ b/tests/test_miniscene2behavior.py
@@ -4,6 +4,8 @@
 import os
 import requests
 from unittest.mock import Mock, patch
+from lxml import etree
+from ruamel.yaml import YAML
 import torch
 import numpy as np
 import pandas as pd
@@ -15,6 +17,7 @@
 from tests.utils import (
     del_file,
     del_dir,
+    clean_dir,
     get_detection
 )
 
@@ -33,6 +36,7 @@
     def setUpClass(cls):
         # download the model from Imageomics HF
         cls.checkpoint = "checkpoint_epoch_00075.pyth"
"checkpoint_epoch_00075.pyth" + cls.model_output = None cls.download_model() # download data @@ -48,9 +52,10 @@ def setUpClass(cls): @classmethod def download_model(cls): if not os.path.exists(cls.checkpoint): + # download checkpoint archive url = "https://huggingface.co/imageomics/" \ - + "x3d-kabr-kinetics/resolve/main/" \ - + f"{cls.checkpoint}.zip" + + "x3d-kabr-kinetics/resolve/main/" \ + + f"{cls.checkpoint}.zip" r = requests.get(url, allow_redirects=True, timeout=120) with open(f"{cls.checkpoint}.zip", "wb") as f: f.write(r.content) @@ -59,13 +64,25 @@ def download_model(cls): with zipfile.ZipFile(f"{cls.checkpoint}.zip", "r") as zip_ref: zip_ref.extractall(".") + # get checkpoint directory + try: + cfg = torch.load(cls.checkpoint, + weights_only=True, + map_location=torch.device("cpu"))["cfg"] + yaml = YAML(typ="rt") + cls.model_output = f"{yaml.load(cfg)['OUTPUT_DIR']}/checkpoints" + except Exception: + pass + @classmethod def tearDownClass(cls): # remove model files after tests - if os.path.exists(f"{cls.checkpoint}.zip"): - os.remove(f"{cls.checkpoint}.zip") - if os.path.exists(cls.checkpoint): - os.remove(cls.checkpoint) + del_file(f"{cls.checkpoint}.zip") + del_file(cls.checkpoint) + if cls.model_output: + clean_dir(cls.model_output) + + # remove data after tests del_file(cls.video) del_file(cls.annotation) del_dir(cls.miniscene) @@ -91,18 +108,40 @@ def test_run(self): sys.argv = [self.tool, "--checkpoint", self.checkpoint, "--miniscene", self.miniscene, - "--video", self.video] + "--video", self.video, + "--output", self.output] run() + # check output CSV + df = pd.read_csv(self.output, sep=' ') + self.assertEqual(list(df.columns), [ + "video", "track", "frame", "label"]) + row_ct = 0 + + root = etree.parse( + f"{self.miniscene}/metadata/{self.video}_tracks.xml").getroot() + for track in root.iterfind("track"): + track_id = int(track.get("id")) + for box in track.iterfind("box"): + row = list(df.loc[row_ct]) + self.assertEqual(row[0], self.video) + self.assertEqual(row[1], track_id) + self.assertEqual(row[2], int(box.get("frame"))) + self.assertTrue(row[3] >= 0) + self.assertTrue(row[3] <= 7) + row_ct += 1 + self.assertEqual(len(df.index), row_ct) + @patch('kabr_tools.miniscene2behavior.process_cv2_inputs') @patch('kabr_tools.miniscene2behavior.cv2.VideoCapture') def test_matching_tracks(self, video_capture, process_cv2_inputs): - - # Create fake model that always returns a prediction of 1 + # create fake model that weights class 98 mock_model = Mock() - mock_model.return_value = torch.tensor([1]) + prob = torch.zeros(99) + prob[-1] = 1 + mock_model.return_value = prob - # Create fake cfg + # create fake cfg mock_config = Mock( DATA=Mock(NUM_FRAMES=16, SAMPLING_RATE=5, @@ -111,25 +150,36 @@ def test_matching_tracks(self, video_capture, process_cv2_inputs): OUTPUT_DIR='' ) - # Create fake video capture + # create fake video capture vc = video_capture.return_value vc.read.return_value = True, np.zeros((8, 8, 3), np.uint8) - vc.get.return_value = 1 + vc.get.return_value = 21 self.output = '/tmp/annotation_data.csv' + miniscene_dir = os.path.join(EXAMPLESDIR, "MINISCENE1") + video_name = "DJI" annotate_miniscene(cfg=mock_config, model=mock_model, - miniscene_path=os.path.join( - EXAMPLESDIR, "MINISCENE1"), - video='DJI', + miniscene_path=miniscene_dir, + video=video_name, output_path=self.output) - # Read in output CSV and make sure we have the expected columns and at least one row + # check output CSV df = pd.read_csv(self.output, sep=' ') self.assertEqual(list(df.columns), [ 
"video", "track", "frame", "label"]) - self.assertGreater(len(df.index), 0) + row_ct = 0 + + root = etree.parse( + f"{miniscene_dir}/metadata/DJI_tracks.xml").getroot() + for track in root.iterfind("track"): + track_id = int(track.get("id")) + for box in track.iterfind("box"): + row_val = [video_name, track_id, int(box.get("frame")), 98] + self.assertEqual(list(df.loc[row_ct]), row_val) + row_ct += 1 + self.assertEqual(len(df.index), row_ct) @patch('kabr_tools.miniscene2behavior.process_cv2_inputs') @patch('kabr_tools.miniscene2behavior.cv2.VideoCapture') @@ -151,9 +201,11 @@ def test_nonmatching_tracks(self, video_capture, process_cv2_inputs): # Create fake video capture vc = video_capture.return_value vc.read.return_value = True, np.zeros((8, 8, 3), np.uint8) - vc.get.return_value = 1 + vc.get.return_value = 21 self.output = '/tmp/annotation_data.csv' + miniscene_dir = os.path.join(EXAMPLESDIR, "MINISCENE2") + video_name = "DJI" annotate_miniscene(cfg=mock_config, model=mock_model, @@ -162,11 +214,21 @@ def test_nonmatching_tracks(self, video_capture, process_cv2_inputs): video='DJI', output_path=self.output) - # Read in output CSV and make sure we have the expected columns and at least one row + # check output CSV df = pd.read_csv(self.output, sep=' ') self.assertEqual(list(df.columns), [ "video", "track", "frame", "label"]) - self.assertGreater(len(df.index), 0) + row_ct = 0 + + root = etree.parse( + f"{miniscene_dir}/metadata/DJI_tracks.xml").getroot() + for track in root.iterfind("track"): + track_id = int(track.get("id")) + for box in track.iterfind("box"): + row_val = [video_name, track_id, int(box.get("frame")), 0] + self.assertEqual(list(df.loc[row_ct]), row_val) + row_ct += 1 + self.assertEqual(len(df.index), row_ct) def test_parse_arg_min(self): # parse arguments diff --git a/tests/test_player.py b/tests/test_player.py index 0a2d3fb..091aeb0 100644 --- a/tests/test_player.py +++ b/tests/test_player.py @@ -6,6 +6,7 @@ from tests.utils import ( del_file, del_dir, + file_exists, get_behavior ) @@ -58,6 +59,7 @@ def test_run(self, getTrackbarPos, setTrackbarPos, createTrackbar, namedWindow, "--folder", self.folder, "--save"] run() + self.assertTrue(file_exists(f"{self.folder}/{self.video}_demo.mp4")) @patch('kabr_tools.player.cv2.imshow') @patch('kabr_tools.player.cv2.namedWindow') @@ -75,10 +77,10 @@ def test_parse_arg_min(self, getTrackbarPos, setTrackbarPos, createTrackbar, nam # check default arguments self.assertEqual(args.save, False) - self.assertEqual(args.imshow, False) # run player run() + self.assertTrue(not file_exists(f"{self.folder}/{self.video}_demo.mp4")) @patch('kabr_tools.player.cv2.imshow') @patch('kabr_tools.player.cv2.namedWindow') @@ -95,7 +97,7 @@ def test_parse_arg_full(self, getTrackbarPos, setTrackbarPos, createTrackbar, na # check parsed arguments self.assertEqual(args.folder, self.folder) self.assertEqual(args.save, True) - self.assertEqual(args.imshow, True) # run player run() + self.assertTrue(file_exists(f"{self.folder}/{self.video}_demo.mp4")) diff --git a/tests/test_tracks_extractor.py b/tests/test_tracks_extractor.py index 8f61a1e..71a7e9a 100644 --- a/tests/test_tracks_extractor.py +++ b/tests/test_tracks_extractor.py @@ -1,13 +1,25 @@ import unittest import sys -from unittest.mock import patch +import os +from unittest.mock import patch, Mock +import json +from lxml import etree +import cv2 from kabr_tools import tracks_extractor +from kabr_tools.utils.tracker import Tracker +from kabr_tools.utils.detector import Detector +from 
 from tests.utils import (
     get_detection,
+    dir_exists,
+    file_exists,
     del_dir,
     del_file
 )
 
+# TODO: promote these to shared kabr_tools constants (values copied from tracks_extractor.py)
+scene_width, scene_height = 400, 300
+
 
 def run():
     tracks_extractor.main()
@@ -46,6 +58,113 @@
                     "--annotation", self.annotation]
         run()
 
+        # check output exists
+        mini_folder = os.path.splitext("|".join(self.video.split("/")[-3:]))[0]
+        video_name = "DJI_0068"
+        self.assertTrue(dir_exists(f"mini-scenes/{mini_folder}"))
+        self.assertTrue(dir_exists(f"mini-scenes/{mini_folder}/actions"))
+        self.assertTrue(dir_exists(f"mini-scenes/{mini_folder}/metadata"))
+        self.assertTrue(file_exists(f"mini-scenes/{mini_folder}/0.mp4"))
+        self.assertTrue(file_exists(f"mini-scenes/{mini_folder}/1.mp4"))
+        self.assertTrue(file_exists(
+            f"mini-scenes/{mini_folder}/{video_name}.mp4"))
+        self.assertTrue(file_exists(
+            f"mini-scenes/{mini_folder}/metadata/{video_name}_metadata.json"))
+        self.assertTrue(file_exists(
+            f"mini-scenes/{mini_folder}/metadata/{video_name}_tracks.xml"))
+        self.assertTrue(file_exists(
+            f"mini-scenes/{mini_folder}/metadata/{video_name}.jpg"))
+
+        # check metadata.json
+        root = etree.parse(self.annotation).getroot()
+        size = int("".join(root.find("meta").find("task").find("size").itertext()))
+        tracks = {"main": [-1] * size}
+        for track in root.iterfind("track"):
+            track_id = track.attrib["id"]
+            tracks[track_id] = []
+            for box in track.iter("box"):
+                frame_id = int(box.attrib["frame"])
+                tracks[track_id].append(frame_id)
+                tracks["main"][frame_id] = frame_id
+
+        colors = list(Tracker.colors_table.values())
+
+        with open(f"mini-scenes/{mini_folder}/metadata/{video_name}_metadata.json",
+                  "r", encoding="utf-8") as f:
+            metadata = json.load(f)
+            self.assertTrue("original" in metadata)
+            self.assertTrue("tracks" in metadata)
+            self.assertTrue("colors" in metadata)
+            self.assertEqual(metadata["original"], self.video)
+            self.assertEqual(metadata["tracks"]["main"], tracks["main"])
+            self.assertEqual(metadata["tracks"]["0"], tracks["0"])
+            self.assertEqual(metadata["tracks"]["1"], tracks["1"])
+            self.assertEqual(metadata["colors"]["0"],
+                             list(colors[0 % len(colors)]))
+            self.assertEqual(metadata["colors"]["1"],
+                             list(colors[1 % len(colors)]))
+
+        # check tracks.xml
+        with open(f"mini-scenes/{mini_folder}/metadata/{video_name}_tracks.xml",
+                  "r", encoding="utf-8") as f:
+            track_copy = f.read()
+
+        with open(self.annotation, "r", encoding="utf-8") as f:
+            track = f.read()
+
+        self.assertEqual(track, track_copy)
+
+        # check 0.mp4, 1.mp4
+        root = etree.parse(self.annotation).getroot()
+        xml_tracks = {}
+        for track in root.findall("track"):
+            track_id = track.attrib["id"]
+            xml_tracks[track_id] = track
+        self.assertEqual(xml_tracks.keys(), {"0", "1"})
+
+        original = cv2.VideoCapture(self.video)
+        self.assertTrue(original.isOpened())
+        mock = Mock()  # stand-in detection: get_scene only needs .box and .centroid
+
+        for track_id, xml_track in xml_tracks.items():
+            track = cv2.VideoCapture(
+                f"mini-scenes/{mini_folder}/{track_id}.mp4")
+            self.assertTrue(track.isOpened())
+
+            for i, box in enumerate(xml_track.iter("box")):
+                original.set(cv2.CAP_PROP_POS_FRAMES, int(box.attrib["frame"]))
+                track.set(cv2.CAP_PROP_POS_FRAMES, i)
+                original_returned, original_frame = original.read()
+                track_returned, track_frame = track.read()
+
+                self.assertTrue(original_returned)
+                self.assertTrue(track_returned)
+
+                mock.box = [int(float(box.attrib["xtl"])),
+                            int(float(box.attrib["ytl"])),
+                            int(float(box.attrib["xbr"])),
+                            int(float(box.attrib["ybr"]))]
+                mock.centroid = Detector.get_centroid(mock.box)
+                original_frame = get_scene(
+                    original_frame, mock, scene_width, scene_height)
+
+                # video encoding adds some noise to frames; allow for it
+                self.assertLess(
+                    cv2.norm(original_frame - track_frame), 1e6)
+            track.release()
+
+        # check DJI_0068.mp4
+        copy = cv2.VideoCapture(f"mini-scenes/{mini_folder}/{video_name}.mp4")
+        self.assertTrue(copy.isOpened())
+        self.assertEqual(copy.get(cv2.CAP_PROP_FRAME_COUNT),
+                         original.get(cv2.CAP_PROP_FRAME_COUNT))
+        self.assertEqual(copy.get(cv2.CAP_PROP_FRAME_WIDTH),
+                         original.get(cv2.CAP_PROP_FRAME_WIDTH))
+        self.assertEqual(copy.get(cv2.CAP_PROP_FRAME_HEIGHT),
+                         original.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+        original.release()
+
     def test_parse_arg_min(self):
         # parse arguments
         sys.argv = [self.tool,
diff --git a/tests/utils.py b/tests/utils.py
index c3b2e5c..b8d3004 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,5 +1,6 @@
 import os
 import shutil
+from pathlib import Path
 from huggingface_hub import hf_hub_download
 
 DATA_HUB = "imageomics/kabr_testing"
@@ -31,6 +32,11 @@
     annotation = get_cached_datafile(DETECTION_ANNOTATION)
     return video, annotation
 
+def clean_dir(path):
+    # remove an empty directory chain; use del_dir for non-empty trees
+    if os.path.exists(path):
+        os.removedirs(path)
+
 def del_dir(path):
     if os.path.exists(path):
         shutil.rmtree(path)
@@ -38,3 +43,10 @@
 def del_file(path):
     if os.path.exists(path):
         os.remove(path)
+
+def file_exists(path):
+    return Path(path).is_file()
+
+
+def dir_exists(path):
+    return Path(path).is_dir()