Skip to content

Commit

Permalink
add loader for npz files
Browse files Browse the repository at this point in the history
  • Loading branch information
Henley13 committed May 5, 2020
1 parent fffe726 commit 98d8ffd
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 18 deletions.
2 changes: 2 additions & 0 deletions bigfish/stack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .io import read_image
from .io import read_dv
from .io import read_array
from .io import read_compressed
from .io import save_image
from .io import save_array

Expand Down Expand Up @@ -80,6 +81,7 @@
"read_image",
"read_dv",
"read_array",
"read_compressed",
"save_image",
"save_array"]

Expand Down
54 changes: 36 additions & 18 deletions bigfish/stack/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,8 @@ def read_dv(path, sanity_check=False):
# read video file
video = mrc.imread(path)

# metadata can be read running 'tensor.Mrc.info()'

# check the output video
# metadata can be read running 'tensor.Mrc.info()'
if sanity_check:
check_array(video,
dtype=[np.uint16, np.int16, np.int32, np.float32],
Expand All @@ -89,42 +88,61 @@ def read_dv(path, sanity_check=False):
return video


def read_array(path, sanity_check=False):
def read_array(path):
"""Read a numpy array with 'npy' extension.
Parameters
----------
path
sanity_check
path : str
Path of the array to read.
Returns
-------
array : ndarray, np.uint or np.int
array : ndarray
Array read.
"""
# check path
check_parameter(path=str,
sanity_check=bool)
check_parameter(path=str)
if ".npy" not in path:
path += ".npy"

# read array file
array = np.load(path)

# check the output array
if sanity_check:
check_array(array,
dtype=[np.uint8, np.uint16, np.uint32,
np.int8, np.int16, np.int32, np.int64,
np.float16, np.float32, np.float64,
bool],
ndim=[2, 3, 4, 5],
allow_nan=False)

return array


def read_compressed(path, verbose=False):
"""Read a NpzFile object with 'npz' extension.
Parameters
----------
path : str
Path of the file to read.
verbose : bool
Return names of the different compressed objects.
Returns
-------
data : NpzFile object
NpzFile read.
"""
# check path
check_parameter(path=str,
verbose=bool)
if ".npz" not in path:
path += ".npz"

# read array file
data = np.load(path)
if verbose:
print("Compressed objects: {0} \n".format(", ".join(data.files)))

return data


# ### Write ###

def save_image(image, path, extension="tif"):
Expand Down
21 changes: 21 additions & 0 deletions bigfish/stack/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,24 @@ def test_npy(shape, dtype):
tensor = stack.read_array(path)
assert_array_equal(test, tensor)
assert test.dtype == tensor.dtype


@pytest.mark.parametrize("shape", [
(8, 8), (8, 8, 8), (8, 8, 8, 8), (8, 8, 8, 8, 8)])
@pytest.mark.parametrize("dtype", [
np.uint8, np.uint16, np.uint32,
np.int8, np.int16, np.int32, np.int64,
np.float16, np.float32, np.float64, bool])
def test_npz(shape, dtype):
# build a temporary directory and save tensors inside
with tempfile.TemporaryDirectory() as tmp_dir:
test_1 = np.zeros(shape, dtype=dtype)
test_2 = np.ones(shape, dtype=dtype)
path = os.path.join(tmp_dir, "test.npz")
np.savez(path, test_1=test_1, test_2=test_2)
data = stack.read_compressed(path)
assert data.files == ["test_1", "test_2"]
assert_array_equal(test_1, data["test_1"])
assert_array_equal(test_2, data["test_2"])
assert test_1.dtype == data["test_1"].dtype
assert test_2.dtype == data["test_2"].dtype

0 comments on commit 98d8ffd

Please sign in to comment.