642 train 3d model with lucchi data #650
base: dev
@@ -0,0 +1,191 @@
import os
import argparse
from tqdm import tqdm
import numpy as np
import imageio.v3 as imageio
from elf.io import open_file
from skimage.measure import label as connected_components

import torch
from glob import glob

from torch_em.util.segmentation import size_filter
from torch_em.util import load_model
from torch_em.transform.raw import normalize
from torch_em.util.prediction import predict_with_halo

from micro_sam import util
from micro_sam.evaluation.inference import _run_inference_with_iterative_prompting_for_image

from segment_anything import SamPredictor

from micro_sam.models.sam_3d_wrapper import get_sam_3d_model
from typing import List, Union, Dict, Optional, Tuple


class RawTrafoFor3dInputs:
    def _normalize_inputs(self, raw):
        raw = normalize(raw)
        raw = raw * 255
        return raw

    def _set_channels_for_inputs(self, raw):
        raw = np.stack([raw] * 3, axis=0)
        return raw

    def __call__(self, raw):
        raw = self._normalize_inputs(raw)
        raw = self._set_channels_for_inputs(raw)
        return raw
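A quick shape check for this transform (an illustrative sketch; the class is defined directly above):

# Illustrative usage: a (z, y, x) volume is normalized to [0, 255] and
# stacked to three identical channels, giving shape (3, z, y, x).
import numpy as np  # already imported above, repeated so the sketch stands alone

volume = np.random.rand(32, 512, 512)        # (z, y, x)
transformed = RawTrafoFor3dInputs()(volume)
print(transformed.shape)                     # (3, 32, 512, 512)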
def _run_semantic_segmentation_for_image_3d(
    model: torch.nn.Module,
    image: np.ndarray,
    prediction_path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int, int],
    halo: Tuple[int, int, int],
):
    device = next(model.parameters()).device
    block_shape = tuple(bs - 2 * ha for bs, ha in zip(patch_shape, halo))

    def preprocess(x):
        x = 255 * normalize(x)
        x = np.stack([x] * 3)
        return x

    def prediction_function(net, inp):
        # Note: we have two singleton axes in front here, I am not quite sure why.
        # Both need to be removed to be compatible with the SAM network.
        batched_input = [{
            "image": inp[0, 0], "original_size": inp.shape[-2:]
        }]
        masks = net(batched_input, multimask_output=True)[0]["masks"]
        masks = torch.argmax(masks, dim=1)
        return masks

    # num_classes = model.sam_model.mask_decoder.num_multimask_outputs
    image_size = patch_shape[-1]
    output = np.zeros(image.shape, dtype="float32")
    predict_with_halo(
        image, model, gpu_ids=[device],
        block_shape=block_shape, halo=halo,
        preprocess=preprocess, output=output,
        prediction_function=prediction_function
    )

    # save the segmentations
    imageio.imwrite(prediction_path, output, compression="zlib")
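Since the halo is cut away on both sides of each block, the effective inner block shape is patch_shape - 2 * halo per axis. A worked check with the script defaults:

# With the defaults used below: patch (32, 512, 512), halo (6, 64, 64).
patch_shape, halo = (32, 512, 512), (6, 64, 64)
block_shape = tuple(bs - 2 * ha for bs, ha in zip(patch_shape, halo))
print(block_shape)  # (20, 384, 384): the inner region each block contributes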
def run_semantic_segmentation_3d(
    model: torch.nn.Module,
    image_paths: List[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    semantic_class_map: Dict[str, int],
    patch_shape: Tuple[int, int, int] = (32, 512, 512),
    halo: Tuple[int, int, int] = (6, 64, 64),
    image_key: Optional[str] = None,
    is_multiclass: bool = False,
):
    """Run 3d semantic segmentation for all given images and save the predictions as tif."""
    for image_path in tqdm(image_paths, desc="Run inference for semantic segmentation with all images"):
        image_name = os.path.basename(image_path)

        assert os.path.exists(image_path), image_path

        # Perform segmentation only on the semantic class.
        for i, (semantic_class_name, _) in enumerate(semantic_class_map.items()):
            if is_multiclass:
                semantic_class_name = "all"
                if i > 0:  # We only perform segmentation for multiclass once.
                    continue

            # We skip the images that have already been segmented.
            image_name = os.path.splitext(image_name)[0] + ".tif"
            prediction_path = os.path.join(prediction_dir, semantic_class_name, image_name)
            if os.path.exists(prediction_path):
                continue

            if image_key is None:
                image = imageio.imread(image_path)
            else:
                with open_file(image_path, "r") as f:
                    image = f[image_key][:]

            # Create the prediction folder.
            os.makedirs(os.path.join(prediction_dir, semantic_class_name), exist_ok=True)

            _run_semantic_segmentation_for_image_3d(
                model=model, image=image, prediction_path=prediction_path,
                patch_shape=patch_shape, halo=halo,
            )


def transform_labels(y):
    return (y > 0).astype("float32")


def predict(args):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if args.checkpoint_path is not None:
Review comment: I am not sure why you would ever run prediction without a checkpoint. I would not make this optional.
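Following that suggestion, a minimal sketch of how the argument could be made mandatory in main() below (standard argparse; the default would be dropped):

# Hypothetical change: require the checkpoint instead of defaulting to a path.
parser.add_argument(
    "--checkpoint_path", "-c", required=True,
    help="The filepath to the checkpoint to use for prediction."
)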
        if os.path.exists(args.checkpoint_path):
            # model = load_model(checkpoint=args.checkpoint_path, device=device)  # does not work

            cp_path = os.path.join(args.checkpoint_path, "best.pt")
            print(cp_path)
            model = get_sam_3d_model(
                device, n_classes=args.n_classes, image_size=args.patch_shape[1],
                lora_rank=4,
                model_type=args.model_type,
                checkpoint_path=cp_path
            )
Review comment: This will not work to actually load the checkpoint. Please read the code I send you carefully and see how I use
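A minimal sketch of explicit checkpoint loading, assuming the trainer stores the weights under a "model_state" key; that key name is taken from the commented-out code below and would need to match how the checkpoint was actually saved:

# Assumed layout: a checkpoint dict with the weights under "model_state".
checkpoint = torch.load(cp_path, map_location=device)
model.load_state_dict(checkpoint["model_state"])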
            # checkpoint = torch.load(cp_path, map_location=device)
            # print(checkpoint.keys())
            # Load the state dictionary from the checkpoint.
            # model.load_state_dict(checkpoint['model_state'])
            model.eval()

    data_paths = glob(os.path.join(args.input_path, "**/*test.h5"), recursive=True)
    pred_path = args.save_root
    semantic_class_map = {"all": 0}

    run_semantic_segmentation_3d(
        model=model, image_paths=data_paths, prediction_dir=pred_path, semantic_class_map=semantic_class_map,
        patch_shape=args.patch_shape, image_key="raw", is_multiclass=True
    )


def main():
    parser = argparse.ArgumentParser(description="Run prediction with the finetuned 3d Segment Anything model on the Lucchi dataset.")
    parser.add_argument(
        "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/lucchi/",
        help="The filepath to the Lucchi data. If the data does not exist yet it will be downloaded."
    )
    parser.add_argument(
        "--model_type", "-m", default="vit_b",
        help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h."
    )
    parser.add_argument(
        "--patch_shape", type=int, nargs=3, default=(32, 512, 512),
        help="Patch shape for data loading (3d tuple)."
    )
    parser.add_argument("--n_iterations", type=int, default=10, help="Number of training iterations.")
    parser.add_argument("--n_classes", type=int, default=2, help="Number of classes to predict.")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size.")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
    parser.add_argument(
        "--save_root", "-s", default="/scratch-grete/usr/nimlufre/micro-sam3d",
        help="The filepath to where the predictions will be saved."
    )
    parser.add_argument(
        "--checkpoint_path", "-c", default="/scratch-grete/usr/nimlufre/micro-sam3d/checkpoints/3d-sam-lucchi-train/",
        help="The filepath to the checkpoint that is loaded for prediction."
    )

    args = parser.parse_args()

    predict(args)


if __name__ == "__main__":
    main()
@@ -0,0 +1,186 @@
import os
import argparse
import numpy as np
from math import ceil, floor
import torch

from torch_em.data.datasets import get_lucchi_loader, get_lucchi_dataset
from torch_em.segmentation import SegmentationDataset
import torch_em
from torch_em.util.debug import check_loader
from torch_em.transform.raw import normalize

from micro_sam.models.sam_3d_wrapper import get_sam_3d_model
from micro_sam.training.semantic_sam_trainer import SemanticSamTrainer
import micro_sam.training as sam_training


class RawTrafoFor3dInputs:
    def _normalize_inputs(self, raw):
        raw = normalize(raw)
        raw = raw * 255
        return raw

    def _set_channels_for_inputs(self, raw):
        raw = np.stack([raw] * 3, axis=0)
        return raw

    def __call__(self, raw):
        raw = self._normalize_inputs(raw)
        raw = self._set_channels_for_inputs(raw)
        return raw


# for sega
class RawResizeTrafoFor3dInputs(RawTrafoFor3dInputs):
    def __init__(self, desired_shape, padding="constant"):
        super().__init__()
        self.desired_shape = desired_shape
        self.padding = padding

    def __call__(self, raw):
        raw = self._normalize_inputs(raw)

        # let's pad the inputs
        tmp_ddim = (
            self.desired_shape[0] - raw.shape[0],
            self.desired_shape[1] - raw.shape[1],
            self.desired_shape[2] - raw.shape[2]
        )
        ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2, tmp_ddim[2] / 2)
        raw = np.pad(
            raw,
            pad_width=(
                (ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1])), (ceil(ddim[2]), floor(ddim[2]))
            ),
            mode=self.padding
        )

        raw = self._set_channels_for_inputs(raw)

        return raw
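The padding splits the size difference between the two sides of each axis, rounding up at the front and down at the back. A worked check with illustrative shapes:

# Padding a (28, 500, 500) volume to desired_shape (32, 512, 512):
# ddim = (2.0, 6.0, 6.0), so the pad widths are (2, 2), (6, 6), (6, 6) per axis.
trafo = RawResizeTrafoFor3dInputs(desired_shape=(32, 512, 512))
padded = trafo(np.random.rand(28, 500, 500))
print(padded.shape)  # (3, 32, 512, 512) after channel stacking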
class LucchiSegmentationDataset(SegmentationDataset):
Review comment: This can now be removed.
    def __init__(self, patch_shape, label_transform=None, **kwargs):
        super().__init__(patch_shape=patch_shape, label_transform=label_transform, **kwargs)  # Call the parent class constructor.

    def __getitem__(self, index):
        raw, label = super().__getitem__(index)
        # raw shape: (z, color channels, y, x); the channel dimension is fixed to 3.
        image_shape = (self.patch_shape[0], 1) + self.patch_shape[1:]
        raw = raw.unsqueeze(2)
        raw = raw.view(image_shape)
        raw = raw.squeeze(0)
        raw = raw.repeat(1, 3, 1, 1)
        # print("raw shape", raw.shape)
        # wanted label shape: (1, z, y, x)
        label = (label != 0).to(torch.float)
        # print("label shape", label.shape)
        return raw, label


def transform_labels(y):
    return (y > 0).astype("float32")


def get_loaders(input_path, patch_shape):
    train_loader = get_lucchi_loader(
        input_path, split="train", patch_shape=patch_shape, batch_size=1, download=True,
        raw_transform=RawTrafoFor3dInputs(), label_transform=transform_labels,
        n_samples=100
    )
    val_loader = get_lucchi_loader(
        input_path, split="test", patch_shape=patch_shape, batch_size=1,
        raw_transform=RawTrafoFor3dInputs(), label_transform=transform_labels
    )
    return train_loader, val_loader

# def get_loader(path, split, patch_shape, n_classes, batch_size, label_transform, num_workers=1):
#     assert split in ("train", "test")
#     data_path = os.path.join(path, f"lucchi_{split}.h5")
#     raw_key, label_key = "raw", "labels"
#     ds = LucchiSegmentationDataset(
#         raw_path=data_path, label_path=data_path, raw_key=raw_key,
#         label_key=label_key, patch_shape=patch_shape, label_transform=label_transform)
#     loader = torch.utils.data.DataLoader(
#         ds, batch_size=batch_size, shuffle=True,
#         num_workers=num_workers)
#     loader.shuffle = True
#     return loader


def train_on_lucchi(args):
    from micro_sam.training.util import ConvertToSemanticSamInputs

    input_path = args.input_path
    patch_shape = args.patch_shape
    batch_size = args.batch_size
    num_workers = args.num_workers
    n_classes = args.n_classes
    model_type = args.model_type
    n_iterations = args.n_iterations
    save_root = args.save_root

    # label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True)
    # label_transform = None
    raw_data = np.random.rand(64, 256, 256)  # shape (z, y, x)
    raw_data2, label = next(iter(get_lucchi_loader(
        input_path, split="train", patch_shape=patch_shape, batch_size=1, download=True
    )))

    # Create an instance of RawTrafoFor3dInputs.
    transformer = RawTrafoFor3dInputs()

    # Apply the transformations.
    processed_data = transformer(raw_data)
    processed_data2 = transformer(raw_data2)
    print("input (64, 256, 256)", processed_data.shape)
    print("input", raw_data2.shape, processed_data2.shape)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam_3d = get_sam_3d_model(
        device, n_classes=n_classes, image_size=patch_shape[1],
        model_type=model_type, lora_rank=4
    )
    train_loader, val_loader = get_loaders(input_path=input_path, patch_shape=patch_shape)
    optimizer = torch.optim.AdamW(sam_3d.parameters(), lr=5e-5)

    trainer = SemanticSamTrainer(
        name="3d-sam-lucchi-train",
        model=sam_3d,
        convert_inputs=ConvertToSemanticSamInputs(),
        num_classes=n_classes,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        device=device,
        compile_model=False,
        save_root=save_root,
        # logger=None
    )
    # check_loader(train_loader, n_samples=10)
    trainer.fit(n_iterations)


def main():
    parser = argparse.ArgumentParser(description="Finetune Segment Anything for the Lucchi dataset.")
    parser.add_argument(
        "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/lucchi/",
        help="The filepath to the Lucchi data. If the data does not exist yet it will be downloaded."
    )
    parser.add_argument(
        "--model_type", "-m", default="vit_b",
        help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h."
    )
    parser.add_argument(
        "--patch_shape", type=int, nargs=3, default=(32, 512, 512),
        help="Patch shape for data loading (3d tuple)."
    )
    parser.add_argument("--n_iterations", type=int, default=10, help="Number of training iterations.")
    parser.add_argument("--n_classes", type=int, default=2, help="Number of classes to predict.")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size.")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
    parser.add_argument(
        "--save_root", "-s", default="/scratch-grete/usr/nimlufre/micro-sam3d",
        help="The filepath to where the logs and the checkpoints will be saved."
    )

    args = parser.parse_args()
    train_on_lucchi(args)


if __name__ == "__main__":
    main()
Review comment: I would suggest removing this part already; it doesn't make sense in the context here.