Lazy init refactor (#76)
* add lazy_init stage to non-training scripts

* fix typo

---------

Co-authored-by: Arian Jamasb <[email protected]>
a-r-j and Arian Jamasb authored Feb 9, 2024
1 parent: 99696a6 · commit: 2efa6c7
Showing 3 changed files with 46 additions and 17 deletions.
proteinworkshop/embed.py (2 changes: 1 addition & 1 deletion)

@@ -42,7 +42,7 @@ def embed(cfg: omegaconf.DictConfig):
     # https://pytorch.org/docs/stable/generated/torch.nn.modules.lazy.LazyModuleMixin.html#torch.nn.modules.lazy.LazyModuleMixin
     log.info("Initializing lazy layers...")
     with torch.no_grad():
-        datamodule.setup()  # type: ignore
+        datamodule.setup(stage="lazy_init")  # type: ignore
         batch = next(iter(datamodule.val_dataloader()))
         log.info(f"Unfeaturized batch: {batch}")
         batch = model.featurise(batch)
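
The dry run above only ever touches the validation dataloader, which is presumably why a dedicated "lazy_init" stage string was introduced: the scripts now call setup(stage="lazy_init") directly, so the datamodule can branch on that string and skip preparing splits the dry run never reads. A hypothetical sketch of that branching (proteinworkshop's actual datamodule is not part of this diff, and _load_split is an assumed stand-in helper):

from typing import Optional

import lightning as L


class ProteinDataModule(L.LightningDataModule):
    """Hypothetical datamodule; the real one is not shown in this diff."""

    def _load_split(self, split: str) -> None:
        print(f"preparing {split} split")  # stand-in for real dataset construction

    def setup(self, stage: Optional[str] = None) -> None:
        if stage == "lazy_init":
            # Dry run: only the validation split is needed to draw one
            # batch and materialise lazy layers, so skip the rest.
            self._load_split("val")
            return
        for split in ("train", "val", "test"):
            self._load_split(split)
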
proteinworkshop/finetune.py (2 changes: 1 addition & 1 deletion)

@@ -58,7 +58,7 @@ def finetune(cfg: DictConfig):
     # https://pytorch.org/docs/stable/generated/torch.nn.modules.lazy.LazyModuleMixin.html#torch.nn.modules.lazy.LazyModuleMixin
     log.info("Initializing lazy layers...")
     with torch.no_grad():
-        datamodule.setup()  # type: ignore
+        datamodule.setup(stage="lazy_init")  # type: ignore
         batch = next(iter(datamodule.val_dataloader()))
         log.info(f"Unfeaturized batch: {batch}")
         batch = model.featurise(batch)
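
For context on why each script pushes one batch through the model under torch.no_grad(): per the LazyModuleMixin docs linked in the diff, lazy modules leave their parameters uninitialized until they see a first input, at which point shapes are inferred and weights materialised. A minimal, self-contained illustration with a toy model (not the workshop's encoder):

import torch
import torch.nn as nn

# Toy model with lazy layers: parameter shapes are unknown until the first batch.
model = nn.Sequential(nn.LazyLinear(64), nn.ReLU(), nn.LazyLinear(10))
print(model[0].weight)  # <UninitializedParameter>

with torch.no_grad():  # dry-run forward pass, as in the scripts above
    model(torch.randn(8, 32))  # infers in_features=32 and materialises weights

print(model[0].weight.shape)  # torch.Size([64, 32])
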
proteinworkshop/visualise.py (59 changes: 44 additions & 15 deletions)

@@ -3,12 +3,12 @@

 import hydra
 import lightning as L
+import matplotlib.pyplot as plt
+import numpy as np
 import omegaconf
 import torch
-import numpy as np
 import umap
 import umap.plot
-import matplotlib.pyplot as plt
 from beartype.typing import Any, Dict, List, Optional
 from loguru import logger as log
 from matplotlib.lines import Line2D
@@ -51,7 +51,7 @@ def draw_simple_ellipse(
             alpha=alpha,
             lw=0,
             color=color,
-            **kwargs
+            **kwargs,
         )
     )

@@ -80,7 +80,7 @@ def visualise(cfg: omegaconf.DictConfig):
     # https://pytorch.org/docs/stable/generated/torch.nn.modules.lazy.LazyModuleMixin.html#torch.nn.modules.lazy.LazyModuleMixin
     log.info("Initializing lazy layers...")
     with torch.no_grad():
-        datamodule.setup()  # type: ignore
+        datamodule.setup(stage="lazy_init")  # type: ignore
         batch = next(iter(datamodule.val_dataloader()))
         log.info(f"Unfeaturized batch: {batch}")
         batch = model.featurise(batch)
@@ -166,20 +166,34 @@ def visualise(cfg: omegaconf.DictConfig):
             out = model.forward(batch)
             graph_embeddings = out["graph_embedding"]
             node_embeddings = graph_embeddings.tolist()
-            collection.append({"embedding": node_embeddings, "labels": labels})
+            collection.append(
+                {"embedding": node_embeddings, "labels": labels}
+            )

     # Derive clustering of embeddings using UMAP
-    assert len(collection) > 0 and len(collection[0]["embedding"]) > 0, "At least one batch of embeddings must be present to plot with UMAP."
-    clustering_data = np.array([batch for x in collection for batch in x["embedding"]])
-    clustering_labels = np.array([label for x in collection for label in x["labels"]])
-    umap_embeddings = umap.UMAP(random_state=cfg.seed).fit_transform(clustering_data)
+    assert (
+        len(collection) > 0 and len(collection[0]["embedding"]) > 0
+    ), "At least one batch of embeddings must be present to plot with UMAP."
+    clustering_data = np.array(
+        [batch for x in collection for batch in x["embedding"]]
+    )
+    clustering_labels = np.array(
+        [label for x in collection for label in x["labels"]]
+    )
+    umap_embeddings = umap.UMAP(random_state=cfg.seed).fit_transform(
+        clustering_data
+    )

     graph_label_available = cfg.visualise.label in batch
     if graph_label_available:
-        clustering_label_indices = np.array(list(range(len(clustering_labels))))
+        clustering_label_indices = np.array(
+            list(range(len(clustering_labels)))
+        )
     elif class_map_available:
         orig_class_map = datamodule.parse_class_map()
-        clustering_label_indices = np.array([orig_class_map[label] for label in clustering_labels])
+        clustering_label_indices = np.array(
+            [orig_class_map[label] for label in clustering_labels]
+        )
     else:
         clustering_label_indices = clustering_labels

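The reformatted block above pools the per-batch graph embeddings into a single array and projects it to two dimensions with umap-learn. A minimal sketch of that projection step, using random stand-in data (the shapes and seed are illustrative only):

import numpy as np
import umap

embeddings = np.random.rand(500, 128)  # stand-in for the collected graph embeddings

# Fixing random_state (the script uses cfg.seed) makes the 2-D layout reproducible.
coords = umap.UMAP(random_state=42).fit_transform(embeddings)
print(coords.shape)  # (500, 2)
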
@@ -192,15 +206,19 @@ def visualise(cfg: omegaconf.DictConfig):
     # Plot UMAP clustering
     fig = plt.figure(figsize=(10, 10))
     ax = fig.add_subplot(111)
-    num_unique_labels = len(class_map) if class_map_available else len(np.unique(clustering_label_indices))
+    num_unique_labels = (
+        len(class_map)
+        if class_map_available
+        else len(np.unique(clustering_label_indices))
+    )
     colors = plt.get_cmap("Spectral")(np.linspace(0, 1, num_unique_labels))

     ax.scatter(
         umap_embeddings[:, 0],
         umap_embeddings[:, 1],
         c=clustering_label_indices,
         cmap="Spectral",
-        s=3
+        s=3,
     )

     # Create the legend with only the 20 most common labels
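
In the scatter call above, matplotlib normalises the integer label indices over their range and maps them through the "Spectral" colormap, which is why the legend colours are later drawn from plt.get_cmap("Spectral")(np.linspace(0, 1, num_unique_labels)) with a matching count. A minimal sketch with made-up points (five classes assumed):

import matplotlib.pyplot as plt
import numpy as np

coords = np.random.rand(200, 2)        # stand-in for the 2-D UMAP coordinates
labels = np.random.randint(0, 5, 200)  # stand-in for clustering_label_indices

fig, ax = plt.subplots(figsize=(10, 10))
# Labels 0..4 normalise to 0, 0.25, 0.5, 0.75, 1.0, matching np.linspace(0, 1, 5).
ax.scatter(coords[:, 0], coords[:, 1], c=labels, cmap="Spectral", s=3)
plt.show()
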
@@ -211,8 +229,19 @@ def visualise(cfg: omegaconf.DictConfig):
             label_counts[label] += 1
         else:
             label_counts[label] = 1
-    top_20_labels = sorted(label_counts, key=label_counts.get, reverse=True)[:20]
-    legend_handles = [Line2D([0], [0], color=colors[orig_class_map[label]], lw=3, label=label) for label in top_20_labels]
+    top_20_labels = sorted(
+        label_counts, key=label_counts.get, reverse=True
+    )[:20]
+    legend_handles = [
+        Line2D(
+            [0],
+            [0],
+            color=colors[orig_class_map[label]],
+            lw=3,
+            label=label,
+        )
+        for label in top_20_labels
+    ]
     plt.legend(handles=legend_handles)

     plt.xlabel("")  # Remove x-axis label
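
The legend code above counts label frequencies with a hand-rolled dict; collections.Counter expresses the same computation, and the Line2D objects are standard matplotlib proxy artists: zero-length lines that exist only to pair a colour with a label in the legend. A sketch of the same idea under assumed toy labels and class map:

from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

labels = ["kinase", "kinase", "protease", "transporter", "protease", "kinase"]
class_map = {"kinase": 0, "protease": 1, "transporter": 2}  # stand-in for orig_class_map
colors = plt.get_cmap("Spectral")(np.linspace(0, 1, len(class_map)))

# Equivalent to the manual counting loop, then keeping the 20 most common labels.
top_labels = [label for label, _ in Counter(labels).most_common(20)]

# Proxy artists carrying only a colour/label pair for the legend.
handles = [
    Line2D([0], [0], color=colors[class_map[label]], lw=3, label=label)
    for label in top_labels
]
plt.legend(handles=handles)
plt.show()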
