Skip to content

Commit

Permalink
some fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
clbarnes committed Jun 26, 2023
1 parent 41b6ac9 commit 52f8c53
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 36 deletions.
161 changes: 126 additions & 35 deletions pymaid/stack.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
from enum import IntEnum
from io import BytesIO
from typing import Literal, Optional, Sequence, Type, TypeVar, Generic, TypedDict, Union
from typing import Any, Literal, Optional, Sequence, Type, TypedDict, Union
import numpy as np
from abc import ABC
from numpy.typing import DTypeLike, ArrayLike
Expand All @@ -15,8 +16,11 @@
import requests
import imageio.v3 as iio

from . import config

logger = config.get_logger(__name__)

Dimension = Literal["x", "y", "z"]
Orientation = Literal["xy", "xz", "zy"]
HALF_PX = 0.5
ENDIAN = "<" if sys.byteorder == "little" else ">"

Expand All @@ -32,38 +36,64 @@ class MirrorInfo(BaseModel):
position: int


N = TypeVar("N", int, float)
class CoordI(TypedDict):
x: int
y: int
z: int


class CoordF(TypedDict):
x: float
y: float
z: float


class Orientation(IntEnum):
# this is a guess
XY = 0
XZ = 1
ZY = 2

class Coord(TypedDict, Generic[N]):
x: N
y: N
z: N
def as_str(self) -> str:
return ("xy", "xz", "zy")[self]

def get_order(self):
return ("xyz", "xzy", "zyx")[self]


class StackInfo(BaseModel):
sid: int
pid: int
ptitle: str
stitle: str
downsample_factors: list[Coord[float]]
num_zoom_levels: int
translation: Coord[float]
resolution: Coord[float]
dimension: Coord[int]
downsample_factors: Optional[list[CoordF]] # None = catmaid default (power2 in XY)?
num_zoom_levels: int # can be -1
translation: CoordF
resolution: CoordF
dimension: CoordI
comment: str
description: str
metadata: str
metadata: Optional[str]
broken_slices: dict[int, int]
mirrors: list[MirrorInfo]
orientation: Orientation
orientation: Orientation # as int ID
attribution: str
canary_location: Coord[int]
placeholder_colour: dict[str, float] # actually {r g b a}
canary_location: CoordI
placeholder_color: dict[str, float] # actually {r g b a}

def get_downsample_factor(self, scale_level: int) -> CoordF:
if not self.has_scale(scale_level):
raise ValueError(f"No scale level {scale_level}")
if self.downsample_factors is not None:
return self.downsample_factors[scale_level]
return CoordF(x=2**scale_level, y=2**scale_level, z=1)

def has_scale(self, scale_level: int) -> bool:
return self.num_zoom_levels == -1 or scale_level <= self.num_zoom_levels


