diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 0000000..a4cdadd
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,34 @@
+name: Test
+
+on:
+ push:
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: ["3.10"]
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip setuptools wheel
+ pip install -r requirements.txt
+ python -m pip install hatchling
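+ # --no-build-isolation reuses the hatchling installed above instead of a fresh build environment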
+ python -m pip install --no-build-isolation .
+ - name: Run unit tests
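+ # HF_TOKEN lets hf_hub_download fetch the test fixtures hosted on Hugging Face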
+ env:
+ HF_TOKEN: ${{ secrets.HF_TOKEN_TESTING }}
+ run: |
+ python -m unittest tests/test_cvat2slowfast.py
+ python -m unittest tests/test_cvat2ultralytics.py
+ python -m unittest tests/test_detector2cvat.py
+ python -m unittest tests/test_miniscene2behavior.py
+ python -m unittest tests/test_player.py
+ python -m unittest tests/test_tracks_extractor.py
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 36d016c..77b27f5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -177,5 +177,18 @@ cython_debug/
# Mac System
.DS_Store
+# Tool output
+*.json
+*.xml
+*.jpg
+*.yaml
+*.csv
+*.txt
+
+# Model files
+*.pyth
+*.pyth.zip
+*.yml
+
+# Re-include tracked project files caught by the broad patterns above
+!requirements.txt
+!ethogram/*.json
+!.github/workflows/*.yml
+
helper_scripts/mini-scenes
\ No newline at end of file
diff --git a/README.md b/README.md
index 6228bec..d26f71f 100644
--- a/README.md
+++ b/README.md
@@ -170,6 +170,6 @@ player --folder path_to_folder [--save] [--imshow]
```
-cvat2slowfast --miniscene path_to_mini_scenes --dataset dataset_name --classes path_to_classes_json [--old2new path_to_old2new_json]
+cvat2slowfast --miniscene path_to_mini_scenes --dataset dataset_name --classes path_to_classes_json [--old2new path_to_old2new_json] [--no_images]
```
diff --git a/ethogram/label2index.json b/ethogram/label2index.json
new file mode 100644
index 0000000..a110b82
--- /dev/null
+++ b/ethogram/label2index.json
@@ -0,0 +1,6 @@
+{
+ "Grevy": 0,
+ "Zebra": 0,
+ "Baboon": 1,
+ "Giraffe": 2
+}
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 46ea954..a9b7d35 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,3 +10,4 @@ ultralytics~=8.0.36
pandas>=1.3.5
pillow==10.4.0
scikit-learn==1.5.1
+huggingface_hub
diff --git a/src/kabr_tools/cvat2ultralytics.py b/src/kabr_tools/cvat2ultralytics.py
index 8dd14d8..073108b 100644
--- a/src/kabr_tools/cvat2ultralytics.py
+++ b/src/kabr_tools/cvat2ultralytics.py
@@ -3,7 +3,7 @@
import argparse
import json
import cv2
-import ruamel.yaml as yaml
+from ruamel.yaml import YAML
from lxml import etree
from collections import OrderedDict
from tqdm import tqdm
@@ -39,8 +39,10 @@ def cvat2ultralytics(video_path: str, annotation_path: str,
shutil.rmtree(f"{dataset}")
with open(f"{dataset}.yaml", "w") as file:
- yaml.dump(yaml.load(dataset_file, Loader=yaml.RoundTripLoader, preserve_quotes=True),
- file, Dumper=yaml.RoundTripDumper)
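+ # round-trip (rt) mode preserves quotes and comments from the template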
+ yaml = YAML(typ='rt')
+ yaml.preserve_quotes = True
+ data = yaml.load(dataset_file)
+ yaml.dump(data, file)
if not os.path.exists(f"{dataset}/images/train"):
os.makedirs(f"{dataset}/images/train")
@@ -57,6 +59,7 @@ def cvat2ultralytics(video_path: str, annotation_path: str,
if label2index is None:
label2index = {
+ "Grevy": 0,
"Zebra": 0,
"Baboon": 1,
"Giraffe": 2
@@ -69,21 +72,24 @@ def cvat2ultralytics(video_path: str, annotation_path: str,
for root, dirs, files in os.walk(annotation_path):
for file in files:
video_name = os.path.join(video_path + root[len(annotation_path):], os.path.splitext(file)[0])
-
- if os.path.exists(video_name + ".MP4"):
- videos.append(video_name + ".MP4")
- else:
- videos.append(video_name + ".mp4")
-
- annotations.append(os.path.join(root, file))
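+ # only pair videos with CVAT .xml annotations; skip other files in the tree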
+ if file.endswith(".xml"):
+ if os.path.exists(video_name + ".MP4"):
+ videos.append(video_name + ".MP4")
+ else:
+ videos.append(video_name + ".mp4")
+ annotations.append(os.path.join(root, file))
for i, (video, annotation) in enumerate(zip(videos, annotations)):
- print(f"{i + 1}/{len(annotations)}:")
+ print(f"{i + 1}/{len(annotations)}:", flush=True)
if not os.path.exists(video):
print(f"Path {video} does not exist.")
continue
+ if not os.path.exists(annotation):
+ print(f"Path {annotation} does not exist.")
+ continue
+
# Parse CVAT for video 1.1 annotation file.
root = etree.parse(annotation).getroot()
name = os.path.splitext(video.split("/")[-1])[0]
diff --git a/src/kabr_tools/miniscene2behavior.py b/src/kabr_tools/miniscene2behavior.py
index 8347350..4ce2fc1 100644
--- a/src/kabr_tools/miniscene2behavior.py
+++ b/src/kabr_tools/miniscene2behavior.py
@@ -19,12 +19,15 @@ def get_input_clip(cap: cv2.VideoCapture, cfg: CfgNode, keyframe_idx: int) -> li
# https://github.com/facebookresearch/SlowFast/blob/bac7b672f40d44166a84e8c51d1a5ba367ace816/slowfast/visualization/ava_demo_precomputed_boxes.py
seq_length = cfg.DATA.NUM_FRAMES * cfg.DATA.SAMPLING_RATE
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
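+ # fail fast when the annotation references a frame past the end of the clip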
+ assert keyframe_idx < total_frames, f"keyframe_idx: {keyframe_idx}" \
+ f" >= total_frames: {total_frames}"
seq = get_sequence(
keyframe_idx,
seq_length // 2,
cfg.DATA.SAMPLING_RATE,
total_frames,
)
+
clip = []
for frame_idx in seq:
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
@@ -124,29 +127,34 @@ def annotate_miniscene(cfg: CfgNode, model: torch.nn.Module,
# find all tracks
tracks = []
+ frames = {}
for track in root.iterfind("track"):
track_id = track.attrib["id"]
tracks.append(track_id)
+ frames[track_id] = []
- # find all frames
- # TODO: rewrite - some tracks may have different frames
- assert len(tracks) > 0, "No tracks found in track file"
- frames = []
- for box in track.iterfind("box"):
- frames.append(int(box.attrib["frame"]))
+ # find all frames
+ for box in track.iterfind("box"):
+ frames[track_id].append(int(box.attrib["frame"]))
# run model on miniscene
for track in tracks:
video_file = f"{miniscene_path}/{track}.mp4"
cap = cv2.VideoCapture(video_file)
- for frame in tqdm(frames, desc=f"{track} frames"):
- inputs = get_input_clip(cap, cfg, frame)
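+ # miniscene clips start at local frame 0, so index into the clip rather than using the global frame id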
+ index = 0
+ for frame in tqdm(frames[track], desc=f"{track} frames"):
+ try:
+ inputs = get_input_clip(cap, cfg, index)
+ except AssertionError as e:
+ print(e)
+ break
+ index += 1
if cfg.NUM_GPUS:
# transfer the data to the current GPU device.
if isinstance(inputs, (list,)):
- for i in range(len(inputs)):
- inputs[i] = inputs[i].cuda(non_blocking=True)
+ for i, input_clip in enumerate(inputs):
+ inputs[i] = input_clip.cuda(non_blocking=True)
else:
inputs = inputs.cuda(non_blocking=True)
@@ -163,6 +171,7 @@ def annotate_miniscene(cfg: CfgNode, model: torch.nn.Module,
if frame % 20 == 0:
pd.DataFrame(label_data).to_csv(
output_path, sep=" ", index=False)
+ cap.release()
pd.DataFrame(label_data).to_csv(output_path, sep=" ", index=False)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/examples/DETECTOR1/DJI_tracks.xml b/tests/examples/DETECTOR1/DJI_tracks.xml
new file mode 100644
index 0000000..22c734f
--- /dev/null
+++ b/tests/examples/DETECTOR1/DJI_tracks.xml
@@ -0,0 +1,49 @@
+
+
+
+
+
\ No newline at end of file
diff --git a/tests/examples/MINISCENE1/metadata/DJI_tracks.xml b/tests/examples/MINISCENE1/metadata/DJI_tracks.xml
new file mode 100644
index 0000000..ee92dbd
--- /dev/null
+++ b/tests/examples/MINISCENE1/metadata/DJI_tracks.xml
@@ -0,0 +1,49 @@
+
+
+
+
+
\ No newline at end of file
diff --git a/tests/examples/MINISCENE2/metadata/DJI_tracks.xml b/tests/examples/MINISCENE2/metadata/DJI_tracks.xml
new file mode 100644
index 0000000..89696a4
--- /dev/null
+++ b/tests/examples/MINISCENE2/metadata/DJI_tracks.xml
@@ -0,0 +1,29 @@
+
+
+
+
+
\ No newline at end of file
diff --git a/tests/test_cvat2slowfast.py b/tests/test_cvat2slowfast.py
new file mode 100644
index 0000000..1c4bf62
--- /dev/null
+++ b/tests/test_cvat2slowfast.py
@@ -0,0 +1,128 @@
+import unittest
+import sys
+import os
+import json
+import pandas as pd
+import cv2
+from kabr_tools import cvat2slowfast
+from tests.test_tracks_extractor import (
+ scene_width,
+ scene_height
+)
+from tests.utils import (
+ del_dir,
+ del_file,
+ dir_exists,
+ file_exists,
+ get_behavior
+)
+
+
+def run():
+ cvat2slowfast.main()
+
+
+class TestCvat2Slowfast(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ # download data
+ cls.video, cls.miniscene, cls.annotation, cls.metadata = get_behavior()
+ cls.dir = os.path.dirname(os.path.dirname(cls.video))
+
+ @classmethod
+ def tearDownClass(cls):
+ # delete data
+ del_file(cls.video)
+ del_file(cls.miniscene)
+ del_file(cls.annotation)
+ del_file(cls.metadata)
+ del_dir(cls.dir)
+
+ def setUp(self):
+ # set params
+ self.tool = "cvat2slowfast.py"
+ self.miniscene = TestCvat2Slowfast.dir
+ self.dataset = "tests/slowfast"
+ self.classes = "ethogram/classes.json"
+ self.old2new = "ethogram/old2new.json"
+
+ def tearDown(self):
+ # delete outputs
+ del_dir(self.dataset)
+
+ def test_run(self):
+ # run cvat2slowfast
+ sys.argv = [self.tool,
+ "--miniscene", self.miniscene,
+ "--dataset", self.dataset,
+ "--classes", self.classes]
+ run()
+
+ # check output dirs
+ self.assertTrue(dir_exists(self.dataset))
+ self.assertTrue(dir_exists(f"{self.dataset}/annotation"))
+ self.assertTrue(dir_exists(f"{self.dataset}/dataset/image"))
+ self.assertTrue(file_exists(f"{self.dataset}/annotation/classes.json"))
+ self.assertTrue(file_exists(f"{self.dataset}/annotation/data.csv"))
+
+ # check classes.json
+ with open(f"{self.dataset}/annotation/classes.json", "r", encoding="utf-8") as f:
+ classes = json.load(f)
+ with open(self.classes, "r", encoding="utf-8") as f:
+ ethogram = json.load(f)
+ self.assertEqual(classes, ethogram)
+
+ # check data.csv
+ with open(f"{self.dataset}/annotation/data.csv", "r", encoding="utf-8") as f:
+ df = pd.read_csv(f, sep=" ")
+
+ video_id = 1
+ for i, row in df.iterrows():
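+ # "original_vido_id" (sic) is the literal column header the SlowFast data loader expects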
+ self.assertEqual(row["original_vido_id"], f"Z{video_id:04d}")
+ self.assertEqual(row["video_id"], video_id)
+ self.assertEqual(row["frame_id"], i+1)
+ self.assertEqual(row["path"], f"Z{video_id:04d}/{i+1}.jpg")
+ self.assertEqual(row["labels"], 1)
+ self.assertEqual(i, 90)
+
+ # check dataset
+ for i in range(1, 92):
+ data_im = f"{self.dataset}/dataset/image/Z{video_id:04d}/{i}.jpg"
+ self.assertTrue(file_exists(data_im))
+ data_im = cv2.imread(data_im)
+ self.assertEqual(data_im.shape, (scene_height, scene_width, 3))
+
+ def test_parse_arg_min(self):
+ # parse arguments
+ sys.argv = [self.tool,
+ "--miniscene", self.miniscene,
+ "--dataset", self.dataset,
+ "--classes", self.classes]
+ args = cvat2slowfast.parse_args()
+
+ # check parsed argument values
+ self.assertEqual(args.miniscene, self.miniscene)
+ self.assertEqual(args.dataset, self.dataset)
+ self.assertEqual(args.classes, self.classes)
+
+ # check default argument values
+ self.assertEqual(args.old2new, None)
+ self.assertTrue(not args.no_images)
+
+ def test_parse_arg_full(self):
+ # parse arguments
+ sys.argv = ["cvat2slowfast.py",
+ "--miniscene", self.miniscene,
+ "--dataset", self.dataset,
+ "--classes", self.classes,
+ "--old2new", self.old2new,
+ "--no_images"]
+ args = cvat2slowfast.parse_args()
+
+ # check parsed argument values
+ self.assertEqual(args.miniscene, self.miniscene)
+ self.assertEqual(args.dataset, self.dataset)
+ self.assertEqual(args.classes, self.classes)
+ self.assertEqual(args.old2new, self.old2new)
+ self.assertTrue(args.no_images)
diff --git a/tests/test_cvat2ultralytics.py b/tests/test_cvat2ultralytics.py
new file mode 100644
index 0000000..0fb246a
--- /dev/null
+++ b/tests/test_cvat2ultralytics.py
@@ -0,0 +1,159 @@
+import unittest
+import sys
+import os
+from lxml import etree
+import pandas as pd
+import cv2
+from kabr_tools import cvat2ultralytics
+from tests.utils import (
+ del_dir,
+ del_file,
+ get_detection,
+ dir_exists,
+ file_exists
+)
+
+
+def run():
+ cvat2ultralytics.main()
+
+
+class TestCvat2Ultralytics(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ # download data
+ cls.video, cls.annotation = get_detection()
+ cls.dir = os.path.dirname(os.path.dirname(cls.video))
+
+ @classmethod
+ def tearDownClass(cls):
+ # delete data
+ del_file(cls.video)
+ del_file(cls.annotation)
+
+ def setUp(self):
+ self.tool = "cvat2ultralytics.py"
+ self.video = TestCvat2Ultralytics.dir
+ self.annotation = TestCvat2Ultralytics.dir
+ self.dataset = "tests/ultralytics"
+ self.skip = "1"
+ self.label2index = "ethogram/label2index.json"
+
+ def tearDown(self):
+ # delete outputs
+ del_dir(self.dataset)
+
+ def test_run(self):
+ # run cvat2ultralytics
+ sys.argv = [self.tool,
+ "--video", self.video,
+ "--annotation", self.annotation,
+ "--dataset", "tests/ultralytics",
+ "--skip", self.skip]
+ run()
+
+ # check for output dirs
+ self.assertTrue(dir_exists(self.dataset))
+ self.assertTrue(dir_exists(f"{self.dataset}/images/test"))
+ self.assertTrue(dir_exists(f"{self.dataset}/images/train"))
+ self.assertTrue(dir_exists(f"{self.dataset}/images/val"))
+ self.assertTrue(dir_exists(f"{self.dataset}/labels/test"))
+ self.assertTrue(dir_exists(f"{self.dataset}/labels/train"))
+ self.assertTrue(dir_exists(f"{self.dataset}/labels/val"))
+
+ # check output
+ annotations = etree.parse(TestCvat2Ultralytics.annotation).getroot()
+ tracks = [list(track.findall("box")) for track in annotations.findall("track")]
+ self.assertEqual(len(tracks[0]), 21)
+ self.assertEqual(len(tracks[0]), len(tracks[1]))
+ original_size = annotations.find("meta").find("task").find("original_size")
+ height = int(original_size.find("height").text)
+ width = int(original_size.find("width").text)
+ for i in range(len(tracks[0])):
+ # check existence
+ if i < 16:
+ data_im = f"{self.dataset}/images/train/DJI_0068_{i}.jpg"
+ self.assertTrue(file_exists(data_im))
+ data_label = f"{self.dataset}/labels/train/DJI_0068_{i}.txt"
+ self.assertTrue(file_exists(data_label))
+ elif i < 18:
+ data_im = f"{self.dataset}/images/val/DJI_0068_{i}.jpg"
+ self.assertTrue(file_exists(data_im))
+ data_label = f"{self.dataset}/labels/val/DJI_0068_{i}.txt"
+ self.assertTrue(file_exists(data_label))
+ else:
+ data_im = f"{self.dataset}/images/test/DJI_0068_{i}.jpg"
+ self.assertTrue(file_exists(data_im))
+ data_label = f"{self.dataset}/labels/test/DJI_0068_{i}.txt"
+ self.assertTrue(file_exists(data_label))
+
+ # check image
+ data_im = cv2.imread(data_im)
+ self.assertEqual(data_im.shape, (height, width, 3))
+
+ # check label
+ data_label = pd.read_csv(data_label, sep=" ", header=None)
+ annotation_label = []
+ for track in tracks:
+ box = track[i]
+ x_start = float(box.attrib["xtl"])
+ y_start = float(box.attrib["ytl"])
+ x_end = float(box.attrib["xbr"])
+ y_end = float(box.attrib["ybr"])
+ x_center = (x_start + (x_end - x_start) / 2) / width
+ y_center = (y_start + (y_end - y_start) / 2) / height
+ w = (x_end - x_start) / width
+ h = (y_end - y_start) / height
+ annotation_label.append(
+ [0, x_center, y_center, w, h]
+ )
+ self.assertEqual(len(data_label.index), len(annotation_label))
+
+ # compare each row of the written label file against the values derived from CVAT
+ for j, label in enumerate(annotation_label):
+ data_row = list(data_label.loc[j])
+ self.assertEqual(data_row[0], label[0])
+ self.assertAlmostEqual(data_row[1], label[1], places=4)
+ self.assertAlmostEqual(data_row[2], label[2], places=4)
+ self.assertAlmostEqual(data_row[3], label[3], places=4)
+ self.assertAlmostEqual(data_row[4], label[4], places=4)
+
+
+ def test_parse_arg_min(self):
+ # parse arguments
+ sys.argv = [self.tool,
+ "--video", self.video,
+ "--annotation", self.annotation,
+ "--dataset", self.dataset]
+ args = cvat2ultralytics.parse_args()
+
+ # check parsed argument values
+ self.assertEqual(args.video, self.video)
+ self.assertEqual(args.annotation, self.annotation)
+ self.assertEqual(args.dataset, self.dataset)
+
+ # check default argument values
+ self.assertEqual(args.skip, 10)
+ self.assertEqual(args.label2index, None)
+
+ # run cvat2ultralytics
+ run()
+
+ def test_parse_arg_full(self):
+ # parse arguments
+ sys.argv = [self.tool,
+ "--video", self.video,
+ "--annotation", self.annotation,
+ "--dataset", self.dataset,
+ "--skip", self.skip,
+ "--label2index", self.label2index]
+ args = cvat2ultralytics.parse_args()
+
+ # check parsed argument values
+ self.assertEqual(args.video, self.video)
+ self.assertEqual(args.annotation, self.annotation)
+ self.assertEqual(args.dataset, self.dataset)
+ self.assertEqual(args.skip, int(self.skip))
+ self.assertEqual(args.label2index, self.label2index)
+
+ # run cvat2ultralytics
+ run()
diff --git a/tests/test_detector2cvat.py b/tests/test_detector2cvat.py
new file mode 100644
index 0000000..b00a122
--- /dev/null
+++ b/tests/test_detector2cvat.py
@@ -0,0 +1,338 @@
+import unittest
+import sys
+import os
+from lxml import etree
+from unittest.mock import MagicMock, patch
+import cv2
+import numpy as np
+from kabr_tools import detector2cvat
+from kabr_tools.utils.yolo import YOLOv8
+from tests.utils import (
+ del_dir,
+ del_file,
+ file_exists,
+ get_detection
+)
+
+
+class DetectionData:
+ def __init__(self, video_dim, video_len, annotation):
+ self.video_dim = video_dim
+ self.video_len = video_len
+ self.frame = -1
+ self.video_frame = np.zeros(video_dim, dtype=np.uint8)
+
+ annotation = etree.parse(annotation).getroot()
+ self.tracks = [[0, list(track.findall("box")),
+ track.get("label")]
+ for track in annotation.findall("track")]
+
+ def read(self):
+ if self.frame >= self.video_len - 1:
+ return False, None
+ self.frame += 1
+ return True, self.video_frame
+
+ def get(self, param):
+ if param == cv2.CAP_PROP_FRAME_COUNT:
+ return self.video_len
+ elif param == cv2.CAP_PROP_FRAME_HEIGHT:
+ return self.video_dim[0]
+ elif param == cv2.CAP_PROP_FRAME_WIDTH:
+ return self.video_dim[1]
+ else:
+ return None
+
+ def forward(self, data):
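+ # stand-in for YOLOv8.forward: replay the reference boxes for the current frame as [bbox, confidence, label]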
+ soln = []
+ for track in self.tracks:
+ ptr = track[0]
+ if ptr < len(track[1]):
+ box = track[1][ptr]
+ frame = int(box.get("frame"))
+ if frame == self.frame:
+ soln.append([[int(float(box.get("xtl"))),
+ int(float(box.get("ytl"))),
+ int(float(box.get("xbr"))),
+ int(float(box.get("ybr")))],
+ 0.95,
+ track[2]])
+ track[0] += 1
+ return soln
+
+
+def run():
+ detector2cvat.main()
+
+
+class TestDetector2Cvat(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ # download data
+ cls.video, cls.annotation = get_detection()
+ cls.dir = os.path.dirname(cls.video)
+
+ @classmethod
+ def tearDownClass(cls):
+ # delete data
+ del_file(cls.video)
+ del_file(cls.annotation)
+ del_dir(cls.dir)
+
+ def setUp(self):
+ # set params
+ self.tool = "detector2cvat.py"
+ self.video = TestDetector2Cvat.dir
+ self.save = "tests/detector2cvat"
+ self.dir = "/".join(os.path.splitext(self.video)[0].split('/')[-2:])
+
+ def tearDown(self):
+ # delete outputs
+ del_dir(self.save)
+
+ def test_run(self):
+ # check if tool runs on real data
+ save = f"{self.save}/run"
+ sys.argv = [self.tool,
+ "--video", self.video,
+ "--save", save]
+ detector2cvat.main()
+
+ # check output exists
+ output_path = f"{save}/{self.dir}/DJI_0068.xml"
+ self.assertTrue(file_exists(output_path))
+ demo_path = f"{save}/{self.dir}/DJI_0068_demo.mp4"
+ self.assertTrue(file_exists(demo_path))
+
+ @patch('kabr_tools.detector2cvat.YOLOv8')
+ def test_mock_yolo(self, yolo):
+ # create fake YOLO
+ yolo_instance = MagicMock()
+ yolo_instance.forward.return_value = [[[0, 0, 0, 0], 0.95, 'Grevy']]
+ yolo.get_centroid.return_value = (50, 50)
+ yolo.return_value = yolo_instance
+
+ # run detector2cvat
+ save = f"{self.save}/mock/0"
+ sys.argv = [self.tool,
+ "--video", self.video,
+ "--save", save]
+ detector2cvat.main()
+
+ # check output exists
+ output_path = f"{save}/{self.dir}/DJI_0068.xml"
+ self.assertTrue(file_exists(output_path))
+ demo_path = f"{save}/{self.dir}/DJI_0068_demo.mp4"
+ self.assertTrue(file_exists(demo_path))
+
+ # check output xml
+ xml_content = etree.parse(output_path).getroot()
+ self.assertEqual(xml_content.tag, "annotations")
+ for j, track in enumerate(xml_content.findall("track")):
+ track_len = len(track.findall("box"))
+ self.assertEqual(track.get("id"), str(j+1))
+ self.assertEqual(track.get("label"), "Grevy")
+ # TODO: Check if source should be manual
+ self.assertEqual(track.get("source"), "manual")
+ for i, box in enumerate(track.findall("box")):
+ self.assertEqual(box.get("frame"), str(i))
+ self.assertEqual(box.get("xtl"), "0.00")
+ self.assertEqual(box.get("ytl"), "0.00")
+ self.assertEqual(box.get("xbr"), "0.00")
+ self.assertEqual(box.get("ybr"), "0.00")
+ self.assertEqual(box.get("occluded"), "0")
+ self.assertEqual(box.get("keyframe"), "1")
+ self.assertEqual(box.get("z_order"), "0")
+ # tracker marks last box as outside
+ if i == track_len - 1:
+ self.assertEqual(box.get("outside"), "1")
+ else:
+ self.assertEqual(box.get("outside"), "0")
+
+ # check output video
+ cap = cv2.VideoCapture(demo_path)
+ video = cv2.VideoCapture(TestDetector2Cvat.video)
+ self.assertTrue(cap.isOpened())
+ self.assertTrue(video.isOpened())
+ self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_COUNT),
+ video.get(cv2.CAP_PROP_FRAME_COUNT))
+ self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_HEIGHT),
+ video.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_WIDTH),
+ video.get(cv2.CAP_PROP_FRAME_WIDTH))
+ cap.release()
+ video.release()
+
+ @patch('kabr_tools.detector2cvat.YOLOv8')
+ @patch('kabr_tools.detector2cvat.cv2.VideoCapture')
+ def test_mock_with_data(self, video_capture, yolo):
+ # mock outputs CVAT data
+ ref_path = "tests/examples/MINISCENE1/metadata/DJI_tracks.xml"
+ height, width, frames = 3078, 5472, 21
+ data = DetectionData((height, width, 3),
+ frames,
+ ref_path)
+ yolo_instance = MagicMock()
+ yolo_instance.forward = data.forward
+ yolo.return_value = yolo_instance
+ yolo.get_centroid = MagicMock(
+ side_effect=lambda pred: YOLOv8.get_centroid(pred))
+
+ vc = MagicMock()
+ vc.read = data.read
+ vc.get = data.get
+ video_capture.return_value = vc
+
+ # run detector2cvat
+ save = f"{self.save}/mock/1"
+ sys.argv = [self.tool,
+ "--video", self.video,
+ "--save", save]
+ detector2cvat.main()
+
+ # check output exists
+ output_path = f"{save}/{self.dir}/DJI_0068.xml"
+ self.assertTrue(file_exists(output_path))
+ demo_path = f"{save}/{self.dir}/DJI_0068_demo.mp4"
+ self.assertTrue(file_exists(demo_path))
+
+ # check output xml
+ xml_content = etree.parse(output_path).getroot()
+ self.assertEqual(xml_content.tag, "annotations")
+ ref_content = etree.parse(ref_path).getroot()
+ ref_track = list(ref_content.findall("track"))
+ for j, track in enumerate(xml_content.findall("track")):
+ self.assertEqual(track.get("id"), str(j+1))
+ self.assertEqual(track.get("label"), "Grevy")
+ self.assertEqual(track.get("source"), "manual")
+ ref_box = list(ref_track[j].findall("box"))
+ for i, box in enumerate(track.findall("box")):
+ self.assertEqual(box.get("frame"), ref_box[i].get("frame"))
+ self.assertEqual(box.get("xtl"), ref_box[i].get("xtl"))
+ self.assertEqual(box.get("ytl"), ref_box[i].get("ytl"))
+ self.assertEqual(box.get("xbr"), ref_box[i].get("xbr"))
+ self.assertEqual(box.get("ybr"), ref_box[i].get("ybr"))
+ self.assertEqual(box.get("occluded"), "0")
+ self.assertEqual(box.get("keyframe"), "1")
+ self.assertEqual(box.get("z_order"), "0")
+ # tracker marks last box as outside
+ if i == frames - 1:
+ self.assertEqual(box.get("outside"), "1")
+ else:
+ self.assertEqual(box.get("outside"), "0")
+
+ # check output video
+ cap = cv2.VideoCapture(demo_path)
+ self.assertTrue(cap.isOpened())
+ self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_COUNT),
+ frames)
+ self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_HEIGHT),
+ height)
+ self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_WIDTH),
+ width)
+ cap.release()
+
+ @patch('kabr_tools.detector2cvat.YOLOv8')
+ @patch('kabr_tools.detector2cvat.cv2.VideoCapture')
+ def test_mock_noncontiguous(self, video_capture, yolo):
+ # mock outputs non-contiguous frame detections
+ ref_path = "tests/examples/DETECTOR1/DJI_tracks.xml"
+ height, width, frames = 3078, 5472, 31
+ data = DetectionData((height, width, 3),
+ frames,
+ ref_path)
+ yolo_instance = MagicMock()
+ yolo_instance.forward = data.forward
+ yolo.return_value = yolo_instance
+ yolo.get_centroid = MagicMock(
+ side_effect=lambda pred: YOLOv8.get_centroid(pred))
+
+ vc = MagicMock()
+ vc.read = data.read
+ vc.get = data.get
+ video_capture.return_value = vc
+
+ # run detector2cvat
+ save = f"{self.save}/mock/2"
+ sys.argv = [self.tool,
+ "--video", self.video,
+ "--save", save]
+ detector2cvat.main()
+
+ # check output exists
+ output_path = f"{save}/{self.dir}/DJI_0068.xml"
+ self.assertTrue(file_exists(output_path))
+ demo_path = f"{save}/{self.dir}/DJI_0068_demo.mp4"
+ self.assertTrue(file_exists(demo_path))
+
+ # check output xml
+ xml_content = etree.parse(output_path).getroot()
+ self.assertEqual(xml_content.tag, "annotations")
+ ref_content = etree.parse(ref_path).getroot()
+ ref_track = list(ref_content.findall("track"))
+ for j, track in enumerate(xml_content.findall("track")):
+ self.assertEqual(track.get("id"), str(j+1))
+ self.assertEqual(track.get("label"), "Grevy")
+ self.assertEqual(track.get("source"), "manual")
+ ref_box = list(ref_track[j].findall("box"))
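+ # between sparse detections the tracker should emit contiguous boxes, repeating the last keyframe with keyframe="0"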
+ i = 0
+ frame = int(track.find("box").get("frame"))
+ for box in track.findall("box"):
+ if box.get("frame") == ref_box[i+1].get("frame"):
+ i += 1
+ self.assertEqual(box.get("frame"), str(frame))
+ self.assertEqual(box.get("xtl"), ref_box[i].get("xtl"))
+ self.assertEqual(box.get("ytl"), ref_box[i].get("ytl"))
+ self.assertEqual(box.get("xbr"), ref_box[i].get("xbr"))
+ self.assertEqual(box.get("ybr"), ref_box[i].get("ybr"))
+ self.assertEqual(box.get("occluded"), "0")
+ if box.get("frame") == ref_box[i].get("frame"):
+ self.assertEqual(box.get("keyframe"), "1")
+ else:
+ self.assertEqual(box.get("keyframe"), "0")
+ self.assertEqual(box.get("z_order"), "0")
+ # tracker marks last box as outside
+ if frame == frames - 1:
+ self.assertEqual(box.get("outside"), "1")
+ else:
+ self.assertEqual(box.get("outside"), "0")
+ frame += 1
+
+ # check output video
+ cap = cv2.VideoCapture(demo_path)
+ self.assertTrue(cap.isOpened())
+ self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_COUNT),
+ frames)
+ self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_HEIGHT),
+ height)
+ self.assertEqual(cap.get(cv2.CAP_PROP_FRAME_WIDTH),
+ width)
+ cap.release()
+
+ def test_parse_arg_min(self):
+ # parse arguments
+ sys.argv = [self.tool,
+ "--video", self.video,
+ "--save", self.save]
+ args = detector2cvat.parse_args()
+
+ # check parsed argument values
+ self.assertEqual(args.video, self.video)
+ self.assertEqual(args.save, self.save)
+ self.assertEqual(args.imshow, False)
+
+ def test_parse_arg_full(self):
+ # parse arguments
+ sys.argv = [self.tool,
+ "--video", self.video,
+ "--save", self.save,
+ "--imshow"]
+ args = detector2cvat.parse_args()
+
+ # check parsed argument values
+ self.assertEqual(args.video, self.video)
+ self.assertEqual(args.save, self.save)
+ self.assertEqual(args.imshow, True)
diff --git a/tests/test_miniscene2behavior.py b/tests/test_miniscene2behavior.py
new file mode 100644
index 0000000..35b589a
--- /dev/null
+++ b/tests/test_miniscene2behavior.py
@@ -0,0 +1,268 @@
+import unittest
+import zipfile
+import sys
+import os
+import requests
+from unittest.mock import Mock, patch
+from lxml import etree
+from ruamel.yaml import YAML
+import torch
+import numpy as np
+import pandas as pd
+from kabr_tools import (
+ miniscene2behavior,
+ tracks_extractor
+)
+from kabr_tools.miniscene2behavior import annotate_miniscene
+from tests.utils import (
+ del_file,
+ del_dir,
+ clean_dir,
+ get_detection
+)
+
+
+TESTSDIR = os.path.dirname(os.path.realpath(__file__))
+EXAMPLESDIR = os.path.join(TESTSDIR, "examples")
+
+
+def run():
+ miniscene2behavior.main()
+
+
+class TestMiniscene2Behavior(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ # download the model from Imageomics HF
+ cls.checkpoint = "checkpoint_epoch_00075.pyth"
+ cls.model_output = None
+ cls.download_model()
+
+ # download data
+ cls.video, cls.annotation = get_detection()
+
+ # extract mini-scene
+ sys.argv = ["tracks_extractor.py",
+ "--video", cls.video,
+ "--annotation", cls.annotation]
+ tracks_extractor.main()
+ cls.miniscene = f'mini-scenes/{os.path.splitext("|".join(cls.video.split("/")[-3:]))[0]}'
+
+ @classmethod
+ def download_model(cls):
+ if not os.path.exists(cls.checkpoint):
+ # download checkpoint archive
+ url = "https://huggingface.co/imageomics/" \
+ + "x3d-kabr-kinetics/resolve/main/" \
+ + f"{cls.checkpoint}.zip"
+ r = requests.get(url, allow_redirects=True, timeout=120)
+ with open(f"{cls.checkpoint}.zip", "wb") as f:
+ f.write(r.content)
+
+ # unzip model checkpoint
+ with zipfile.ZipFile(f"{cls.checkpoint}.zip", "r") as zip_ref:
+ zip_ref.extractall(".")
+
+ # read OUTPUT_DIR from the cfg embedded in the checkpoint so tearDownClass can clean it up
+ try:
+ cfg = torch.load(cls.checkpoint,
+ weights_only=True,
+ map_location=torch.device("cpu"))["cfg"]
+ yaml = YAML(typ="rt")
+ cls.model_output = f"{yaml.load(cfg)['OUTPUT_DIR']}/checkpoints"
+ except Exception:
+ pass
+
+ @classmethod
+ def tearDownClass(cls):
+ # remove model files after tests
+ del_file(f"{cls.checkpoint}.zip")
+ del_file(cls.checkpoint)
+ if cls.model_output:
+ clean_dir(cls.model_output)
+
+ # remove data after tests
+ del_file(cls.video)
+ del_file(cls.annotation)
+ del_dir(cls.miniscene)
+
+ def setUp(self):
+ self.tool = "miniscene2behavior.py"
+ self.checkpoint = "checkpoint_epoch_00075.pyth"
+ self.miniscene = TestMiniscene2Behavior.miniscene
+ self.video = "DJI_0068"
+ self.config = "special_config.yml"
+ self.gpu_num = "1"
+ self.output = "DJI_0068.csv"
+
+ def tearDown(self):
+ # delete outputs
+ del_file(self.output)
+
+ def test_run(self):
+ # download model
+ self.download_model()
+
+ # annotate mini-scenes
+ sys.argv = [self.tool,
+ "--checkpoint", self.checkpoint,
+ "--miniscene", self.miniscene,
+ "--video", self.video,
+ "--output", self.output]
+ run()
+
+ # check output CSV
+ df = pd.read_csv(self.output, sep=' ')
+ self.assertEqual(list(df.columns), [
+ "video", "track", "frame", "label"])
+ row_ct = 0
+
+ root = etree.parse(
+ f"{self.miniscene}/metadata/{self.video}_tracks.xml").getroot()
+ for track in root.iterfind("track"):
+ track_id = int(track.get("id"))
+ for box in track.iterfind("box"):
+ row = list(df.loc[row_ct])
+ self.assertEqual(row[0], self.video)
+ self.assertEqual(row[1], track_id)
+ self.assertEqual(row[2], int(box.get("frame")))
+ self.assertTrue(row[3] >= 0)
+ self.assertTrue(row[3] <= 7)
+ row_ct += 1
+ self.assertEqual(len(df.index), row_ct)
+
+ @patch('kabr_tools.miniscene2behavior.process_cv2_inputs')
+ @patch('kabr_tools.miniscene2behavior.cv2.VideoCapture')
+ def test_matching_tracks(self, video_capture, process_cv2_inputs):
+ # create fake model that weights class 98
+ mock_model = Mock()
+ prob = torch.zeros(99)
+ prob[-1] = 1
+ mock_model.return_value = prob
+
+ # create fake cfg
+ mock_config = Mock(
+ DATA=Mock(NUM_FRAMES=16,
+ SAMPLING_RATE=5,
+ TEST_CROP_SIZE=300),
+ NUM_GPUS=0,
+ OUTPUT_DIR=''
+ )
+
+ # create fake video capture
+ vc = video_capture.return_value
+ vc.read.return_value = True, np.zeros((8, 8, 3), np.uint8)
+ vc.get.return_value = 21
+
+ self.output = '/tmp/annotation_data.csv'
+ miniscene_dir = os.path.join(EXAMPLESDIR, "MINISCENE1")
+ video_name = "DJI"
+
+ annotate_miniscene(cfg=mock_config,
+ model=mock_model,
+ miniscene_path=miniscene_dir,
+ video=video_name,
+ output_path=self.output)
+
+ # check output CSV
+ df = pd.read_csv(self.output, sep=' ')
+ self.assertEqual(list(df.columns), [
+ "video", "track", "frame", "label"])
+ row_ct = 0
+
+ root = etree.parse(
+ f"{miniscene_dir}/metadata/DJI_tracks.xml").getroot()
+ for track in root.iterfind("track"):
+ track_id = int(track.get("id"))
+ for box in track.iterfind("box"):
+ row_val = [video_name, track_id, int(box.get("frame")), 98]
+ self.assertEqual(list(df.loc[row_ct]), row_val)
+ row_ct += 1
+ self.assertEqual(len(df.index), row_ct)
+
+ @patch('kabr_tools.miniscene2behavior.process_cv2_inputs')
+ @patch('kabr_tools.miniscene2behavior.cv2.VideoCapture')
+ def test_nonmatching_tracks(self, video_capture, process_cv2_inputs):
+
+ # Create fake model that always returns a prediction of 1
+ mock_model = Mock()
+ mock_model.return_value = torch.tensor([1])
+
+ # Create fake cfg
+ mock_config = Mock(
+ DATA=Mock(NUM_FRAMES=16,
+ SAMPLING_RATE=5,
+ TEST_CROP_SIZE=300),
+ NUM_GPUS=0,
+ OUTPUT_DIR=''
+ )
+
+ # Create fake video capture
+ vc = video_capture.return_value
+ vc.read.return_value = True, np.zeros((8, 8, 3), np.uint8)
+ vc.get.return_value = 21
+
+ self.output = '/tmp/annotation_data.csv'
+ miniscene_dir = os.path.join(EXAMPLESDIR, "MINISCENE2")
+ video_name = "DJI"
+
+ annotate_miniscene(cfg=mock_config,
+ model=mock_model,
+ miniscene_path=miniscene_dir,
+ video=video_name,
+ output_path=self.output)
+
+ # check output CSV
+ df = pd.read_csv(self.output, sep=' ')
+ self.assertEqual(list(df.columns), [
+ "video", "track", "frame", "label"])
+ row_ct = 0
+
+ root = etree.parse(
+ f"{miniscene_dir}/metadata/DJI_tracks.xml").getroot()
+ for track in root.iterfind("track"):
+ track_id = int(track.get("id"))
+ for box in track.iterfind("box"):
+ row_val = [video_name, track_id, int(box.get("frame")), 0]
+ self.assertEqual(list(df.loc[row_ct]), row_val)
+ row_ct += 1
+ self.assertEqual(len(df.index), row_ct)
+
+ def test_parse_arg_min(self):
+ # parse arguments
+ sys.argv = [self.tool,
+ "--checkpoint", self.checkpoint,
+ "--miniscene", self.miniscene,
+ "--video", self.video]
+ args = miniscene2behavior.parse_args()
+
+ # check parsed argument values
+ self.assertEqual(args.checkpoint, self.checkpoint)
+ self.assertEqual(args.miniscene, self.miniscene)
+ self.assertEqual(args.video, self.video)
+
+ # check default argument values
+ self.assertEqual(args.config, "config.yml")
+ self.assertEqual(args.gpu_num, 0)
+ self.assertEqual(args.output, "annotation_data.csv")
+
+ def test_parse_arg_full(self):
+ # parse arguments
+ sys.argv = [self.tool,
+ "--config", self.config,
+ "--checkpoint", self.checkpoint,
+ "--gpu_num", self.gpu_num,
+ "--miniscene", self.miniscene,
+ "--video", self.video,
+ "--output", self.output]
+ args = miniscene2behavior.parse_args()
+
+ # check parsed argument values
+ self.assertEqual(args.config, self.config)
+ self.assertEqual(args.checkpoint, self.checkpoint)
+ self.assertEqual(args.gpu_num, 1)
+ self.assertEqual(args.miniscene, self.miniscene)
+ self.assertEqual(args.video, self.video)
+ self.assertEqual(args.output, self.output)
diff --git a/tests/test_player.py b/tests/test_player.py
new file mode 100644
index 0000000..091aeb0
--- /dev/null
+++ b/tests/test_player.py
@@ -0,0 +1,103 @@
+import unittest
+import sys
+import os
+from unittest.mock import patch
+from kabr_tools import player
+from tests.utils import (
+ del_file,
+ del_dir,
+ file_exists,
+ get_behavior
+)
+
+
+def run():
+ player.main()
+
+
+class TestPlayer(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ # download data
+ cls.video, cls.miniscene, cls.annotation, cls.metadata = get_behavior()
+ cls.dir = os.path.dirname(cls.video)
+
+ @classmethod
+ def tearDownClass(cls):
+ # delete data
+ del_file(cls.video)
+ del_file(cls.miniscene)
+ del_file(cls.annotation)
+ del_file(cls.metadata)
+ del_dir(cls.dir)
+
+ def setUp(self):
+ # set params
+ self.tool = "player.py"
+ self.folder = TestPlayer.dir
+ self.video = self.folder.rsplit("/", maxsplit=1)[-1]
+
+ # delete output
+ del_file(f"{self.folder}/{self.video}_demo.mp4")
+
+ def tearDown(self):
+ # delete output
+ del_file(f"{self.folder}/{self.video}_demo.mp4")
+
+ @patch('kabr_tools.player.cv2.imshow')
+ @patch('kabr_tools.player.cv2.namedWindow')
+ @patch('kabr_tools.player.cv2.createTrackbar')
+ @patch('kabr_tools.player.cv2.setTrackbarPos')
+ @patch('kabr_tools.player.cv2.getTrackbarPos')
+ def test_run(self, getTrackbarPos, setTrackbarPos, createTrackbar, namedWindow, imshow):
+ # mock getTrackbarPos
+ getTrackbarPos.return_value = 0
+
+ # run player
+ sys.argv = [self.tool,
+ "--folder", self.folder,
+ "--save"]
+ run()
+ self.assertTrue(file_exists(f"{self.folder}/{self.video}_demo.mp4"))
+
+ @patch('kabr_tools.player.cv2.imshow')
+ @patch('kabr_tools.player.cv2.namedWindow')
+ @patch('kabr_tools.player.cv2.createTrackbar')
+ @patch('kabr_tools.player.cv2.setTrackbarPos')
+ @patch('kabr_tools.player.cv2.getTrackbarPos')
+ def test_parse_arg_min(self, getTrackbarPos, setTrackbarPos, createTrackbar, namedWindow, imshow):
+ # parse arguments
+ sys.argv = [self.tool,
+ "--folder", self.folder]
+ args = player.parse_args()
+
+ # check parsed arguments
+ self.assertEqual(args.folder, self.folder)
+
+ # check default arguments
+ self.assertEqual(args.save, False)
+
+ # run player
+ run()
+ self.assertTrue(not file_exists(f"{self.folder}/{self.video}_demo.mp4"))
+
+ @patch('kabr_tools.player.cv2.imshow')
+ @patch('kabr_tools.player.cv2.namedWindow')
+ @patch('kabr_tools.player.cv2.createTrackbar')
+ @patch('kabr_tools.player.cv2.setTrackbarPos')
+ @patch('kabr_tools.player.cv2.getTrackbarPos')
+ def test_parse_arg_full(self, getTrackbarPos, setTrackbarPos, createTrackbar, namedWindow, imshow):
+ # parse arguments
+ sys.argv = [self.tool,
+ "--folder", self.folder,
+ "--save", "--imshow"]
+ args = player.parse_args()
+
+ # check parsed arguments
+ self.assertEqual(args.folder, self.folder)
+ self.assertEqual(args.save, True)
+
+ # run player
+ run()
+ self.assertTrue(file_exists(f"{self.folder}/{self.video}_demo.mp4"))
diff --git a/tests/test_tracks_extractor.py b/tests/test_tracks_extractor.py
new file mode 100644
index 0000000..71a7e9a
--- /dev/null
+++ b/tests/test_tracks_extractor.py
@@ -0,0 +1,203 @@
+import unittest
+import sys
+import os
+from unittest.mock import patch, Mock
+import json
+from lxml import etree
+import cv2
+from kabr_tools import tracks_extractor
+from kabr_tools.utils.tracker import Tracker
+from kabr_tools.utils.detector import Detector
+from kabr_tools.utils.utils import get_scene
+from tests.utils import (
+ get_detection,
+ dir_exists,
+ file_exists,
+ del_dir,
+ del_file
+)
+
+# TODO: make constants for kabr tools (copied values in tracks_extractor.py)
+scene_width, scene_height = 400, 300
+
+
+def run():
+ tracks_extractor.main()
+
+
+class TestTracksExtractor(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ # download data
+ cls.video, cls.annotation = get_detection()
+
+ @classmethod
+ def tearDownClass(cls):
+ # delete data
+ del_file(cls.video)
+ del_file(cls.annotation)
+
+ def setUp(self):
+ # set params
+ self.tool = "tracks_extractor.py"
+ self.video = TestTracksExtractor.video
+ self.annotation = TestTracksExtractor.annotation
+
+ # remove output directory
+ del_dir("mini-scenes")
+
+ def tearDown(self):
+ # remove output directory
+ del_dir("mini-scenes")
+
+ def test_run(self):
+ # run tracks_extractor
+ sys.argv = [self.tool,
+ "--video", self.video,
+ "--annotation", self.annotation]
+ run()
+
+ # check output exists
+ mini_folder = os.path.splitext("|".join(self.video.split("/")[-3:]))[0]
+ video_name = "DJI_0068"
+ self.assertTrue(dir_exists(f"mini-scenes/{mini_folder}"))
+ self.assertTrue(dir_exists(f"mini-scenes/{mini_folder}/actions"))
+ self.assertTrue(dir_exists(f"mini-scenes/{mini_folder}/metadata"))
+ self.assertTrue(file_exists(f"mini-scenes/{mini_folder}/0.mp4"))
+ self.assertTrue(file_exists(f"mini-scenes/{mini_folder}/1.mp4"))
+ self.assertTrue(file_exists(
+ f"mini-scenes/{mini_folder}/{video_name}.mp4"))
+ self.assertTrue(file_exists(
+ f"mini-scenes/{mini_folder}/metadata/{video_name}_metadata.json"))
+ self.assertTrue(file_exists(
+ f"mini-scenes/{mini_folder}/metadata/{video_name}_tracks.xml"))
+ self.assertTrue(file_exists(
+ f"mini-scenes/{mini_folder}/metadata/{video_name}.jpg"))
+
+ # check metadata.json
+ root = etree.parse(self.annotation).getroot()
+ tracks = {"main":
+ [-1] * int("".join(root.find("meta").find("task").find("size").itertext()))}
+ for track in root.iterfind("track"):
+ track_id = track.attrib["id"]
+ tracks[track_id] = []
+ for box in track.iter("box"):
+ frame_id = int(box.attrib["frame"])
+ tracks[track_id].append(frame_id)
+ tracks["main"][frame_id] = frame_id
+
+ colors = list(Tracker.colors_table.values())
+
+ with open(f"mini-scenes/{mini_folder}/metadata/{video_name}_metadata.json",
+ "r", encoding="utf-8") as f:
+ metadata = json.load(f)
+ self.assertTrue("original" in metadata)
+ self.assertTrue("tracks" in metadata)
+ self.assertTrue("colors" in metadata)
+ self.assertEqual(metadata["original"], self.video)
+ self.assertEqual(metadata["tracks"]["main"], tracks["main"])
+ self.assertEqual(metadata["tracks"]["0"], tracks["0"])
+ self.assertEqual(metadata["tracks"]["1"], tracks["1"])
+ self.assertEqual(metadata["colors"]["0"],
+ list(colors[0 % len(colors)]))
+ self.assertEqual(metadata["colors"]["1"],
+ list(colors[1 % len(colors)]))
+
+ # check tracks.xml
+ with open(f"mini-scenes/{mini_folder}/metadata/{video_name}_tracks.xml",
+ "r", encoding="utf-8") as f:
+ track_copy = f.read()
+
+ with open(self.annotation, "r", encoding="utf-8") as f:
+ track = f.read()
+
+ self.assertEqual(track, track_copy)
+
+ # check 0.mp4, 1.mp4
+ root = etree.parse(self.annotation).getroot()
+ xml_tracks = {}
+ for track in root.findall("track"):
+ track_id = track.attrib["id"]
+ xml_tracks[track_id] = track
+ self.assertEqual(xml_tracks.keys(), {"0", "1"})
+
+ original = cv2.VideoCapture(self.video)
+ self.assertTrue(original.isOpened())
+ mock = Mock()
+
+ for track_id, xml_track in xml_tracks.items():
+ track = cv2.VideoCapture(
+ f"mini-scenes/{mini_folder}/{track_id}.mp4")
+ self.assertTrue(track.isOpened())
+
+ for i, box in enumerate(xml_track.iter("box")):
+ original.set(cv2.CAP_PROP_POS_FRAMES, int(box.attrib["frame"]))
+ track.set(cv2.CAP_PROP_POS_FRAMES, i)
+ original_returned, original_frame = original.read()
+ track_returned, track_frame = track.read()
+
+ self.assertTrue(original_returned)
+ self.assertTrue(track_returned)
+
+ mock.box = [int(float(box.attrib["xtl"])),
+ int(float(box.attrib["ytl"])),
+ int(float(box.attrib["xbr"])),
+ int(float(box.attrib["ybr"]))]
+ mock.centroid = Detector.get_centroid(mock.box)
+ original_frame = get_scene(
+ original_frame, mock, scene_width, scene_height)
+
+ # re-encoding adds compression noise; two-arg cv2.norm avoids uint8
+ # wraparound in the subtraction and allows a loose tolerance
+ self.assertTrue(
+ cv2.norm(original_frame, track_frame) < 1e6)
+ track.release()
+
+ # check DJI_0068.mp4
+ copy = cv2.VideoCapture(f"mini-scenes/{mini_folder}/{video_name}.mp4")
+ self.assertTrue(copy.isOpened())
+ self.assertEqual(copy.get(cv2.CAP_PROP_FRAME_COUNT),
+ original.get(cv2.CAP_PROP_FRAME_COUNT))
+ self.assertEqual(copy.get(cv2.CAP_PROP_FRAME_WIDTH),
+ original.get(cv2.CAP_PROP_FRAME_WIDTH))
+ self.assertEqual(copy.get(cv2.CAP_PROP_FRAME_HEIGHT),
+ original.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+ original.release()
+
+ def test_parse_arg_min(self):
+ # parse arguments
+ sys.argv = [self.tool,
+ "--video", self.video,
+ "--annotation", self.annotation]
+ args = tracks_extractor.parse_args()
+
+ # check parsed arguments
+ self.assertEqual(args.video, self.video)
+ self.assertEqual(args.annotation, self.annotation)
+
+ # check default arguments
+ self.assertEqual(args.tracking, False)
+ self.assertEqual(args.imshow, False)
+
+ # run tracks_extractor
+ run()
+
+ @patch('kabr_tools.tracks_extractor.cv2.imshow')
+ def test_parse_arg_full(self, imshow):
+ # parse arguments
+ sys.argv = [self.tool,
+ "--video", self.video,
+ "--annotation", self.annotation,
+ "--tracking",
+ "--imshow"]
+ args = tracks_extractor.parse_args()
+
+ # check parsed arguments
+ self.assertEqual(args.video, self.video)
+ self.assertEqual(args.annotation, self.annotation)
+ self.assertEqual(args.tracking, True)
+ self.assertEqual(args.imshow, True)
+
+ # run tracks_extractor
+ run()
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 0000000..b8d3004
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,52 @@
+import os
+import shutil
+from pathlib import Path
+from huggingface_hub import hf_hub_download
+
+DATA_HUB = "imageomics/kabr_testing"
+REPO_TYPE = "dataset"
+
+DETECTION_VIDEO = "DJI_0068/DJI_0068.mp4"
+DETECTION_ANNOTATION = "DJI_0068/DJI_0068.xml"
+
+BEHAVIOR_VIDEO = "DJI_0001/DJI_0001.mp4"
+BEHAVIOR_MINISCENE = "DJI_0001/43.mp4"
+BEHAVIOR_ANNOTATION = "DJI_0001/actions/43.xml"
+BEHAVIOR_METADATA = "DJI_0001/metadata/DJI_0001_metadata.json"
+
+def get_hf(repo_id: str, filename: str, repo_type: str):
+ return hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type)
+
+def get_cached_datafile(filename: str):
+ return get_hf(DATA_HUB, filename, REPO_TYPE)
+
+def get_behavior():
+ video = get_cached_datafile(BEHAVIOR_VIDEO)
+ miniscene = get_cached_datafile(BEHAVIOR_MINISCENE)
+ annotation = get_cached_datafile(BEHAVIOR_ANNOTATION)
+ metadata = get_cached_datafile(BEHAVIOR_METADATA)
+ return video, miniscene, annotation, metadata
+
+def get_detection():
+ video = get_cached_datafile(DETECTION_VIDEO)
+ annotation = get_cached_datafile(DETECTION_ANNOTATION)
+ return video, annotation
+
+def clean_dir(path):
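+ # os.removedirs only deletes empty directories (and their empty parents); use del_dir for non-empty trees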
+ if os.path.exists(path):
+ os.removedirs(path)
+
+def del_dir(path):
+ if os.path.exists(path):
+ shutil.rmtree(path)
+
+def del_file(path):
+ if os.path.exists(path):
+ os.remove(path)
+
+def file_exists(path):
+ return Path(path).is_file()
+
+
+def dir_exists(path):
+ return Path(path).is_dir()