From 2b0c4c42dccecd41e38e1348072ef42b4674ce83 Mon Sep 17 00:00:00 2001 From: ctr26 Date: Tue, 1 Oct 2024 14:49:19 +0000 Subject: [PATCH] Auto-commit updated notebooks --- notebooks/_shape_embed.ipynb | 570 +++++++++++++++++++++++++++++++ notebooks/shape_embed.ipynb | 633 +++++++++++++++++++++++++++++++++++ scripts/_shape_embed.py | 4 +- scripts/shape_embed.py | 6 +- 4 files changed, 1208 insertions(+), 5 deletions(-) create mode 100644 notebooks/_shape_embed.ipynb create mode 100644 notebooks/shape_embed.ipynb diff --git a/notebooks/_shape_embed.ipynb b/notebooks/_shape_embed.ipynb new file mode 100644 index 00000000..76e17a52 --- /dev/null +++ b/notebooks/_shape_embed.ipynb @@ -0,0 +1,570 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "57c2b078", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "import seaborn as sns\n", + "import pyefd\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn.model_selection import cross_validate, KFold, train_test_split\n", + "from sklearn.metrics import make_scorer\n", + "import pandas as pd\n", + "from sklearn import metrics\n", + "from pathlib import Path\n", + "import umap\n", + "from torch.autograd import Variable\n", + "from types import SimpleNamespace\n", + "import numpy as np\n", + "import logging\n", + "from skimage import measure\n", + "import umap.plot\n", + "from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint\n", + "import pytorch_lightning as pl\n", + "import torch\n", + "\n", + "# Deal with the filesystem\n", + "import torch.multiprocessing\n", + "\n", + "torch.multiprocessing.set_sharing_strategy(\"file_system\")\n", + "\n", + "from shape_embed import shapes\n", + "import bioimage_embed\n", + "\n", + "# Note - you must have torchvision installed for this example\n", + "\n", + "from pytorch_lightning import loggers as pl_loggers\n", + "from torchvision import transforms\n", + "from bioimage_embed.lightning import DataModule\n", + "\n", + "from torchvision import datasets\n", + "from bioimage_embed.shapes.transforms import (\n", + " ImageToCoords,\n", + " CropCentroidPipeline,\n", + " DistogramToCoords,\n", + " MaskToDistogramPipeline,\n", + ")\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "\n", + "def scoring_df(X, y):\n", + " # Split the data into training and test sets\n", + " X_train, X_test, y_train, y_test = train_test_split(\n", + " X, y, test_size=0.2, random_state=42, shuffle=True, stratify=y\n", + " )\n", + " # Define a dictionary of metrics\n", + " scoring = {\n", + " \"accuracy\": make_scorer(metrics.accuracy_score),\n", + " \"precision\": make_scorer(metrics.precision_score, average=\"macro\"),\n", + " \"recall\": make_scorer(metrics.recall_score, average=\"macro\"),\n", + " \"f1\": make_scorer(metrics.f1_score, average=\"macro\"),\n", + " }\n", + "\n", + " # Create a random forest classifier\n", + " clf = RandomForestClassifier()\n", + "\n", + " # Specify the number of folds\n", + " k_folds = 10\n", + "\n", + " # Perform k-fold cross-validation\n", + " cv_results = cross_validate(\n", + " estimator=clf,\n", + " X=X,\n", + " y=y,\n", + " cv=KFold(n_splits=k_folds),\n", + " scoring=scoring,\n", + " n_jobs=-1,\n", + " return_train_score=False,\n", + " )\n", + "\n", + " # Put the results into a DataFrame\n", + " return pd.DataFrame(cv_results)\n", + "\n", + "\n", + "def shape_embed_process():\n", + " # Setting the font size\n", + "\n", + " # 
rc(\"text\", usetex=True)\n", + " width = 3.45\n", + " height = width / 1.618\n", + " plt.rcParams[\"figure.figsize\"] = [width, height]\n", + "\n", + " sns.set(\n", + " style=\"white\",\n", + " context=\"notebook\",\n", + " rc={\"figure.figsize\": (width, height)},\n", + " )\n", + "\n", + " # matplotlib.use(\"TkAgg\")\n", + " interp_size = 128 * 2\n", + " max_epochs = 100\n", + " window_size = 128 * 2\n", + "\n", + " params = {\n", + " \"model\": \"resnet18_vqvae_legacy\",\n", + " \"epochs\": 75,\n", + " \"batch_size\": 3,\n", + " \"num_workers\": 2**4,\n", + " \"input_dim\": (3, interp_size, interp_size),\n", + " \"latent_dim\": (interp_size) * 8,\n", + " \"num_embeddings\": 16,\n", + " \"num_hiddens\": 16,\n", + " \"pretrained\": True,\n", + " \"commitment_cost\": 0.25,\n", + " \"decay\": 0.99,\n", + " \"loss_weights\": [1, 1, 1, 1],\n", + " }\n", + "\n", + " optimizer_params = {\n", + " \"opt\": \"LAMB\",\n", + " \"lr\": 0.001,\n", + " \"weight_decay\": 0.0001,\n", + " \"momentum\": 0.9,\n", + " }\n", + "\n", + " lr_scheduler_params = {\n", + " \"sched\": \"cosine\",\n", + " \"min_lr\": 1e-4,\n", + " \"warmup_epochs\": 5,\n", + " \"warmup_lr\": 1e-6,\n", + " \"cooldown_epochs\": 10,\n", + " \"t_max\": 50,\n", + " \"cycle_momentum\": False,\n", + " }\n", + "\n", + " args = SimpleNamespace(**params, **optimizer_params, **lr_scheduler_params)\n", + "\n", + " dataset_path = \"bbbc010/BBBC010_v1_foreground_eachworm\"\n", + " # dataset_path = \"vampire/mefs/data/processed/Control\"\n", + " # dataset_path = \"shape_embed_data/data/vampire/torchvision/Control/\"\n", + " # dataset_path = \"vampire/torchvision/Control\"\n", + " # dataset = \"bbbc010\"\n", + "\n", + " # train_data_path = f\"scripts/shapes/data/{dataset_path}\"\n", + " train_data_path = f\"data/{dataset_path}\"\n", + " metadata = lambda x: f\"results/{dataset_path}_{args.model}/{x}\"\n", + "\n", + " path = Path(metadata(\"\"))\n", + " path.mkdir(parents=True, exist_ok=True)\n", + " model_dir = f\"models/{dataset_path}_{args.model}\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f50128cb", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "\n", + " transform_crop = CropCentroidPipeline(window_size)\n", + " transform_dist = MaskToDistogramPipeline(\n", + " window_size, interp_size, matrix_normalised=False\n", + " )\n", + " transform_mdscoords = DistogramToCoords(window_size)\n", + " transform_coords = ImageToCoords(window_size)\n", + "\n", + " transform_mask_to_gray = transforms.Compose([transforms.Grayscale(1)])\n", + "\n", + " transform_mask_to_crop = transforms.Compose(\n", + " [\n", + " # transforms.ToTensor(),\n", + " transform_mask_to_gray,\n", + " transform_crop,\n", + " ]\n", + " )\n", + "\n", + " transform_mask_to_dist = transforms.Compose(\n", + " [\n", + " transform_mask_to_crop,\n", + " transform_dist,\n", + " ]\n", + " )\n", + " transform_mask_to_coords = transforms.Compose(\n", + " [\n", + " transform_mask_to_crop,\n", + " transform_coords,\n", + " ]\n", + " )\n", + "\n", + " transforms_dict = {\n", + " \"none\": transform_mask_to_gray,\n", + " \"transform_crop\": transform_mask_to_crop,\n", + " \"transform_dist\": transform_mask_to_dist,\n", + " \"transform_coords\": transform_mask_to_coords,\n", + " }\n", + "\n", + " train_data = {\n", + " key: datasets.ImageFolder(train_data_path, transform=value)\n", + " for key, value in transforms_dict.items()\n", + " }\n", + "\n", + " for key, value in train_data.items():\n", + " print(key, len(value))\n", + " 
plt.imshow(train_data[key][0][0], cmap=\"gray\")\n",
+    "        plt.imsave(metadata(f\"{key}.png\"), train_data[key][0][0], cmap=\"gray\")\n",
+    "        # plt.show()\n",
+    "        plt.close()\n",
+    "\n",
+    "    # plt.scatter(*train_data[\"transform_coords\"][0][0])\n",
+    "    # plt.savefig(metadata(f\"transform_coords.png\"))\n",
+    "    # plt.show()\n",
+    "\n",
+    "    # plt.imshow(train_data[\"transform_crop\"][0][0], cmap=\"gray\")\n",
+    "    # plt.scatter(*train_data[\"transform_coords\"][0][0],c=np.arange(interp_size), cmap='rainbow', s=1)\n",
+    "    # plt.show()\n",
+    "    # plt.savefig(metadata(f\"transform_coords.png\"))\n",
+    "\n",
+    "    # Retrieve the coordinates and cropped image\n",
+    "    coords = train_data[\"transform_coords\"][0][0]\n",
+    "    crop_image = train_data[\"transform_crop\"][0][0]\n",
+    "\n",
+    "    fig = plt.figure(frameon=True)\n",
+    "    ax = plt.Axes(fig, [0, 0, 1, 1])\n",
+    "    ax.set_axis_off()\n",
+    "    fig.add_axes(ax)\n",
+    "\n",
+    "    # Display the cropped image using grayscale colormap\n",
+    "    plt.imshow(crop_image, cmap=\"gray_r\")\n",
+    "\n",
+    "    # Scatter plot with smaller point size\n",
+    "    plt.scatter(*coords, c=np.arange(interp_size), cmap=\"rainbow\", s=2)\n",
+    "\n",
+    "    # Save the plot as an image without border and coordinate axes\n",
+    "    plt.savefig(metadata(\"transform_coords.png\"), bbox_inches=\"tight\", pad_inches=0)\n",
+    "\n",
+    "    # Close the plot\n",
+    "    plt.close()\n",
+    "    # import albumentations as A"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b26e4d66",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "    gray2rgb = transforms.Lambda(lambda x: x.repeat(3, 1, 1))\n",
+    "    transform = transforms.Compose(\n",
+    "        [transform_mask_to_dist, transforms.ToTensor(), gray2rgb]\n",
+    "    )\n",
+    "\n",
+    "    dataset = datasets.ImageFolder(train_data_path, transform=transform)\n",
+    "\n",
+    "    valid_indices = []\n",
+    "    # Iterate through the dataset and apply the transform to each image\n",
+    "    for idx in range(len(dataset)):\n",
+    "        try:\n",
+    "            image, label = dataset[idx]\n",
+    "            # If the transform works without errors, add the index to the list of valid indices\n",
+    "            valid_indices.append(idx)\n",
+    "        except Exception as e:\n",
+    "            # A better way to do this would be with batch collation\n",
+    "            print(f\"Error occurred for image {idx}: {e}\")\n",
+    "\n",
+    "    # Create a Subset using the valid indices\n",
+    "    dataset = torch.utils.data.Subset(dataset, valid_indices)\n",
+    "    dataloader = DataModule(\n",
+    "        dataset,\n",
+    "        batch_size=args.batch_size,\n",
+    "        shuffle=True,\n",
+    "        num_workers=args.num_workers,\n",
+    "    )\n",
+    "\n",
+    "    model = bioimage_embed.models.create_model(**vars(args))\n",
+    "    logger.info(model)\n",
+    "\n",
+    "    # lit_model = shapes.MaskEmbedLatentAugment(model, args)\n",
+    "    lit_model = shapes.MaskEmbed(model, args)\n",
+    "    test_data = dataset[0][0].unsqueeze(0)\n",
+    "    # test_lit_data = 2*(dataset[0][0].unsqueeze(0).repeat_interleave(3, dim=1),)\n",
+    "    test_output = lit_model.forward((test_data,))\n",
+    "\n",
+    "    dataloader.setup()\n",
+    "    model.eval()\n",
+    "    # Model\n",
+    "    lit_model.eval()\n",
+    "\n",
+    "    model_dir = f\"my_models/{dataset_path}_{model._get_name()}_{lit_model._get_name()}\"\n",
+    "    logger.info(f\"Saving model to {model_dir}\")\n",
+    "    Path(f\"{model_dir}/\").mkdir(parents=True, exist_ok=True)\n",
+    "\n",
+    "    tb_logger = pl_loggers.TensorBoardLogger(\n",
+    "        \"logs/\",\n",
+    "        name=f\"{dataset_path}_{args.model}_{model._get_name()}_{lit_model._get_name()}\",\n",
+    "    )\n",
+    "\n",
+    "    checkpoint_callback = ModelCheckpoint(dirpath=f\"{model_dir}/\", save_last=True)\n",
+    "\n",
+    "    trainer = pl.Trainer(\n",
+    "        logger=tb_logger,\n",
+    "        gradient_clip_val=0.5,\n",
+    "        enable_checkpointing=True,\n",
+    "        devices=1,\n",
+    "        accelerator=\"gpu\",\n",
+    "        precision=16, # Use mixed precision\n",
+    "        accumulate_grad_batches=4,\n",
+    "        callbacks=[checkpoint_callback],\n",
+    "        min_epochs=50,\n",
+    "        max_epochs=args.epochs,\n",
+    "    )\n",
+    "    # # %%\n",
+    "\n",
+    "    # Sanity-check the untrained model end-to-end before fitting\n",
+    "    testing = trainer.test(lit_model, datamodule=dataloader)\n",
+    "\n",
+    "    try:\n",
+    "        trainer.fit(\n",
+    "            lit_model, datamodule=dataloader, ckpt_path=f\"{model_dir}/last.ckpt\"\n",
+    "        )\n",
+    "    except Exception:\n",
+    "        # No usable checkpoint to resume from; train from scratch\n",
+    "        trainer.fit(lit_model, datamodule=dataloader)\n",
+    "\n",
+    "    logger.info(f\"Saving model to {model_dir}\")\n",
+    "    try:\n",
+    "        example_input = Variable(torch.rand(2, 1, *args.input_dim))\n",
+    "        torch.onnx.export(lit_model, example_input, f\"{model_dir}/model.onnx\")\n",
+    "        torch.jit.save(lit_model.to_torchscript(), f\"{model_dir}/model.pt\")\n",
+    "\n",
+    "    except Exception:\n",
+    "        logger.info(\"Model export to ONNX/TorchScript failed\")\n",
+    "\n",
+    "    validation = trainer.validate(lit_model, datamodule=dataloader)\n",
+    "    # testing = trainer.test(lit_model, datamodule=dataloader)\n",
+    "\n",
+    "    example_input = Variable(torch.rand(1, *args.input_dim))\n",
+    "    logger.info(f\"Saving model to {model_dir}\")\n",
+    "    torch.jit.save(lit_model.to_torchscript(), f\"{model_dir}/model.pt\")\n",
+    "    torch.onnx.export(lit_model, example_input, f\"{model_dir}/model.onnx\")\n",
+    "\n",
+    "    # Inference\n",
+    "\n",
+    "    dataloader = DataModule(\n",
+    "        dataset,\n",
+    "        batch_size=1,\n",
+    "        shuffle=False,\n",
+    "        num_workers=args.num_workers,\n",
+    "        # Transform is commented here to avoid augmentations in real data\n",
+    "        # HOWEVER, applying the transform multiple times and averaging the results might produce better latent embeddings\n",
+    "        # transform=transform,\n",
+    "    )\n",
+    "    dataloader.setup()\n",
+    "\n",
+    "    predictions = trainer.predict(lit_model, datamodule=dataloader)\n",
+    "    latent_space = torch.stack([d[\"z\"].flatten() for d in predictions])\n",
+    "    scalings = torch.stack([d[\"scalings\"].flatten() for d in predictions])\n",
+    "\n",
+    "    idx_to_class = {v: k for k, v in dataset.dataset.class_to_idx.items()}\n",
+    "\n",
+    "    y = np.array([int(data[-1]) for data in dataloader.predict_dataloader()])\n",
+    "\n",
+    "    y_partial = y.copy()\n",
+    "    indices = np.random.choice(y.size, int(0.3 * y.size), replace=False)\n",
+    "    y_partial[indices] = -1\n",
+    "    y_blind = -1 * np.ones_like(y)\n",
+    "    umap_labels = y_blind\n",
+    "    classes = np.array([idx_to_class[i] for i in y])\n",
+    "\n",
+    "    n_components = 64  # Number of UMAP components\n",
+    "    component_names = [f\"umap{i}\" for i in range(n_components)]  # List of column names\n",
+    "\n",
+    "    logger.info(\"UMAP fitting\")\n",
+    "    mapper = umap.UMAP(n_components=n_components, random_state=42).fit(\n",
+    "        latent_space.numpy(), y=umap_labels\n",
+    "    )\n",
+    "\n",
+    "    logger.info(\"UMAP transforming\")\n",
+    "    semi_supervised_latent = mapper.transform(latent_space.numpy())\n",
+    "\n",
+    "    df = pd.DataFrame(semi_supervised_latent, columns=component_names)\n",
+    "    df[\"Class\"] = y\n",
+    "    # Map numeric classes to their labels\n",
+    "    idx_to_class = {0: \"alive\", 1: \"dead\"}\n",
+    "    df[\"Class\"] = df[\"Class\"].map(idx_to_class)\n",
+    "    df[\"Scale\"] = scalings\n",
+    "    df = df.set_index(\"Class\")\n",
+    "    df_shape_embed = df.copy()\n",
+    "\n",
+    "    ax = sns.relplot(\n",
+    "        data=df,\n",
+    "        x=\"umap0\",\n",
+    "        y=\"umap1\",\n",
+    "        
hue=\"Class\",\n", + " palette=\"deep\",\n", + " alpha=0.5,\n", + " edgecolor=None,\n", + " s=5,\n", + " height=height,\n", + " aspect=0.5 * width / height,\n", + " )\n", + "\n", + " sns.move_legend(\n", + " ax,\n", + " \"upper center\",\n", + " )\n", + " ax.set(xlabel=None, ylabel=None)\n", + " sns.despine(left=True, bottom=True)\n", + " plt.tick_params(bottom=False, left=False, labelbottom=False, labelleft=False)\n", + " plt.tight_layout()\n", + " plt.savefig(metadata(\"umap_no_axes.pdf\"))\n", + " # plt.show()\n", + " plt.close()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0c2a51ff", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + " X = df_shape_embed.to_numpy()\n", + " y = df_shape_embed.index.values\n", + "\n", + " properties = [\n", + " \"area\",\n", + " \"perimeter\",\n", + " \"centroid\",\n", + " \"major_axis_length\",\n", + " \"minor_axis_length\",\n", + " \"orientation\",\n", + " ]\n", + " dfs = []\n", + " for i, data in enumerate(train_data[\"transform_crop\"]):\n", + " X, y = data\n", + " # Do regionprops here\n", + " # Calculate shape summary statistics using regionprops\n", + " # We're considering that the mask has only one object, thus we take the first element [0]\n", + " # props = regionprops(np.array(X).astype(int))[0]\n", + " props_table = measure.regionprops_table(\n", + " np.array(X).astype(int), properties=properties\n", + " )\n", + "\n", + " # Store shape properties in a dataframe\n", + " df = pd.DataFrame(props_table)\n", + "\n", + " # Assuming the class or label is contained in 'y' variable\n", + " df[\"class\"] = y\n", + " df.set_index(\"class\", inplace=True)\n", + " dfs.append(df)\n", + "\n", + " df_regionprops = pd.concat(dfs)\n", + "\n", + " # Assuming 'dataset_contour' is your DataLoader for the dataset\n", + " dfs = []\n", + " for i, data in enumerate(train_data[\"transform_coords\"]):\n", + " # Convert the tensor to a numpy array\n", + " X, y = data\n", + "\n", + " # Feed it to PyEFD's calculate_efd function\n", + " coeffs = pyefd.elliptic_fourier_descriptors(X, order=10, normalize=False)\n", + " # coeffs_df = pd.DataFrame({'class': [y], 'norm_coeffs': [norm_coeffs.flatten().tolist()]})\n", + "\n", + " norm_coeffs = pyefd.normalize_efd(coeffs)\n", + " df = pd.DataFrame(\n", + " {\n", + " \"norm_coeffs\": norm_coeffs.flatten().tolist(),\n", + " \"coeffs\": coeffs.flatten().tolist(),\n", + " }\n", + " ).T.rename_axis(\"coeffs\")\n", + " df[\"class\"] = y\n", + " df.set_index(\"class\", inplace=True, append=True)\n", + " dfs.append(df)\n", + "\n", + " df_pyefd = pd.concat(dfs)\n", + "\n", + " trials = [\n", + " {\n", + " \"name\": \"mask_embed\",\n", + " \"features\": df_shape_embed.to_numpy(),\n", + " \"labels\": df_shape_embed.index,\n", + " },\n", + " {\n", + " \"name\": \"fourier_coeffs\",\n", + " \"features\": df_pyefd.xs(\"coeffs\", level=\"coeffs\"),\n", + " \"labels\": df_pyefd.xs(\"coeffs\", level=\"coeffs\").index,\n", + " },\n", + " # {\"name\": \"fourier_norm_coeffs\",\n", + " # \"features\": df_pyefd.xs(\"norm_coeffs\", level=\"coeffs\"),\n", + " # \"labels\": df_pyefd.xs(\"norm_coeffs\", level=\"coeffs\").index\n", + " # }\n", + " {\n", + " \"name\": \"regionprops\",\n", + " \"features\": df_regionprops,\n", + " \"labels\": df_regionprops.index,\n", + " },\n", + " ]\n", + "\n", + " trial_df = pd.DataFrame()\n", + " for trial in trials:\n", + " X = trial[\"features\"]\n", + " y = trial[\"labels\"]\n", + " trial[\"score_df\"] = scoring_df(X, y)\n", + " trial[\"score_df\"][\"trial\"] = trial[\"name\"]\n", + " 
print(trial[\"score_df\"])\n", + " trial[\"score_df\"].to_csv(metadata(f\"{trial['name']}_score_df.csv\"))\n", + " trial_df = pd.concat([trial_df, trial[\"score_df\"]])\n", + " trial_df = trial_df.drop([\"fit_time\", \"score_time\"], axis=1)\n", + "\n", + " trial_df.to_csv(metadata(\"trial_df.csv\"))\n", + " trial_df.groupby(\"trial\").mean().to_csv(metadata(\"trial_df_mean.csv\"))\n", + " trial_df.plot(kind=\"bar\")\n", + "\n", + " melted_df = trial_df.melt(id_vars=\"trial\", var_name=\"Metric\", value_name=\"Score\")\n", + " # fig, ax = plt.subplots(figsize=(width, height))\n", + " ax = sns.catplot(\n", + " data=melted_df,\n", + " kind=\"bar\",\n", + " x=\"trial\",\n", + " hue=\"Metric\",\n", + " y=\"Score\",\n", + " errorbar=\"se\",\n", + " height=height,\n", + " aspect=width * 2**0.5 / height,\n", + " )\n", + " # ax.xtick_params(labelrotation=45)\n", + " # plt.legend(loc='lower center', bbox_to_anchor=(1, 1))\n", + " # sns.move_legend(ax, \"lower center\", bbox_to_anchor=(1, 1))\n", + " # ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')\n", + " # plt.tight_layout()\n", + " plt.savefig(metadata(\"trials_barplot.pdf\"))\n", + " plt.close()\n", + "\n", + " avs = (\n", + " melted_df.set_index([\"trial\", \"Metric\"])\n", + " .xs(\"test_f1\", level=\"Metric\", drop_level=False)\n", + " .groupby(\"trial\")\n", + " .mean()\n", + " )\n", + " print(avs)\n", + " # tikzplotlib.save(metadata(f\"trials_barplot.tikz\"))\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + " shape_embed_process()" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/shape_embed.ipynb b/notebooks/shape_embed.ipynb new file mode 100644 index 00000000..9c0f3331 --- /dev/null +++ b/notebooks/shape_embed.ipynb @@ -0,0 +1,633 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "2d7ff7fc", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "import seaborn as sns\n", + "import pyefd\n", + "from sklearn.discriminant_analysis import StandardScaler\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn.model_selection import (\n", + " cross_validate,\n", + " KFold,\n", + " train_test_split,\n", + ")\n", + "from sklearn.metrics import make_scorer\n", + "import pandas as pd\n", + "from sklearn import metrics\n", + "import matplotlib as mpl\n", + "from pathlib import Path\n", + "from sklearn.pipeline import Pipeline\n", + "from torch.autograd import Variable\n", + "from types import SimpleNamespace\n", + "import numpy as np\n", + "from skimage import measure\n", + "from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint\n", + "import pytorch_lightning as pl\n", + "import torch\n", + "from types import SimpleNamespace\n", + "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n", + "from umap import UMAP\n", + "# Deal with the filesystem\n", + "import torch.multiprocessing\n", + "import logging\n", + "from tqdm import tqdm\n", + "\n", + "logging.basicConfig(level=logging.INFO)\n", + "\n", + "torch.multiprocessing.set_sharing_strategy(\"file_system\")\n", + "\n", + "from shape_embed import shapes\n", + "import bioimage_embed\n", + "from pytorch_lightning import loggers as pl_loggers\n", + "from torchvision import transforms\n", + "from bioimage_embed.lightning import DataModule\n", + "\n", + "from torchvision import datasets\n", + "from 
shape_embed.shapes.transforms import (\n",
+    "    ImageToCoords,\n",
+    "    CropCentroidPipeline,\n",
+    "    DistogramToCoords,\n",
+    "    RotateIndexingClockwise,\n",
+    "    CoordsToDistogram,\n",
+    "    AsymmetricDistogramToCoordsPipeline,\n",
+    ")\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "from matplotlib import rc\n",
+    "\n",
+    "import os  # used below when checking for resumable checkpoints\n",
+    "import pickle\n",
+    "import base64\n",
+    "import hashlib\n",
+    "\n",
+    "from sklearn.model_selection import StratifiedKFold  # used by scoring_df below\n",
+    "\n",
+    "logger = logging.getLogger(__name__)\n",
+    "\n",
+    "# Seed everything\n",
+    "np.random.seed(42)\n",
+    "pl.seed_everything(42)\n",
+    "\n",
+    "\n",
+    "def hashing_fn(args):\n",
+    "    serialized_args = pickle.dumps(vars(args))\n",
+    "    hash_object = hashlib.sha256(serialized_args)\n",
+    "    hashed_string = base64.urlsafe_b64encode(hash_object.digest()).decode()\n",
+    "    return hashed_string\n",
+    "\n",
+    "\n",
+    "def umap_plot(df, metadata, width=3.45, height=3.45 / 1.618):\n",
+    "    umap_reducer = UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42)\n",
+    "    mask = np.random.rand(len(df)) < 0.7\n",
+    "\n",
+    "    # Encode classes as integer codes; -1 indicates an unknown label for semi-supervision\n",
+    "    semi_labels = pd.Categorical(df[\"Class\"]).codes.copy()\n",
+    "    semi_labels[~mask] = -1\n",
+    "\n",
+    "    # Fit on the numeric features only, then carry the class labels across for plotting\n",
+    "    umap_embedding = umap_reducer.fit_transform(df.drop(columns=[\"Class\"]), y=semi_labels)\n",
+    "    plot_df = pd.DataFrame(umap_embedding, columns=[\"umap0\", \"umap1\"])\n",
+    "    plot_df[\"Class\"] = df[\"Class\"].to_numpy()\n",
+    "\n",
+    "    ax = sns.relplot(\n",
+    "        data=plot_df,\n",
+    "        x=\"umap0\",\n",
+    "        y=\"umap1\",\n",
+    "        hue=\"Class\",\n",
+    "        palette=\"deep\",\n",
+    "        alpha=0.5,\n",
+    "        edgecolor=None,\n",
+    "        s=5,\n",
+    "        height=height,\n",
+    "        aspect=0.5 * width / height,\n",
+    "    )\n",
+    "\n",
+    "    sns.move_legend(\n",
+    "        ax,\n",
+    "        \"upper center\",\n",
+    "    )\n",
+    "    ax.set(xlabel=None, ylabel=None)\n",
+    "    sns.despine(left=True, bottom=True)\n",
+    "    plt.tick_params(bottom=False, left=False, labelbottom=False, labelleft=False)\n",
+    "    plt.tight_layout()\n",
+    "    plt.savefig(metadata(\"umap_no_axes.pdf\"))\n",
+    "    # plt.show()\n",
+    "    plt.close()\n",
+    "\n",
+    "\n",
+    "def scoring_df(X, y):\n",
+    "    # Split the data into training and test sets\n",
+    "    X_train, X_test, y_train, y_test = train_test_split(\n",
+    "        X, y, test_size=0.2, random_state=42, shuffle=True, stratify=y\n",
+    "    )\n",
+    "    # Define a dictionary of metrics\n",
+    "    scoring = {\n",
+    "        # Balanced accuracy is reported under the \"accuracy\" key\n",
+    "        \"accuracy\": make_scorer(metrics.balanced_accuracy_score),\n",
+    "        \"precision\": make_scorer(metrics.precision_score, average=\"macro\"),\n",
+    "        \"recall\": make_scorer(metrics.recall_score, average=\"macro\"),\n",
+    "        \"f1\": make_scorer(metrics.f1_score, average=\"macro\"),\n",
+    "        \"roc_auc\": make_scorer(metrics.roc_auc_score, average=\"macro\"),\n",
+    "    }\n",
+    "\n",
+    "    # Create a random forest classifier\n",
+    "    pipeline = Pipeline(\n",
+    "        [\n",
+    "            (\"scaler\", StandardScaler()),\n",
+    "            # (\"pca\", PCA(n_components=0.95, whiten=True, random_state=42)),\n",
+    "            (\"clf\", RandomForestClassifier()),\n",
+    "            # (\"clf\", DummyClassifier()),\n",
+    "        ]\n",
+    "    )\n",
+    "\n",
+    "    # Specify the number of folds\n",
+    "    k_folds = 5\n",
+    "\n",
+    "    # Perform k-fold cross-validation\n",
+    "    cv_results = cross_validate(\n",
+    "        estimator=pipeline,\n",
+    "        X=X,\n",
+    "        y=y,\n",
+    "        cv=StratifiedKFold(n_splits=k_folds),\n",
+    "        scoring=scoring,\n",
+    "        n_jobs=-1,\n",
+    "        return_train_score=False,\n",
+    "    )\n",
+    "\n",
+    "    # Put the results into a DataFrame\n",
+    "    return pd.DataFrame(cv_results)\n",
+    "\n",
+    "\n",
+    "def shape_embed_process():\n",
+    "    # Setting the font size\n",
+    "    mpl.rcParams[\"font.size\"] = 10\n",
+    "\n",
+    "    # rc(\"text\", usetex=True)\n",
+    "    rc(\"font\", **{\"family\": 
\"sans-serif\", \"sans-serif\": [\"Arial\"]})\n", + " width = 3.45\n", + " height = width / 1.618\n", + " plt.rcParams[\"figure.figsize\"] = [width, height]\n", + "\n", + " sns.set(\n", + " style=\"white\",\n", + " context=\"notebook\",\n", + " rc={\"figure.figsize\": (width, height)},\n", + " )\n", + "\n", + " # matplotlib.use(\"TkAgg\")\n", + " interp_size = 128 * 2\n", + " max_epochs = 100\n", + " window_size = 128 * 2\n", + "\n", + " params = {\n", + " \"model\": \"resnet50_vqvae\",\n", + " \"epochs\": 250,\n", + " \"batch_size\": 4,\n", + " \"num_workers\": 2**4,\n", + " \"input_dim\": (3, interp_size, interp_size),\n", + " \"latent_dim\": interp_size,\n", + " \"num_embeddings\": interp_size,\n", + " \"num_hiddens\": interp_size,\n", + " \"pretrained\": True,\n", + " \"commitment_cost\": 0.25,\n", + " \"decay\": 0.99,\n", + " \"frobenius_norm\": False,\n", + " # dataset = \"bbbc010/BBBC010_v1_foreground_eachworm\"\n", + " # dataset = \"vampire/mefs/data/processed/Control\"\n", + " \"dataset\": \"synthcellshapes_dataset\",\n", + " }\n", + "\n", + " optimizer_params = {\n", + " \"opt\": \"AdamW\",\n", + " \"lr\": 0.001,\n", + " \"weight_decay\": 0.0001,\n", + " \"momentum\": 0.9,\n", + " }\n", + "\n", + " lr_scheduler_params = {\n", + " \"sched\": \"cosine\",\n", + " \"min_lr\": 1e-4,\n", + " \"warmup_epochs\": 5,\n", + " \"warmup_lr\": 1e-6,\n", + " \"cooldown_epochs\": 10,\n", + " \"t_max\": 50,\n", + " \"cycle_momentum\": False,\n", + " }\n", + "\n", + " args = SimpleNamespace(**params, **optimizer_params, **lr_scheduler_params)\n", + "\n", + " dataset_path = args.dataset\n", + "\n", + " # train_data_path = f\"scripts/shapes/data/{dataset_path}\"\n", + " train_data_path = f\"data/{dataset_path}\"\n", + " metadata = lambda x: f\"results/{dataset_path}_{args.model}/{x}\"\n", + "\n", + " path = Path(metadata(\"\"))\n", + " path.mkdir(parents=True, exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "059effef", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "\n", + " transform_crop = CropCentroidPipeline(window_size)\n", + " # transform_dist = MaskToDistogramPipeline(\n", + " # window_size, interp_size, matrix_normalised=False\n", + " # )\n", + " transform_coord_to_dist = CoordsToDistogram(interp_size, matrix_normalised=False)\n", + " transform_mdscoords = DistogramToCoords(window_size)\n", + " transform_coords = ImageToCoords(window_size)\n", + "\n", + " transform_mask_to_gray = transforms.Compose([transforms.Grayscale(1)])\n", + "\n", + " transform_mask_to_crop = transforms.Compose(\n", + " [\n", + " # transforms.ToTensor(),\n", + " transform_mask_to_gray,\n", + " transform_crop,\n", + " ]\n", + " )\n", + "\n", + " transform_mask_to_coords = transforms.Compose(\n", + " [\n", + " transform_mask_to_crop,\n", + " transform_coords,\n", + " ]\n", + " )\n", + "\n", + " transform_mask_to_dist = transforms.Compose(\n", + " [\n", + " transform_mask_to_coords,\n", + " transform_coord_to_dist,\n", + " ]\n", + " )\n", + "\n", + " gray2rgb = transforms.Lambda(lambda x: x.repeat(3, 1, 1))\n", + " transform = transforms.Compose(\n", + " [\n", + " transform_mask_to_dist,\n", + " transforms.ToTensor(),\n", + " RotateIndexingClockwise(p=1),\n", + " gray2rgb,\n", + " ]\n", + " )\n", + "\n", + " transforms_dict = {\n", + " \"none\": transform_mask_to_gray,\n", + " \"transform_crop\": transform_mask_to_crop,\n", + " \"transform_dist\": transform_mask_to_dist,\n", + " \"transform_coords\": transform_mask_to_coords,\n", + " }\n", + "\n", + " 
# Apply transform to find which images don't work\n",
+    "    dataset = datasets.ImageFolder(train_data_path, transform=transform)\n",
+    "\n",
+    "    valid_indices = []\n",
+    "    # Iterate through the dataset and apply the transform to each image\n",
+    "    for idx in range(len(dataset)):\n",
+    "        try:\n",
+    "            image, label = dataset[idx]\n",
+    "            # If the transform works without errors, add the index to the list of valid indices\n",
+    "            valid_indices.append(idx)\n",
+    "        except Exception as e:\n",
+    "            # A better way to do this would be with batch collation\n",
+    "            logger.warning(f\"Error occurred for image {idx}: {e}\")\n",
+    "\n",
+    "    train_data = {\n",
+    "        key: torch.utils.data.Subset(\n",
+    "            datasets.ImageFolder(train_data_path, transform=value),\n",
+    "            valid_indices,\n",
+    "        )\n",
+    "        for key, value in transforms_dict.items()\n",
+    "    }\n",
+    "\n",
+    "    dataset = torch.utils.data.Subset(\n",
+    "        datasets.ImageFolder(train_data_path, transform=transform),\n",
+    "        valid_indices,\n",
+    "    )\n",
+    "\n",
+    "    for key, value in train_data.items():\n",
+    "        logger.info(\"%s: %d\", key, len(value))\n",
+    "        plt.imshow(np.array(train_data[key][0][0]), cmap=\"gray\")\n",
+    "        plt.imsave(metadata(f\"{key}.png\"), np.array(train_data[key][0][0]), cmap=\"gray\")\n",
+    "        # plt.show()\n",
+    "        plt.close()\n",
+    "\n",
+    "    # Retrieve the coordinates and cropped image\n",
+    "    coords = train_data[\"transform_coords\"][0][0]\n",
+    "    crop_image = train_data[\"transform_crop\"][0][0]\n",
+    "\n",
+    "    fig = plt.figure(frameon=True)\n",
+    "    ax = plt.Axes(fig, [0, 0, 1, 1])\n",
+    "    ax.set_axis_off()\n",
+    "    fig.add_axes(ax)\n",
+    "\n",
+    "    # Display the cropped image using grayscale colormap\n",
+    "    plt.imshow(crop_image, cmap=\"gray_r\")\n",
+    "\n",
+    "    # Scatter plot with smaller point size\n",
+    "    plt.scatter(*coords, c=np.arange(interp_size), cmap=\"rainbow\", s=2)\n",
+    "\n",
+    "    # Save the plot as an image without border and coordinate axes\n",
+    "    plt.savefig(metadata(\"transform_coords.png\"), bbox_inches=\"tight\", pad_inches=0)\n",
+    "\n",
+    "    # Close the plot\n",
+    "    plt.close()\n",
+    "\n",
+    "    # Create a DataModule over the validated subset\n",
+    "    dataloader = DataModule(\n",
+    "        dataset,\n",
+    "        batch_size=args.batch_size,\n",
+    "        shuffle=True,\n",
+    "        num_workers=args.num_workers,\n",
+    "    )\n",
+    "\n",
+    "    model = bioimage_embed.models.create_model(\n",
+    "        model=args.model,\n",
+    "        input_dim=args.input_dim,\n",
+    "        latent_dim=args.latent_dim,\n",
+    "        pretrained=args.pretrained,\n",
+    "    )\n",
+    "\n",
+    "    # lit_model = shapes.MaskEmbedLatentAugment(model, args)\n",
+    "    lit_model = shapes.MaskEmbed(model, args)\n",
+    "    test_data = dataset[0][0].unsqueeze(0)\n",
+    "    # test_lit_data = 2*(dataset[0][0].unsqueeze(0).repeat_interleave(3, dim=1),)\n",
+    "    test_output = lit_model.forward((test_data,))\n",
+    "\n",
+    "    dataloader.setup()\n",
+    "    model.eval()\n",
+    "\n",
+    "    model_dir = f\"checkpoints/{hashing_fn(args)}\"\n",
+    "\n",
+    "    tb_logger = pl_loggers.TensorBoardLogger(\"logs/\")\n",
+    "    wandb = pl_loggers.WandbLogger(project=\"bioimage-embed\", name=\"shapes\")\n",
+    "\n",
+    "    Path(f\"{model_dir}/\").mkdir(parents=True, exist_ok=True)\n",
+    "\n",
+    "    checkpoint_callback = ModelCheckpoint(\n",
+    "        dirpath=f\"{model_dir}/\",\n",
+    "        save_last=True,\n",
+    "        save_top_k=1,\n",
+    "        monitor=\"loss/val\",\n",
+    "        mode=\"min\",\n",
+    "    )\n",
+    "    wandb.watch(lit_model, log=\"all\")\n",
+    "\n",
+    "    trainer = pl.Trainer(\n",
+    "        logger=[wandb, tb_logger],\n",
+    "        gradient_clip_val=0.5,\n",
+    "        enable_checkpointing=True,\n",
+    "        devices=1,\n",
+    "        
accelerator=\"gpu\",\n", + " accumulate_grad_batches=4,\n", + " callbacks=[checkpoint_callback],\n", + " min_epochs=50,\n", + " max_epochs=args.epochs,\n", + " # callbacks=[EarlyStopping(monitor=\"loss/val\", mode=\"min\")],\n", + " log_every_n_steps=1,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2aff6834", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + " # Determine the checkpoint path for resuming\n", + " last_checkpoint_path = f\"{model_dir}/last.ckpt\"\n", + " best_checkpoint_path = checkpoint_callback.best_model_path\n", + "\n", + " # Check if a last checkpoint exists to resume from\n", + " if os.path.isfile(last_checkpoint_path):\n", + " resume_checkpoint = last_checkpoint_path\n", + " elif best_checkpoint_path and os.path.isfile(best_checkpoint_path):\n", + " resume_checkpoint = best_checkpoint_path\n", + " else:\n", + " resume_checkpoint = None\n", + "\n", + " trainer.fit(lit_model, datamodule=dataloader, ckpt_path=resume_checkpoint)\n", + "\n", + " lit_model.eval()\n", + "\n", + " validation = trainer.validate(lit_model, datamodule=dataloader)\n", + " # testing = trainer.test(lit_model, datamodule=dataloader)\n", + " example_input = Variable(torch.rand(1, *args.input_dim))\n", + "\n", + " # torch.jit.save(lit_model.to_torchscript(), f\"{model_dir}/model.pt\")\n", + " # torch.onnx.export(lit_model, example_input, f\"{model_dir}/model.onnx\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3659d80e", + "metadata": {}, + "outputs": [], + "source": [ + " # Inference on full dataset\n", + " dataloader = DataModule(\n", + " dataset,\n", + " batch_size=1,\n", + " shuffle=False,\n", + " num_workers=args.num_workers,\n", + " # Transform is commented here to avoid augmentations in real data\n", + " # HOWEVER, applying the transform multiple times and averaging the results might produce better latent embeddings\n", + " # transform=transform,\n", + " )\n", + " dataloader.setup()\n", + "\n", + " predictions = trainer.predict(lit_model, datamodule=dataloader)\n", + "\n", + " test_dist_pred = predictions[0].out.recon_x\n", + " plt.imsave(metadata(\"test_dist_pred.png\"), test_dist_pred.mean(axis=(0, 1)))\n", + " plt.close()\n", + "\n", + " test_dist_in = predictions[0].x.data\n", + " plt.imsave(metadata(\"test_dist_in.png\"), test_dist_in.mean(axis=(0, 1)))\n", + " plt.close()\n", + "\n", + " test_pred_coords = AsymmetricDistogramToCoordsPipeline(window_size=window_size)(\n", + " np.array(test_dist_pred[:, 0, :, :].unsqueeze(dim=0))\n", + " )\n", + "\n", + " plt.scatter(*test_pred_coords[0, 0].T)\n", + " # Save the plot as an image without border and coordinate axes\n", + " plt.savefig(metadata(\"test_pred_coords.png\"), bbox_inches=\"tight\", pad_inches=0)\n", + " plt.close()\n", + "\n", + " # Use the namespace variables\n", + " latent_space = torch.stack([d.out.z.flatten() for d in predictions])\n", + " scalings = torch.stack([d.x.scalings.flatten() for d in predictions])\n", + " idx_to_class = {v: k for k, v in dataset.dataset.class_to_idx.items()}\n", + " y = np.array([int(data[-1]) for data in dataloader.predict_dataloader()])\n", + "\n", + " y_partial = y.copy()\n", + " indices = np.random.choice(y.size, int(0.3 * y.size), replace=False)\n", + " y_partial[indices] = -1\n", + " y_blind = -1 * np.ones_like(y)\n", + "\n", + " df = pd.DataFrame(latent_space.numpy())\n", + " df[\"Class\"] = pd.Series(y).map(idx_to_class).astype(\"category\")\n", + " df[\"Scale\"] = scalings[:, 0].squeeze()\n", + " df = 
df.set_index(\"Class\")\n", + " df_shape_embed = df.copy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "940d43b7", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + " X = df_shape_embed.to_numpy()\n", + " y = df_shape_embed.index\n", + "\n", + " properties = [\n", + " \"area\",\n", + " \"perimeter\",\n", + " \"centroid\",\n", + " \"major_axis_length\",\n", + " \"minor_axis_length\",\n", + " \"orientation\",\n", + " ]\n", + " dfs = []\n", + " # Distance matrix data\n", + " for i, data in enumerate(tqdm(train_data[\"transform_crop\"])):\n", + " X, y = data\n", + " # Do regionprops here\n", + " # Calculate shape summary statistics using regionprops\n", + " # We're considering that the mask has only one object, so we take the first element [0]\n", + " # props = regionprops(np.array(X).astype(int))[0]\n", + " props_table = measure.regionprops_table(\n", + " np.array(X).astype(int), properties=properties\n", + " )\n", + "\n", + " # Store shape properties in a dataframe\n", + " df = pd.DataFrame(props_table)\n", + "\n", + " # Assuming the class or label is contained in 'y' variable\n", + " df[\"class\"] = y\n", + " df.set_index(\"class\", inplace=True)\n", + " dfs.append(df)\n", + "\n", + " df_regionprops = pd.concat(dfs)\n", + "\n", + " dfs = []\n", + " for i, data in enumerate(tqdm(train_data[\"transform_coords\"])):\n", + " # Convert the tensor to a numpy array\n", + " X, y = data\n", + "\n", + " # Feed it to PyEFD's calculate_efd function\n", + " coeffs = pyefd.elliptic_fourier_descriptors(X, order=10, normalize=False)\n", + " # coeffs_df = pd.DataFrame({'class': [y], 'norm_coeffs': [norm_coeffs.flatten().tolist()]})\n", + "\n", + " norm_coeffs = pyefd.normalize_efd(coeffs)\n", + " df = pd.DataFrame(\n", + " {\n", + " \"norm_coeffs\": norm_coeffs.flatten().tolist(),\n", + " \"coeffs\": coeffs.flatten().tolist(),\n", + " }\n", + " ).T.rename_axis(\"coeffs\")\n", + " df[\"class\"] = y\n", + " df.set_index(\"class\", inplace=True, append=True)\n", + " dfs.append(df)\n", + "\n", + " df_pyefd = pd.concat(dfs)\n", + "\n", + " trials = [\n", + " {\n", + " \"name\": \"mask_embed\",\n", + " \"features\": df_shape_embed.to_numpy(),\n", + " \"labels\": df_shape_embed.index,\n", + " },\n", + " {\n", + " \"name\": \"fourier_coeffs\",\n", + " \"features\": df_pyefd.xs(\"coeffs\", level=\"coeffs\"),\n", + " \"labels\": df_pyefd.xs(\"coeffs\", level=\"coeffs\").index,\n", + " },\n", + " # {\"name\": \"fourier_norm_coeffs\",\n", + " # \"features\": df_pyefd.xs(\"norm_coeffs\", level=\"coeffs\"),\n", + " # \"labels\": df_pyefd.xs(\"norm_coeffs\", level=\"coeffs\").index\n", + " # }\n", + " {\n", + " \"name\": \"regionprops\",\n", + " \"features\": df_regionprops,\n", + " \"labels\": df_regionprops.index,\n", + " },\n", + " ]\n", + "\n", + " trial_df = pd.DataFrame()\n", + " for trial in trials:\n", + " X = trial[\"features\"]\n", + " y = trial[\"labels\"]\n", + " trial[\"score_df\"] = scoring_df(X, y)\n", + " trial[\"score_df\"][\"trial\"] = trial[\"name\"]\n", + " logger.info(trial[\"score_df\"])\n", + " trial[\"score_df\"].to_csv(metadata(f\"{trial['name']}_score_df.csv\"))\n", + " trial_df = pd.concat([trial_df, trial[\"score_df\"]])\n", + " trial_df = trial_df.drop([\"fit_time\", \"score_time\"], axis=1)\n", + "\n", + " trial_df.to_csv(metadata(\"trial_df.csv\"))\n", + " trial_df.groupby(\"trial\").mean().to_csv(metadata(\"trial_df_mean.csv\"))\n", + " trial_df.plot(kind=\"bar\")\n", + "\n", + " avg = trial_df.groupby(\"trial\").mean()\n", + " logger.info(avg)\n", + " 
avg.to_latex(metadata(\"trial_df.tex\"))\n", + "\n", + " melted_df = trial_df.melt(id_vars=\"trial\", var_name=\"Metric\", value_name=\"Score\")\n", + " # fig, ax = plt.subplots(figsize=(width, height))\n", + " ax = sns.catplot(\n", + " data=melted_df,\n", + " kind=\"bar\",\n", + " x=\"trial\",\n", + " hue=\"Metric\",\n", + " y=\"Score\",\n", + " errorbar=\"se\",\n", + " height=height,\n", + " aspect=width * 2**0.5 / height,\n", + " )\n", + " # ax.xtick_params(labelrotation=45)\n", + " # plt.legend(loc='lower center', bbox_to_anchor=(1, 1))\n", + " # sns.move_legend(ax, \"lower center\", bbox_to_anchor=(1, 1))\n", + " # ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')\n", + " # plt.tight_layout()\n", + " plt.savefig(metadata(\"trials_barplot.pdf\"))\n", + " plt.close()\n", + "\n", + " avs = (\n", + " melted_df.set_index([\"trial\", \"Metric\"])\n", + " .xs(\"test_f1\", level=\"Metric\", drop_level=False)\n", + " .groupby(\"trial\")\n", + " .mean()\n", + " )\n", + " logger.info(avs)\n", + " # tikzplotlib.save(metadata(f\"trials_barplot.tikz\"))\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + " shape_embed_process()" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-all" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/scripts/_shape_embed.py b/scripts/_shape_embed.py index 5e8b11b9..755492e3 100644 --- a/scripts/_shape_embed.py +++ b/scripts/_shape_embed.py @@ -146,7 +146,7 @@ def shape_embed_process(): path = Path(metadata("")) path.mkdir(parents=True, exist_ok=True) model_dir = f"models/{dataset_path}_{args.model}" - # %% +# %% transform_crop = CropCentroidPipeline(window_size) transform_dist = MaskToDistogramPipeline( @@ -394,7 +394,7 @@ def shape_embed_process(): # plt.show() plt.close() - # %% +# %% X = df_shape_embed.to_numpy() y = df_shape_embed.index.values diff --git a/scripts/shape_embed.py b/scripts/shape_embed.py index f1064f7c..dad19802 100644 --- a/scripts/shape_embed.py +++ b/scripts/shape_embed.py @@ -213,7 +213,7 @@ def shape_embed_process(): path = Path(metadata("")) path.mkdir(parents=True, exist_ok=True) - # %% +# %% transform_crop = CropCentroidPipeline(window_size) # transform_dist = MaskToDistogramPipeline( @@ -372,7 +372,7 @@ def shape_embed_process(): # callbacks=[EarlyStopping(monitor="loss/val", mode="min")], log_every_n_steps=1, ) - # %% +# %% # Determine the checkpoint path for resuming last_checkpoint_path = f"{model_dir}/last.ckpt" @@ -446,7 +446,7 @@ def shape_embed_process(): df = df.set_index("Class") df_shape_embed = df.copy() - # %% +# %% X = df_shape_embed.to_numpy() y = df_shape_embed.index
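
For reference, both notebooks score each feature set with the same protocol: a StandardScaler + RandomForestClassifier pipeline evaluated by stratified k-fold cross-validation (the scoring_df function in notebooks/shape_embed.ipynb). A minimal, self-contained sketch of that protocol follows; the features and labels here are synthetic placeholders, not data from the notebooks.

# Sketch of the scoring_df protocol; X and y are synthetic placeholders.
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score, make_scorer
from sklearn.model_selection import StratifiedKFold, cross_validate
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(42)
X = rng.normal(size=(60, 16))    # placeholder features (e.g. latent embeddings)
y = rng.integers(0, 2, size=60)  # placeholder binary labels

pipeline = Pipeline(
    [
        ("scaler", StandardScaler()),
        ("clf", RandomForestClassifier()),
    ]
)
scoring = {"f1": make_scorer(f1_score, average="macro")}
cv_results = cross_validate(
    pipeline,
    X,
    y,
    cv=StratifiedKFold(n_splits=5),
    scoring=scoring,
    n_jobs=-1,
)
print(pd.DataFrame(cv_results).mean())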