Add unit test and refactor TiledDataset #23

Merged · 22 commits · Apr 15, 2024
@@ -29,3 +29,6 @@ jobs:
- name: Test formatting with black
run: |
black . --check
- name: pytest
run: |
pytest
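
The CI job now runs the test suite right after the black formatting check. A minimal local equivalent, assuming the packages from requirements-dev.txt are installed:

```python
# Run the same suite the CI step runs, from the repository root.
# pytest.main returns an exit code; a non-zero value mirrors a failing CI job.
import sys

import pytest

sys.exit(pytest.main(["-q", "src/_tests"]))
```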
5 changes: 4 additions & 1 deletion .gitignore
@@ -168,4 +168,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/

# vscode
.vscode/
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -1,2 +1,7 @@
[tool.isort]
profile = "black"

[tool.pytest.ini_options]
pythonpath = [
"src"
]
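
The new `[tool.pytest.ini_options]` table puts `src` on the import path during test runs, so modules under `src/` resolve without installing the project as a package. A small sketch of what that enables (the test module here is hypothetical):

```python
# Hypothetical test module: with pythonpath = ["src"] in pyproject.toml,
# top-level imports of modules living under src/ work during pytest collection.
from tiled_dataset import TiledDataset  # src/tiled_dataset.py


def test_tiled_dataset_importable():
    assert TiledDataset is not None
```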
2 changes: 2 additions & 0 deletions requirements-dev.txt
@@ -2,3 +2,5 @@ black==24.3.0
flake8==7.0.0
isort==5.13.2
pre-commit==3.6.2
tiled[all]==0.1.0a114
pytest
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,5 +1,5 @@
qlty==0.1.6
qlty==0.1.7
dvclive==3.44.0
#dlsia==0.3.0
dlsia==0.3.1
pydantic==2.6.3
tiled[client]==0.1.0a114
Empty file added src/__init__.py
Empty file.
Empty file added src/_tests/__init__.py
Empty file.
38 changes: 38 additions & 0 deletions src/_tests/conftest.py
@@ -0,0 +1,38 @@
import numpy as np
import pytest
from tiled.catalog import from_uri
from tiled.client import Context, from_context
from tiled.server.app import build_app


@pytest.fixture
def catalog(tmpdir):
adapter = from_uri(
f"sqlite+aiosqlite:///{tmpdir}/catalog.db",
writable_storage=str(tmpdir),
init_if_not_exists=True,
)
yield adapter


@pytest.fixture
def app(catalog):
app = build_app(catalog)
yield app


@pytest.fixture
def context(app):
with Context.from_app(app) as context:
yield context


@pytest.fixture
def client(context):
"Fixture for tests which only read data"
client = from_context(context)
recons_container = client.create_container("reconstructions")
recons_container.write_array(np.zeros((2, 3, 3), dtype=np.int8), key="recon1")
masks_container = client.create_container("uid0001", metadata={"mask_idx": ["0"]})
masks_container.write_array(np.zeros((1, 3, 3), dtype=np.int8), key="mask")
yield client
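
The fixtures chain tmpdir → catalog → app → context → client to stand up an in-memory Tiled server seeded with a small (2, 3, 3) reconstruction array and a one-slice mask container, so the tests never need a live Tiled deployment. A sketch of a test consuming the `client` fixture directly (the test itself is hypothetical; the data matches the fixture above):

```python
# Hypothetical test showing how the conftest.py fixture chain is consumed.
def test_fixture_contents(client):
    # "reconstructions/recon1" holds the (2, 3, 3) int8 array written in conftest.py.
    recon = client["reconstructions"]["recon1"].read()
    assert recon.shape == (2, 3, 3)

    # The mask container's metadata carries the slice indices as strings.
    assert client["uid0001"].metadata["mask_idx"] == ["0"]
```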
16 changes: 16 additions & 0 deletions src/_tests/test_tiled_dataset.py
@@ -0,0 +1,16 @@
from ..tiled_dataset import TiledDataset


def test_tiled_dataset(client):
tiled_dataset = TiledDataset(
client["reconstructions"]["recon1"],
)
assert tiled_dataset
assert tiled_dataset[0].shape == (3, 3)


def test_tiled_dataset_with_masks(client):
tiled_dataset = TiledDataset(
client["reconstructions"]["recon1"], mask_tiled_client=client["uid0001"]
)
assert tiled_dataset[0].shape == (3, 3)
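
The two tests exercise both constructor paths against the in-memory fixtures. One natural follow-up, sketched here as a hypothetical addition to the same module, would assert that the string `mask_idx` metadata is converted to integers in `TiledDataset.__init__`:

```python
# Hypothetical extra test for this module, reusing the conftest fixtures.
def test_tiled_dataset_mask_idx(client):
    tiled_dataset = TiledDataset(
        client["reconstructions"]["recon1"], mask_tiled_client=client["uid0001"]
    )
    # conftest stores mask_idx as ["0"]; __init__ casts each entry to int.
    assert tiled_dataset.mask_idx == [0]
```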
15 changes: 11 additions & 4 deletions src/segment.py
@@ -5,6 +5,7 @@
import torch
import yaml
from qlty.qlty2D import NCYXQuilt
from tiled.client import from_uri
from torchvision import transforms

from network import baggin_smsnet_ensemble, load_network
@@ -52,11 +53,17 @@

print("Parameters loaded successfully.")

data_tiled_client = from_uri(
io_parameters.data_tiled_uri, api_key=io_parameters.data_tiled_api_key
)
mask_tiled_client = None
if io_parameters.mask_tiled_uri:
mask_tiled_client = from_uri(
io_parameters.mask_tiled_uri, api_key=io_parameters.mask_tiled_api_key
)
dataset = TiledDataset(
data_tiled_uri=io_parameters.data_tiled_uri,
data_tiled_api_key=io_parameters.data_tiled_api_key,
mask_tiled_uri=io_parameters.mask_tiled_uri,
mask_tiled_api_key=io_parameters.mask_tiled_api_key,
data_tiled_client,
mask_tiled_client=mask_tiled_client,
is_training=False,
using_qlty=False,
qlty_window=model_parameters.qlty_window,
26 changes: 10 additions & 16 deletions src/tiled_dataset.py
@@ -1,16 +1,14 @@
import torch
from qlty import cleanup
from qlty.qlty2D import NCYXQuilt
from tiled.client import from_uri


class TiledDataset(torch.utils.data.Dataset):

def __init__(
self,
data_tiled_uri,
data_tiled_api_key=None,
mask_tiled_uri=None,
mask_tiled_api_key=None,
data_tiled_client,
mask_tiled_client=None,
is_training=None,
using_qlty=False,
qlty_window=50,
@@ -33,20 +31,16 @@ def __init__(
Return:
ml_data: tuple, (data_tensor, mask_tensor)
"""
self.data_tiled_uri = data_tiled_uri
self.data_client = from_uri(data_tiled_uri, api_key=data_tiled_api_key)
self.mask_tiled_uri = mask_tiled_uri
if mask_tiled_uri:
self.mask_client_one_up = from_uri(
mask_tiled_uri, api_key=mask_tiled_api_key
)
self.mask_client = self.mask_client_one_up["mask"]
self.mask_idx = [
int(idx) for idx in self.mask_client_one_up.metadata["mask_idx"]
]

self.data_client = data_tiled_client
self.mask_client = None
if mask_tiled_client:
self.mask_client = mask_tiled_client["mask"]
self.mask_idx = [int(idx) for idx in mask_tiled_client.metadata["mask_idx"]]
else:
self.mask_client = None
self.mask_idx = None

self.transform = transform
if using_qlty:
# this object handles unstitching and stitching
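The refactor replaces URI/API-key arguments with injected Tiled clients, which is what lets the tests above pass an in-memory client instead of hitting a live server; the scripts now own client construction. A sketch of the two call patterns (the URI and API key below are placeholders, and the qlty/transform keyword arguments used in segment.py and train.py are omitted):

```python
# Sketch of the refactored constructor being driven two ways; values are placeholders.
from tiled.client import from_uri

from tiled_dataset import TiledDataset  # resolvable via pythonpath = ["src"]

# Production path (segment.py / train.py): clients are built from configured URIs
# and handed to the dataset.
data_client = from_uri("https://tiled.example.org/api", api_key="PLACEHOLDER")
dataset = TiledDataset(data_client, mask_tiled_client=None)

# Test path (conftest.py): an in-memory client from tiled.client.from_context is
# injected instead, so the dataset code never needs network access.
```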
29 changes: 20 additions & 9 deletions src/train.py
@@ -7,6 +7,7 @@
import yaml
from dlsia.core.train_scripts import Trainer
from dvclive import Live
from tiled.client import from_uri
from torchvision import transforms

from network import build_network
@@ -17,15 +18,12 @@
TUNet3PlusParameters,
TUNetParameters,
)
from seg_utils import crop_split_load, train_segmentation
from seg_utils import crop_split_load
from tiled_dataset import TiledDataset
from utils import create_directory

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("yaml_path", type=str, help="path of yaml file for parameters")
args = parser.parse_args()

def train(args):
# Open the YAML file for all parameters
with open(args.yaml_path, "r") as file:
# Load parameters
@@ -59,11 +57,17 @@
# Create Result Directory if not existed
create_directory(model_dir)

data_tiled_client = from_uri(
io_parameters.data_tiled_uri, api_key=io_parameters.data_tiled_api_key
)
mask_tiled_client = None
if io_parameters.mask_tiled_uri:
mask_tiled_client = from_uri(
io_parameters.mask_tiled_uri, api_key=io_parameters.mask_tiled_api_key
)
dataset = TiledDataset(
data_tiled_uri=io_parameters.data_tiled_uri,
data_tiled_api_key=io_parameters.data_tiled_api_key,
mask_tiled_uri=io_parameters.mask_tiled_uri,
mask_tiled_api_key=io_parameters.mask_tiled_api_key,
data_tiled_client=data_tiled_client,
mask_tiled_client=mask_tiled_client,
is_training=True,
using_qlty=False,
qlty_window=model_parameters.qlty_window,
@@ -146,3 +150,10 @@
torch.cuda.empty_cache()

print(f"{network} trained successfully.")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("yaml_path", type=str, help="path of yaml file for parameters")
args = parser.parse_args()
train(args)
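
With the body moved into `train(args)` and argparse confined to the `__main__` block, the training entry point can be invoked programmatically as well as from the CLI. A minimal sketch, assuming only `args.yaml_path` is read (the path is a placeholder):

```python
# Hypothetical programmatic invocation of src/train.py's entry point.
from types import SimpleNamespace

from train import train  # importable via pythonpath = ["src"]

train(SimpleNamespace(yaml_path="path/to/example_parameters.yaml"))
```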
8 changes: 6 additions & 2 deletions src/utils.py
@@ -77,9 +77,13 @@ def allocate_array_space(
# For now, only save image 1 by 1 regardless of the batch_size_inference.
structure.chunks = ((1,) * array_shape[0], (array_shape[1],), (array_shape[2],))

mask_uri = None
if tiled_dataset.mask_client is not None:
mask_uri = tiled_dataset.mask_client.uri

metadata = {
"data_uri": tiled_dataset.data_tiled_uri,
"mask_uri": tiled_dataset.mask_tiled_uri,
"data_uri": tiled_dataset.data_client.uri,
"mask_uri": mask_uri,
"mask_idx": tiled_dataset.mask_idx,
"uid": uid,
"model": model,
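Since the dataset no longer stores URIs, `allocate_array_space` now derives them from the clients and guards against a missing mask client. The same guard could also be written as a one-liner with `getattr` (a style alternative, not what this PR does):

```python
# getattr returns the default (None) when tiled_dataset.mask_client is None.
mask_uri = getattr(tiled_dataset.mask_client, "uri", None)
```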