Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve speed and RAM consumption of buffered slice writer #937

Merged
merged 13 commits into from
Aug 15, 2023
125 changes: 123 additions & 2 deletions webknossos/tests/dataset/test_buffered_slice_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from pathlib import Path

import numpy as np
Expand All @@ -8,6 +9,10 @@
from webknossos.geometry import BoundingBox, Mag, Vec3Int
from webknossos.utils import rmtree

# This module effectively tests BufferedSliceWriter and
# BufferedSliceReader (by calling get_buffered_slice_writer
# and get_buffered_slice_reader).


def test_buffered_slice_writer() -> None:
test_img = np.arange(24 * 24).reshape(24, 24).astype(np.uint16) + 1
Expand Down Expand Up @@ -77,11 +82,13 @@ def test_buffered_slice_writer() -> None:
def test_buffered_slice_writer_along_different_axis(tmp_path: Path) -> None:
test_cube = (np.random.random((3, 13, 13, 13)) * 100).astype(np.uint8)
cube_size_without_channel = test_cube.shape[1:]
offset = Vec3Int(5, 10, 20)
offset = Vec3Int(64, 96, 32)

for dim in [0, 1, 2]:
ds = Dataset(tmp_path / f"buffered_slice_writer_{dim}", voxel_size=(1, 1, 1))
mag_view = ds.add_layer("color", COLOR_CATEGORY, num_channels=3).add_mag(1)
mag_view = ds.add_layer(
"color", COLOR_CATEGORY, num_channels=test_cube.shape[0]
).add_mag(1)

