
Adding scaling to latent space manually
ctr26 committed Nov 9, 2023
1 parent 1c312e0 commit 086c814
Showing 3 changed files with 39 additions and 25 deletions.
bioimage_embed/lightning/torch.py (6 changes: 3 additions & 3 deletions)

@@ -45,7 +45,7 @@ def __init__(self, model, args=None):

    def forward(self, batch):
        x = self.batch_to_tensor(batch)
-       return self.model(x)
+       return {**x, **self.model(x)}

    def get_results(self, batch):
        # if self.PYTHAE_FLAG:
@@ -57,7 +57,7 @@ def batch_to_tensor(self, batch):
return {"data": batch}

    def embedding_from_output(self, model_output):
-       return model_output.z.view(model_output.z.shape[0], -1)
+       return model_output["z"].view(model_output["z"].shape[0], -1)

    def get_model_output(self, x, batch_idx):
        model_output = self.model(x, epoch=batch_idx)
@@ -125,7 +125,7 @@ def validation_step(self, batch, batch_idx):
        self.logger.experiment.add_scalar("Loss/val", loss, batch_idx)
        self.logger.experiment.add_image(
            "val",
-           torchvision.utils.make_grid(model_output.recon_x),
+           torchvision.utils.make_grid(model_output["recon_x"]),
            batch_idx,
        )

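Note: `forward` now returns `{**x, **self.model(x)}`, merging the input batch dict into the model output so that keys attached upstream (such as the "scalings" added in bioimage_embed/shapes/lightning.py below) survive through to consumers like `trainer.predict`. A minimal sketch of the merge, with hypothetical shapes and keys:

```python
import torch

# Hypothetical batch dict and model output, shaped like the ones in this diff.
x = {"data": torch.rand(4, 1, 64, 64), "scalings": torch.ones(4, 1, 1, 1)}
model_output = {"z": torch.rand(4, 16), "recon_x": torch.rand(4, 1, 64, 64)}

# The merged dict keeps the input keys; on collision, the model output wins.
merged = {**x, **model_output}
assert set(merged) == {"data", "scalings", "z", "recon_x"}
```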
bioimage_embed/shapes/lightning.py (8 changes: 6 additions & 2 deletions)

@@ -14,7 +14,12 @@ def __init__(self, model, args=None):
        super().__init__(model, args)

    def batch_to_tensor(self, batch):
-       return super().batch_to_tensor(batch[0].float())
+       x = batch[0].float()
+       output = super().batch_to_tensor(x)
+       frobenius_norm = torch.norm(
+           output["data"], p="fro", dim=(-2, -1), keepdim=True
+       )
+       return {"data": x / frobenius_norm, "scalings": frobenius_norm}

    def loss_function(self, model_output, *args, **kwargs):
        loss_ops = lf.DistanceMatrixLoss(model_output.recon_x, norm=True)
@@ -110,4 +115,3 @@ def training_step(self, batch, batch_idx, optimizer_idx=0):
    def configure_optimizers(self):
        opt_ed, lr_s_ed = self.timm_optimizers(self.model)
        return self.timm_to_lightning(optimizer=opt_ed, lr_scheduler=lr_s_ed)
-
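The new `batch_to_tensor` divides each input by its Frobenius norm and returns the norm alongside the data as "scalings", so the absolute size lost by normalisation can be restored later. A small sketch of the round trip, assuming (B, C, H, W) inputs:

```python
import torch

x = torch.rand(4, 1, 64, 64)  # stand-in for batch[0].float()

# Per-sample Frobenius norm over the last two dims; keepdim lets it broadcast.
frobenius_norm = torch.norm(x, p="fro", dim=(-2, -1), keepdim=True)  # (4, 1, 1, 1)
x_scaled = x / frobenius_norm  # every matrix now has unit Frobenius norm

# Multiplying by the stored scaling recovers the original input exactly.
assert torch.allclose(x_scaled * frobenius_norm, x)
```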
scripts/shapes/shape_embed.py (50 changes: 30 additions & 20 deletions)
@@ -1,7 +1,6 @@
# %%
import seaborn as sns
import pyefd
-import tikzplotlib
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_validate, KFold, train_test_split
from sklearn.metrics import make_scorer
@@ -14,8 +13,7 @@
from torch.autograd import Variable
from types import SimpleNamespace
import numpy as np
-import tikzplotlib
-
+import logging
from skimage import measure
import umap.plot
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
@@ -50,6 +48,8 @@
import matplotlib as mpl
from matplotlib import rc

+logger = logging.getLogger(__name__)


def scoring_df(X, y):
    # Split the data into training and test sets
@@ -107,7 +107,6 @@ def shape_embed_process():
"batch_size": 4,
"num_workers": 2**4,
# "window_size": 64*2,
"num_workers": 1,
"input_dim": (1, window_size, window_size),
# "channels": 3,
"latent_dim": 16,
@@ -144,7 +143,7 @@ def shape_embed_process():
    # input_dim = (params["channels"], params["window_size"], params["window_size"])
    args = SimpleNamespace(**params, **optimizer_params, **lr_scheduler_params)

-   dataset_path = "bbbc010"
+   dataset_path = "bbbc010/BBBC010_v1_foreground_eachworm"
    # dataset_path = "vampire/mefs/data/processed/Control"
    # dataset_path = "bbbc010/BBBC010_v1_foreground_eachworm"
    # dataset_path = "vampire/torchvision/Control"
@@ -292,9 +291,7 @@ def shape_embed_process():
        min_epochs=50,
        max_epochs=args.epochs,
    )
-
    # %%
-
    try:
        trainer.fit(
            lit_model, datamodule=dataloader, ckpt_path=f"{model_dir}/last.ckpt"
@@ -309,7 +306,7 @@ def shape_embed_process():
    example_input = Variable(torch.rand(1, *args.input_dim))

    # torch.jit.save(lit_model.to_torchscript(), f"{model_dir}/model.pt")
-   torch.onnx.export(lit_model, example_input, f"{model_dir}/model.onnx")
+   # torch.onnx.export(lit_model, example_input, f"{model_dir}/model.onnx")

    # %%
    # Inference
@@ -327,29 +324,39 @@ def shape_embed_process():
    dataloader.setup()

    predictions = trainer.predict(lit_model, datamodule=dataloader)
-   latent_space = torch.stack(
-       [prediction.z.flatten() for prediction in predictions[:-1]], dim=0
-   )
+   latent_space = torch.stack([d["z"].flatten() for d in predictions])
+   scalings = torch.stack([d["scalings"].flatten() for d in predictions])

    idx_to_class = {v: k for k, v in dataset.dataset.class_to_idx.items()}

-   y = np.array([int(data[-1]) for data in dataloader.predict_dataloader()])[:-1]
+   y = np.array([int(data[-1]) for data in dataloader.predict_dataloader()])
+
+   y_partial = y.copy()
+   indices = np.random.choice(y.size, int(0.3 * y.size), replace=False)
+   y_partial[indices] = -1
+
+   y_blind = -1 * np.ones_like(y)
+   umap_labels = y_blind
    classes = np.array([idx_to_class[i] for i in y])

-   mapper = umap.UMAP().fit(latent_space.numpy(), y=y)
+   n_components = 64  # Number of UMAP components
+   component_names = [f"umap{i}" for i in range(n_components)]  # List of column names
+
+   logger.info("UMAP fitting")
+   mapper = umap.UMAP(n_components=n_components, random_state=42).fit(
+       latent_space.numpy(), y=umap_labels
+   )

+   logger.info("UMAP transforming")
    semi_supervised_latent = mapper.transform(latent_space.numpy())

-   df = pd.DataFrame(semi_supervised_latent, columns=["umap0", "umap1"])
+   df = pd.DataFrame(semi_supervised_latent, columns=component_names)
    df["Class"] = y
    # Map numeric classes to their labels
    idx_to_class = {0: "alive", 1: "dead"}
    df["Class"] = df["Class"].map(idx_to_class)
+   df["Scale"] = scalings
    df = df.set_index("Class")
    df_shape_embed = df.copy()

    ax = sns.relplot(
        data=df,
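The hunk above prepares three label variants: `y` (fully labelled), `y_partial` (30% masked to -1), and `y_blind` (all -1). umap-learn treats -1 as "unlabelled" in its semi-supervised mode, so fitting with `umap_labels = y_blind` is effectively an unsupervised fit routed through the semi-supervised API, with `y_partial` available as a drop-in alternative. A sketch with stand-in data (shapes and sizes are illustrative only):

```python
import numpy as np
import umap  # umap-learn

rng = np.random.default_rng(42)
X = rng.normal(size=(200, 16))    # stand-in for latent_space.numpy()
y = rng.integers(0, 2, size=200)  # stand-in class labels

y_partial = y.copy()
masked = rng.choice(y.size, int(0.3 * y.size), replace=False)
y_partial[masked] = -1            # -1 marks a point as unlabelled

# Fit on the partially masked labels, then embed the full dataset.
mapper = umap.UMAP(n_components=2, random_state=42).fit(X, y=y_partial)
embedding = mapper.transform(X)   # (200, 2)
```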
@@ -378,10 +385,9 @@ def shape_embed_process():

    # %%

-   X = latent_space.numpy()
-   y = classes
+   X = df_shape_embed.to_numpy()
+   y = df_shape_embed.index.values

-   dfs = []
    properties = [
        "area",
        "perimeter",
@@ -435,7 +441,11 @@ def shape_embed_process():
    df_pyefd = pd.concat(dfs)

    trials = [
-       {"name": "mask_embed", "features": latent_space.numpy(), "labels": classes},
+       {
+           "name": "mask_embed",
+           "features": df_shape_embed.to_numpy(),
+           "labels": df_shape_embed.index,
+       },
        {
            "name": "fourier_coeffs",
            "features": df_pyefd.xs("coeffs", level="coeffs"),
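Each entry in `trials` pairs one feature set (learned embedding, elliptic Fourier coefficients, regionprops, and so on) with its labels so all of them can be scored by the same classifier pipeline. The body of `scoring_df` is elided from this diff; the sketch below only illustrates the pattern, using the estimators imported at the top of the script (the `score_trial` helper is hypothetical):

```python
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold, cross_validate

def score_trial(features, labels, n_splits=5):
    # Cross-validated classification on one feature set from `trials`.
    clf = RandomForestClassifier(random_state=42)
    cv = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    scores = cross_validate(clf, features, labels, cv=cv, scoring="f1_macro")
    return pd.DataFrame(scores)

# for trial in trials:
#     print(trial["name"], score_trial(trial["features"], trial["labels"]).mean())
```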
