Merge pull request #42 from ctr26/shape_embed
ctr26 authored Mar 4, 2024
2 parents 622a8a6 + dcdfa63 commit 6c6b1cf
Showing 3 changed files with 80 additions and 82 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/test.yaml
@@ -21,6 +21,16 @@ jobs:
# shell: bash -l {0}
steps:
- uses: actions/checkout@v2
- name: Free Disk Space (Ubuntu)
uses: jlumbroso/free-disk-space@main
with:
tool-cache: false
android: true
dotnet: true
haskell: true
large-packages: true
docker-images: true
swap-storage: true
- uses: conda-incubator/setup-miniconda@v2
with:
environment-file: environment.yml
1 change: 1 addition & 0 deletions environment.yml
@@ -14,5 +14,6 @@ dependencies:
- pytorch
- pillow=9.5.0
- pip
- conda-forge::opencv
- pip:
- -e .
151 changes: 69 additions & 82 deletions scripts/shapes/shape_embed.py
@@ -9,14 +9,11 @@
import pandas as pd
from sklearn import metrics
import matplotlib as mpl
import seaborn as sns
from pathlib import Path
from sklearn.pipeline import Pipeline
import umap
from torch.autograd import Variable
from types import SimpleNamespace
import numpy as np
import logging
from skimage import measure
import umap.plot
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
@@ -25,17 +22,16 @@
from types import SimpleNamespace
from umap import UMAP
import os

# Deal with the filesystem
import torch.multiprocessing
import logging
from tqdm import tqdm

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

torch.multiprocessing.set_sharing_strategy("file_system")

from bioimage_embed import shapes
import bioimage_embed

# Note - you must have torchvision installed for this example

from pytorch_lightning import loggers as pl_loggers
from torchvision import transforms
from bioimage_embed.lightning import DataModule
@@ -47,16 +43,15 @@
DistogramToCoords,
MaskToDistogramPipeline,
RotateIndexingClockwise,
CoordsToDistogram,
)

import matplotlib.pyplot as plt

from bioimage_embed.lightning import DataModule
import matplotlib as mpl
from matplotlib import rc

import logging
import pickle
import base64
import hashlib

@@ -66,6 +61,7 @@
np.random.seed(42)
pl.seed_everything(42)


def hashing_fn(args):
serialized_args = pickle.dumps(vars(args))
hash_object = hashlib.sha256(serialized_args)
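The rest of hashing_fn is collapsed in this diff. A minimal sketch of how such a helper can finish, assuming the goal is a short, filesystem-safe digest for naming results directories (the return line is illustrative, not from the commit):

import base64
import hashlib
import pickle

def hashing_fn_sketch(args):
    # Serialize the namespace fields deterministically
    serialized_args = pickle.dumps(vars(args))
    hash_object = hashlib.sha256(serialized_args)
    # Base64-encode the digest and strip padding to get a short path-safe tag
    return base64.urlsafe_b64encode(hash_object.digest()).decode().rstrip("=")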
@@ -176,6 +172,9 @@ def shape_embed_process():
"latent_dim": int(128),
"pretrained": True,
"frobenius_norm": False,
# dataset = "bbbc010/BBBC010_v1_foreground_eachworm"
# dataset = "vampire/mefs/data/processed/Control"
"dataset": "synthcellshapes_dataset",
}

optimizer_params = {
@@ -197,14 +196,8 @@ def shape_embed_process():

args = SimpleNamespace(**params, **optimizer_params, **lr_scheduler_params)
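Unpacking the three dicts into a single SimpleNamespace gives argparse-style attribute access without an argparse dependency; a quick illustration with toy values (not the commit's settings):

from types import SimpleNamespace

params = {"model": "resnet18_vae", "batch_size": 4}
optimizer_params = {"opt": "adamw", "lr": 0.001}

args = SimpleNamespace(**params, **optimizer_params)
assert args.model == "resnet18_vae" and args.lr == 0.001
# A key repeated across the dicts would raise TypeError at this call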

dataset_path = args.dataset

train_data_path = f"scripts/shapes/data/{dataset_path}"
metadata = lambda x: f"results/{dataset_path}_{args.model}/{x}"

@@ -213,9 +206,10 @@
# %%

transform_crop = CropCentroidPipeline(window_size)
# transform_dist = MaskToDistogramPipeline(
# window_size, interp_size, matrix_normalised=False
# )
transform_coord_to_dist = CoordsToDistogram(interp_size, matrix_normalised=False)
transform_mdscoords = DistogramToCoords(window_size)
transform_coords = ImageToCoords(window_size)

@@ -229,16 +223,27 @@
]
)

transform_mask_to_coords = transforms.Compose(
[
transform_mask_to_crop,
transform_coords,
]
)

transform_mask_to_dist = transforms.Compose(
[
transform_mask_to_coords,
transform_coord_to_dist,
]
)
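The refactor above replaces the one-shot MaskToDistogramPipeline with a two-stage route: mask → interpolated contour coordinates → distogram, via the new CoordsToDistogram transform. At its core a distogram is just the pairwise distance matrix of the contour points; a minimal scipy-based sketch of that step (not the bioimage_embed implementation):

import numpy as np
from scipy.spatial.distance import pdist, squareform

def coords_to_distogram_sketch(coords):
    # coords: (N, 2) array of contour points sampled from the mask outline
    return squareform(pdist(coords))  # (N, N) symmetric distance matrix

contour = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 0.0]])
D = coords_to_distogram_sketch(contour)
assert D.shape == (4, 4) and np.allclose(D, D.T)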

gray2rgb = transforms.Lambda(lambda x: x.repeat(3, 1, 1))
transform = transforms.Compose(
[
transform_mask_to_dist,
transforms.ToTensor(),
RotateIndexingClockwise(p=1),
gray2rgb,
]
)
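gray2rgb tiles the single-channel distogram across three channels so that ImageNet-pretrained backbones, which expect RGB input, can consume it. A quick check of the Lambda's behaviour:

import torch
from torchvision import transforms

gray2rgb = transforms.Lambda(lambda x: x.repeat(3, 1, 1))

x = torch.rand(1, 128, 128)  # single-channel tensor, as produced by ToTensor()
assert gray2rgb(x).shape == (3, 128, 128)  # channel axis repeated three times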

@@ -249,27 +254,38 @@
"transform_coords": transform_mask_to_coords,
}

# Apply transform to find which images don't work
dataset = datasets.ImageFolder(train_data_path, transform=transform)

valid_indices = []
# Iterate through the dataset and apply the transform to each image
for idx in range(len(dataset)):
try:
image, label = dataset[idx]
# If the transform works without errors, add the index to the list of valid indices
valid_indices.append(idx)
except Exception as e:
# A better way to do this would be with batch collation
logger.warning(f"Error occurred for image {idx}: {e}")
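The loop above eagerly probes every sample to collect the indices the transform can handle. As the in-code comment notes, batch collation can do the same lazily; a sketch of that alternative, assuming a wrapper dataset that turns transform failures into None and a collate_fn that drops them (both names are illustrative, not part of bioimage_embed):

import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

class SkipBadSamples(torch.utils.data.Dataset):
    # Wraps a dataset so transform errors yield None instead of raising
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        try:
            return self.dataset[idx]
        except Exception:
            return None

def collate_skip_none(batch):
    batch = [b for b in batch if b is not None]
    return default_collate(batch) if batch else None

# loader = DataLoader(SkipBadSamples(dataset), batch_size=4, collate_fn=collate_skip_none)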

train_data = {
key: datasets.ImageFolder(train_data_path, transform=value)
key: torch.utils.data.Subset(
datasets.ImageFolder(train_data_path, transform=value), valid_indices
)
for key, value in transforms_dict.items()
}

dataset = torch.utils.data.Subset(
datasets.ImageFolder(train_data_path, transform=transform), valid_indices
)

for key, value in train_data.items():
logger.info("%s %s", key, len(value))
plt.imshow(np.array(train_data[key][0][0]), cmap="gray")
plt.imsave(metadata(f"{key}.png"), train_data[key][0][0], cmap="gray")
# plt.show()
plt.close()

# plt.scatter(*train_data["transform_coords"][0][0])
# plt.savefig(metadata(f"transform_coords.png"))
# plt.show()

# plt.imshow(train_data["transform_crop"][0][0], cmap="gray")
# plt.scatter(*train_data["transform_coords"][0][0],c=np.arange(interp_size), cmap='rainbow', s=1)
# plt.show()
# plt.savefig(metadata(f"transform_coords.png"))

# Retrieve the coordinates and cropped image
coords = train_data["transform_coords"][0][0]
crop_image = train_data["transform_crop"][0][0]
@@ -290,51 +306,22 @@ def shape_embed_process():

# Close the plot
plt.close()
# %%
dataloader = DataModule(
dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
)

# model = bioimage_embed.models.create_model("resnet18_vqvae_legacy", **vars(args))
#
model = bioimage_embed.models.create_model(
model=args.model,
input_dim=args.input_dim,
latent_dim=args.latent_dim,
pretrained=args.pretrained,
)

# model = bioimage_embed.models.factory.ModelFactory(**vars(args)).resnet50_vqvae_legacy()

# lit_model = shapes.MaskEmbedLatentAugment(model, args)
lit_model = shapes.MaskEmbed(model, args)
test_data = dataset[0][0].unsqueeze(0)
@@ -398,16 +385,14 @@ def shape_embed_process():
# torch.onnx.export(lit_model, example_input, f"{model_dir}/model.onnx")

# %%
# Inference on full dataset
dataloader = DataModule(
dataset,
batch_size=1,
shuffle=False,
num_workers=args.num_workers,
# Transform is commented here to avoid augmentations in real data
# HOWEVER, applying the transform multiple times and averaging the results might produce better latent embeddings
# transform=transform,
)
dataloader.setup()
@@ -421,16 +406,14 @@ def shape_embed_process():
y = np.array([int(data[-1]) for data in dataloader.predict_dataloader()])

df = pd.DataFrame(latent_space.numpy())
df["Class"] = y
# Map numeric classes to their labels
idx_to_class = {0: "alive", 1: "dead"}
df["Class"] = df["Class"].map(idx_to_class).astype("category")
df["Class"] = pd.Series(y).map(idx_to_class).astype("category")
df["Scale"] = scalings[:, 0].squeeze()
df = df.set_index("Class")
df_shape_embed = df.copy()

# %% UMAP plot
umap_plot(df, metadata, width, height, split=0.9)
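umap_plot is defined earlier in the script and collapsed out of this diff, so only its call signature is visible. A plausible sketch consistent with that signature, assuming it fits UMAP on a train split and scatters the held-out embedding coloured by class (a guess at the helper, not the repository's code):

import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from umap import UMAP

def umap_plot_sketch(df, metadata, width, height, split=0.9):
    X, y = df.to_numpy(), df.index
    X_train, X_test, _, y_test = train_test_split(X, y, train_size=split, random_state=42)
    embedding = UMAP(random_state=42).fit(X_train).transform(X_test)
    fig, ax = plt.subplots(figsize=(width, height))
    for label in np.unique(y_test):
        mask = y_test == label
        ax.scatter(embedding[mask, 0], embedding[mask, 1], s=5, label=label)
    ax.legend()
    fig.savefig(metadata("umap.png"))
    plt.close(fig)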

X = df_shape_embed.to_numpy()
y = df_shape_embed.index
@@ -444,11 +427,12 @@ def shape_embed_process():
"orientation",
]
dfs = []
# Distance matrix data
for i, data in enumerate(tqdm(train_data["transform_crop"])):
X, y = data
# Calculate shape summary statistics using regionprops
# We're considering that the mask has only one object, so we take the first element [0]
# props = regionprops(np.array(X).astype(int))[0]
props_table = measure.regionprops_table(
np.array(X).astype(int), properties=properties
Expand All @@ -464,9 +448,8 @@ def shape_embed_process():

df_regionprops = pd.concat(dfs)
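regionprops_table returns a dict of per-region arrays keyed by property name, which drops straight into a one-row-per-object DataFrame; a toy run on a synthetic square mask:

import numpy as np
import pandas as pd
from skimage import measure

mask = np.zeros((16, 16), dtype=int)
mask[4:12, 4:12] = 1  # one 8x8 square object

props_table = measure.regionprops_table(
    mask, properties=["area", "perimeter", "eccentricity"]
)
print(pd.DataFrame(props_table))  # one row, area == 64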

dfs = []
for i, data in enumerate(tqdm(train_data["transform_coords"])):
# Convert the tensor to a numpy array
X, y = data

@@ -515,7 +498,7 @@ def shape_embed_process():
y = trial["labels"]
trial["score_df"] = scoring_df(X, y)
trial["score_df"]["trial"] = trial["name"]
logger.info(trial["score_df"])
trial["score_df"].to_csv(metadata(f"{trial['name']}_score_df.csv"))
trial_df = pd.concat([trial_df, trial["score_df"]])
trial_df = trial_df.drop(["fit_time", "score_time"], axis=1)
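scoring_df is also defined above the visible hunks; judging from the fit_time and score_time columns dropped here, it wraps sklearn's cross_validate. A plausible sketch under that assumption (the classifier and metric choices are illustrative):

import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_validate
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

def scoring_df_sketch(X, y, n_splits=5):
    pipeline = Pipeline([("scale", StandardScaler()), ("clf", RandomForestClassifier())])
    scores = cross_validate(
        pipeline, X, y, cv=n_splits,
        scoring=["accuracy", "f1_macro", "precision_macro", "recall_macro"],
    )
    # cross_validate returns fit_time and score_time alongside the test metrics
    return pd.DataFrame(scores)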
@@ -524,6 +507,10 @@
trial_df.groupby("trial").mean().to_csv(metadata(f"trial_df_mean.csv"))
trial_df.plot(kind="bar")

avg = trial_df.groupby("trial").mean()
logger.info(avg)
avg.to_latex(metadata(f"trial_df.tex"))

melted_df = trial_df.melt(id_vars="trial", var_name="Metric", value_name="Score")
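melt reshapes the wide per-trial score table into the long (trial, Metric, Score) form that seaborn's catplot expects; a small illustration with made-up numbers:

import pandas as pd

wide = pd.DataFrame(
    {"trial": ["a", "b"], "test_accuracy": [0.9, 0.8], "test_f1": [0.88, 0.79]}
)
long = wide.melt(id_vars="trial", var_name="Metric", value_name="Score")
print(long)  # one row per (trial, metric) pair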
# fig, ax = plt.subplots(figsize=(width, height))
ax = sns.catplot(
@@ -550,7 +537,7 @@ def shape_embed_process():
.groupby("trial")
.mean()
)
logger.info(avs)
# tikzplotlib.save(metadata(f"trials_barplot.tikz"))


