Skip to content

Commit

Permalink
Added simple multi part reader
Browse files Browse the repository at this point in the history
  • Loading branch information
Marc Tonsen committed Oct 22, 2024
1 parent dce7b74 commit ce0ebd2
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/pupil_labs/video/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
from .multi_part_reader import MultiPartReader
from .reader import Reader
from .video_frame import PixelFormat, VideoFrame
from .writer import Writer

__all__: list[str] = ["Reader", "Writer", "VideoFrame", "PixelFormat"]
# Names exported as the public API of this package (used by `from ... import *`).
__all__: list[str] = [
"Reader",
"MultiPartReader",
"Writer",
"VideoFrame",
"PixelFormat",
]
94 changes: 94 additions & 0 deletions src/pupil_labs/video/multi_part_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from collections.abc import Sequence
from pathlib import Path
from types import TracebackType
from typing import Optional, overload

import numpy as np

from .reader import Reader, TimesArray
from .video_frame import VideoFrame


class MultiPartReader(Sequence[VideoFrame]):
    """Read several video part files as one continuous sequence of frames.

    Each path is opened with its own ``Reader``. Frames are addressed by a
    global index that runs across all parts in the order the paths were given.
    """

    def __init__(
        self, paths: list[str] | list[Path], times: Optional[list[TimesArray]] = None
    ):
        """Open one ``Reader`` per path.

        Args:
            paths: Video part files, in the order they should be concatenated.
            times: Optional times arrays, one per path, each forwarded to the
                ``Reader`` of the matching part.

        Raises:
            ValueError: If ``times`` is given but its length differs from the
                number of paths.
        """
        if times is not None and len(times) != len(paths):
            raise ValueError("Number of times arrays must match number of video parts.")

        if times is None:
            self.parts = [Reader(path) for path in paths]
        else:
            self.parts = [Reader(path, time) for path, time in zip(paths, times)]

        # _start_indices[i] is the global index of the first frame of part i;
        # the last entry is the total number of frames over all parts.
        self._start_indices = np.cumsum([0] + [len(part) for part in self.parts])

    def __len__(self) -> int:
        """Return the total number of frames over all parts."""
        # Reuse the cumulative table so the length always agrees with the
        # index mapping used by __getitem__.
        return int(self._start_indices[-1])

    @overload
    def __getitem__(self, key: int) -> VideoFrame: ...
    @overload
    def __getitem__(self, key: slice) -> Sequence[VideoFrame]: ...

    def __getitem__(self, key: int | slice) -> VideoFrame | Sequence[VideoFrame]:
        """Return the frame at global index ``key``.

        Negative indices count from the end, as for built-in sequences. The
        returned frame's ``index`` attribute is set to the (non-negative)
        global index.

        Raises:
            IndexError: If ``key`` is outside ``[-len(self), len(self))``.
            NotImplementedError: If ``key`` is not an ``int`` (slicing is not
                supported yet).
        """
        if not isinstance(key, int):
            raise NotImplementedError

        total = len(self)
        # Normalize negative indices before the lookup; without this a negative
        # key would silently be mapped to the wrong part and frame.
        index = key + total if key < 0 else key
        if not 0 <= index < total:
            raise IndexError("Index out of range.")

        # side="right" picks the last part whose start index is <= index, which
        # also skips over parts that contain no frames.
        part_index = (
            np.searchsorted(self._start_indices, index, side="right").item() - 1
        )
        part_key = int(index - self._start_indices[part_index])
        frame = self.parts[part_index][part_key]
        frame.index = index
        # TODO(marc): How do we want to set frame.ts and frame.pts?
        return frame

    def _parse_key(self, key: int | slice) -> tuple[int, int]:
        """Convert an int or slice key into a ``(start, stop)`` index pair.

        ``None`` bounds default to the start/end of the sequence and negative
        bounds are offset by ``len(self)``. Only ``start`` and ``stop`` of a
        slice are read; its ``step`` is ignored. This helper is not called
        within this class.

        Raises:
            TypeError: If ``key`` is neither an ``int`` nor a ``slice``.
        """
        if isinstance(key, slice):
            start_index, stop_index = key.start, key.stop
        elif isinstance(key, int):
            start_index, stop_index = key, key + 1
            if key < 0:
                # Resolve here so that key == -1 does not yield stop == 0.
                start_index = len(self) + key
                stop_index = start_index + 1
        else:
            raise TypeError(f"key must be int or slice, not {type(key)}")

        if start_index is None:
            start_index = 0
        if start_index < 0:
            start_index = len(self) + start_index
        if stop_index is None:
            stop_index = len(self)
        if stop_index < 0:
            stop_index = len(self) + stop_index

        return start_index, stop_index

    def __enter__(self) -> "MultiPartReader":
        """Return the reader itself for use in a ``with`` statement."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Close all parts when leaving the ``with`` block."""
        self.close()

    def close(self) -> None:
        """Close the reader of every part."""
        # NOTE(review): relies on Reader providing a close() method - confirm
        # against reader.py.
        for part in self.parts:
            part.close()

    @property
    def width(self) -> int:
        """Width of the first video stream of the first part."""
        # TODO(marc): Add an appropriate attribute to the Reader class.
        return self.parts[0]._container.streams.video[0].width

    @property
    def height(self) -> int:
        """Height of the first video stream of the first part."""
        # TODO(marc): Add an appropriate attribute to the Reader class.
        return self.parts[0]._container.streams.video[0].height