with mag_view.get_buffered_slice_writer(
absolute_offset=offset, buffer_size=5, dimension=dim
Expand Down Expand Up @@ -129,3 +136,117 @@ def test_buffered_slice_reader_along_different_axis(tmp_path: Path) -> None:

assert np.array_equal(slice_data_a, original_slice)
assert np.array_equal(slice_data_b, original_slice)


def test_basic_buffered_slice_writer(tmp_path: Path) -> None:
    """Round-trip a small volume through the buffered slice writer.

    The writer is fed one z-slice at a time with fully aligned defaults,
    so no warnings may be emitted; any warning is escalated to an error.
    """
    # Set up a single-channel uint8 layer at mag 1 with 256^3 shards.
    dataset = Dataset(tmp_path, voxel_size=(1, 1, 1))
    layer = dataset.add_layer(
        layer_name="color", category="color", dtype_per_channel="uint8", num_channels=1
    )
    mag1 = layer.add_mag("1", chunk_shape=(32, 32, 32), chunks_per_shard=(8, 8, 8))

    # Allocate some data (~ 8 MB)
    shape = (512, 512, 32)
    data = np.random.randint(0, 255, shape, dtype=np.uint8)

    with warnings.catch_warnings():
        # Escalate every warning to an error: aligned usage must be silent.
        warnings.filterwarnings("error")

        # Stream the volume into the writer slice-by-slice.
        with mag1.get_buffered_slice_writer() as writer:
            for z in range(shape[2]):
                writer.send(data[:, :, z])

        written_data = mag1.read(absolute_offset=(0, 0, 0), size=shape)

    assert np.all(data == written_data)


def test_buffered_slice_writer_should_warn_about_unaligned_usage(
    tmp_path: Path,
) -> None:
    """Unaligned usage must warn (twice) but still write correct data.

    Counterpart to the aligned tests: an offset that does not align with
    the chunk shape and a buffer size that is not a multiple of the chunk
    depth should each emit a UserWarning, while the written data still
    round-trips exactly.
    """
    # Create DS
    dataset = Dataset(tmp_path, voxel_size=(1, 1, 1))
    layer = dataset.add_layer(
        layer_name="color", category="color", dtype_per_channel="uint8", num_channels=1
    )
    mag1 = layer.add_mag("1", chunk_shape=(32, 32, 32), chunks_per_shard=(8, 8, 8))

    # Deliberately misaligned with the 32-voxel chunk shape.
    offset = (1, 1, 1)

    # Allocate some data (~ 8 MB)
    shape = (512, 512, 32)
    data = np.random.randint(0, 255, shape, dtype=np.uint8)

    with warnings.catch_warnings(record=True) as recorded_warnings:
        warnings.filterwarnings("default", module="webknossos", message=r"\[WARNING\]")
        # Write some slices; buffer_size=35 is not a multiple of the
        # chunk depth (32), which should trigger the second warning.
        with mag1.get_buffered_slice_writer(
            absolute_offset=offset, buffer_size=35
        ) as writer:
            for z in range(0, shape[2]):
                section = data[:, :, z]
                writer.send(section)

    # Exactly two warnings: one for the offset, one for the buffer size.
    warning1, warning2 = recorded_warnings
    assert issubclass(warning1.category, UserWarning) and "Using an offset" in str(
        warning1.message
    )
    assert issubclass(
        warning2.category, UserWarning
    ) and "Using a buffer size" in str(warning2.message)

    written_data = mag1.read(absolute_offset=offset, size=shape)

    assert np.all(data == written_data)


def test_basic_buffered_slice_writer_multi_shard(tmp_path: Path) -> None:
    """Round-trip a volume that spans several shards in every dimension.

    With 32^3 chunks and 4x4x4 chunks per shard, a shard covers 128
    voxels per axis, so the 160x150x140 volume crosses shard borders in
    x, y and z. Aligned defaults are used, so any warning is an error.
    """
    # Single-channel uint8 layer, 128^3 shards.
    dataset = Dataset(tmp_path, voxel_size=(1, 1, 1))
    layer = dataset.add_layer(
        layer_name="color", category="color", dtype_per_channel="uint8", num_channels=1
    )
    mag1 = layer.add_mag("1", chunk_shape=(32, 32, 32), chunks_per_shard=(4, 4, 4))

    # Allocate some data (~ 3 MB) that covers multiple shards (also in z)
    shape = (160, 150, 140)
    data = np.random.randint(0, 255, shape, dtype=np.uint8)

    with warnings.catch_warnings():
        # Escalate warnings to errors: aligned usage must stay silent.
        warnings.filterwarnings("error")

        # Stream the volume into the writer slice-by-slice.
        with mag1.get_buffered_slice_writer() as writer:
            for z in range(shape[2]):
                writer.send(data[:, :, z])

        written_data = mag1.read(absolute_offset=(0, 0, 0), size=shape)

    assert np.all(data == written_data)


def test_basic_buffered_slice_writer_multi_shard_multi_channel(tmp_path: Path) -> None:
    """Round-trip a 3-channel volume that spans multiple shards.

    Each slice sent to the writer is channel-first (3, x, y); the read
    back uses only the spatial part of the shape.
    """
    # Three-channel uint8 layer, 128^3 shards.
    dataset = Dataset(tmp_path, voxel_size=(1, 1, 1))
    layer = dataset.add_layer(
        layer_name="color", category="color", dtype_per_channel="uint8", num_channels=3
    )
    mag1 = layer.add_mag("1", chunk_shape=(32, 32, 32), chunks_per_shard=(4, 4, 4))

    # Allocate some data (~ 3 MB) that covers multiple shards (also in z)
    shape = (3, 160, 150, 140)
    data = np.random.randint(0, 255, shape, dtype=np.uint8)

    # Feed the writer one channel-first z-slice at a time.
    with mag1.get_buffered_slice_writer() as writer:
        for z in range(shape[-1]):
            writer.send(data[:, :, :, z])

    # The read size is spatial only (channels are implicit).
    written_data = mag1.read(absolute_offset=(0, 0, 0), size=shape[1:])

    assert np.all(data == written_data)
127 changes: 85 additions & 42 deletions webknossos/webknossos/dataset/_utils/buffered_slice_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
import psutil

from webknossos.geometry import Vec3Int, Vec3IntLike
from webknossos.geometry import BoundingBox, Vec3Int, Vec3IntLike

if TYPE_CHECKING:
from webknossos.dataset import View
Expand Down Expand Up @@ -69,21 +69,43 @@ def __init__(
)
self.dimension = dimension

effective_offset = Vec3Int.full(0)
if self.relative_offset is not None:
effective_offset = self.view.bounding_box.topleft + self.relative_offset

if self.absolute_offset is not None:
effective_offset = self.absolute_offset

view_chunk_depth = self.view.info.chunk_shape[self.dimension]
if (
effective_offset is not None
and effective_offset[self.dimension] % view_chunk_depth != 0
):
warnings.warn(
"[WARNING] Using an offset that doesn't align with the datataset's chunk size, "
+ "will slow down the buffered slice writer, because twice as many chunks will be written.",
)
if buffer_size >= view_chunk_depth and buffer_size % view_chunk_depth > 0:
warnings.warn(
"[WARNING] Using a buffer size that doesn't align with the datataset's chunk size, "
+ "will slow down the buffered slice writer.",
)

assert 0 <= dimension <= 2

self.buffer: List[np.ndarray] = []
self.slices_to_write: List[np.ndarray] = []
self.current_slice: Optional[int] = None
self.buffer_start_slice: Optional[int] = None

def _write_buffer(self) -> None:
if len(self.buffer) == 0:
def _flush_buffer(self) -> None:
if len(self.slices_to_write) == 0:
return

assert (
len(self.buffer) <= self.buffer_size
len(self.slices_to_write) <= self.buffer_size
), "The WKW buffer is larger than the defined batch_size. The buffer should have been flushed earlier. This is probably a bug in the BufferedSliceWriter."

uniq_dtypes = set(map(lambda _slice: _slice.dtype, self.buffer))
uniq_dtypes = set(map(lambda _slice: _slice.dtype, self.slices_to_write))
assert (
len(uniq_dtypes) == 1
), "The buffer of BufferedSliceWriter contains slices with differing dtype."
Expand All @@ -95,7 +117,7 @@ def _write_buffer(self) -> None:
if self.use_logging:
info(
"({}) Writing {} slices at position {}.".format(
getpid(), len(self.buffer), self.buffer_start_slice
getpid(), len(self.slices_to_write), self.buffer_start_slice
)
)
log_memory_consumption()
Expand All @@ -104,44 +126,65 @@ def _write_buffer(self) -> None:
assert (
self.buffer_start_slice is not None
), "Failed to write buffer: The buffer_start_slice is not set."
max_width = max(slice.shape[-2] for slice in self.buffer)
max_height = max(slice.shape[-1] for slice in self.buffer)

self.buffer = [
np.pad(
slice,
mode="constant",
pad_width=[
(0, 0),
(0, max_width - slice.shape[-2]),
(0, max_height - slice.shape[-1]),
],
)
for slice in self.buffer
]
max_width = max(section.shape[-2] for section in self.slices_to_write)
max_height = max(section.shape[-1] for section in self.slices_to_write)
channel_count = self.slices_to_write[0].shape[0]

buffer_bbox = BoundingBox(
(0, 0, 0), (max_width, max_height, self.buffer_size)
)

data = np.concatenate(
[np.expand_dims(slice, self.dimension + 1) for slice in self.buffer],
axis=self.dimension + 1,
shard_dimensions = self.view._get_file_dimensions().moveaxis(
-1, self.dimension
)
buffer_start_list = [0, 0, 0]
buffer_start_list[self.dimension] = self.buffer_start_slice
buffer_start = Vec3Int(buffer_start_list)
buffer_start_mag1 = buffer_start * self.view.mag.to_vec3_int()
self.view.write(
data,
offset=buffer_start.add_or_none(self.offset),
relative_offset=buffer_start_mag1.add_or_none(self.relative_offset),
absolute_offset=buffer_start_mag1.add_or_none(self.absolute_offset),
json_update_allowed=self.json_update_allowed,
chunk_size = Vec3Int(
min(shard_dimensions[0], max_width),
min(shard_dimensions[1], max_height),
self.buffer_size,
)
for chunk_bbox in buffer_bbox.chunk(chunk_size):
info(f"Writing chunk {chunk_bbox}")
width, height, _ = chunk_bbox.size
data = np.zeros(
(channel_count, width, height, self.buffer_size),
dtype=self.slices_to_write[0].dtype,
)

z = 0
for section in self.slices_to_write:
section_chunk = section[
:,
chunk_bbox.topleft.x : chunk_bbox.bottomright.x,
chunk_bbox.topleft.y : chunk_bbox.bottomright.y,
]
data[
:, 0 : section_chunk.shape[-2], 0 : section_chunk.shape[-1], z
] = section_chunk

z += 1

buffer_start = Vec3Int(
chunk_bbox.topleft.x, chunk_bbox.topleft.y, self.buffer_start_slice
).moveaxis(-1, self.dimension)
buffer_start_mag1 = buffer_start * self.view.mag.to_vec3_int()

data = np.moveaxis(data, -1, self.dimension + 1)

self.view.write(
data,
offset=buffer_start.add_or_none(self.offset),
relative_offset=buffer_start_mag1.add_or_none(self.relative_offset),
absolute_offset=buffer_start_mag1.add_or_none(self.absolute_offset),
json_update_allowed=self.json_update_allowed,
)
del data

except Exception as exc:
error(
"({}) An exception occurred in BufferedSliceWriter._write_buffer with {} "
"({}) An exception occurred in BufferedSliceWriter._flush_buffer with {} "
"slices at position {}. Original error is:\n{}:{}\n\nTraceback:".format(
getpid(),
len(self.buffer),
len(self.slices_to_write),
self.buffer_start_slice,
type(exc).__name__,
exc,
Expand All @@ -152,23 +195,23 @@ def _write_buffer(self) -> None:

raise exc
finally:
self.buffer = []
self.slices_to_write = []

def _get_slice_generator(self) -> Generator[None, np.ndarray, None]:
current_slice = 0
while True:
data = yield # Data gets send from the user
if len(self.buffer) == 0:
if len(self.slices_to_write) == 0:
self.buffer_start_slice = current_slice
if len(data.shape) == 2:
# The input data might contain channel data or not.
# Bringing it into the same shape simplifies the code
data = np.expand_dims(data, axis=0)
self.buffer.append(data)
self.slices_to_write.append(data)
current_slice += 1

if current_slice % self.buffer_size == 0:
self._write_buffer()
self._flush_buffer()

def __enter__(self) -> Generator[None, np.ndarray, None]:
gen = self._get_slice_generator()
Expand All @@ -182,4 +225,4 @@ def __exit__(
_value: Optional[BaseException],
_tb: Optional[TracebackType],
) -> None:
self._write_buffer()
self._flush_buffer()
15 changes: 15 additions & 0 deletions webknossos/webknossos/geometry/vec3_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,21 @@ def __repr__(self) -> str:
def add_or_none(self, other: Optional["Vec3Int"]) -> Optional["Vec3Int"]:
return None if other is None else self + other

def moveaxis(
    self, source: Union[int, List[int]], target: Union[int, List[int]]
) -> "Vec3Int":
    """Return a copy with the element(s) at `source` moved to `target`.

    Mirrors np.moveaxis semantics: this is *not* a swap — the remaining
    elements shift over so the moved element lands at the target index.
    """
    # Piggy-back on np.moveaxis for the index bookkeeping: an empty array
    # of shape (0, 1, 2) encodes the identity permutation, and the shape
    # of the moved array is exactly the permuted index order.
    permutation = np.moveaxis(np.zeros((0, 1, 2)), source, target).shape
    reordered = self.to_np()[np.array(permutation)]
    return Vec3Int(reordered)

@classmethod
def zeros(cls) -> "Vec3Int":
return cls(0, 0, 0)
Expand Down