def to_array(
coord: Union[Coord[N], ArrayLike],
coord: Union[CoordF, CoordI, ArrayLike],
dtype: DTypeLike = np.float64,
order: Sequence[Dimension] = ("z", "y", "x"),
) -> np.ndarray:
Expand Down Expand Up @@ -97,17 +127,24 @@ def __init__(
self.mirror_info = mirror_info
self.zoom_level = zoom_level

brok_sl = {int(k): int(k) + v for k, v in self.stack_info.broken_slices.items()}
self.broken_slices = dict()
for k, v in brok_sl.items():
while v in brok_sl:
v = brok_sl[v]
self.broken_slices[k] = v

if session is None:
cm = utils._eval_remote_instance(None)
self.session = cm._session
else:
self.session = session

order = full_orientation[self.stack_info.orientation]
order = self.stack_info.orientation.get_order()
self.metadata_payload = json.dumps(
{
"zarr_format": 2,
"shape": to_array(stack_info.dimension, order, int).tolist(),
"shape": to_array(stack_info.dimension, int, order).tolist(),
"chunks": [mirror_info.tile_width, mirror_info.tile_height, 1],
"dtype": ENDIAN + "u1",
"compressor": None,
Expand Down Expand Up @@ -141,7 +178,8 @@ def __getitem__(self, key):
if last == ".zarray":
return self.metadata_payload
# todo: check order
slice_idx, col, row = (int(i) for i in last.split("."))
row, col, slice_idx = (int(i) for i in last.split("."))
slice_idx = self.broken_slices.get(slice_idx, slice_idx)
url = self._format_url(row, col, slice_idx)
response = self.session.get(url)
if response.status_code == 404:
Expand All @@ -160,6 +198,18 @@ def to_array(self) -> zarr.Array:
def to_dask(self) -> da.Array:
return da.from_zarr(self.to_array())

def __delitem__(self, __key: Any) -> None:
raise NotImplementedError("Store not erasable")

def __iter__(self):
raise NotImplementedError("Store not listable")

def __len__(self) -> int:
raise NotImplementedError("Store not listable")

def __setitem__(self, __key: Any, __value: Any) -> None:
raise NotImplementedError("Store not writeable")


class TileStore1(TileStore):
tile_source_type = 1
Expand Down Expand Up @@ -194,7 +244,8 @@ def from_catmaid(
cls, stack_id: int, mirror_id: Optional[int] = None, remote_instance=None
):
cm = utils._eval_remote_instance(remote_instance)
info = cm.make_url("stack", stack_id, "info")
url = cm.make_url(cm.project_id, "stack", stack_id, "info")
info = cm.fetch(url)
sinfo = parse_obj_as(StackInfo, info)
return cls(sinfo, mirror_id)

Expand All @@ -211,11 +262,58 @@ def _get_mirror_info(self, mirror_id: Optional[int] = None) -> MirrorInfo:
)

def set_mirror(self, mirror_id: int):
self.mirror_id = self._get_mirror_info(mirror_id)
self.mirror_info = self._get_mirror_info(mirror_id)

def supported_mirrors(self):
out = dict()
for mirror in self.stack_info.mirrors:
if mirror.tile_source_type not in tile_stores:
continue
out[mirror.id] = mirror
return out

def select_mirror(self) -> bool:
"""Interactively select a stack mirror.
Returns
-------
bool
Whether a mirror was successfully selected.
"""
supported = self.supported_mirrors()

if not supported:
logger.warning("No supported mirrors found")
return False

supported_sorted = sorted(self.supported_mirrors().items(), key=lambda x: x[0])
prompt_rows = ["Available mirrors:"]
for idx, mirror in supported_sorted:
prompt_rows.append(f"{idx}.\t{mirror.title} ( {mirror.image_base} )")
last = "Enter mirror ID and return (empty to cancel): "
prompt_rows.append(last)
prompt = "\n".join(prompt_rows)
while True:
response = input(prompt).strip()
if not response:
return False
try:
mid = int(response)
except ValueError:
logger.warning("Not a valid mirror ID, try again")
prompt = last
continue

if mid not in supported:
logger.warning("ID is not for a supported mirror")
prompt = last
else:
self.set_mirror(mid)
return True

def _res_for_scale(self, scale_level: int) -> np.ndarray:
return to_array(self.stack_info.resolution) * to_array(
self.stack_info.downsample_factors[scale_level]
self.stack_info.get_downsample_factor(scale_level)
)

def _from_array(self, arr, scale_level: int) -> ImageVolume:
Expand All @@ -230,7 +328,7 @@ def get_scale(
self, scale_level: int, mirror_id: Optional[int] = None
) -> ImageVolume:
mirror_info = self._get_mirror_info(mirror_id)
if scale_level > self.stack_info.num_zoom_levels:
if not self.stack_info.has_scale(scale_level):
raise ValueError(
f"Scale level {scale_level} does not exist "
f"for stack {self.stack_info.sid} "
Expand Down Expand Up @@ -258,13 +356,6 @@ def get_scale(
)


full_orientation: dict[Orientation, Sequence[Dimension]] = {
"xy": "xyz",
"xz": "xzy",
"zy": "zyx",
}


class ImageVolume:
def __init__(self, array, offset, resolution, orientation: Orientation):
self.array = array
Expand All @@ -276,7 +367,7 @@ def __init__(self, array, offset, resolution, orientation: Orientation):

@property
def full_orientation(self):
return full_orientation[self.orientation]
return self.orientation.get_order()

@property
def offset_oriented(self):
Expand All @@ -290,8 +381,8 @@ def __getitem__(self, selection):
return self.array.__getitem__(selection)

def get_roi(
self, offset: Coord[float], shape: Coord[float]
) -> tuple[Coord[float], np.ndarray]:
self, offset: CoordF, shape: CoordF
) -> tuple[CoordF, np.ndarray]:
order = self.full_orientation
offset_o = to_array(offset, order=order)
shape_o = to_array(shape, order=order)
Expand All @@ -301,7 +392,7 @@ def get_roi(
).astype("uint64")
slicing = tuple(slice(mi, ma) for mi, ma in zip(mins, maxes))
# todo: finalise orientation
actual_offset = Coord(
actual_offset = CoordF(
**{
d: m
for d, m in zip(
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ tqdm>=4.50.0
psutil>=5.4.3

#extra: extras
fuzzywuzzy[speedup]~=0.17.0
fuzzywuzzy[speedup]~=0.18.0
ujson~=1.35
zarr
pydantic
imageio
dask

# diskcache>=4.0.0

Expand Down

0 comments on commit 52f8c53

Please sign in to comment.