81 changes: 81 additions & 0 deletions tests/test_multi_part_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from dataclasses import dataclass
from functools import cached_property

import av
import numpy as np
import pytest

from pupil_labs.video.multi_part_reader import MultiPartReader


@dataclass
class PacketData:
    """Reference packet information used to check reader results in tests."""

    pts: list[int]  # presentation timestamp of each packet
    times: list[float]  # timestamp of each packet as a float
    keyframe_indices: list[int]  # positions (into pts/times) of keyframe packets

    @cached_property
    def gop_size(self) -> int:
        """Largest distance between two consecutive keyframe indices.

        Raises:
            ValueError: If there are fewer than two keyframe indices, because
                there is then no difference to take the maximum of.
        """
        return int(max(np.diff(self.keyframe_indices)))

    def _summarize_list(self, lst: list) -> str:
        """Return a compact string for ``lst`` showing its first and last three items.

        Lists with six or fewer items are shown in full, so that no item is
        repeated and no ellipsis appears when nothing was left out. Non-string
        items are rounded to four decimals.
        """
        shown = lst if len(lst) <= 6 else lst[:3] + ["..."] + lst[-3:]
        rendered = (x if isinstance(x, str) else str(round(x, 4)) for x in shown)
        return f"[{', '.join(rendered)}]"

    def __repr__(self) -> str:
        """Return a short representation with summarized list fields."""
        fields = [
            ("len", len(self.pts)),
            ("pts", self._summarize_list(self.pts)),
            ("times", self._summarize_list(self.times)),
            ("keyframe_indices", self._summarize_list(self.keyframe_indices)),
        ]
        body = ", ".join(f"{key}={value}" for key, value in fields)
        return f"{self.__class__.__name__}({body})"


@pytest.fixture
def correct_data(multi_part_video_paths: list[str]) -> PacketData:
    """Demux every video part and collect the expected packet data.

    Packets without a pts are skipped. The pts and times of each part are
    shifted by a bias accumulated from the durations of the preceding parts,
    and keyframe positions are recorded as indices into the combined lists.
    """
    pts_bias = 0
    times_bias = 0
    pts = []
    times = []
    index = 0
    keyframe_indices = []
    for video_path in multi_part_video_paths:
        # Open as a context manager so the container is closed once this part
        # has been demuxed instead of being leaked.
        with av.open(str(video_path)) as container:
            stream = container.streams.video[0]
            assert stream.time_base

            for packet in container.demux(stream):
                if packet.pts is None:
                    continue
                pts.append(packet.pts + pts_bias)
                times.append(float(packet.pts * stream.time_base) + times_bias)
                if packet.is_keyframe:
                    keyframe_indices.append(index)
                index += 1

            # NOTE(review): PyAV documents container.duration in av.time_base
            # units, while packet.pts is in stream.time_base units - confirm
            # that mixing the two in the bias is intended.
            pts_bias += container.duration
            times_bias = pts_bias * stream.time_base
    return PacketData(pts=pts, times=times, keyframe_indices=keyframe_indices)


@pytest.fixture
def reader(multi_part_video_paths: list[str]) -> MultiPartReader:
    """Provide a ``MultiPartReader`` constructed from all multi-part test videos."""
    multi_part_reader = MultiPartReader(multi_part_video_paths)
    return multi_part_reader


def test_indexing(reader: MultiPartReader, correct_data: PacketData) -> None:
    """Every frame fetched by global index must report that index."""
    frame_count = len(reader)
    for position in range(frame_count):
        frame = reader[position]
        assert frame.index == position

0 comments on commit ce0ebd2

Please sign in to comment.