diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 0000000..a4cdadd
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,34 @@
+name: Test
+
+on:
+  push:
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10"]
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip setuptools wheel
+          pip install -r requirements.txt
+          python -m pip install hatchling
+          python -m pip install --no-build-isolation .
+      - name: Run unit tests
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN_TESTING }}
+        run: |
+          python -m unittest tests/test_cvat2slowfast.py
+          python -m unittest tests/test_cvat2ultralytics.py
+          python -m unittest tests/test_detector2cvat.py
+          python -m unittest tests/test_miniscene2behavior.py
+          python -m unittest tests/test_player.py
+          python -m unittest tests/test_tracks_extractor.py
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 36d016c..77b27f5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -177,5 +177,18 @@ cython_debug/
 
 # Mac System
 .DS_Store
 
+# Tool output
+*.json
+*.xml
+*.jpg
+*.yaml
+*.csv
+*.txt
+
+# Model files
+*.pyth
+*.pyth.zip
+*.yml
+
 helper_scripts/mini-scenes
\ No newline at end of file
diff --git a/README.md b/README.md
index 6228bec..d26f71f 100644
--- a/README.md
+++ b/README.md
@@ -170,6 +170,6 @@ player --folder path_to_folder [--save] [--imshow]
 ```
 
 ```
-cvat2slowfast --miniscene path_to_mini_scenes --dataset dataset_name --classes path_to_classes_json [--old2new path_to_old2new_json]
+cvat2slowfast --miniscene path_to_mini_scenes --dataset dataset_name --classes path_to_classes_json [--old2new path_to_old2new_json] [--no_images]
 ```
 
diff --git a/ethogram/label2index.json b/ethogram/label2index.json
new file mode 100644
index 0000000..a110b82
--- /dev/null
+++ b/ethogram/label2index.json
@@ -0,0 +1,6 @@
+{
+    "Grevy": 0,
+    "Zebra": 0,
+    "Baboon": 1,
+    "Giraffe": 2
+}
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 46ea954..a9b7d35 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,3 +10,4 @@ ultralytics~=8.0.36
 pandas>=1.3.5
 pillow==10.4.0
 scikit-learn==1.5.1
+huggingface_hub
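
Note on the new CI workflow and `huggingface_hub` dependency: the test suite added below pulls its video/annotation fixtures from the Hugging Face Hub, authenticating through the `HF_TOKEN` environment variable that the workflow populates from `secrets.HF_TOKEN_TESTING`. A minimal sketch of the download call (repo and file names taken from `tests/utils.py` further down):

```python
from huggingface_hub import hf_hub_download

# Downloads into the local HF cache (reused across runs) and returns the
# cached path; a gated/private dataset repo picks up credentials from HF_TOKEN.
path = hf_hub_download(
    repo_id="imageomics/kabr_testing",
    filename="DJI_0068/DJI_0068.mp4",
    repo_type="dataset",
)
```
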
diff --git a/src/kabr_tools/cvat2ultralytics.py b/src/kabr_tools/cvat2ultralytics.py
index 8dd14d8..073108b 100644
--- a/src/kabr_tools/cvat2ultralytics.py
+++ b/src/kabr_tools/cvat2ultralytics.py
@@ -3,7 +3,7 @@
 import argparse
 import json
 import cv2
-import ruamel.yaml as yaml
+from ruamel.yaml import YAML
 from lxml import etree
 from collections import OrderedDict
 from tqdm import tqdm
@@ -39,8 +39,10 @@ def cvat2ultralytics(video_path: str, annotation_path: str,
         shutil.rmtree(f"{dataset}")
 
     with open(f"{dataset}.yaml", "w") as file:
-        yaml.dump(yaml.load(dataset_file, Loader=yaml.RoundTripLoader, preserve_quotes=True),
-                  file, Dumper=yaml.RoundTripDumper)
+        yaml = YAML(typ='rt')
+        yaml.preserve_quotes = True
+        data = yaml.load(dataset_file)
+        yaml.dump(data, file)
 
     if not os.path.exists(f"{dataset}/images/train"):
         os.makedirs(f"{dataset}/images/train")
@@ -57,6 +59,7 @@ def cvat2ultralytics(video_path: str, annotation_path: str,
     if label2index is None:
         label2index = {
+            "Grevy": 0,
             "Zebra": 0,
             "Baboon": 1,
             "Giraffe": 2
         }
@@ -69,21 +72,24 @@ def cvat2ultralytics(video_path: str, annotation_path: str,
     for root, dirs, files in os.walk(annotation_path):
         for file in files:
             video_name = os.path.join(video_path + root[len(annotation_path):], os.path.splitext(file)[0])
-
-            if os.path.exists(video_name + ".MP4"):
-                videos.append(video_name + ".MP4")
-            else:
-                videos.append(video_name + ".mp4")
-
-            annotations.append(os.path.join(root, file))
+            if file.endswith(".xml"):
+                if os.path.exists(video_name + ".MP4"):
+                    videos.append(video_name + ".MP4")
+                else:
+                    videos.append(video_name + ".mp4")
+                annotations.append(os.path.join(root, file))
 
     for i, (video, annotation) in enumerate(zip(videos, annotations)):
-        print(f"{i + 1}/{len(annotations)}:")
+        print(f"{i + 1}/{len(annotations)}:", flush=True)
 
         if not os.path.exists(video):
             print(f"Path {video} does not exist.")
             continue
 
+        if not os.path.exists(annotation):
+            print(f"Path {annotation} does not exist.")
+            continue
+
         # Parse CVAT for video 1.1 annotation file.
         root = etree.parse(annotation).getroot()
         name = os.path.splitext(video.split("/")[-1])[0]
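
Note: the `cvat2ultralytics.py` hunk above migrates off ruamel.yaml's module-level `load()`/`dump()` with `RoundTripLoader`/`RoundTripDumper`, which were deprecated for years and dropped in recent releases, onto the `YAML` instance API. A minimal sketch of the round-trip pattern, with an illustrative file name:

```python
from ruamel.yaml import YAML

yaml = YAML(typ="rt")        # round-trip mode keeps comments and key order
yaml.preserve_quotes = True  # keep the original quoting of scalars

with open("dataset.yaml") as src:        # hypothetical input
    data = yaml.load(src)
with open("dataset_copy.yaml", "w") as dst:
    yaml.dump(data, dst)
```
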
diff --git a/src/kabr_tools/miniscene2behavior.py b/src/kabr_tools/miniscene2behavior.py
index 8347350..4ce2fc1 100644
--- a/src/kabr_tools/miniscene2behavior.py
+++ b/src/kabr_tools/miniscene2behavior.py
@@ -19,12 +19,15 @@ def get_input_clip(cap: cv2.VideoCapture, cfg: CfgNode, keyframe_idx: int) -> list:
     # https://github.com/facebookresearch/SlowFast/blob/bac7b672f40d44166a84e8c51d1a5ba367ace816/slowfast/visualization/ava_demo_precomputed_boxes.py
     seq_length = cfg.DATA.NUM_FRAMES * cfg.DATA.SAMPLING_RATE
     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    assert keyframe_idx < total_frames, f"keyframe_idx: {keyframe_idx}" \
+        f" >= total_frames: {total_frames}"
     seq = get_sequence(
         keyframe_idx,
         seq_length // 2,
         cfg.DATA.SAMPLING_RATE,
         total_frames,
     )
+
     clip = []
     for frame_idx in seq:
         cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
@@ -124,29 +127,34 @@ def annotate_miniscene(cfg: CfgNode, model: torch.nn.Module,
     # find all tracks
     tracks = []
+    frames = {}
     for track in root.iterfind("track"):
         track_id = track.attrib["id"]
         tracks.append(track_id)
+        frames[track_id] = []
 
-    # find all frames
-    # TODO: rewrite - some tracks may have different frames
-    assert len(tracks) > 0, "No tracks found in track file"
-    frames = []
-    for box in track.iterfind("box"):
-        frames.append(int(box.attrib["frame"]))
+        # find all frames
+        for box in track.iterfind("box"):
+            frames[track_id].append(int(box.attrib["frame"]))
 
     # run model on miniscene
     for track in tracks:
         video_file = f"{miniscene_path}/{track}.mp4"
         cap = cv2.VideoCapture(video_file)
-        for frame in tqdm(frames, desc=f"{track} frames"):
-            inputs = get_input_clip(cap, cfg, frame)
+        index = 0
+        for frame in tqdm(frames[track], desc=f"{track} frames"):
+            try:
+                inputs = get_input_clip(cap, cfg, index)
+            except AssertionError as e:
+                print(e)
+                break
+            index += 1
 
             if cfg.NUM_GPUS:
                 # transfer the data to the current GPU device.
                 if isinstance(inputs, (list,)):
-                    for i in range(len(inputs)):
-                        inputs[i] = inputs[i].cuda(non_blocking=True)
+                    for i, input_clip in enumerate(inputs):
+                        inputs[i] = input_clip.cuda(non_blocking=True)
                 else:
                     inputs = inputs.cuda(non_blocking=True)
@@ -163,6 +171,7 @@ def annotate_miniscene(cfg: CfgNode, model: torch.nn.Module,
             if frame % 20 == 0:
                 pd.DataFrame(label_data).to_csv(
                     output_path, sep=" ", index=False)
+        cap.release()
 
     pd.DataFrame(label_data).to_csv(output_path, sep=" ", index=False)
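
Note: for readers unfamiliar with the SlowFast sampling that `get_input_clip` wraps: a clip covers `NUM_FRAMES * SAMPLING_RATE` source frames centered on the keyframe, clamped to the video bounds, which is what the new assert guards against. A rough standalone equivalent of the index math (my approximation, not the library call itself):

```python
def centered_clip_indices(keyframe_idx: int, num_frames: int,
                          sampling_rate: int, total_frames: int) -> list:
    """Approximate slowfast's get_sequence: num_frames indices spaced
    sampling_rate apart, centered on keyframe_idx, clamped to the video."""
    half = (num_frames * sampling_rate) // 2
    start = keyframe_idx - half
    return [min(max(start + i * sampling_rate, 0), total_frames - 1)
            for i in range(num_frames)]

# With the values the tests mock (NUM_FRAMES=16, SAMPLING_RATE=5),
# a clip spans roughly 80 source frames.
print(centered_clip_indices(40, 16, 5, 100))
```
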
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/examples/DETECTOR1/DJI_tracks.xml b/tests/examples/DETECTOR1/DJI_tracks.xml
new file mode 100644
index 0000000..22c734f
--- /dev/null
+++ b/tests/examples/DETECTOR1/DJI_tracks.xml
@@ -0,0 +1,49 @@
+<!-- 49 lines of CVAT for video 1.1 track annotations; XML content not preserved in this excerpt -->
\ No newline at end of file
diff --git a/tests/examples/MINISCENE1/metadata/DJI_tracks.xml b/tests/examples/MINISCENE1/metadata/DJI_tracks.xml
new file mode 100644
index 0000000..ee92dbd
--- /dev/null
+++ b/tests/examples/MINISCENE1/metadata/DJI_tracks.xml
@@ -0,0 +1,49 @@
+<!-- 49 lines of CVAT for video 1.1 track annotations; XML content not preserved in this excerpt -->
\ No newline at end of file
diff --git a/tests/examples/MINISCENE2/metadata/DJI_tracks.xml b/tests/examples/MINISCENE2/metadata/DJI_tracks.xml
new file mode 100644
index 0000000..89696a4
--- /dev/null
+++ b/tests/examples/MINISCENE2/metadata/DJI_tracks.xml
@@ -0,0 +1,29 @@
+<!-- 29 lines of CVAT for video 1.1 track annotations; XML content not preserved in this excerpt -->
\ No newline at end of file
diff --git a/tests/test_cvat2slowfast.py b/tests/test_cvat2slowfast.py
new file mode 100644
index 0000000..1c4bf62
--- /dev/null
+++ b/tests/test_cvat2slowfast.py
@@ -0,0 +1,128 @@
+import unittest
+import sys
+import os
+import json
+import pandas as pd
+import cv2
+from kabr_tools import cvat2slowfast
+from tests.test_tracks_extractor import (
+    scene_width,
+    scene_height
+)
+from tests.utils import (
+    del_dir,
+    del_file,
+    dir_exists,
+    file_exists,
+    get_behavior
+)
+
+
+def run():
+    cvat2slowfast.main()
+
+
+class TestCvat2Slowfast(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        # download data
+        cls.video, cls.miniscene, cls.annotation, cls.metadata = get_behavior()
+        cls.dir = os.path.dirname(os.path.dirname(cls.video))
+
+    @classmethod
+    def tearDownClass(cls):
+        # delete data
+        del_file(cls.video)
+        del_file(cls.miniscene)
+        del_file(cls.annotation)
+        del_file(cls.metadata)
+        del_dir(cls.dir)
+
+    def setUp(self):
+        # set params
+        self.tool = "cvat2slowfast.py"
+        self.miniscene = TestCvat2Slowfast.dir
+        self.dataset = "tests/slowfast"
+        self.classes = "ethogram/classes.json"
+        self.old2new = "ethogram/old2new.json"
+
+    def tearDown(self):
+        # delete outputs
+        del_dir(self.dataset)
+
+    def test_run(self):
+        # run cvat2slowfast
+        sys.argv = [self.tool,
+                    "--miniscene", self.miniscene,
+                    "--dataset", self.dataset,
+                    "--classes", self.classes]
+        run()
+
+        # check output dirs
+        self.assertTrue(dir_exists(self.dataset))
+        self.assertTrue(dir_exists(f"{self.dataset}/annotation"))
+        self.assertTrue(dir_exists(f"{self.dataset}/dataset/image"))
+        self.assertTrue(file_exists(f"{self.dataset}/annotation/classes.json"))
+        self.assertTrue(file_exists(f"{self.dataset}/annotation/data.csv"))
+
+        # check classes.json
+        with open(f"{self.dataset}/annotation/classes.json", "r", encoding="utf-8") as f:
+            classes = json.load(f)
+        with open(self.classes, "r", encoding="utf-8") as f:
+            ethogram = json.load(f)
+        self.assertEqual(classes, ethogram)
+
+        # check data.csv
+        with open(f"{self.dataset}/annotation/data.csv", "r", encoding="utf-8") as f:
+            df = pd.read_csv(f, sep=" ")
+
+        video_id = 1
+        for i, row in df.iterrows():
+            self.assertEqual(row["original_vido_id"], f"Z{video_id:04d}")
+            self.assertEqual(row["video_id"], video_id)
+            self.assertEqual(row["frame_id"], i+1)
+            self.assertEqual(row["path"], f"Z{video_id:04d}/{i+1}.jpg")
+            self.assertEqual(row["labels"], 1)
+        # 91 rows expected, so the last index is 90
+        self.assertEqual(i, 90)
+
+        # check dataset
+        for i in range(1, 92):
+            data_im = f"{self.dataset}/dataset/image/Z{video_id:04d}/{i}.jpg"
+            self.assertTrue(file_exists(data_im))
+            data_im = cv2.imread(data_im)
+            self.assertEqual(data_im.shape, (scene_height, scene_width, 3))
+
+    def test_parse_arg_min(self):
+        # parse arguments
+        sys.argv = [self.tool,
+                    "--miniscene", self.miniscene,
+                    "--dataset", self.dataset,
+                    "--classes", self.classes]
+        args = cvat2slowfast.parse_args()
+
+        # check parsed argument values
+        self.assertEqual(args.miniscene, self.miniscene)
+        self.assertEqual(args.dataset, self.dataset)
+        self.assertEqual(args.classes, self.classes)
+
+        # check default argument values
+        self.assertEqual(args.old2new, None)
+        self.assertTrue(not args.no_images)
+
+    def test_parse_arg_full(self):
+        # parse arguments
+        sys.argv = ["cvat2slowfast.py",
+                    "--miniscene", self.miniscene,
+                    "--dataset", self.dataset,
+                    "--classes", self.classes,
+                    "--old2new", self.old2new,
+                    "--no_images"]
+        args = cvat2slowfast.parse_args()
+
+        # check parsed argument values
+        self.assertEqual(args.miniscene, self.miniscene)
+        self.assertEqual(args.dataset, self.dataset)
+        self.assertEqual(args.classes, self.classes)
+        self.assertEqual(args.old2new, self.old2new)
+        self.assertTrue(args.no_images)
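
Note: the oddly spelled `original_vido_id` column asserted in `test_run` is not a typo to fix here; to my knowledge it is the literal header PySlowFast expects in its space-separated frame-list CSVs, so `cvat2slowfast` has to emit it verbatim. A sketch of the layout being checked (values illustrative):

```python
import pandas as pd

# One row per extracted frame of a mini-scene; written space-separated.
df = pd.DataFrame({
    "original_vido_id": ["Z0001"] * 3,  # sic: spelling required by PySlowFast
    "video_id": [1] * 3,
    "frame_id": [1, 2, 3],
    "path": [f"Z0001/{i}.jpg" for i in (1, 2, 3)],
    "labels": [1] * 3,
})
df.to_csv("data.csv", sep=" ", index=False)
```
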
diff --git a/tests/test_cvat2ultralytics.py b/tests/test_cvat2ultralytics.py
new file mode 100644
index 0000000..0fb246a
--- /dev/null
+++ b/tests/test_cvat2ultralytics.py
@@ -0,0 +1,159 @@
+import unittest
+import sys
+import os
+from lxml import etree
+import pandas as pd
+import cv2
+from kabr_tools import cvat2ultralytics
+from tests.utils import (
+    del_dir,
+    del_file,
+    get_detection,
+    dir_exists,
+    file_exists
+)
+
+
+def run():
+    cvat2ultralytics.main()
+
+
+class TestCvat2Ultralytics(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        # download data
+        cls.video, cls.annotation = get_detection()
+        cls.dir = os.path.dirname(os.path.dirname(cls.video))
+
+    @classmethod
+    def tearDownClass(cls):
+        # delete data
+        del_file(cls.video)
+        del_file(cls.annotation)
+
+    def setUp(self):
+        self.tool = "cvat2ultralytics.py"
+        self.video = TestCvat2Ultralytics.dir
+        self.annotation = TestCvat2Ultralytics.dir
+        self.dataset = "tests/ultralytics"
+        self.skip = "1"
+        self.label2index = "ethogram/label2index.json"
+
+    def tearDown(self):
+        # delete outputs
+        del_dir(self.dataset)
+
+    def test_run(self):
+        # run cvat2ultralytics
+        sys.argv = [self.tool,
+                    "--video", self.video,
+                    "--annotation", self.annotation,
+                    "--dataset", "tests/ultralytics",
+                    "--skip", self.skip]
+        run()
+
+        # check for output dirs
+        self.assertTrue(dir_exists(self.dataset))
+        self.assertTrue(dir_exists(f"{self.dataset}/images/test"))
+        self.assertTrue(dir_exists(f"{self.dataset}/images/train"))
+        self.assertTrue(dir_exists(f"{self.dataset}/images/val"))
+        self.assertTrue(dir_exists(f"{self.dataset}/labels/test"))
+        self.assertTrue(dir_exists(f"{self.dataset}/labels/train"))
+        self.assertTrue(dir_exists(f"{self.dataset}/labels/val"))
+
+        # check output
+        annotations = etree.parse(TestCvat2Ultralytics.annotation).getroot()
+        tracks = [list(track.findall("box")) for track in annotations.findall("track")]
+        self.assertEqual(len(tracks[0]), 21)
+        self.assertEqual(len(tracks[0]), len(tracks[1]))
+        original_size = annotations.find("meta").find("task").find("original_size")
+        height = int(original_size.find("height").text)
+        width = int(original_size.find("width").text)
+        for i in range(len(tracks[0])):
+            # check existence
+            if i < 16:
+                data_im = f"{self.dataset}/images/train/DJI_0068_{i}.jpg"
+                self.assertTrue(file_exists(data_im))
+                data_label = f"{self.dataset}/labels/train/DJI_0068_{i}.txt"
+                self.assertTrue(file_exists(data_label))
+            elif i < 18:
+                data_im = f"{self.dataset}/images/val/DJI_0068_{i}.jpg"
+                self.assertTrue(file_exists(data_im))
+                data_label = f"{self.dataset}/labels/val/DJI_0068_{i}.txt"
+                self.assertTrue(file_exists(data_label))
+            else:
+                data_im = f"{self.dataset}/images/test/DJI_0068_{i}.jpg"
+                self.assertTrue(file_exists(data_im))
+                data_label = f"{self.dataset}/labels/test/DJI_0068_{i}.txt"
+                self.assertTrue(file_exists(data_label))
+
+            # check image
+            data_im = cv2.imread(data_im)
+            self.assertEqual(data_im.shape, (height, width, 3))
+
+            # check label
+            data_label = pd.read_csv(data_label, sep=" ", header=None)
+            annotation_label = []
+            for track in tracks:
+                box = track[i]
+                x_start = float(box.attrib["xtl"])
+                y_start = float(box.attrib["ytl"])
+                x_end = float(box.attrib["xbr"])
+                y_end = float(box.attrib["ybr"])
+                x_center = (x_start + (x_end - x_start) / 2) / width
+                y_center = (y_start + (y_end - y_start) / 2) / height
+                w = (x_end - x_start) / width
+                h = (y_end - y_start) / height
+                annotation_label.append(
+                    [0, x_center, y_center, w, h]
+                )
+            self.assertEqual(len(data_label.index), len(annotation_label))
+
+            # compare each written label row against the expected values
+            for j, label in enumerate(annotation_label):
+                self.assertEqual(data_label.iloc[j, 0], label[0])
+                self.assertAlmostEqual(data_label.iloc[j, 1], label[1], places=4)
+                self.assertAlmostEqual(data_label.iloc[j, 2], label[2], places=4)
+                self.assertAlmostEqual(data_label.iloc[j, 3], label[3], places=4)
+                self.assertAlmostEqual(data_label.iloc[j, 4], label[4], places=4)
+
+    def test_parse_arg_min(self):
+        # parse arguments
+        sys.argv = [self.tool,
+                    "--video", self.video,
+                    "--annotation", self.annotation,
+                    "--dataset", self.dataset]
+        args = cvat2ultralytics.parse_args()
+
+        # check parsed argument values
+        self.assertEqual(args.video, self.video)
+        self.assertEqual(args.annotation, self.annotation)
+        self.assertEqual(args.dataset, self.dataset)
+
+        # check default argument values
+        self.assertEqual(args.skip, 10)
+        self.assertEqual(args.label2index, None)
+
+        # run cvat2ultralytics
+        run()
+
+    def test_parse_arg_full(self):
+        # parse arguments
+        sys.argv = [self.tool,
+                    "--video", self.video,
+                    "--annotation", self.annotation,
+                    "--dataset", self.dataset,
+                    "--skip", self.skip,
+                    "--label2index", self.label2index]
+        args = cvat2ultralytics.parse_args()
+
+        # check parsed argument values
+        self.assertEqual(args.video, self.video)
+        self.assertEqual(args.annotation, self.annotation)
+        self.assertEqual(args.dataset, self.dataset)
+        self.assertEqual(args.skip, int(self.skip))
+        self.assertEqual(args.label2index, self.label2index)
+
+        # run cvat2ultralytics
+        run()
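
Note: the assertions in `test_run` above re-derive the Ultralytics/YOLO label convention: one `class x_center y_center width height` row per box, with all coordinates normalized to image size. A standalone sketch of the conversion from CVAT's pixel corners (the helper name is mine, not part of the tools):

```python
def cvat_box_to_yolo(xtl: float, ytl: float, xbr: float, ybr: float,
                     img_w: int, img_h: int, cls: int = 0) -> str:
    # CVAT stores absolute pixel corners; YOLO wants normalized center + size.
    x_center = (xtl + (xbr - xtl) / 2) / img_w
    y_center = (ytl + (ybr - ytl) / 2) / img_h
    w = (xbr - xtl) / img_w
    h = (ybr - ytl) / img_h
    return f"{cls} {x_center} {y_center} {w} {h}"

# e.g. a 100x50 box at the top-left of a 5472x3078 frame
print(cvat_box_to_yolo(0, 0, 100, 50, 5472, 3078))
```
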
diff --git a/tests/test_detector2cvat.py b/tests/test_detector2cvat.py
new file mode 100644
index 0000000..b00a122
--- /dev/null
+++ b/tests/test_detector2cvat.py
@@ -0,0 +1,338 @@
+import unittest
+import sys
+import os
+from unittest.mock import MagicMock, patch
+import cv2
+from lxml import etree
+import numpy as np
+from kabr_tools import detector2cvat
+from kabr_tools.utils.yolo import YOLOv8
+from tests.utils import (
+    del_dir,
+    del_file,
+    file_exists,
+    get_detection
+)
+
+
+class DetectionData:
+    def __init__(self, video_dim, video_len, annotation):
+        self.video_dim = video_dim
+        self.video_len = video_len
+        self.frame = -1
+        self.video_frame = np.zeros(video_dim, dtype=np.uint8)
+
+        annotation = etree.parse(annotation).getroot()
+        self.tracks = [[0, list(track.findall("box")),
+                        track.get("label")]
+                       for track in annotation.findall("track")]
+
+    def read(self):
+        if self.frame >= self.video_len - 1:
+            return False, None
+        self.frame += 1
+        return True, self.video_frame
+
+    def get(self, param):
+        if param == cv2.CAP_PROP_FRAME_COUNT:
+            return self.video_len
+        elif param == cv2.CAP_PROP_FRAME_HEIGHT:
+            return self.video_dim[0]
+        elif param == cv2.CAP_PROP_FRAME_WIDTH:
+            return self.video_dim[1]
+        else:
+            return None
+
+    def forward(self, data):
+        soln = []
+        for track in self.tracks:
+            ptr = track[0]
+            if ptr < len(track[1]):
+                box = track[1][ptr]
+                frame = int(box.get("frame"))
+                if frame == self.frame:
+                    soln.append([[int(float(box.get("xtl"))),
+                                  int(float(box.get("ytl"))),
+                                  int(float(box.get("xbr"))),
+                                  int(float(box.get("ybr")))],
+                                 0.95,
+                                 track[2]])
+                    track[0] += 1
+        return soln
+
+
+def run():
+    detector2cvat.main()
+
+
+class TestDetector2Cvat(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        # download data
+        cls.video, cls.annotation = get_detection()
+        cls.dir = os.path.dirname(cls.video)
+
+    @classmethod
+    def tearDownClass(cls):
+        # delete data
+        del_file(cls.video)
+        del_file(cls.annotation)
+        del_dir(cls.dir)
+
+    def setUp(self):
+        # set params
+        self.tool = "detector2cvat.py"
+        self.video = TestDetector2Cvat.dir
+        self.save = "tests/detector2cvat"
+        self.dir = "/".join(os.path.splitext(self.video)[0].split('/')[-2:])
+
+    def tearDown(self):
+        # delete outputs
+        del_dir(self.save)
+
+    def test_run(self):
+        # check if tool runs on real data
+        save = f"{self.save}/run"
+        sys.argv = [self.tool,
+                    "--video", self.video,
+                    "--save", save]
+        detector2cvat.main()
+
+        # check output exists
+        output_path = f"{save}/{self.dir}/DJI_0068.xml"
+        self.assertTrue(file_exists(output_path))
+        demo_path = f"{save}/{self.dir}/DJI_0068_demo.mp4"
+        self.assertTrue(file_exists(demo_path))
+
+    @patch('kabr_tools.detector2cvat.YOLOv8')
+    def test_mock_yolo(self, yolo):
+        # create fake YOLO
+        yolo_instance = MagicMock()
+        yolo_instance.forward.return_value = [[[0, 0, 0, 0], 0.95, 'Grevy']]
+        yolo.get_centroid.return_value = (50, 50)
+        yolo.return_value = yolo_instance
+
+        # run detector2cvat
+        save = f"{self.save}/mock/0"
+        sys.argv = [self.tool,
+                    "--video", self.video,
+                    "--save", save]
+        detector2cvat.main()
+
+        # check output exists
+        output_path = f"{save}/{self.dir}/DJI_0068.xml"
+        self.assertTrue(file_exists(output_path))
+        demo_path = f"{save}/{self.dir}/DJI_0068_demo.mp4"
+        self.assertTrue(file_exists(demo_path))
+
+        # check output xml
+        xml_content = etree.parse(output_path).getroot()
+        self.assertEqual(xml_content.tag, "annotations")
+        for j, track in enumerate(xml_content.findall("track")):
+            track_len = len(track.findall("box"))
+            self.assertEqual(track.get("id"), str(j+1))
+            self.assertEqual(track.get("label"), "Grevy")
+            # TODO: Check if source should be manual
+            self.assertEqual(track.get("source"), "manual")
+            for i, box in enumerate(track.findall("box")):
self.assertEqual(box.get("frame"), str(i)) + self.assertEqual(box.get("xtl"), "0.00") + self.assertEqual(box.get("ytl"), "0.00") + self.assertEqual(box.get("xbr"), "0.00") + self.assertEqual(box.get("ybr"), "0.00") + self.assertEqual(box.get("occluded"), "0") + self.assertEqual(box.get("keyframe"), "1") + self.assertEqual(box.get("z_order"), "0") + # tracker marks last box as outside + if i == track_len - 1: + self.assertEqual(box.get("outside"), "1") + else: + self.assertEqual(box.get("outside"), "0") + + # checkout output video + cap = cv2.VideoCapture(demo_path) + video = cv2.VideoCapture(TestDetector2Cvat.video) + self.assertTrue(cap.isOpened()) + self.assertTrue(video.isOpened()) + self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_COUNT), + video.get(cv2.CAP_PROP_FRAME_COUNT)) + self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_HEIGHT), + video.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_WIDTH), + video.get(cv2.CAP_PROP_FRAME_WIDTH)) + cap.release() + video.release() + + @patch('kabr_tools.detector2cvat.YOLOv8') + @patch('kabr_tools.detector2cvat.cv2.VideoCapture') + def test_mock_with_data(self, video_capture, yolo): + # mock outputs CVAT data + ref_path = "tests/examples/MINISCENE1/metadata/DJI_tracks.xml" + height, width, frames = 3078, 5472, 21 + data = DetectionData((height, width, 3), + frames, + ref_path) + yolo_instance = MagicMock() + yolo_instance.forward = data.forward + yolo.return_value = yolo_instance + yolo.get_centroid = MagicMock( + side_effect=lambda pred: YOLOv8.get_centroid(pred)) + + vc = MagicMock() + vc.read = data.read + vc.get = data.get + video_capture.return_value = vc + + # run detector2cvat + save = f"{self.save}/mock/1" + sys.argv = [self.tool, + "--video", self.video, + "--save", save] + detector2cvat.main() + + # check output exists + output_path = f"{save}/{self.dir}/DJI_0068.xml" + self.assertTrue(file_exists(output_path)) + demo_path = f"{save}/{self.dir}/DJI_0068_demo.mp4" + self.assertTrue(file_exists(demo_path)) + + # check output xml + xml_content = etree.parse(output_path).getroot() + self.assertEqual(xml_content.tag, "annotations") + ref_content = etree.parse(ref_path).getroot() + ref_track = list(ref_content.findall("track")) + for j, track in enumerate(xml_content.findall("track")): + self.assertEqual(track.get("id"), str(j+1)) + self.assertEqual(track.get("label"), "Grevy") + self.assertEqual(track.get("source"), "manual") + ref_box = list(ref_track[j].findall("box")) + for i, box in enumerate(track.findall("box")): + self.assertEqual(box.get("frame"), ref_box[i].get("frame")) + self.assertEqual(box.get("xtl"), ref_box[i].get("xtl")) + self.assertEqual(box.get("ytl"), ref_box[i].get("ytl")) + self.assertEqual(box.get("xbr"), ref_box[i].get("xbr")) + self.assertEqual(box.get("ybr"), ref_box[i].get("ybr")) + self.assertEqual(box.get("occluded"), "0") + self.assertEqual(box.get("keyframe"), "1") + self.assertEqual(box.get("z_order"), "0") + # tracker marks last box as outside + if i == frames - 1: + self.assertEqual(box.get("outside"), "1") + else: + self.assertEqual(box.get("outside"), "0") + + # checkout output video + cap = cv2.VideoCapture(demo_path) + self.assertTrue(cap.isOpened()) + self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_COUNT), + frames) + self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_HEIGHT), + height) + self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_WIDTH), + width) + cap.release() + + @patch('kabr_tools.detector2cvat.YOLOv8') + @patch('kabr_tools.detector2cvat.cv2.VideoCapture') + def 
+    @patch('kabr_tools.detector2cvat.YOLOv8')
+    @patch('kabr_tools.detector2cvat.cv2.VideoCapture')
+    def test_mock_noncontiguous(self, video_capture, yolo):
+        # mock outputs non-contiguous frame detections
+        ref_path = "tests/examples/DETECTOR1/DJI_tracks.xml"
+        height, width, frames = 3078, 5472, 31
+        data = DetectionData((height, width, 3),
+                             frames,
+                             ref_path)
+        yolo_instance = MagicMock()
+        yolo_instance.forward = data.forward
+        yolo.return_value = yolo_instance
+        yolo.get_centroid = MagicMock(
+            side_effect=lambda pred: YOLOv8.get_centroid(pred))
+
+        vc = MagicMock()
+        vc.read = data.read
+        vc.get = data.get
+        video_capture.return_value = vc
+
+        # run detector2cvat
+        save = f"{self.save}/mock/2"
+        sys.argv = [self.tool,
+                    "--video", self.video,
+                    "--save", save]
+        detector2cvat.main()
+
+        # check output exists
+        output_path = f"{save}/{self.dir}/DJI_0068.xml"
+        self.assertTrue(file_exists(output_path))
+        demo_path = f"{save}/{self.dir}/DJI_0068_demo.mp4"
+        self.assertTrue(file_exists(demo_path))
+
+        # check output xml
+        xml_content = etree.parse(output_path).getroot()
+        self.assertEqual(xml_content.tag, "annotations")
+        ref_content = etree.parse(ref_path).getroot()
+        ref_track = list(ref_content.findall("track"))
+        for j, track in enumerate(xml_content.findall("track")):
+            self.assertEqual(track.get("id"), str(j+1))
+            self.assertEqual(track.get("label"), "Grevy")
+            self.assertEqual(track.get("source"), "manual")
+            ref_box = list(ref_track[j].findall("box"))
+            i = 0
+            frame = int(track.find("box").get("frame"))
+            for box in track.findall("box"):
+                if box.get("frame") == ref_box[i+1].get("frame"):
+                    i += 1
+                self.assertEqual(box.get("frame"), str(frame))
+                self.assertEqual(box.get("xtl"), ref_box[i].get("xtl"))
+                self.assertEqual(box.get("ytl"), ref_box[i].get("ytl"))
+                self.assertEqual(box.get("xbr"), ref_box[i].get("xbr"))
+                self.assertEqual(box.get("ybr"), ref_box[i].get("ybr"))
+                self.assertEqual(box.get("occluded"), "0")
+                if box.get("frame") == ref_box[i].get("frame"):
+                    self.assertEqual(box.get("keyframe"), "1")
+                else:
+                    self.assertEqual(box.get("keyframe"), "0")
+                self.assertEqual(box.get("z_order"), "0")
+                # tracker marks last box as outside
+                if frame == frames - 1:
+                    self.assertEqual(box.get("outside"), "1")
+                else:
+                    self.assertEqual(box.get("outside"), "0")
+                frame += 1
+
+        # check output video
+        cap = cv2.VideoCapture(demo_path)
+        self.assertTrue(cap.isOpened())
+        self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_COUNT),
+                         frames)
+        self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_HEIGHT),
+                         height)
+        self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_WIDTH),
+                         width)
+        cap.release()
+
+    def test_parse_arg_min(self):
+        # parse arguments
+        sys.argv = [self.tool,
+                    "--video", self.video,
+                    "--save", self.save]
+        args = detector2cvat.parse_args()
+
+        # check parsed argument values
+        self.assertEqual(args.video, self.video)
+        self.assertEqual(args.save, self.save)
+        self.assertEqual(args.imshow, False)
+
+    def test_parse_arg_full(self):
+        # parse arguments
+        sys.argv = [self.tool,
+                    "--video", self.video,
+                    "--save", self.save,
+                    "--imshow"]
+        args = detector2cvat.parse_args()
+
+        # check parsed argument values
+        self.assertEqual(args.video, self.video)
+        self.assertEqual(args.save, self.save)
+        self.assertEqual(args.imshow, True)
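
Note: the mock-based tests above share one pattern: patch `cv2.VideoCapture` where `detector2cvat` looks it up and substitute an object with hand-rolled `read`/`get` methods so no real video decoding happens. A condensed sketch of that idea (helper name and sizes are illustrative):

```python
from unittest.mock import MagicMock
import numpy as np
import cv2

def make_fake_capture(n_frames: int, height: int, width: int) -> MagicMock:
    cap = MagicMock()
    state = {"frame": 0}

    def read():
        # Emulate cv2.VideoCapture.read(): (ok, frame) until the video ends.
        if state["frame"] >= n_frames:
            return False, None
        state["frame"] += 1
        return True, np.zeros((height, width, 3), np.uint8)

    cap.read.side_effect = read
    cap.get.side_effect = lambda prop: {
        cv2.CAP_PROP_FRAME_COUNT: n_frames,
        cv2.CAP_PROP_FRAME_HEIGHT: height,
        cv2.CAP_PROP_FRAME_WIDTH: width,
    }.get(prop)
    return cap

# Patched at the consumer's import site, e.g.:
# @patch('kabr_tools.detector2cvat.cv2.VideoCapture',
#        return_value=make_fake_capture(21, 3078, 5472))
```
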
diff --git a/tests/test_miniscene2behavior.py b/tests/test_miniscene2behavior.py
new file mode 100644
index 0000000..35b589a
--- /dev/null
+++ b/tests/test_miniscene2behavior.py
@@ -0,0 +1,268 @@
+import unittest
+import zipfile
+import sys
+import os
+import requests
+from unittest.mock import Mock, patch
+from lxml import etree
+from ruamel.yaml import YAML
+import torch
+import numpy as np
+import pandas as pd
+from kabr_tools import (
+    miniscene2behavior,
+    tracks_extractor
+)
+from kabr_tools.miniscene2behavior import annotate_miniscene
+from tests.utils import (
+    del_file,
+    del_dir,
+    clean_dir,
+    get_detection
+)
+
+
+TESTSDIR = os.path.dirname(os.path.realpath(__file__))
+EXAMPLESDIR = os.path.join(TESTSDIR, "examples")
+
+
+def run():
+    miniscene2behavior.main()
+
+
+class TestMiniscene2Behavior(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        # download the model from Imageomics HF
+        cls.checkpoint = "checkpoint_epoch_00075.pyth"
+        cls.model_output = None
+        cls.download_model()
+
+        # download data
+        cls.video, cls.annotation = get_detection()
+
+        # extract mini-scene
+        sys.argv = ["tracks_extractor.py",
+                    "--video", cls.video,
+                    "--annotation", cls.annotation]
+        tracks_extractor.main()
+        cls.miniscene = f'mini-scenes/{os.path.splitext("|".join(cls.video.split("/")[-3:]))[0]}'
+
+    @classmethod
+    def download_model(cls):
+        if not os.path.exists(cls.checkpoint):
+            # download checkpoint archive
+            url = "https://huggingface.co/imageomics/" \
+                + "x3d-kabr-kinetics/resolve/main/" \
+                + f"{cls.checkpoint}.zip"
+            r = requests.get(url, allow_redirects=True, timeout=120)
+            with open(f"{cls.checkpoint}.zip", "wb") as f:
+                f.write(r.content)
+
+            # unzip model checkpoint
+            with zipfile.ZipFile(f"{cls.checkpoint}.zip", "r") as zip_ref:
+                zip_ref.extractall(".")
+
+        # get checkpoint directory
+        try:
+            cfg = torch.load(cls.checkpoint,
+                             weights_only=True,
+                             map_location=torch.device("cpu"))["cfg"]
+            yaml = YAML(typ="rt")
+            cls.model_output = f"{yaml.load(cfg)['OUTPUT_DIR']}/checkpoints"
+        except Exception:
+            pass
+
+    @classmethod
+    def tearDownClass(cls):
+        # remove model files after tests
+        del_file(f"{cls.checkpoint}.zip")
+        del_file(cls.checkpoint)
+        if cls.model_output:
+            clean_dir(cls.model_output)
+
+        # remove data after tests
+        del_file(cls.video)
+        del_file(cls.annotation)
+        del_dir(cls.miniscene)
+
+    def setUp(self):
+        self.tool = "miniscene2behavior.py"
+        self.checkpoint = "checkpoint_epoch_00075.pyth"
+        self.miniscene = TestMiniscene2Behavior.miniscene
+        self.video = "DJI_0068"
+        self.config = "special_config.yml"
+        self.gpu_num = "1"
+        self.output = "DJI_0068.csv"
+
+    def tearDown(self):
+        # delete outputs
+        del_file(self.output)
+
+    def test_run(self):
+        # download model
+        self.download_model()
+
+        # annotate mini-scenes
+        sys.argv = [self.tool,
+                    "--checkpoint", self.checkpoint,
+                    "--miniscene", self.miniscene,
+                    "--video", self.video,
+                    "--output", self.output]
+        run()
+
+        # check output CSV
+        df = pd.read_csv(self.output, sep=' ')
+        self.assertEqual(list(df.columns), [
+            "video", "track", "frame", "label"])
+        row_ct = 0
+
+        root = etree.parse(
+            f"{self.miniscene}/metadata/{self.video}_tracks.xml").getroot()
+        for track in root.iterfind("track"):
+            track_id = int(track.get("id"))
+            for box in track.iterfind("box"):
+                row = list(df.loc[row_ct])
+                self.assertEqual(row[0], self.video)
+                self.assertEqual(row[1], track_id)
+                self.assertEqual(row[2], int(box.get("frame")))
+                self.assertTrue(row[3] >= 0)
+                self.assertTrue(row[3] <= 7)
+                row_ct += 1
+        self.assertEqual(len(df.index), row_ct)
+
+    @patch('kabr_tools.miniscene2behavior.process_cv2_inputs')
+    @patch('kabr_tools.miniscene2behavior.cv2.VideoCapture')
+    def test_matching_tracks(self, video_capture, process_cv2_inputs):
+        # create fake model that weights class 98
+        mock_model = Mock()
+        prob = torch.zeros(99)
+        prob[-1] = 1
+        mock_model.return_value = prob
+
+        # create fake cfg
+        mock_config = Mock(
+            DATA=Mock(NUM_FRAMES=16,
+                      SAMPLING_RATE=5,
+                      TEST_CROP_SIZE=300),
+            NUM_GPUS=0,
+            OUTPUT_DIR=''
+        )
+
+        # create fake video capture
+        vc = video_capture.return_value
+        vc.read.return_value = True, np.zeros((8, 8, 3), np.uint8)
+        vc.get.return_value = 21
+
+        self.output = '/tmp/annotation_data.csv'
+        miniscene_dir = os.path.join(EXAMPLESDIR, "MINISCENE1")
+        video_name = "DJI"
+
+        annotate_miniscene(cfg=mock_config,
+                           model=mock_model,
+                           miniscene_path=miniscene_dir,
+                           video=video_name,
+                           output_path=self.output)
+
+        # check output CSV
+        df = pd.read_csv(self.output, sep=' ')
+        self.assertEqual(list(df.columns), [
+            "video", "track", "frame", "label"])
+        row_ct = 0
+
+        root = etree.parse(
+            f"{miniscene_dir}/metadata/DJI_tracks.xml").getroot()
+        for track in root.iterfind("track"):
+            track_id = int(track.get("id"))
+            for box in track.iterfind("box"):
+                row_val = [video_name, track_id, int(box.get("frame")), 98]
+                self.assertEqual(list(df.loc[row_ct]), row_val)
+                row_ct += 1
+        self.assertEqual(len(df.index), row_ct)
+
+    @patch('kabr_tools.miniscene2behavior.process_cv2_inputs')
+    @patch('kabr_tools.miniscene2behavior.cv2.VideoCapture')
+    def test_nonmatching_tracks(self, video_capture, process_cv2_inputs):
+        # create fake model with a single class, so argmax is always 0
+        mock_model = Mock()
+        mock_model.return_value = torch.tensor([1])
+
+        # create fake cfg
+        mock_config = Mock(
+            DATA=Mock(NUM_FRAMES=16,
+                      SAMPLING_RATE=5,
+                      TEST_CROP_SIZE=300),
+            NUM_GPUS=0,
+            OUTPUT_DIR=''
+        )
+
+        # create fake video capture
+        vc = video_capture.return_value
+        vc.read.return_value = True, np.zeros((8, 8, 3), np.uint8)
+        vc.get.return_value = 21
+
+        self.output = '/tmp/annotation_data.csv'
+        miniscene_dir = os.path.join(EXAMPLESDIR, "MINISCENE2")
+        video_name = "DJI"
+
+        annotate_miniscene(cfg=mock_config,
+                           model=mock_model,
+                           miniscene_path=miniscene_dir,
+                           video=video_name,
+                           output_path=self.output)
+
+        # check output CSV
+        df = pd.read_csv(self.output, sep=' ')
+        self.assertEqual(list(df.columns), [
+            "video", "track", "frame", "label"])
+        row_ct = 0
+
+        root = etree.parse(
+            f"{miniscene_dir}/metadata/DJI_tracks.xml").getroot()
+        for track in root.iterfind("track"):
+            track_id = int(track.get("id"))
+            for box in track.iterfind("box"):
+                row_val = [video_name, track_id, int(box.get("frame")), 0]
+                self.assertEqual(list(df.loc[row_ct]), row_val)
+                row_ct += 1
+        self.assertEqual(len(df.index), row_ct)
+
+    def test_parse_arg_min(self):
+        # parse arguments
+        sys.argv = [self.tool,
+                    "--checkpoint", self.checkpoint,
+                    "--miniscene", self.miniscene,
+                    "--video", self.video]
+        args = miniscene2behavior.parse_args()
+
+        # check parsed argument values
+        self.assertEqual(args.checkpoint, self.checkpoint)
+        self.assertEqual(args.miniscene, self.miniscene)
+        self.assertEqual(args.video, self.video)
+
+        # check default argument values
+        self.assertEqual(args.config, "config.yml")
+        self.assertEqual(args.gpu_num, 0)
+        self.assertEqual(args.output, "annotation_data.csv")
+
+    def test_parse_arg_full(self):
+        # parse arguments
+        sys.argv = [self.tool,
+                    "--config", self.config,
+                    "--checkpoint", self.checkpoint,
+                    "--gpu_num", self.gpu_num,
+                    "--miniscene", self.miniscene,
+                    "--video", self.video,
+                    "--output", self.output]
+        args = miniscene2behavior.parse_args()
+
+        # check parsed argument values
+        self.assertEqual(args.config, self.config)
+        self.assertEqual(args.checkpoint, self.checkpoint)
+        self.assertEqual(args.gpu_num, 1)
+        self.assertEqual(args.miniscene, self.miniscene)
+        self.assertEqual(args.video, self.video)
+        self.assertEqual(args.output, self.output)
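
Note: one detail from `download_model` worth surfacing: the `.pyth` checkpoint bundles the training config as a YAML string under its `"cfg"` key, which is how the test recovers `OUTPUT_DIR` for cleanup. A sketch of that round trip (checkpoint name as in the test; the `"cfg"` key layout is what the test assumes):

```python
import torch
from ruamel.yaml import YAML

ckpt = torch.load("checkpoint_epoch_00075.pyth",
                  weights_only=True,            # safe load: tensors + plain data
                  map_location=torch.device("cpu"))
cfg_yaml = ckpt["cfg"]                          # YAML dump of the SlowFast config
cfg = YAML(typ="rt").load(cfg_yaml)
print(cfg["OUTPUT_DIR"])                        # where checkpoints were written
```
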
diff --git a/tests/test_player.py b/tests/test_player.py
new file mode 100644
index 0000000..091aeb0
--- /dev/null
+++ b/tests/test_player.py
@@ -0,0 +1,103 @@
+import unittest
+import sys
+import os
+from unittest.mock import patch
+from kabr_tools import player
+from tests.utils import (
+    del_file,
+    del_dir,
+    file_exists,
+    get_behavior
+)
+
+
+def run():
+    player.main()
+
+
+class TestPlayer(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        # download data
+        cls.video, cls.miniscene, cls.annotation, cls.metadata = get_behavior()
+        cls.dir = os.path.dirname(cls.video)
+
+    @classmethod
+    def tearDownClass(cls):
+        # delete data
+        del_file(cls.video)
+        del_file(cls.miniscene)
+        del_file(cls.annotation)
+        del_file(cls.metadata)
+        del_dir(cls.dir)
+
+    def setUp(self):
+        # set params
+        self.tool = "player.py"
+        self.folder = TestPlayer.dir
+        self.video = self.folder.rsplit("/", maxsplit=1)[-1]
+
+        # delete output
+        del_file(f"{self.folder}/{self.video}_demo.mp4")
+
+    def tearDown(self):
+        # delete output
+        del_file(f"{self.folder}/{self.video}_demo.mp4")
+
+    @patch('kabr_tools.player.cv2.imshow')
+    @patch('kabr_tools.player.cv2.namedWindow')
+    @patch('kabr_tools.player.cv2.createTrackbar')
+    @patch('kabr_tools.player.cv2.setTrackbarPos')
+    @patch('kabr_tools.player.cv2.getTrackbarPos')
+    def test_run(self, getTrackbarPos, setTrackbarPos, createTrackbar, namedWindow, imshow):
+        # mock getTrackbarPos
+        getTrackbarPos.return_value = 0
+
+        # run player
+        sys.argv = [self.tool,
+                    "--folder", self.folder,
+                    "--save"]
+        run()
+        self.assertTrue(file_exists(f"{self.folder}/{self.video}_demo.mp4"))
+
+    @patch('kabr_tools.player.cv2.imshow')
+    @patch('kabr_tools.player.cv2.namedWindow')
+    @patch('kabr_tools.player.cv2.createTrackbar')
+    @patch('kabr_tools.player.cv2.setTrackbarPos')
+    @patch('kabr_tools.player.cv2.getTrackbarPos')
+    def test_parse_arg_min(self, getTrackbarPos, setTrackbarPos, createTrackbar, namedWindow, imshow):
+        # parse arguments
+        sys.argv = [self.tool,
+                    "--folder", self.folder]
+        args = player.parse_args()
+
+        # check parsed arguments
+        self.assertEqual(args.folder, self.folder)
+
+        # check default arguments
+        self.assertEqual(args.save, False)
+
+        # run player
+        run()
+        self.assertTrue(not file_exists(f"{self.folder}/{self.video}_demo.mp4"))
+
+    @patch('kabr_tools.player.cv2.imshow')
+    @patch('kabr_tools.player.cv2.namedWindow')
+    @patch('kabr_tools.player.cv2.createTrackbar')
+    @patch('kabr_tools.player.cv2.setTrackbarPos')
+    @patch('kabr_tools.player.cv2.getTrackbarPos')
+    def test_parse_arg_full(self, getTrackbarPos, setTrackbarPos, createTrackbar, namedWindow, imshow):
+        # parse arguments
+        sys.argv = [self.tool,
+                    "--folder", self.folder,
+                    "--save", "--imshow"]
+        args = player.parse_args()
+
+        # check parsed arguments
+        self.assertEqual(args.folder, self.folder)
+        self.assertEqual(args.save, True)
+
+        # run player
+        run()
+        self.assertTrue(file_exists(f"{self.folder}/{self.video}_demo.mp4"))
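
Note: `test_player.py` runs an interactive cv2 UI headlessly in CI by patching every GUI entry point `player` touches and pinning `getTrackbarPos` to a real int so frame arithmetic keeps working. A condensed sketch of the pattern (folder path is illustrative):

```python
import sys
from unittest.mock import patch
from kabr_tools import player

@patch('kabr_tools.player.cv2.imshow')
@patch('kabr_tools.player.cv2.namedWindow')
@patch('kabr_tools.player.cv2.createTrackbar')
@patch('kabr_tools.player.cv2.setTrackbarPos')
@patch('kabr_tools.player.cv2.getTrackbarPos')
def run_headless(getTrackbarPos, setTrackbarPos, createTrackbar,
                 namedWindow, imshow):
    # Decorators apply bottom-up, so getTrackbarPos is the first mock argument.
    getTrackbarPos.return_value = 0   # a real int keeps the player's math working
    sys.argv = ["player.py", "--folder", "path/to/mini-scenes", "--save"]
    player.main()
```
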
diff --git a/tests/test_tracks_extractor.py b/tests/test_tracks_extractor.py
new file mode 100644
index 0000000..71a7e9a
--- /dev/null
+++ b/tests/test_tracks_extractor.py
@@ -0,0 +1,203 @@
+import unittest
+import sys
+import os
+from unittest.mock import patch, Mock
+import json
+from lxml import etree
+import cv2
+from kabr_tools import tracks_extractor
+from kabr_tools.utils.tracker import Tracker
+from kabr_tools.utils.detector import Detector
+from kabr_tools.utils.utils import get_scene
+from tests.utils import (
+    get_detection,
+    dir_exists,
+    file_exists,
+    del_dir,
+    del_file
+)
+
+# TODO: make constants for kabr tools (copied values in tracks_extractor.py)
+scene_width, scene_height = 400, 300
+
+
+def run():
+    tracks_extractor.main()
+
+
+class TestTracksExtractor(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        # download data
+        cls.video, cls.annotation = get_detection()
+
+    @classmethod
+    def tearDownClass(cls):
+        # delete data
+        del_file(cls.video)
+        del_file(cls.annotation)
+
+    def setUp(self):
+        # set params
+        self.tool = "tracks_extractor.py"
+        self.video = TestTracksExtractor.video
+        self.annotation = TestTracksExtractor.annotation
+
+        # remove output directory
+        del_dir("mini-scenes")
+
+    def tearDown(self):
+        # remove output directory
+        del_dir("mini-scenes")
+
+    def test_run(self):
+        # run tracks_extractor
+        sys.argv = [self.tool,
+                    "--video", self.video,
+                    "--annotation", self.annotation]
+        run()
+
+        # check output exists
+        mini_folder = os.path.splitext("|".join(self.video.split("/")[-3:]))[0]
+        video_name = "DJI_0068"
+        self.assertTrue(dir_exists(f"mini-scenes/{mini_folder}"))
+        self.assertTrue(dir_exists(f"mini-scenes/{mini_folder}/actions"))
+        self.assertTrue(dir_exists(f"mini-scenes/{mini_folder}/metadata"))
+        self.assertTrue(file_exists(f"mini-scenes/{mini_folder}/0.mp4"))
+        self.assertTrue(file_exists(f"mini-scenes/{mini_folder}/1.mp4"))
+        self.assertTrue(file_exists(
+            f"mini-scenes/{mini_folder}/{video_name}.mp4"))
+        self.assertTrue(file_exists(
+            f"mini-scenes/{mini_folder}/metadata/{video_name}_metadata.json"))
+        self.assertTrue(file_exists(
+            f"mini-scenes/{mini_folder}/metadata/{video_name}_tracks.xml"))
+        self.assertTrue(file_exists(
+            f"mini-scenes/{mini_folder}/metadata/{video_name}.jpg"))
+
+        # check metadata.json
+        root = etree.parse(self.annotation).getroot()
+        tracks = {"main":
+                  [-1] * int("".join(root.find("meta").find("task").find("size").itertext()))}
+        for track in root.iterfind("track"):
+            track_id = track.attrib["id"]
+            tracks[track_id] = []
+            for box in track.iter("box"):
+                frame_id = int(box.attrib["frame"])
+                tracks[track_id].append(frame_id)
+                tracks["main"][frame_id] = frame_id
+
+        colors = list(Tracker.colors_table.values())
+
+        with open(f"mini-scenes/{mini_folder}/metadata/{video_name}_metadata.json",
+                  "r", encoding="utf-8") as f:
+            metadata = json.load(f)
+        self.assertTrue("original" in metadata)
+        self.assertTrue("tracks" in metadata)
+        self.assertTrue("colors" in metadata)
+        self.assertEqual(metadata["original"], self.video)
+        self.assertEqual(metadata["tracks"]["main"], tracks["main"])
+        self.assertEqual(metadata["tracks"]["0"], tracks["0"])
+        self.assertEqual(metadata["tracks"]["1"], tracks["1"])
+        self.assertEqual(metadata["colors"]["0"],
+                         list(colors[0 % len(colors)]))
+        self.assertEqual(metadata["colors"]["1"],
+                         list(colors[1 % len(colors)]))
+
+        # check tracks.xml
+        with open(f"mini-scenes/{mini_folder}/metadata/{video_name}_tracks.xml",
+                  "r", encoding="utf-8") as f:
+            track_copy = f.read()
+
+        with open(self.annotation, "r", encoding="utf-8") as f:
+            track = f.read()
+
+        self.assertEqual(track, track_copy)
+
+        # check 0.mp4, 1.mp4
+        root = etree.parse(self.annotation).getroot()
+        xml_tracks = {}
+        for track in root.findall("track"):
+            track_id = track.attrib["id"]
+            xml_tracks[track_id] = track
+        self.assertEqual(xml_tracks.keys(), {"0", "1"})
+
+        original = cv2.VideoCapture(self.video)
+        self.assertTrue(original.isOpened())
+        mock = Mock()
+        for track_id, xml_track in xml_tracks.items():
+            track = cv2.VideoCapture(
+                f"mini-scenes/{mini_folder}/{track_id}.mp4")
+            self.assertTrue(track.isOpened())
+
+            for i, box in enumerate(xml_track.iter("box")):
+                original.set(cv2.CAP_PROP_POS_FRAMES, int(box.attrib["frame"]))
+                track.set(cv2.CAP_PROP_POS_FRAMES, i)
+                original_returned, original_frame = original.read()
+                track_returned, track_frame = track.read()
+
+                self.assertTrue(original_returned)
+                self.assertTrue(track_returned)
+
+                mock.box = [int(float(box.attrib["xtl"])),
+                            int(float(box.attrib["ytl"])),
+                            int(float(box.attrib["xbr"])),
+                            int(float(box.attrib["ybr"]))]
+                mock.centroid = Detector.get_centroid(mock.box)
+                original_frame = get_scene(
+                    original_frame, mock, scene_width, scene_height)
+
+                # encoding seems to add some noise to frames, allow for that
+                self.assertTrue(
+                    cv2.norm(original_frame - track_frame) < 1e6)
+            track.release()
+
+        # check DJI_0068.mp4
+        copy = cv2.VideoCapture(f"mini-scenes/{mini_folder}/{video_name}.mp4")
+        self.assertTrue(copy.isOpened())
+        self.assertEqual(copy.get(cv2.CAP_PROP_FRAME_COUNT),
+                         original.get(cv2.CAP_PROP_FRAME_COUNT))
+        self.assertEqual(copy.get(cv2.CAP_PROP_FRAME_WIDTH),
+                         original.get(cv2.CAP_PROP_FRAME_WIDTH))
+        self.assertEqual(copy.get(cv2.CAP_PROP_FRAME_HEIGHT),
+                         original.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+        original.release()
+
+    def test_parse_arg_min(self):
+        # parse arguments
+        sys.argv = [self.tool,
+                    "--video", self.video,
+                    "--annotation", self.annotation]
+        args = tracks_extractor.parse_args()
+
+        # check parsed arguments
+        self.assertEqual(args.video, self.video)
+        self.assertEqual(args.annotation, self.annotation)
+
+        # check default arguments
+        self.assertEqual(args.tracking, False)
+        self.assertEqual(args.imshow, False)
+
+        # run tracks_extractor
+        run()
+
+    @patch('kabr_tools.tracks_extractor.cv2.imshow')
+    def test_parse_arg_full(self, imshow):
+        # parse arguments
+        sys.argv = [self.tool,
+                    "--video", self.video,
+                    "--annotation", self.annotation,
+                    "--tracking",
+                    "--imshow"]
+        args = tracks_extractor.parse_args()
+
+        # check parsed arguments
+        self.assertEqual(args.video, self.video)
+        self.assertEqual(args.annotation, self.annotation)
+        self.assertEqual(args.tracking, True)
+        self.assertEqual(args.imshow, True)
+
+        # run tracks_extractor
+        run()
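
Note: the `cv2.norm(original_frame - track_frame) < 1e6` comparison above is deliberately loose: H.264 re-encoding perturbs pixel values, so exact equality would be flaky, while an L2 norm with a generous bound still catches a wrong crop. One caveat worth knowing: the frames are uint8, so raw subtraction wraps around; widening first gives the safer form of the same check, sketched here with synthetic frames:

```python
import cv2
import numpy as np

a = np.zeros((300, 400, 3), np.uint8)   # reference crop (scene_height x scene_width)
b = a.copy()
b[0, 0, 0] = 3                          # pretend codec noise

# Widen to a signed type to avoid uint8 wraparound; cv2.norm defaults to L2.
noise = cv2.norm(a.astype(np.float32) - b.astype(np.float32))
assert noise < 1e6                      # same bound the test uses
```
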
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 0000000..b8d3004
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,52 @@
+import os
+import shutil
+from pathlib import Path
+from huggingface_hub import hf_hub_download
+
+DATA_HUB = "imageomics/kabr_testing"
+REPO_TYPE = "dataset"
+
+DETECTION_VIDEO = "DJI_0068/DJI_0068.mp4"
+DETECTION_ANNOTATION = "DJI_0068/DJI_0068.xml"
+
+BEHAVIOR_VIDEO = "DJI_0001/DJI_0001.mp4"
+BEHAVIOR_MINISCENE = "DJI_0001/43.mp4"
+BEHAVIOR_ANNOTATION = "DJI_0001/actions/43.xml"
+BEHAVIOR_METADATA = "DJI_0001/metadata/DJI_0001_metadata.json"
+
+def get_hf(repo_id: str, filename: str, repo_type: str):
+    return hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type)
+
+def get_cached_datafile(filename: str):
+    return get_hf(DATA_HUB, filename, REPO_TYPE)
+
+def get_behavior():
+    video = get_cached_datafile(BEHAVIOR_VIDEO)
+    miniscene = get_cached_datafile(BEHAVIOR_MINISCENE)
+    annotation = get_cached_datafile(BEHAVIOR_ANNOTATION)
+    metadata = get_cached_datafile(BEHAVIOR_METADATA)
+    return video, miniscene, annotation, metadata
+
+def get_detection():
+    video = get_cached_datafile(DETECTION_VIDEO)
+    annotation = get_cached_datafile(DETECTION_ANNOTATION)
+    return video, annotation
+
+def clean_dir(path):
+    if os.path.exists(path):
+        os.removedirs(path)  # removes only empty dirs (and empty parents)
+
+def del_dir(path):
+    if os.path.exists(path):
+        shutil.rmtree(path)
+
+def del_file(path):
+    if os.path.exists(path):
+        os.remove(path)
+
+def file_exists(path):
+    return Path(path).is_file()
+
+
+def dir_exists(path):
+    return Path(path).is_dir()
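
Note: the fetch helpers return paths inside the local Hugging Face cache, and `del_dir` (`shutil.rmtree`) removes a whole tree while `clean_dir` (`os.removedirs`) only prunes directories that are already empty. A short sketch of the intended flow in a test module:

```python
from tests.utils import get_detection, del_file, file_exists

video, annotation = get_detection()  # cached by huggingface_hub between runs
assert file_exists(video) and file_exists(annotation)

# ... exercise a tool against the fixtures, then tidy up:
del_file(video)
del_file(annotation)
```
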