-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Marc Tonsen
committed
Oct 22, 2024
1 parent
dce7b74
commit ce0ebd2
Showing
3 changed files
with
183 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,12 @@ | ||
from .multi_part_reader import MultiPartReader | ||
from .reader import Reader | ||
from .video_frame import PixelFormat, VideoFrame | ||
from .writer import Writer | ||
|
||
__all__: list[str] = ["Reader", "Writer", "VideoFrame", "PixelFormat"] | ||
__all__: list[str] = [ | ||
"Reader", | ||
"MultiPartReader", | ||
"Writer", | ||
"VideoFrame", | ||
"PixelFormat", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
from collections.abc import Sequence | ||
from pathlib import Path | ||
from types import TracebackType | ||
from typing import Optional, overload | ||
|
||
import numpy as np | ||
|
||
from .reader import Reader, TimesArray | ||
from .video_frame import VideoFrame | ||
|
||
|
||
class MultiPartReader(Sequence[VideoFrame]): | ||
def __init__( | ||
self, paths: list[str] | list[Path], times: Optional[list[TimesArray]] = None | ||
): | ||
if times is not None and len(times) != len(paths): | ||
raise ValueError("Number of times arrays must match number of video parts.") | ||
|
||
if times is None: | ||
self.parts = [Reader(path) for path in paths] | ||
else: | ||
self.parts = [Reader(path, time) for path, time in zip(paths, times)] | ||
|
||
self._start_indices = np.cumsum([0] + [len(part) for part in self.parts]) | ||
|
||
def __len__(self) -> int: | ||
return sum(len(part) for part in self.parts) | ||
|
||
@overload | ||
def __getitem__(self, key: int) -> VideoFrame: ... | ||
@overload | ||
def __getitem__(self, key: slice) -> Sequence[VideoFrame]: ... | ||
|
||
def __getitem__(self, key: int | slice) -> VideoFrame | Sequence[VideoFrame]: | ||
if isinstance(key, int): | ||
if key >= len(self): | ||
raise IndexError("Index out of range.") | ||
|
||
part_index = ( | ||
np.searchsorted(self._start_indices, key, side="right").item() - 1 | ||
) | ||
part_key = int(key - self._start_indices[part_index]) | ||
frame = self.parts[part_index][part_key] | ||
frame.index = key | ||
# TODO(marc): How do we want to set frame.ts and frame.pts? | ||
return frame | ||
else: | ||
raise NotImplementedError | ||
|
||
def _parse_key(self, key: int | slice) -> tuple[int, int]: | ||
if isinstance(key, slice): | ||
start_index, stop_index = key.start, key.stop | ||
elif isinstance(key, int): | ||
start_index, stop_index = key, key + 1 | ||
if key < 0: | ||
start_index = len(self) + key | ||
stop_index = start_index + 1 | ||
else: | ||
raise TypeError(f"key must be int or slice, not {type(key)}") | ||
|
||
if start_index is None: | ||
start_index = 0 | ||
if start_index < 0: | ||
start_index = len(self) + start_index | ||
if stop_index is None: | ||
stop_index = len(self) | ||
if stop_index < 0: | ||
stop_index = len(self) + stop_index | ||
|
||
return start_index, stop_index | ||
|
||
def __enter__(self) -> "MultiPartReader": | ||
return self | ||
|
||
def __exit__( | ||
self, | ||
exc_type: type[BaseException] | None, | ||
exc_value: BaseException | None, | ||
traceback: TracebackType | None, | ||
) -> None: | ||
self.close() | ||
|
||
def close(self) -> None: | ||
raise NotImplementedError | ||
|
||
@property | ||
def width(self) -> int: | ||
# TODO(marc): Add an appropriate attribute to the Reader class. | ||
return self.parts[0]._container.streams.video[0].width | ||
|
||
@property | ||
def height(self) -> int: | ||
# TODO(marc): Add an appropriate attribute to the Reader class. | ||
return self.parts[0]._container.streams.video[0].height |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from dataclasses import dataclass | ||
from functools import cached_property | ||
|
||
import av | ||
import numpy as np | ||
import pytest | ||
|
||
from pupil_labs.video.multi_part_reader import MultiPartReader | ||
|
||
|
||
@dataclass | ||
class PacketData: | ||
pts: list[int] | ||
times: list[float] | ||
keyframe_indices: list[int] | ||
|
||
@cached_property | ||
def gop_size(self) -> int: | ||
return int(max(np.diff(self.keyframe_indices))) | ||
|
||
def _summarize_list(self, lst: list) -> str: | ||
return f"""[{ | ||
( | ||
", ".join( | ||
x if isinstance(x, str) else str(round(x, 4)) | ||
for x in lst[:3] + ["..."] + lst[-3:] | ||
) | ||
) | ||
}]""" | ||
|
||
def __repr__(self) -> str: | ||
return ( | ||
f"{self.__class__.__name__}(" | ||
+ ", ".join( | ||
f"{key}={value}" | ||
for key, value in [ | ||
("len", len(self.pts)), | ||
("pts", self._summarize_list(self.pts)), | ||
("times", self._summarize_list(self.times)), | ||
("keyframe_indices", self._summarize_list(self.keyframe_indices)), | ||
] | ||
) | ||
+ ")" | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def correct_data(multi_part_video_paths: list[str]) -> PacketData: | ||
pts_bias = 0 | ||
times_bias = 0 | ||
pts = [] | ||
times = [] | ||
index = 0 | ||
keyframe_indices = [] | ||
for video_path in multi_part_video_paths: | ||
container = av.open(str(video_path)) | ||
stream = container.streams.video[0] | ||
assert stream.time_base | ||
|
||
for packet in container.demux(stream): | ||
if packet.pts is None: | ||
continue | ||
pts.append(packet.pts + pts_bias) | ||
times.append(float(packet.pts * stream.time_base) + times_bias) | ||
if packet.is_keyframe: | ||
keyframe_indices.append(index) | ||
index += 1 | ||
|
||
pts_bias += container.duration | ||
times_bias = pts_bias * stream.time_base | ||
return PacketData(pts=pts, times=times, keyframe_indices=keyframe_indices) | ||
|
||
|
||
@pytest.fixture | ||
def reader(multi_part_video_paths: list[str]) -> MultiPartReader: | ||
return MultiPartReader(multi_part_video_paths) | ||
|
||
|
||
def test_indexing(reader: MultiPartReader, correct_data: PacketData) -> None: | ||
for i in range(len(reader)): | ||
assert reader[i].index == i |