diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml
index 68077759..68577f72 100644
--- a/.github/workflows/checks.yml
+++ b/.github/workflows/checks.yml
@@ -36,15 +36,31 @@ jobs:
env:
WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ versions:
+ - python: "3.10"
+ torch: "1.13.1"
+ - python: "3.10"
+ torch: "2.0.1"
+ - python: "3.11"
+ torch: "2.0.1"
steps:
- name: Checkout code
uses: actions/checkout@v3
+ - name: Set up python
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.versions.python }}
+
- name: Install dependencies
run: |
curl -sSL https://install.python-poetry.org | python3 -
poetry lock --check
- poetry install
+ export CUDA_VISIBLE_DEVICES=0
+ poetry add torch@${{ matrix.versions.torch }}+cpu --source torch_cpu
+ poetry install --all-extras
- name: Unit tests
run: make unit
diff --git a/.gitignore b/.gitignore
index 31f650c1..0bfc5dcc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,6 +5,7 @@ models/
wandb
tests/_temp/**
tests/**/_temp/**
+notebooks/data/**
.coverage
htmlcov/
diff --git a/examples/datasets/custom-hallway-g8-n100-a_dfs-h20593.zanj b/examples/datasets/custom-hallway-g8-n100-a_dfs-h20593.zanj
new file mode 100644
index 00000000..70848975
Binary files /dev/null and b/examples/datasets/custom-hallway-g8-n100-a_dfs-h20593.zanj differ
diff --git a/examples/hallway-medium_2023-06-16-03-40-47.iter_26554.zanj b/examples/hallway-medium_2023-06-16-03-40-47.iter_26554.zanj
new file mode 100644
index 00000000..4c88d9d4
Binary files /dev/null and b/examples/hallway-medium_2023-06-16-03-40-47.iter_26554.zanj differ
diff --git a/maze_transformer/mechinterp/direct_logit_attribution.py b/maze_transformer/mechinterp/direct_logit_attribution.py
new file mode 100644
index 00000000..756eab91
--- /dev/null
+++ b/maze_transformer/mechinterp/direct_logit_attribution.py
@@ -0,0 +1,469 @@
+import datetime
+import json
+from pathlib import Path
+from typing import Literal
+
+import einops
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import torch
+from jaxtyping import Float, Int
+
+# maze-dataset stuff
+from maze_dataset import MazeDataset, MazeDatasetConfig
+from maze_dataset.tokenization import MazeTokenizer
+
+# TransformerLens imports
+from transformer_lens import ActivationCache
+
+# mechinterp stuff
+from maze_transformer.mechinterp.logit_attrib_task import (
+ LOGIT_ATTRIB_TASKS,
+ DLAProtocolFixed,
+)
+from maze_transformer.mechinterp.logit_diff import (
+ logit_diff_residual_stream,
+ logits_diff_multi,
+ residual_stack_to_logit_diff,
+)
+from maze_transformer.mechinterp.logit_lens import plot_logit_lens
+from maze_transformer.mechinterp.plot_attention import plot_attention_final_token
+from maze_transformer.mechinterp.plot_logits import plot_logits
+
+# model stuff
+from maze_transformer.training.config import ZanjHookedTransformer
+
+
+def compute_direct_logit_attribution(
+ model: ZanjHookedTransformer,
+ cache: ActivationCache,
+ answer_tokens: Int[torch.Tensor, "n_mazes"],
+) -> dict[Literal["heads", "neurons"], Float[np.ndarray, "layer index"]]:
+ n_layers: int = model.zanj_model_config.model_cfg.n_layers
+ n_heads: int = model.zanj_model_config.model_cfg.n_heads
+ d_model: int = model.zanj_model_config.model_cfg.d_model
+ mlp_dim: int = 4 * d_model
+
+ print(f"{answer_tokens.shape = }")
+ print(f"{n_layers = }, {n_heads = }, {d_model = }")
+ print(f"{n_layers * n_heads = }")
+ print(f"{n_layers * mlp_dim = }")
+
+ # logit diff
+ avg_diff, diff_direction = logit_diff_residual_stream(
+ model=model,
+ cache=cache,
+ tokens_correct=answer_tokens,
+ tokens_compare_to=None,
+ directions=True,
+ )
+
+ # per head
+ per_head_residual, head_labels = cache.stack_head_results(
+ layer=-1, pos_slice=-1, return_labels=True
+ )
+
+ per_head_logit_diffs = residual_stack_to_logit_diff(
+ residual_stack=per_head_residual,
+ cache=cache,
+ logit_diff_directions=diff_direction,
+ )
+
+ print(f"{per_head_residual.shape = }")
+ print(f"{per_head_logit_diffs.shape = }")
+
+ per_head_logit_diffs = einops.rearrange(
+ per_head_logit_diffs,
+ "(layer head_index) -> layer head_index",
+ layer=n_layers,
+ head_index=n_heads,
+ )
+
+ print(f"{per_head_logit_diffs.shape = }")
+
+ # per neuron
+ per_neuron_residual, neuron_labels = cache.stack_neuron_results(
+ layer=-1,
+ pos_slice=-1,
+ return_labels=True,
+ )
+
+ per_neuron_logit_diffs = residual_stack_to_logit_diff(
+ residual_stack=per_neuron_residual,
+ cache=cache,
+ logit_diff_directions=diff_direction,
+ )
+
+ print(f"{per_neuron_residual.shape = }")
+ print(f"{per_neuron_logit_diffs.shape = }")
+
+ per_neuron_logit_diffs = einops.rearrange(
+ per_neuron_logit_diffs,
+ "(layer neuron_index) -> layer neuron_index",
+ layer=n_layers,
+ neuron_index=mlp_dim,
+ )
+
+ print(f"{per_neuron_logit_diffs.shape = }")
+
+ # return
+ return dict(
+ heads=per_head_logit_diffs.to("cpu").numpy(),
+ neurons=per_neuron_logit_diffs.to("cpu").numpy(),
+ )
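+
+# minimal usage sketch (hypothetical variable names; assumes `model`, a `cache` from
+# `model.run_with_cache(...)`, and `target_ids` prepared as in `create_report` below):
+#
+#   logits, cache = model.run_with_cache(dataset_prompts_joined)
+#   dla = compute_direct_logit_attribution(model, cache, answer_tokens=target_ids)
+#   dla["heads"].shape    # (n_layers, n_heads)
+#   dla["neurons"].shape  # (n_layers, 4 * d_model)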
+
+
+def plot_direct_logit_attribution(
+ model: ZanjHookedTransformer,
+ cache: ActivationCache,
+ answer_tokens: Int[torch.Tensor, "n_mazes"],
+ show: bool = True,
+) -> tuple[
+    tuple[plt.Figure, plt.Figure],
+    tuple[plt.Axes, plt.Axes],
+    dict[str, Float[np.ndarray, "layer index"]],
+]:
+    dla_data: dict[str, Float[np.ndarray, "layer index"]] = compute_direct_logit_attribution(
+ model=model,
+ cache=cache,
+ answer_tokens=answer_tokens,
+ )
+ dla_heads: Float[np.ndarray, "layer head"] = dla_data["heads"]
+ dla_neurons: Float[np.ndarray, "layer neuron"] = dla_data["neurons"]
+
+ extval_heads: float = np.max(np.abs(dla_heads))
+ extval_neurons: float = np.max(np.abs(dla_neurons))
+ fig_heads, ax_heads = plt.subplots(figsize=(5, 5))
+ fig_neurons, ax_neurons = plt.subplots(figsize=(20, 5))
+
+ # heads
+ ax_heads.imshow(dla_heads, cmap="RdBu", vmin=-extval_heads, vmax=extval_heads)
+ ax_heads.set_xlabel("Head")
+ ax_heads.set_ylabel("Layer")
+ plt.colorbar(ax_heads.get_images()[0], ax=ax_heads)
+ ax_heads.set_title(
+ f"Logit Difference from each head\n{model.zanj_model_config.name}"
+ )
+
+ # neurons
+ # don't enforce aspect ratio, no blending
+ ax_neurons.imshow(
+ dla_neurons,
+ cmap="RdBu",
+ vmin=-extval_neurons,
+ vmax=extval_neurons,
+ aspect="auto",
+ interpolation="none",
+ )
+ ax_neurons.set_xlabel("Neuron")
+ ax_neurons.set_ylabel("Layer")
+ plt.colorbar(ax_neurons.get_images()[0], ax=ax_neurons)
+ ax_neurons.set_title(
+ f"Logit Difference from each neuron\n{model.zanj_model_config.name}"
+ )
+
+ if show:
+ plt.show()
+
+ return (fig_heads, fig_neurons), (ax_heads, ax_neurons), dla_data
+
+
+def _output_codeblock(
+ data: str | dict,
+ lang: str = "",
+) -> str:
+ newdata: str = data
+
+ if isinstance(data, dict):
+ newdata = json.dumps(data, indent=2)
+
+ return f"```{lang}\n{newdata}\n```"
+
+
+def create_report(
+ model: ZanjHookedTransformer | str | Path,
+ dataset_cfg_source: MazeDatasetConfig | None,
+ logit_attribution_task_name: str,
+ n_examples: int = 100,
+ out_path: str | Path | None = None,
+ device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+) -> Path:
+ # setup
+ # ======================================================================
+ torch.set_grad_enabled(False)
+
+ # model and tokenizer
+ if not isinstance(model, ZanjHookedTransformer):
+ model = ZanjHookedTransformer.read(model)
+ tokenizer: MazeTokenizer = model.zanj_model_config.maze_tokenizer
+
+ # dataset cfg
+ if dataset_cfg_source is None:
+ dataset_cfg_source = model.zanj_model_config.dataset_cfg
+
+ # output
+    if out_path is None:
+        out_path = Path(
+            f"data/dla_reports/{model.zanj_model_config.name}-{dataset_cfg_source.name}-{logit_attribution_task_name}-n{n_examples}/"
+        )
+    # coerce to Path, since a plain string path is also accepted
+    out_path = Path(out_path)
+
+    out_path.mkdir(parents=True, exist_ok=True)
+
+ fig_path: Path = out_path / "figures"
+ fig_path.mkdir(parents=True, exist_ok=True)
+ fig_path_md: Path = Path(f"figures")
+
+ output_md_path: Path = out_path / "report.md"
+ output_md = output_md_path.open("w")
+
+ # write header
+ output_md.write(
+ f"""---
+title: Direct Logit Attribution Report
+model_name: {model.zanj_model_config.name}
+dataset_cfg_name: {dataset_cfg_source.name}
+logit_attribution_task_name: {logit_attribution_task_name}
+n_examples: {n_examples}
+time: {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
+---
+
+# Direct Logit Attribution Report
+
+## Model
+`{model.zanj_model_config.name}`
+{_output_codeblock(model.zanj_model_config.summary(), 'json')}
+
+## Dataset
+`{dataset_cfg_source.name}`
+{_output_codeblock(dataset_cfg_source.summary(), 'json')}
+
+"""
+ )
+
+ # task
+ logit_attribution_task: DLAProtocolFixed = LOGIT_ATTRIB_TASKS[
+ logit_attribution_task_name
+ ]
+
+ # dataset
+ dataset: MazeDataset = MazeDataset.from_config(dataset_cfg_source)
+ dataset_tokens: list[list[str]] = dataset.as_tokens(
+ tokenizer, join_tokens_individual_maze=False
+ )
+
+ dataset_prompts: list[list[str]]
+ dataset_targets: list[str]
+ dataset_prompts, dataset_targets = logit_attribution_task(dataset_tokens)
+ dataset_prompts_joined: list[str] = [" ".join(prompt) for prompt in dataset_prompts]
+    dataset_target_ids: Int[torch.Tensor, "n_mazes"] = torch.tensor(
+ tokenizer.encode(dataset_targets), dtype=torch.long
+ )
+
+ # print some info about dataset
+
+ n_mazes: int = len(dataset)
+ d_vocab: int = tokenizer.vocab_size
+
+ output_md.write(
+ f"""
+
+number of mazes: {n_mazes}
+vocabulary size: {d_vocab}
+
+### First Maze
+
+full: {_output_codeblock(' '.join(dataset_prompts[0]))}
+prompt: {_output_codeblock('[...] ' + dataset_prompts_joined[0][-150:])}
+target: {_output_codeblock(dataset_targets[0])}
+target id: {_output_codeblock(str(dataset_target_ids[0]))}
+
+![First maze as raster image]({fig_path_md / 'first_maze.png'})
+
+"""
+ )
+
+ plt.imsave(fig_path / "first_maze.png", dataset[0].as_pixels())
+
+ # run model
+ # ======================================================================
+
+ logits: Float[torch.Tensor, "n_mazes seq_len d_vocab"]
+ cache: ActivationCache
+ logits, cache = model.run_with_cache(dataset_prompts_joined, device=device)
+
+ last_tok_logits: Float[torch.Tensor, "n_mazes d_vocab"] = logits[:, -1, :]
+
+ output_md.write(
+ f"""# Model Output
+
+```
+logits.shape: {logits.shape}
+cache_shapes: TODO
+last_tok_logits.shape: {last_tok_logits.shape}
+```
+"""
+ )
+
+ plot_logits(
+ last_tok_logits=last_tok_logits,
+ target_idxs=dataset_target_ids,
+ tokenizer=tokenizer,
+ n_bins=50,
+ show=False,
+ )
+ plt.savefig(fig_path / "last_tok_logits.png")
+ output_md.write(f"![last token logits]({fig_path_md / 'last_tok_logits.png'})\n")
+
+ predicted_tokens: list[str] = tokenizer.decode(
+ last_tok_logits.argmax(dim=-1).tolist()
+ )
+ prediction_correct: Float[torch.Tensor, "n_mazes"] = torch.tensor(
+ [pred == target for pred, target in zip(predicted_tokens, dataset_targets)]
+ )
+
+ output_md.write(
+ f"""
+```
+predicted_tokens.shape: {len(predicted_tokens)}
+prediction_correct.shape: {prediction_correct.shape}
+prediction_correct.mean(): {prediction_correct.float().mean().item()}
+```
+"""
+ )
+
+ # logit diff
+ logit_diff_df: pd.DataFrame = logits_diff_multi(
+ model=model,
+ cache=cache,
+ dataset_target_ids=dataset_target_ids,
+ last_tok_logits=last_tok_logits,
+ noise_sigmas=np.logspace(0, 3, 100),
+ )
+
+ output_md.write(
+ f"""
+# Logit Difference
+```
+logit_diff_df.shape: {logit_diff_df.shape}
+```
+
+```
+{logit_diff_df}
+```
+"""
+ )
+
+ # scatter separately for "all" vs "random"
+ fig, ax = plt.subplots()
+ for compare_to in ["all", "random"]:
+ df = logit_diff_df[logit_diff_df["compare_to"] == compare_to]
+ ax.scatter(
+ df["result_orig"],
+ df["result_res"],
+ label=f"comparing to {compare_to}",
+ marker="o",
+ )
+ ax.legend()
+ plt.xlabel("result_orig")
+ plt.ylabel("result_res")
+ plt.title("Scatter Plot between result_orig and result_res")
+ plt.savefig(fig_path / "logit_diff_scatter.png")
+ output_md.write(
+ f"![logit difference scatterplot comparison]({fig_path_md / 'logit_diff_scatter.png'})\n"
+ )
+
+ # logit lens
+ logitlens_figax, logitlens_results = plot_logit_lens(
+ model=model,
+ cache=cache,
+ answer_tokens=dataset_target_ids,
+ show=False,
+ )
+ logitlens_figax[0].savefig(fig_path / "logitlens.png")
+ output_md.write(f"![logit lens results]({fig_path_md / 'logitlens.png'})\n")
+
+ # direct logit attribution
+ output_md.write(f"# Direct Logit Attribution")
+ dla_fig, dla_ax, dla_data = plot_direct_logit_attribution(
+ model=model,
+ cache=cache,
+ answer_tokens=dataset_target_ids,
+ show=False,
+ )
+ dla_ax.set_title(
+ f"Logit difference from each head\n{model.zanj_model_config.name}\n'{logit_attribution_task_name}' task"
+ )
+
+ dla_fig.savefig(fig_path / "logit_attribution.png")
+ output_md.write(f"![logit attribution]({fig_path_md / 'logit_attribution.png'})\n")
+
+ # head analysis
+ # let's try to plot the values of the attention heads for the top and bottom n contributing heads
+ # (layer, head, value)
+ top_heads: int = 5
+    dla_heads: Float[np.ndarray, "layer head"] = dla_data["heads"]
+    important_heads: list[tuple[int, int, float]] = sorted(
+        [
+            (i, j, dla_heads[i, j])
+            for i in range(dla_heads.shape[0])
+            for j in range(dla_heads.shape[1])
+        ],
+        key=lambda x: abs(x[2]),
+        reverse=True,
+    )[:top_heads]
+ # print(f"{important_heads = }")
+ output_md.write(
+ f"""
+# Head Analysis
+top {top_heads} heads: `{important_heads}`
+"""
+ )
+
+ # plot the attention heads
+ important_heads_scores = {
+ f"layer_{i}.head_{j}": (
+ c,
+ cache[f"blocks.{i}.attn.hook_attn_scores"][:, j, :, :].numpy(),
+ )
+ for i, j, c in important_heads
+ }
+
+ attn_final_tok_output: list[dict] = plot_attention_final_token(
+ important_heads_scores=important_heads_scores,
+ prompts=dataset_prompts,
+ targets=dataset_targets,
+ mazes=dataset,
+ tokenizer=tokenizer,
+ n_mazes=3,
+ last_n_tokens=20,
+ exponentiate_scores=False,
+ maze_colormap_center=0.0,
+ # important
+ show_all=False,
+ print_fmt="latex",
+ )
+
+ head_fig_path: Path = fig_path / "head_analysis"
+ head_fig_path.mkdir(parents=True, exist_ok=True)
+ head_fig_path_md: Path = fig_path_md / "head_analysis"
+
+ for i, attn_final_tok in enumerate(attn_final_tok_output):
+ head_info: dict = attn_final_tok["head_info"]
+ head_lbl: str = head_info["head"]
+
+ attn_final_tok["scores"][0].savefig(head_fig_path / f"scores-{head_lbl}.png")
+ attn_final_tok["attn_maze"][0].savefig(
+ head_fig_path / f"attn_maze-{head_lbl}.png"
+ )
+
+ output_md.write(
+ f"""
+## Head {head_lbl}
+head info: `{head_info}`
+
+{attn_final_tok['colored_tokens']}
+
+![scores of attention head over tokens]({head_fig_path_md / f'scores-{head_lbl}.png'})
+![scores of attention head over maze]({head_fig_path_md / f'attn_maze-{head_lbl}.png'})
+"""
+ )
+
+ # cleaning up
+ output_md.flush()
+ output_md.close()
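+
+
+# minimal usage sketch (hypothetical invocation, using the example checkpoint under
+# `examples/`; all other arguments fall back to their defaults):
+#
+#   create_report(
+#       model="examples/hallway-medium_2023-06-16-03-40-47.iter_26554.zanj",
+#       dataset_cfg_source=None,  # fall back to the model's own dataset config
+#       logit_attribution_task_name="path_start",
+#       n_examples=100,
+#   )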
diff --git a/maze_transformer/mechinterp/logit_attrib_task.py b/maze_transformer/mechinterp/logit_attrib_task.py
new file mode 100644
index 00000000..6ef67363
--- /dev/null
+++ b/maze_transformer/mechinterp/logit_attrib_task.py
@@ -0,0 +1,137 @@
+import functools
+import typing
+
+import numpy as np
+from jaxtyping import Float
+from maze_dataset import SPECIAL_TOKENS
+
+
+def get_token_first_index(search_token: str, token_list: list[str]) -> int:
+ return token_list.index(search_token)
+
+
+TaskSetup = typing.NamedTuple(
+ "TaskSetup",
+ [
+ ("prompts", list[list[str]]),
+ ("targets", str),
+ ],
+)
+
+
+class DLAProtocol(typing.Protocol):
+ """should take a dataset's tokens, and return a tuple of (prompts, targets)"""
+
+ def __call__(self, dataset_tokens: list[list[str]], **kwargs) -> TaskSetup:
+ ...
+
+
+class DLAProtocolFixed(typing.Protocol):
+ """should take a dataset's tokens, and return a tuple of (prompts, targets)
+
+ this variant signifies it's ready to be used -- no keyword arguments are needed
+ """
+
+ def __call__(self, dataset_tokens: list[list[str]]) -> TaskSetup:
+ ...
+
+
+def token_after_fixed_start_token(
+ dataset_tokens: list[list[str]],
+ start_token: str = SPECIAL_TOKENS.PATH_START,
+ offset: int = 1,
+) -> TaskSetup:
+ """in this task, we simply predict the token after `start_token`
+
+ # Parameters:
+ - `dataset_tokens : list[list[str]]`
+ list of string-lists
+ - `start_token : str`
+ token to look for
+ (defaults to `SPECIAL_TOKENS.PATH_START`)
+ - `offset : int`
+ which token to predict:
+ 1: the token after `start_token`, given everything up to and including `start_token`
+ 0: the token at `start_token`, given everything up to and **not** including `start_token`
+ (defaults to `1`)
+
+ # Returns:
+ - `TaskSetup`
+ tuple of (prompts, targets)
+ """
+
+ prompts: list[list[str]] = list()
+ targets: list[str] = list()
+
+ for maze_tokens in dataset_tokens:
+ path_start_idx: int = get_token_first_index(start_token, maze_tokens)
+ prompt_tokens: list[str] = maze_tokens[: path_start_idx + offset]
+ prompts.append(prompt_tokens)
+ targets.append(maze_tokens[path_start_idx + offset])
+
+ return TaskSetup(prompts=prompts, targets=targets)
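+
+
+def _example_token_after_fixed_start_token() -> None:
+    """minimal sketch on toy tokens (the coordinate tokens here are hypothetical)"""
+    toy: list[list[str]] = [
+        ["(1,1)", SPECIAL_TOKENS.PATH_START, "(0,0)", "(0,1)", SPECIAL_TOKENS.PATH_END]
+    ]
+    # offset=1: prompt ends with the start token, target is the token right after it
+    prompts, targets = token_after_fixed_start_token(toy, offset=1)
+    assert prompts[0][-1] == SPECIAL_TOKENS.PATH_START
+    assert targets[0] == "(0,0)"
+    # offset=0: predict the start token itself, given everything before it
+    prompts, targets = token_after_fixed_start_token(toy, offset=0)
+    assert targets[0] == SPECIAL_TOKENS.PATH_START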
+
+
+def rand_token_in_range(
+ dataset_tokens: list[list[str]],
+ start_token: str = SPECIAL_TOKENS.PATH_START,
+ end_token: str = SPECIAL_TOKENS.PATH_END,
+ start_offset: int = 1,
+ end_offset: int = -1,
+) -> TaskSetup:
+ """predict some random token between (non-inclusive) `start_token` and `end_token`"""
+ n_samples: int = len(dataset_tokens)
+
+ prompts: list[list[str]] = list()
+ targets: list[str] = list()
+ positions_p: Float[np.ndarray, "n_samples"] = np.random.uniform(size=(n_samples,))
+
+ for i, sample_tokens in enumerate(dataset_tokens):
+ start_idx: int = (
+ get_token_first_index(start_token, sample_tokens) + start_offset
+ )
+ end_idx: int = get_token_first_index(end_token, sample_tokens) + end_offset
+
+ selected_token_idx: int
+ if start_idx < end_idx:
+ selected_token_idx = int(positions_p[i] * (end_idx - start_idx) + start_idx)
+ else:
+ selected_token_idx = start_idx
+
+ prompts.append(sample_tokens[:selected_token_idx])
+ targets.append(sample_tokens[selected_token_idx])
+
+ return TaskSetup(prompts=prompts, targets=targets)
+
+
+LOGIT_ATTRIB_TASKS: dict[str, DLAProtocolFixed] = {
+ "path_start": functools.partial(
+ token_after_fixed_start_token, start_token=SPECIAL_TOKENS.PATH_START, offset=0
+ ),
+ "origin_after_path_start": functools.partial(
+ token_after_fixed_start_token, start_token=SPECIAL_TOKENS.PATH_START, offset=1
+ ),
+ "first_path_choice": functools.partial(
+ token_after_fixed_start_token, start_token=SPECIAL_TOKENS.PATH_START, offset=2
+ ),
+ "path_end": functools.partial(
+ token_after_fixed_start_token, start_token=SPECIAL_TOKENS.PATH_END, offset=0
+ ),
+ "final_before_path_end": functools.partial(
+ token_after_fixed_start_token, start_token=SPECIAL_TOKENS.PATH_END, offset=-1
+ ),
+ "rand_path_token": functools.partial(
+ rand_token_in_range,
+ start_token=SPECIAL_TOKENS.PATH_START,
+ end_token=SPECIAL_TOKENS.PATH_END,
+ start_offset=1,
+ end_offset=-1,
+ ),
+ "rand_path_token_non_endpoint": functools.partial(
+ rand_token_in_range,
+ start_token=SPECIAL_TOKENS.PATH_START,
+ end_token=SPECIAL_TOKENS.PATH_END,
+ start_offset=3,
+ end_offset=-2,
+ ),
+}
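+
+
+def _example_logit_attrib_tasks() -> None:
+    """minimal sketch of the task registry on toy tokens (coordinate tokens hypothetical)"""
+    toy: list[list[str]] = [
+        [SPECIAL_TOKENS.PATH_START, "(0,0)", "(0,1)", "(1,1)", SPECIAL_TOKENS.PATH_END]
+    ]
+    # "path_end" predicts PATH_END itself, given everything before it
+    _, targets = LOGIT_ATTRIB_TASKS["path_end"](toy)
+    assert targets[0] == SPECIAL_TOKENS.PATH_END
+    # "rand_path_token" predicts a random token strictly between PATH_START and PATH_END
+    _, targets = LOGIT_ATTRIB_TASKS["rand_path_token"](toy)
+    assert targets[0] in ("(0,0)", "(0,1)")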
diff --git a/maze_transformer/mechinterp/logit_diff.py b/maze_transformer/mechinterp/logit_diff.py
new file mode 100644
index 00000000..87976659
--- /dev/null
+++ b/maze_transformer/mechinterp/logit_diff.py
@@ -0,0 +1,198 @@
+# Numerical Computing
+import pandas as pd
+import torch
+from fancy_einsum import einsum
+from jaxtyping import Float, Int
+
+# TransformerLens imports
+from transformer_lens import ActivationCache, HookedTransformer
+
+# model stuff
+from maze_transformer.training.config import ZanjHookedTransformer
+
+LArr = Float[torch.Tensor, "samples"]
+
+
+def residual_stack_to_logit_diff(
+ residual_stack: Float[torch.Tensor, "components batch d_model"],
+ cache: ActivationCache,
+ logit_diff_directions: Float[torch.Tensor, "samples d_model"],
+) -> Float[torch.Tensor, "components"]:
+ scaled_residual_stack = cache.apply_ln_to_stack(
+ residual_stack, layer=-1, pos_slice=-1
+ )
+
+ return (
+ einsum(
+ "... batch d_model, batch d_model -> ...",
+ scaled_residual_stack,
+ logit_diff_directions,
+ )
+ / logit_diff_directions.shape[0]
+ )
+
+
+def logit_diff_direct(
+ model_logits: Float[torch.Tensor, "samples d_vocab"],
+ tokens_correct: Int[torch.Tensor, "samples"],
+ tokens_compare_to: Int[torch.Tensor, "samples"] | None = None,
+ diff_per_prompt: bool = True,
+) -> Float[torch.Tensor, "samples"] | float:
+ """based on Neel's explanatory notebook
+
+ https://colab.research.google.com/github/neelnanda-io/Easy-Transformer/blob/main/Exploratory_Analysis_Demo.ipynb
+
+ if `tokens_compare_to` is None, then compare to sum of logits on all other tokens
+ """
+
+ # logit on the correct answer token for each sample
+ model_logits_on_correct: LArr = torch.gather(
+ model_logits, 1, tokens_correct.unsqueeze(1)
+ ).squeeze(1)
+
+ output_diff: LArr
+ if tokens_compare_to is None:
+ # subtract total logits across all other tokens
+ all_logits: LArr = torch.sum(model_logits, dim=1)
+ output_diff = model_logits_on_correct - (all_logits - model_logits_on_correct)
+ else:
+ # subtract just the logit on the compare_to token
+ logits_compare_to: LArr = torch.gather(
+ model_logits, 1, tokens_compare_to.unsqueeze(1)
+ ).squeeze(1)
+ output_diff = model_logits_on_correct - logits_compare_to
+
+ assert output_diff.shape == tokens_correct.shape
+
+ if diff_per_prompt:
+ return output_diff
+ else:
+ return output_diff.mean().item()
+
+ # return answer_logits / (all_logits - answer_logits)
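+
+
+def _example_logit_diff_direct() -> None:
+    """minimal sketch on toy logits (values hypothetical)"""
+    toy_logits = torch.tensor([[2.0, 0.5, -1.0], [0.0, 3.0, 1.0]])
+    correct = torch.tensor([0, 1])
+    # compare against a specific token per sample: logit(correct) - logit(compare_to)
+    per_prompt = logit_diff_direct(toy_logits, correct, torch.tensor([1, 2]))
+    assert torch.allclose(per_prompt, torch.tensor([1.5, 2.0]))
+    # compare against the sum of logits on all other tokens, averaged over prompts
+    avg = logit_diff_direct(toy_logits, correct, None, diff_per_prompt=False)
+    assert isinstance(avg, float)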
+
+
+def logit_diff_residual_stream(
+ model: ZanjHookedTransformer,
+ cache: ActivationCache,
+ tokens_correct: Int[torch.Tensor, "samples"],
+ tokens_compare_to: Int[torch.Tensor, "samples"] | None = None,
+ directions: bool = False,
+) -> float | tuple[float, torch.Tensor]:
+    d_vocab: int = model.zanj_model_config.maze_tokenizer.vocab_size
+    d_model: int = model.zanj_model_config.model_cfg.d_model
+
+ # embed the whole vocab first
+ vocab_tensor: Float[torch.Tensor, "d_vocab"] = torch.arange(
+ d_vocab, dtype=torch.long
+ )
+ vocab_residual_directions: Float[
+ torch.Tensor, "d_vocab d_model"
+ ] = model.tokens_to_residual_directions(vocab_tensor)
+ # get embedding of answer tokens
+ answer_residual_directions = vocab_residual_directions[tokens_correct]
+    # get the directional difference between logits on the correct token and logits on {all other tokens, comparison tokens}
+    logit_diff_directions: Float[torch.Tensor, "samples d_model"]
+    if tokens_compare_to is None:
+        # compare to the summed direction of all tokens other than the correct one
+        logit_diff_directions = answer_residual_directions - (
+            vocab_residual_directions.sum(dim=0) - answer_residual_directions
+        )
+ else:
+ logit_diff_directions = (
+ answer_residual_directions - vocab_residual_directions[tokens_compare_to]
+ )
+
+ # get the values from the cache at the last layer and last token
+ final_token_residual_stream: Float[torch.Tensor, "samples d_model"] = cache[
+ "resid_post", -1
+ ][:, -1, :]
+
+ # scaling the values in residual stream with layer norm
+ scaled_final_token_residual_stream: Float[
+ torch.Tensor, "samples d_model"
+ ] = cache.apply_ln_to_stack(
+ final_token_residual_stream,
+ layer=-1,
+ pos_slice=-1,
+ )
+
+ # measure similarity between the logit diff directions and the residual stream at final layer directions
+ average_logit_diff: float = (
+ torch.dot(
+ scaled_final_token_residual_stream.flatten(),
+ logit_diff_directions.flatten(),
+ )
+ / logit_diff_directions.shape[0]
+ ).item()
+
+ if directions:
+ return average_logit_diff, logit_diff_directions
+ else:
+ return average_logit_diff
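+
+# usage sketch (hypothetical names; `cache` from `model.run_with_cache(...)`): the
+# returned directions feed directly into `residual_stack_to_logit_diff` above:
+#
+#   avg_diff, directions = logit_diff_residual_stream(
+#       model, cache, tokens_correct=target_ids, directions=True
+#   )
+#   head_stack, labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
+#   per_head = residual_stack_to_logit_diff(head_stack, cache, directions)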
+
+
+def logits_diff_multi(
+ model: HookedTransformer,
+ cache: ActivationCache,
+ dataset_target_ids: Int[torch.Tensor, "samples"],
+ last_tok_logits: Float[torch.Tensor, "samples d_vocab"],
+ noise_sigmas: list[float] = [1, 2, 3, 5, 10],
+ n_randoms: int = 1,
+) -> pd.DataFrame:
+ d_vocab: int = last_tok_logits.shape[1]
+
+ test_logits: dict[str, Float[torch.Tensor, "samples"]] = {
+ "target": dataset_target_ids,
+ "predicted": last_tok_logits.argmax(dim=-1),
+ "sampled": torch.multinomial(
+ torch.softmax(last_tok_logits, dim=-1), num_samples=1
+ ).squeeze(-1),
+ **{
+ f"noise={s:.2f}": (
+ last_tok_logits + s * torch.randn_like(last_tok_logits)
+ ).argmax(dim=-1)
+ for s in noise_sigmas
+ },
+ # "random": torch.randint_like(dataset_target_ids, low=0, high=d_vocab),
+ **{
+ f"random_r{i}": torch.randint_like(dataset_target_ids, low=0, high=d_vocab)
+ for i in range(n_randoms)
+ },
+ }
+ compare_dict: dict[str, None | Float[torch.Tensor, "samples"]] = {
+ "all": None,
+ "random": torch.randint_like(dataset_target_ids, low=0, high=d_vocab),
+ "target": dataset_target_ids,
+ }
+
+ outputs: list[dict] = list()
+
+ for k_comp, compare_to in compare_dict.items():
+ for k, d in test_logits.items():
+ result_orig: float = logit_diff_direct(
+ model_logits=last_tok_logits,
+ tokens_correct=d,
+ diff_per_prompt=False,
+ tokens_compare_to=compare_to,
+ )
+ result_res: float = logit_diff_residual_stream(
+ model=model,
+ cache=cache,
+ tokens_correct=d,
+ tokens_compare_to=compare_to,
+ )
+ # print(f"logit diff of {k}\tcompare:\t{'all' if compare_to is None else 'random'}\t{result = }\t{result_res = }")
+ outputs.append(
+ dict(
+ test=k,
+ compare_to=k_comp,
+ result_orig=result_orig,
+ result_res=result_res,
+ )
+ )
+
+ df_out: pd.DataFrame = pd.DataFrame(outputs)
+ df_out["diff"] = df_out["result_orig"] - df_out["result_res"]
+ df_out["ratio"] = df_out["result_orig"] / df_out["result_res"]
+
+ return df_out
diff --git a/maze_transformer/mechinterp/logit_lens.py b/maze_transformer/mechinterp/logit_lens.py
new file mode 100644
index 00000000..e418502c
--- /dev/null
+++ b/maze_transformer/mechinterp/logit_lens.py
@@ -0,0 +1,109 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from jaxtyping import Int
+
+# TransformerLens imports
+from transformer_lens import ActivationCache
+
+# mechinterp stuff
+from maze_transformer.mechinterp.logit_diff import (
+ logit_diff_residual_stream,
+ residual_stack_to_logit_diff,
+)
+
+# model stuff
+from maze_transformer.training.config import ZanjHookedTransformer
+
+
+def compute_logit_lens(
+ model: ZanjHookedTransformer,
+ cache: ActivationCache,
+ answer_tokens: Int[torch.Tensor, "n_mazes"],
+) -> tuple[
+    np.ndarray,
+    np.ndarray,  # x/y for diff
+    np.ndarray,
+    np.ndarray,  # x/y for attr
+]:
+ # logit diff
+ avg_diff, diff_direction = logit_diff_residual_stream(
+ model=model,
+ cache=cache,
+ tokens_correct=answer_tokens,
+ tokens_compare_to=None,
+ directions=True,
+ )
+
+ accumulated_residual, labels = cache.accumulated_resid(
+ layer=-1,
+ incl_mid=True,
+ pos_slice=-1,
+ return_labels=True,
+ )
+
+ logit_lens_logit_diffs = residual_stack_to_logit_diff(
+ residual_stack=accumulated_residual,
+ cache=cache,
+ logit_diff_directions=diff_direction,
+ )
+
+ # logit attribution
+ per_layer_residual, labels = cache.decompose_resid(
+ layer=-1, pos_slice=-1, return_labels=True
+ )
+ per_layer_logit_diffs = residual_stack_to_logit_diff(
+ residual_stack=per_layer_residual,
+ cache=cache,
+ logit_diff_directions=diff_direction,
+ )
+
+ return (
+ # np.arange(model.zanj_model_config.model_cfg.n_layers*2+1)/2,
+ np.arange(logit_lens_logit_diffs.shape[0]),
+ logit_lens_logit_diffs.to("cpu").numpy(),
+ np.arange(per_layer_logit_diffs.shape[0]),
+ per_layer_logit_diffs.to("cpu").numpy(),
+ )
+
+
+def plot_logit_lens(
+ model: ZanjHookedTransformer,
+ cache: ActivationCache,
+ answer_tokens: Int[torch.Tensor, "n_mazes"],
+ show: bool = True,
+) -> tuple[
+ tuple[plt.Figure, plt.Axes, plt.Axes], # figure and axes
+ tuple[
+        np.ndarray,
+        np.ndarray,  # x/y for diff
+        np.ndarray,
+        np.ndarray,  # x/y for attr
+ ],
+]:
+ diff_x, diff_y, attr_x, attr_y = compute_logit_lens(
+ model=model,
+ cache=cache,
+ answer_tokens=answer_tokens,
+ )
+
+ fig, ax1 = plt.subplots(figsize=(10, 5))
+
+ ax1.set_xlabel("Layer")
+ ax1.set_ylabel("Logit Difference", color="tab:blue")
+ ax1.set_title("Logit Lens")
+ ax1.plot(diff_x, diff_y, label="Logit Difference", color="tab:blue")
+ ax1.tick_params(axis="y", labelcolor="tab:blue")
+ ax1.legend(loc="upper left")
+
+ # create a second y-axis sharing the same x-axis
+ ax2 = ax1.twinx()
+
+ ax2.set_ylabel("Logit Attribution", color="tab:red")
+ ax2.plot(attr_x, attr_y, label="Logit Attribution", color="tab:red")
+ ax2.tick_params(axis="y", labelcolor="tab:red")
+ ax2.legend(loc="lower right")
+
+    if show:
+        plt.show()
+
+ return (fig, ax1, ax2), (diff_x, diff_y, attr_x, attr_y)
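+
+
+# usage sketch (hypothetical names; `model`, `cache`, and `target_ids` as in
+# `maze_transformer.mechinterp.direct_logit_attribution.create_report`):
+#
+#   (fig, ax1, ax2), (diff_x, diff_y, attr_x, attr_y) = plot_logit_lens(
+#       model, cache, answer_tokens=target_ids, show=False
+#   )
+#   fig.savefig("logitlens.png")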
diff --git a/maze_transformer/evaluation/plot_attention.py b/maze_transformer/mechinterp/plot_attention.py
similarity index 55%
rename from maze_transformer/evaluation/plot_attention.py
rename to maze_transformer/mechinterp/plot_attention.py
index 1910643d..d5a28aa6 100644
--- a/maze_transformer/evaluation/plot_attention.py
+++ b/maze_transformer/mechinterp/plot_attention.py
@@ -1,9 +1,8 @@
# Generic
import typing
+from collections import defaultdict
# plotting
-import IPython
-import matplotlib
import matplotlib.pyplot as plt
# Numerical Computing
@@ -16,6 +15,9 @@
from jaxtyping import Float
from maze_dataset import CoordTup, MazeDataset, MazeDatasetConfig, SolvedMaze
from maze_dataset.plotting import MazePlot
+from maze_dataset.plotting.plot_tokens import plot_colored_text
+from maze_dataset.plotting.print_tokens import color_tokens_cmap
+from maze_dataset.tokenization import MazeTokenizer
from maze_dataset.tokenization.token_utils import coord_str_to_tuple_noneable
# Utilities
@@ -197,6 +199,7 @@ def plot_attentions_on_maze(
for idx_attn, (name, attn) in enumerate(
zip(self.attention_names, self.attention_tensored)
):
+ # --------------------------
# empty node values
node_values: Float[np.ndarray, "grid_n grid_n"] = np.zeros(
self.input_maze.grid_shape
@@ -216,6 +219,7 @@ def plot_attentions_on_maze(
node_values=node_values,
color_map=color_map,
)
+ # --------------------------
# create a shared figure
fig, axs = plt.subplots(
@@ -232,32 +236,202 @@ def plot_attentions_on_maze(
return fig, axs
-def colorize(
- tokens: list[str],
- weights: list[float],
- cmap: matplotlib.colors.Colormap | str = "Blues",
- template: str = ' {tok} ',
-) -> str:
- """given a sequence of tokens and their weights, colorize the tokens according to the weights (output is html)
+def mazeplot_attention(
+ maze: SolvedMaze,
+    tokens_context: list[str],
+ target: str,
+ attention: Float[np.ndarray, "n_tokens"],
+ mazeplot: MazePlot | None = None,
+ cmap: str = "RdBu",
+ min_for_positive: float = 0.0,
+ show_other_tokens: bool = True,
+ fig_ax: tuple[plt.Figure, plt.Axes] | None = None,
+ colormap_center: None | float | typing.Literal["median", "mean"] = None,
+) -> tuple[MazePlot, plt.Figure, plt.Axes]:
+ # storing attention
+ node_values: Float[np.ndarray, "grid_n grid_n"] = np.zeros(maze.grid_shape)
+ total_logits_nonpos = defaultdict(float)
+
+ # get node values for each token
+ for idx_token, token in enumerate(tokens_context):
+ coord: CoordTup | None = coord_str_to_tuple_noneable(token)
+ # TODO: mean/median instead of just sum?
+ if coord is not None:
+ node_values[coord[0], coord[1]] += np.sum(attention[idx_token])
+ else:
+ total_logits_nonpos[token] += attention[idx_token]
+
+ # MazePlot attentions
+ if mazeplot is None:
+ mazeplot = MazePlot(maze)
+
+ final_prompt_coord: CoordTup | None = coord_str_to_tuple_noneable(
+ tokens_context[-1]
+ )
+ target_coord: CoordTup | None = coord_str_to_tuple_noneable(target)
+
+ colormap_center_val: float | None
+ if colormap_center is None:
+ colormap_center_val = None
+ elif colormap_center == "median":
+ colormap_center_val = np.median(attention)
+ elif colormap_center == "mean":
+ colormap_center_val = np.mean(attention)
+ else:
+ colormap_center_val = colormap_center
+
+ mazeplot.add_node_values(
+ node_values=node_values,
+ color_map=cmap,
+ target_token_coord=target_coord,
+ preceeding_tokens_coords=[final_prompt_coord]
+ if final_prompt_coord is not None
+ else None,
+ colormap_center=colormap_center_val,
+ )
+
+ # set up combined figure
+ if fig_ax is None:
+ fig, (ax_maze, ax_other) = plt.subplots(
+ 2,
+ 1,
+ figsize=(7, 7),
+ height_ratios=[7, 1],
+ )
+ else:
+ fig, (ax_maze, ax_other) = fig_ax
+ # set height ratio
+ mazeplot.plot(
+ title=f"{attention.min() = }\n{attention.max() = }",
+ fig_ax=(fig, ax_maze),
+ )
+
+ # non-pos tokens attention
+    total_logits_nonpos_processed: tuple[tuple[str, ...], tuple[float, ...]] = tuple(
+ zip(*sorted(total_logits_nonpos.items(), key=lambda x: x[0]))
+ )
+
+ if len(total_logits_nonpos_processed) == 2:
+ plot_colored_text(
+ total_logits_nonpos_processed[0],
+ total_logits_nonpos_processed[1],
+ cmap=cmap,
+ ax=ax_other,
+ fontsize=5,
+ width_scale=0.01,
+ char_min=5,
+ )
+ else:
+ print(f"No non-pos tokens found!\n{total_logits_nonpos_processed = }")
+
+ ax_other.set_title("Non-Positional Tokens Attention")
+
+ return mazeplot, fig, (ax_maze, ax_other)
+
+
+def plot_attention_final_token(
+ important_heads_scores: dict[
+ str,
+ tuple[float, Float[np.ndarray, "n_mazes n_tokens n_tokens"]],
+ ],
+ prompts: list[list[str]],
+ targets: list[str],
+ mazes: list[SolvedMaze],
+ tokenizer: MazeTokenizer,
+ n_mazes: int = 5,
+ last_n_tokens: int = 20,
+ exponentiate_scores: bool = False,
+ plot_colored_tokens: bool = True,
+ plot_scores: bool = True,
+ plot_attn_maze: bool = True,
+ maze_colormap_center: None | float | typing.Literal["median", "mean"] = None,
+ show_all: bool = True,
+ print_fmt: str = "terminal",
+) -> list[dict]:
+ # str, # head info
+ # str|None, # colored tokens text
+ # tuple[plt.Figure, plt.Axes]|None, # scores plot
+ # tuple[plt.Figure, plt.Axes]|None, # attn maze plot
+ output: list[dict[str, str | None | tuple[plt.Figure, plt.Axes]]] = list()
+
+ for k, (c, v) in important_heads_scores.items():
+ head_info: str = f"head: {k}, score: {c = }, {v.shape = }"
+ if show_all:
+ print("-" * 80)
+ print(head_info)
+
+ head_output: dict[str, str | None | tuple[plt.Figure, plt.Axes]] = dict(
+ head_info_str=head_info,
+ head_info=dict(
+ head=k,
+ score=c,
+ shape=v.shape,
+ ),
+ )
- originally from https://stackoverflow.com/questions/59220488/to-visualize-attention-color-tokens-using-attention-weights
- """
+ # set up scores across tokens figure
+ if plot_scores:
+ scores_fig, scores_ax = plt.subplots(n_mazes, 1)
+ scores_fig.set_size_inches(30, 4 * n_mazes)
+
+ # set up attention across maze figure
+ if plot_attn_maze:
+ mazes_fig, mazes_ax = plt.subplots(
+ 2,
+ n_mazes,
+ figsize=(7 * n_mazes, 7),
+ height_ratios=[7, 1],
+ )
- if isinstance(cmap, str):
- cmap = matplotlib.cm.get_cmap(cmap)
+ # for each maze
+ for i in range(n_mazes):
+ # process tokens and attention scores
+ n_tokens_prompt = len(prompts[i])
+ n_tokens_view = min(n_tokens_prompt, last_n_tokens)
+ v_final = v[i][-1] # -1 for last token
+ if exponentiate_scores:
+ v_final = np.exp(v_final)
+
+ # print token scores
+ if plot_colored_tokens:
+ color_tokens_text: str = color_tokens_cmap(
+ prompts[i][-n_tokens_view:],
+ v_final[-n_tokens_view:],
+ fmt=print_fmt,
+ labels=(print_fmt == "terminal"),
+ )
+ if show_all:
+ print(color_tokens_text)
- colored_string: str = ""
+ head_output["colored_tokens"] = color_tokens_text
- for word, color in zip(tokens, weights):
- color_hex: str = matplotlib.colors.rgb2hex(cmap(color)[:3])
- colored_string += template.format(clr=color_hex, tok=word)
+ # plot across tokens
+ if plot_scores:
+ scores_ax[i].plot(
+ v_final[-n_tokens_prompt:],
+ "o",
+ )
+ scores_ax[i].grid(axis="x", which="major", color="black", alpha=0.1)
+ scores_ax[i].set_xticks(range(n_tokens_prompt), prompts[i], rotation=90)
+
+ head_output["scores"] = (scores_fig, scores_ax)
+
+ # plot attention across maze
+ if plot_attn_maze:
+ mazeplot, fig, ax = mazeplot_attention(
+ maze=mazes[i],
+ tokens_context=prompts[i][-n_tokens_prompt:],
+ target=targets[i],
+ attention=v_final[-n_tokens_prompt:],
+ fig_ax=(mazes_fig, mazes_ax[:, i]),
+ colormap_center=maze_colormap_center,
+ )
- return colored_string
+ head_output["attn_maze"] = (mazes_fig, mazes_ax)
+ if show_all:
+ plt.show()
+ else:
+ output.append(head_output)
-def _test():
- mystr: str = "lorem ipsum dolor sit amet consectetur adipiscing elit sed do eiusmod tempor incididunt ut labore et dolore magna aliqua ut enim ad minim veniam quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur excepteur sint occaecat cupidatat non proident sunt in culpa qui officia deserunt mollit anim id est laborum"
- tokens: list[str] = mystr.split()
- weights: list[float] = np.random.rand(len(tokens)).tolist()
- colored: str = colorize(tokens, weights)
- IPython.display.display(IPython.display.HTML(colored))
+ return output
diff --git a/maze_transformer/mechinterp/plot_logits.py b/maze_transformer/mechinterp/plot_logits.py
new file mode 100644
index 00000000..3ce7c717
--- /dev/null
+++ b/maze_transformer/mechinterp/plot_logits.py
@@ -0,0 +1,119 @@
+# Numerical Computing
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from jaxtyping import Float, Int
+
+# Our Code
+from maze_dataset.tokenization import MazeTokenizer
+
+_DEFAULT_SUBPLOTS_KWARGS: dict = dict(
+ figsize=(20, 20),
+ height_ratios=[3, 1],
+)
+
+
+def plot_logits(
+ last_tok_logits: Float[torch.Tensor, "n_mazes d_vocab"],
+ target_idxs: Int[torch.Tensor, "n_mazes"],
+ tokenizer: MazeTokenizer,
+ n_bins: int = 50,
+ mark_incorrect: bool = True,
+ mark_correct: bool = False,
+ subplots_kwargs: dict | None = None,
+ show: bool = True,
+ density: bool = True,
+ logy: bool = False,
+) -> tuple[plt.Figure, tuple[plt.Axes, plt.Axes]]:
+ # set up figure
+ # --------------------------------------------------
+ n_mazes: int
+ d_vocab: int
+ n_mazes, d_vocab = last_tok_logits.shape
+ if subplots_kwargs is None:
+ subplots_kwargs = _DEFAULT_SUBPLOTS_KWARGS
+
+ fig, (ax_all, ax_sum) = plt.subplots(
+ 2, 1, **{**_DEFAULT_SUBPLOTS_KWARGS, **subplots_kwargs}
+ )
+
+ # fig.subplots_adjust(hspace=0.5, bottom=0.1, top=0.9, left=0.1, right=0.9)
+
+ # plot heatmap of logits
+ # --------------------------------------------------
+ # all vocab elements
+ ax_all.set_xlabel("vocab element logit")
+ ax_all.set_ylabel("maze index")
+ # add vocab as xticks
+ ax_all.set_xticks(ticks=np.arange(d_vocab), labels=tokenizer.token_arr, rotation=90)
+    im_all = ax_all.imshow(last_tok_logits.numpy(), aspect="auto")
+    # set colorbar
+    plt.colorbar(im_all, ax=ax_all)
+
+ if mark_correct:
+ # place yellow x at max logit token
+ ax_all.scatter(
+ last_tok_logits.argmax(dim=1),
+ np.arange(n_mazes),
+ marker="x",
+ color="yellow",
+ )
+ # place red dot at correct token
+ ax_all.scatter(target_idxs, np.arange(n_mazes), marker=".", color="red")
+ if mark_incorrect:
+ raise ValueError("mark_correct and mark_incorrect cannot both be True")
+
+ if mark_incorrect:
+ # place a red dot wherever the max logit token is not the correct token
+ ax_all.scatter(
+ last_tok_logits.argmax(dim=1)[last_tok_logits.argmax(dim=1) != target_idxs],
+ np.arange(n_mazes)[last_tok_logits.argmax(dim=1) != target_idxs],
+ marker=".",
+ color="red",
+ )
+
+ # histogram of logits for correct and incorrect tokens
+ # --------------------------------------------------
+ ax_sum.set_ylabel("probability density" if density else "frequency")
+ ax_sum.set_xlabel("logit value")
+
+ # get correct token logits
+ correct_token_logits: Float[torch.Tensor, "n_mazes"] = torch.gather(
+ last_tok_logits, 1, target_idxs.unsqueeze(1)
+ ).squeeze(1)
+ mask = torch.ones(n_mazes, d_vocab, dtype=torch.bool)
+ mask.scatter_(1, target_idxs.unsqueeze(1), False)
+ other_token_logits: Float[torch.Tensor, "n_mazes d_vocab-1"] = last_tok_logits[
+ mask
+ ].reshape(n_mazes, d_vocab - 1)
+
+ # plot histogram
+ bins: Float[np.ndarray, "n_bins"] = np.linspace(
+ last_tok_logits.min(), last_tok_logits.max(), n_bins
+ )
+ ax_sum.hist(
+ correct_token_logits.numpy(),
+ density=density,
+ bins=bins,
+ label="correct token",
+ alpha=0.5,
+ )
+ ax_sum.hist(
+ other_token_logits.numpy().flatten(),
+ density=density,
+ bins=bins,
+ label="other token",
+ alpha=0.5,
+ )
+ ax_sum.legend()
+ if logy:
+ ax_sum.set_yscale("log")
+
+ if show:
+ plt.show()
+
+ return fig, (ax_all, ax_sum)
+
+
+def plot_logits_maze(*args, **kwargs):
+ raise NotImplementedError()
diff --git a/maze_transformer/mechinterp/plot_weights.py b/maze_transformer/mechinterp/plot_weights.py
new file mode 100644
index 00000000..2deffb7a
--- /dev/null
+++ b/maze_transformer/mechinterp/plot_weights.py
@@ -0,0 +1,196 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from jaxtyping import Float, Int
+from transformer_lens import HookedTransformer
+
+
+def _weights_plot_helper(
+ fig: plt.Figure,
+ ax: plt.Axes,
+ data: Float[np.ndarray, "inputs_outputs_or_1 n_interesting_neurons"],
+ title: str,
+    ylabel: str | None = None,
+ cmap: str = "RdBu",
+):
+ # plot heatmap
+ n_rows: int = data.shape[0]
+ singlerow: bool = n_rows == 1
+ vbound: float = np.max(np.abs(data))
+    im = ax.imshow(
+ data,
+ aspect="auto" if not singlerow else "equal",
+ interpolation="none",
+ cmap=cmap,
+ vmin=-vbound,
+ vmax=vbound,
+ )
+ # colorbar
+ fig.colorbar(im, ax=ax)
+ # other figure adjustments
+ ax.set_title(title)
+ if ylabel:
+ ax.set_ylabel(ylabel)
+
+ if singlerow:
+ ax.set_yticks([])
+
+
+def plot_important_neurons(
+ model: HookedTransformer,
+ layer: int,
+ neuron_idxs: Int[np.ndarray, "neuron_idxs"] | None = None,
+ neuron_dla_data: Float[np.ndarray, "n_layers n_neurons"] | None = None,
+ n_important_neurons: int = 10,
+ show: bool = True,
+) -> tuple[plt.Figure, plt.Axes]:
+ """Plot the weights and biases for the selected or most important neurons in a given layer
+
+ - if both of `neuron_idxs` and `neuron_dla_data` are `None`, then all neurons will be plotted
+ - if a value is provided for `neuron_idxs`, then only those neurons will be plotted
+ - if a value is provided for `neuron_dla_data`, then the most important neurons will be selected based on the DLA data
+ """
+
+ # get dimension info from model state dict (expecting TransformerLens style)
+
+ # state dict
+ state_dict: dict[str, torch.Tensor] = model.state_dict()
+ state_dict_keys: list[str] = list(state_dict.keys())
+
+ # layers
+ layer_ids: list[int] = sorted(
+ list(
+ set(
+ [
+ int(key.split(".")[1])
+ for key in state_dict_keys
+ if key.startswith("blocks.")
+ ]
+ )
+ )
+ )
+ n_layers: int = len(layer_ids)
+ assert n_layers == max(layer_ids) + 1, f"Layers are not contiguous? {layer_ids}"
+ assert layer_ids == list(range(n_layers)), f"Layers are not contiguous? {layer_ids}"
+ # handle layer negative indexing
+ if layer < 0:
+ layer = layer_ids[layer]
+ assert layer in layer_ids, f"Layer {layer} not found in {layer_ids}"
+
+ # model dim and hidden dim
+ d_model: int
+ n_neurons: int
+ d_model, n_neurons = state_dict[f"blocks.{layer}.mlp.W_in"].shape
+
+ # dim checks for sanity
+ assert state_dict[f"blocks.{layer}.mlp.b_in"].shape[0] == n_neurons
+ assert state_dict[f"blocks.{layer}.mlp.W_out"].shape[0] == n_neurons
+ assert state_dict[f"blocks.{layer}.mlp.W_out"].shape[1] == d_model
+ assert state_dict[f"blocks.{layer}.mlp.b_out"].shape[0] == d_model
+
+ # get the neuron indices to plot
+
+ # all neurons if nothing specified
+ if neuron_idxs is None and neuron_dla_data is None:
+ neuron_idxs = np.arange(n_neurons)
+
+ # from dla data
+ if neuron_dla_data is not None:
+ assert (
+ neuron_idxs is None
+ ), "Cannot provide both neuron_idxs and neuron_dla_data"
+
+ neuron_idxs: np.ndarray = np.argsort(np.abs(neuron_dla_data[layer]))[
+ -n_important_neurons:
+ ][::-1]
+
+ mlp_key_base: str = f"blocks.{layer}.mlp"
+
+    # reuse the state dict fetched above
+    model_state = state_dict
+
+ # Create named subplots, tight layout
+ fig, axes = plt.subplots(
+ 3 + int(neuron_dla_data is not None), # w_in, b_in, w_out, dla (if applicable)
+ 1,
+ figsize=(10, 10),
+ sharex=True,
+ gridspec_kw=dict(hspace=0.1, wspace=0.1),
+ )
+
+ if neuron_dla_data is not None:
+ ax_w_in, ax_b_in, ax_w_out, ax_dla = axes
+ else:
+ ax_w_in, ax_b_in, ax_w_out = axes
+
+ # Plot in weight
+ w_in_data = model_state[mlp_key_base + ".W_in"].cpu().numpy()[:, neuron_idxs]
+ _weights_plot_helper(fig, ax_w_in, w_in_data, "W_in", "input neuron")
+
+ # Plot in bias
+ b_in_data = model_state[mlp_key_base + ".b_in"].cpu().numpy()[neuron_idxs][None, :]
+ _weights_plot_helper(fig, ax_b_in, b_in_data, "b_in")
+
+ # Plot out weight
+ w_out_data = model_state[mlp_key_base + ".W_out"].cpu().numpy()[neuron_idxs, :].T
+ _weights_plot_helper(fig, ax_w_out, w_out_data, "W_out", "output neuron")
+
+    # Plot DLA (only when DLA data was provided)
+    if neuron_dla_data is not None:
+        dla_row = neuron_dla_data[layer][neuron_idxs][None, :]
+        _weights_plot_helper(fig, ax_dla, dla_row, "DLA")
+
+ # Show the plot
+ if show:
+ plt.show()
+
+ return fig, axes
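+
+# usage sketch (hypothetical names; `dla["neurons"]` as returned by
+# `maze_transformer.mechinterp.direct_logit_attribution.compute_direct_logit_attribution`):
+#
+#   fig, axes = plot_important_neurons(
+#       model, layer=-1, neuron_dla_data=dla["neurons"], n_important_neurons=10
+#   )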
+
+
+def plot_embeddings(
+ model: HookedTransformer, token_arr: list[str], show: bool = True
+) -> tuple[plt.Figure, plt.Axes]:
+ # Get the weight matrices for vocab and positional embeddings
+ W_E: Float[torch.Tensor, "vocab_size d_model"] = model.W_E
+ W_pos: Float[torch.Tensor, "max_seq_len d_model"] = model.W_pos
+
+ # Make sure they have the same dimension
+ d_model: int = W_E.shape[1]
+ assert W_pos.shape[1] == d_model
+
+ # Create the figure and axes
+ fig, (ax_e, ax_pos) = plt.subplots(2, 1, figsize=(16, 16), sharex=True)
+
+ # Visualize vocab embeddings
+ vbound_e: float = W_E.abs().max().item()
+ ax_e.imshow(
+ W_E.cpu().detach().numpy(),
+ cmap="RdBu",
+ aspect="auto",
+ vmin=-vbound_e,
+ vmax=vbound_e,
+ )
+ ax_e.set_title("Vocab Embeddings")
+ ax_e.set_ylabel("vocab item")
+ ax_e.set_yticks(np.arange(len(token_arr)))
+ ax_e.set_yticklabels(token_arr, fontsize=5)
+ fig.colorbar(ax_e.get_images()[0], ax=ax_e)
+
+ # Visualize positional embeddings
+ vbound_pos: float = W_pos.abs().max().item()
+ ax_pos.imshow(
+ W_pos.cpu().detach().numpy(),
+ cmap="RdBu",
+ aspect="auto",
+ vmin=-vbound_pos,
+ vmax=vbound_pos,
+ )
+ ax_pos.set_title("Positional Embeddings")
+ ax_pos.set_ylabel("pos vs token embed")
+ ax_pos.set_xlabel("d_model")
+ fig.colorbar(ax_pos.get_images()[0], ax=ax_pos)
+
+ # Show the plot
+ if show:
+ plt.show()
+
+ return fig, (ax_e, ax_pos)
diff --git a/maze_transformer/mechinterp/residual_stream_structure.py b/maze_transformer/mechinterp/residual_stream_structure.py
new file mode 100644
index 00000000..86fa7e77
--- /dev/null
+++ b/maze_transformer/mechinterp/residual_stream_structure.py
@@ -0,0 +1,404 @@
+import itertools
+from typing import NamedTuple
+
+import matplotlib.pyplot as plt
+
+# numerical
+import numpy as np
+import seaborn as sns
+from jaxtyping import Float
+
+# maze_dataset
+from maze_dataset.constants import _SPECIAL_TOKENS_ABBREVIATIONS
+from maze_dataset.tokenization import MazeTokenizer
+from maze_dataset.tokenization.token_utils import strings_to_coords
+
+# scipy
+from scipy.spatial.distance import pdist, squareform
+from scipy.stats import pearsonr
+from sklearn.decomposition import PCA
+
+# transformerlens
+from transformer_lens import HookedTransformer
+
+# from scipy.spatial.distance import cosine
+
+
+def coordinate_to_color(
+ coord: tuple[float, float], max_val: float = 1.0
+) -> tuple[float, float, float]:
+ """Maps a coordinate (i, j) to a unique RGB color"""
+ coord = np.array(coord)
+ if max_val < coord.max():
+ raise ValueError(
+ f"max_val ({max_val}) must be at least as large as the largest coordinate ({coord.max()})"
+ )
+
+ coord = coord / max_val
+
+ return (
+ coord[0] * 0.6 + 0.3, # r
+ 0.5, # g
+ coord[1] * 0.6 + 0.3, # b
+ )
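+
+
+def _example_coordinate_to_color() -> None:
+    """minimal sketch: r and b are affine in the normalized coordinate, g is fixed"""
+    assert np.allclose(coordinate_to_color((0.0, 0.0), max_val=1.0), (0.3, 0.5, 0.3))
+    assert np.allclose(coordinate_to_color((1.0, 1.0), max_val=1.0), (0.9, 0.5, 0.9))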
+
+
+TokenPlottingInfo = NamedTuple(
+ "TokenPlottingInfo",
+ token=str,
+ coord=tuple[float, float] | str,
+ color=tuple[float, float, float],
+)
+
+
+def process_tokens_for_pca(tokenizer: MazeTokenizer) -> list[TokenPlottingInfo]:
+ tokens_coords: list[str | tuple[int, int]] = strings_to_coords(
+ tokenizer.token_arr, when_noncoord="include"
+ )
+ tokens_coords_only: list[tuple[int, int]] = strings_to_coords(
+ tokenizer.token_arr, when_noncoord="skip"
+ )
+ max_coord: int = np.array(tokens_coords_only).max()
+ # token_idxs_coords: list[int] = tokenizer.encode(tokenizer.coords_to_strings(tokens_coords_only))
+
+ vocab_coordinates_colored: list[TokenPlottingInfo] = [
+ TokenPlottingInfo(*x)
+ for x in zip(
+ tokenizer.token_arr,
+ tokens_coords,
+ [
+ coordinate_to_color(coord, max_val=max_coord)
+ if isinstance(coord, tuple)
+ else (0.0, 1.0, 0.0)
+ for coord in tokens_coords
+ ],
+ )
+ ]
+
+ return vocab_coordinates_colored
+
+
+EmbeddingsPCAResult = NamedTuple(
+ "EmbeddingsPCAResult",
+ result=np.ndarray,
+ index_map=list[int] | None,
+ pca_obj=PCA,
+)
+
+
+def compute_pca(
+ model: HookedTransformer,
+ token_plotting_info: list[TokenPlottingInfo],
+) -> dict[str, EmbeddingsPCAResult]:
+ pca_all: PCA = PCA(svd_solver="full")
+ pca_coords: PCA = PCA(svd_solver="full")
+ pca_special: PCA = PCA(svd_solver="full")
+
+ # PCA_RESULTS = pca_all.fit_transform(MODEL.W_E.cpu().numpy().T)
+ # PCA_RESULTS_COORDS_ONLY = pca_coords.fit_transform(MODEL.W_E[token_idxs_coords].cpu().numpy().T)
+
+ idxs_coords: list[int] = list()
+ idxs_special: list[int] = list()
+
+ i: int
+ tokinfo: TokenPlottingInfo
+ for i, tokinfo in enumerate(token_plotting_info):
+ if isinstance(tokinfo.coord, tuple):
+ idxs_coords.append(i)
+ elif isinstance(tokinfo.coord, str):
+ idxs_special.append(i)
+ else:
+ raise ValueError(
+ f"unexpected coord type: {type(tokinfo.coord)}\n{tokinfo = }"
+ )
+
+ return dict(
+ all=EmbeddingsPCAResult(
+ result=pca_all.fit_transform(model.W_E.cpu().numpy().T),
+ index_map=None,
+ pca_obj=pca_all,
+ ),
+ coords_only=EmbeddingsPCAResult(
+ result=pca_coords.fit_transform(model.W_E[idxs_coords].cpu().numpy().T),
+ index_map=idxs_coords,
+ pca_obj=pca_coords,
+ ),
+ special_only=EmbeddingsPCAResult(
+ result=pca_special.fit_transform(model.W_E[idxs_special].cpu().numpy().T),
+ index_map=idxs_special,
+ pca_obj=pca_special,
+ ),
+ )
+
+
+def plot_pca_colored(
+ pca_results_options: dict[str, EmbeddingsPCAResult],
+ pca_results_key: str,
+ vocab_colors: list[tuple],
+ dim1: int,
+ dim2: int,
+ lattice_connections: bool = True,
+ symlog_scale: float | None = None,
+ axes_and_centered: bool = True,
+) -> tuple[plt.Figure, plt.Axes]:
+ # set up figure, get PCA results
+ fig, ax = plt.subplots(figsize=(5, 5))
+ pca_result: EmbeddingsPCAResult = pca_results_options[pca_results_key]
+
+ # Store lattice points for drawing connections
+    lattice_points: list[tuple[tuple[int, int], tuple[float, float]]] = list()
+
+ for i in range(pca_result.result.shape[1]):
+ # map index if necessary
+ if pca_result.index_map is not None:
+ i_map: int = pca_result.index_map[i]
+ else:
+ i_map = i
+ token, coord, color = vocab_colors[i_map]
+ # plot the point
+ ax.scatter(
+ pca_result.result[dim1 - 1, i],
+ pca_result.result[dim2 - 1, i],
+ alpha=0.5,
+ color=color,
+ )
+ if isinstance(coord, str):
+ # label with the abbreviated token name
+ ax.text(
+ pca_result.result[dim1 - 1, i],
+ pca_result.result[dim2 - 1, i],
+ _SPECIAL_TOKENS_ABBREVIATIONS[coord],
+ fontsize=8,
+ )
+ else:
+ # add to the lattice points list for later
+ lattice_points.append(
+ (
+ coord,
+ (pca_result.result[dim1 - 1, i], pca_result.result[dim2 - 1, i]),
+ )
+ )
+
+ if axes_and_centered:
+ # find x and y limits
+ xbound: float = np.max(np.abs(pca_result.result[dim1 - 1])) * 1.1
+ ybound: float = np.max(np.abs(pca_result.result[dim2 - 1])) * 1.1
+ # set axes limits
+ ax.set_xlim(-xbound, xbound)
+ ax.set_ylim(-ybound, ybound)
+ # plot axes
+ ax.plot([-xbound, xbound], [0, 0], color="black", alpha=0.5, linewidth=0.5)
+ ax.plot([0, 0], [-ybound, ybound], color="black", alpha=0.5, linewidth=0.5)
+
+ # add lattice connections
+ if lattice_connections:
+ for (i, j), (x, y) in lattice_points:
+ for (i2, j2), (x2, y2) in lattice_points:
+ # manhattan distance of 1
+ if np.linalg.norm(np.array([i, j]) - np.array([i2, j2]), ord=1) == 1:
+ # plot a line between the two points
+ ax.plot(
+ [x, x2],
+ [y, y2],
+ color="red",
+ alpha=0.2,
+ linewidth=0.5,
+ )
+
+ ax.set_xlabel(f"PC{dim1}")
+ ax.set_ylabel(f"PC{dim2}")
+ ax.set_title(f"PCA of Survey Responses:\nPC{dim1} vs PC{dim2}")
+
+ # semi-log scale
+ if isinstance(symlog_scale, (float, int)):
+ if symlog_scale > 0:
+ ax.set_xscale("symlog", linthresh=symlog_scale)
+ ax.set_yscale("symlog", linthresh=symlog_scale)
+
+ return fig, ax
+
+
+def compute_distances_and_correlation(
+ embedding_matrix: Float[np.ndarray, "d_vocab d_model"],
+ tokenizer: MazeTokenizer,
+ embedding_metric: str = "cosine",
+ coordinate_metric: str = "euclidean",
+ show: bool = True,
+) -> dict:
+ """embedding distances passed to pdist from scipy"""
+
+ coord_tokens_ids: dict[str, int] = tokenizer.coordinate_tokens_ids
+ coord_embeddings: Float[np.ndarray, "n_coord_tokens d_model"] = np.array(
+ [embedding_matrix[v] for v in coord_tokens_ids.values()]
+ )
+
+ # Calculate the pairwise distances in embedding space
+ embedding_distances: Float[np.ndarray, "n_coord_tokens d_model"] = pdist(
+ coord_embeddings,
+ metric=embedding_metric,
+ )
+ # normalize the distance by the maximum distance
+ embedding_distances /= embedding_distances.max()
+
+ # Convert the distances to a square matrix
+ embedding_distances_matrix: Float[
+ np.ndarray, "n_coord_tokens n_coord_tokens"
+ ] = squareform(embedding_distances)
+
+ # Calculate the correlation between the embedding and coordinate distances
+ coordinate_coordinates: Float[np.ndarray, "n_coord_tokens 2"] = np.array(
+ list(tokenizer.coordinate_tokens_coords.keys())
+ )
+ coordinate_distances = pdist(
+ coordinate_coordinates,
+ metric=coordinate_metric,
+ )
+ correlation, corr_pval = pearsonr(embedding_distances, coordinate_distances)
+
+ return dict(
+ embedding_distances_matrix=embedding_distances_matrix,
+ correlation=correlation,
+ corr_pval=corr_pval,
+ tokenizer=tokenizer,
+ embedding_metric=embedding_metric,
+ coordinate_metric=coordinate_metric,
+ )
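+
+# usage sketch (hypothetical names; assumes a trained model and its `MazeTokenizer`):
+#
+#   result = compute_distances_and_correlation(
+#       embedding_matrix=model.W_E.cpu().numpy(),
+#       tokenizer=tokenizer,
+#       embedding_metric="cosine",
+#   )
+#   plot_distances_matrix(result["embedding_distances_matrix"], tokenizer, "cosine")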
+
+
+def plot_distances_matrix(
+ embedding_distances_matrix: Float[np.ndarray, "n_coord_tokens n_coord_tokens"],
+ tokenizer: MazeTokenizer,
+ embedding_metric: str,
+ show: bool = True,
+ **kwargs,
+) -> tuple[plt.Figure, plt.Axes]:
+ coord_tokens_ids: dict[str, int] = tokenizer.coordinate_tokens_ids
+
+ # Plot the embedding distances
+ fig, ax = plt.subplots(figsize=(15, 15))
+ cax = ax.matshow(
+ embedding_distances_matrix,
+ cmap="viridis",
+ interpolation="none",
+ )
+ ax.grid(which="major", color="white", linestyle="-", linewidth=0.5)
+ ax.grid(which="minor", color="white", linestyle="-", linewidth=0.0)
+ fig.colorbar(cax)
+
+ ax.set_xticks(np.arange(len(coord_tokens_ids)))
+ ax.set_yticks(np.arange(len(coord_tokens_ids)))
+ ax.set_xticklabels(coord_tokens_ids.keys())
+ ax.set_yticklabels(coord_tokens_ids.keys())
+
+ plt.setp(ax.get_xticklabels(), rotation=90, ha="left", rotation_mode="anchor")
+
+ ax.set_title(f"{embedding_metric} Distances Between Coordinate Embeddings")
+ ax.grid(False)
+
+ if show:
+ plt.show()
+
+ return fig, ax
+
+
+def compute_grid_distances(
+ embedding_distances_matrix: Float[np.ndarray, "n_coord_tokens n_coord_tokens"],
+ tokenizer: MazeTokenizer,
+) -> Float[np.ndarray, "n n n n"]:
+ n: int = tokenizer.max_grid_size
+ grid_distances: Float[np.ndarray, "n n n n"] = np.full((n, n, n, n), np.nan)
+
+ for idx, ((x, y), token_id) in enumerate(
+ tokenizer.coordinate_tokens_coords.items()
+ ):
+ # Extract distances for this particular token from the distance matrix
+ distances: Float[np.ndarray, "n_coord_tokens"] = embedding_distances_matrix[
+ idx, :
+ ]
+
+ # get distances
+ for (x2, y2), distance in zip(
+ tokenizer.coordinate_tokens_coords.keys(), distances
+ ):
+ grid_distances[x, y, x2, y2] = distance
+ # coords = np.array(list(tokenizer.coordinate_tokens_coords.keys()))
+ # grid_distances[x, y, coords[:, 0], coords[:, 1]] = distances
+
+ return grid_distances
+
+
+def plot_distance_grid(
+ grid_distances: Float[np.ndarray, "n n n n"],
+ embedding_metric: str,
+ show: bool = True,
+) -> tuple[plt.Figure, plt.Axes]:
+ n: int = grid_distances.shape[0]
+ # print(n)
+ # print(tokenizer.coordinate_tokens_coords)
+ fig, axs = plt.subplots(n, n, figsize=(20, 20))
+
+ for i in range(n):
+ for j in range(n):
+ ax = axs[i, j]
+ cax = ax.matshow(grid_distances[i, j], cmap="viridis", interpolation="none")
+ ax.plot(j, i, "rx")
+ ax.set_title(f"from ({i},{j})")
+ # fully remove both major and minor gridlines
+ ax.grid(False)
+
+ fig.suptitle(f"{embedding_metric} distances grid")
+ plt.colorbar(cax, ax=axs.ravel().tolist())
+
+ if show:
+ plt.show()
+
+ return fig, axs
+
+
+def plot_distance_correlation(
+ distance_grid: Float[np.ndarray, "n n n n"],
+ embedding_metric: str,
+ coordinate_metric: str,
+ show: bool = False,
+ **kwargs,
+) -> plt.Axes:
+ n: int = distance_grid.shape[0]
+ n_coord_tokens: int = n**2
+    n_dists: int = (n_coord_tokens**2 - n_coord_tokens) // 2
+
+ # Initialize lists to store distances
+ embedding_distances: list[float] = []
+
+ # Create an array of points to be used with pdist
+ points: Float[np.ndarray, "n_coord_tokens 2"] = np.array(
+ list(itertools.product(range(n), range(n)))
+ )
+ assert points.shape == (n_coord_tokens, 2)
+ pdist_distances: Float[np.ndarray, "n_dists"] = pdist(
+ points, metric=coordinate_metric
+ )
+ assert pdist_distances.shape == (n_dists,)
+
+ # calculate embedding-space distances over the unique (upper-triangle)
+ # pairs, in the same order that pdist produces
+ for idx1, idx2 in zip(*np.triu_indices(len(points), k=1)):
+ x1, y1 = points[idx1]
+ x2, y2 = points[idx2]
+ embedding_distances.append(distance_grid[x1, y1, x2, y2])
+
+ embedding_distances_np: Float[np.ndarray, "n_dists"] = np.array(embedding_distances)
+ assert embedding_distances_np.shape == (n_dists,)
+
+ ax = sns.boxplot(
+ x=pdist_distances,
+ y=embedding_distances_np,
+ color=sns.color_palette()[0],
+ showfliers=False,
+ width=0.5,
+ )
+ ax.set_xlabel(f"{coordinate_metric} distance between coordinates")
+ ax.set_ylabel(f"{embedding_metric} embedding distance")
+
+ if show:
+ plt.show()
+
+ return ax
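+
+
+# Closing out the sketch: the correlation boxplot, fed from the same
+# hypothetical `grid` and `results` as above (illustrative only):
+#
+# ax = plot_distance_correlation(
+#     grid, results["embedding_metric"], results["coordinate_metric"], show=True
+# )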
diff --git a/maze_transformer/training/config.py b/maze_transformer/training/config.py
index 97e9d891..16af592a 100644
--- a/maze_transformer/training/config.py
+++ b/maze_transformer/training/config.py
@@ -408,6 +408,23 @@ class ConfigHolder(SerializableDataclass):
loading_fn=_load_maze_tokenizer,
)
+ # shortcut properties
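+ # (e.g. `config_holder.d_model` as shorthand for `config_holder.model_cfg.d_model`)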
+ @property
+ def d_model(self) -> int:
+ return self.model_cfg.d_model
+
+ @property
+ def d_head(self) -> int:
+ return self.model_cfg.d_head
+
+ @property
+ def n_layers(self) -> int:
+ return self.model_cfg.n_layers
+
+ @property
+ def n_heads(self) -> int:
+ return self.model_cfg.n_heads
+
def _set_tok_gridsize_from_dataset(self):
self.maze_tokenizer.max_grid_size = self.dataset_cfg.max_grid_n
self.maze_tokenizer.clear_cache()
@@ -552,7 +569,7 @@ def get_config_multisource(
@set_config_class(ConfigHolder)
-class ZanjHookedTransformer(ConfiguredModel, HookedTransformer):
+class ZanjHookedTransformer(ConfiguredModel[ConfigHolder], HookedTransformer):
"""A hooked transformer that is configured by a ConfigHolder
the inheritance order is critical here -- super() does not call the parent class, but the next class in the MRO
diff --git a/maze_transformer/training/training.py b/maze_transformer/training/training.py
index 63570b3f..15af763d 100644
--- a/maze_transformer/training/training.py
+++ b/maze_transformer/training/training.py
@@ -123,6 +123,8 @@ def train(
f"will train for {n_batches} batches, {evals_enabled=}, with intervals: {intervals}"
)
+ # TODO: add model output dir / run name to model.training_records
+
# start up training
# ==============================
model.train()
diff --git a/maze_transformer/utils/dict_shapes.py b/maze_transformer/utils/dict_shapes.py
new file mode 100644
index 00000000..d00b4fbf
--- /dev/null
+++ b/maze_transformer/utils/dict_shapes.py
@@ -0,0 +1,23 @@
+import json
+from typing import TYPE_CHECKING
+
+from muutils.dictmagic import dotlist_to_nested_dict
+
+if TYPE_CHECKING:
+ import torch
+
+
+def get_dict_shapes(d: dict[str, "torch.Tensor"]) -> dict[str, tuple[int, ...]]:
+ """given a state dict or cache dict, compute the shapes and put them in a nested dict"""
+ return dotlist_to_nested_dict({k: tuple(v.shape) for k, v in d.items()})
+
+
+def string_dict_shapes(d: dict[str, "torch.Tensor"]) -> str:
+ """printable version of get_dict_shapes"""
+ return json.dumps(
+ dotlist_to_nested_dict(
+ {
+ k: str(
+ tuple(v.shape)
+ ) # to string, since json's indent won't play nice with tuples
+ for k, v in d.items()
+ }
+ ),
+ indent=2,
+ )
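+
+
+# Minimal usage sketch (illustrative, not part of this patch's call sites):
+# any mapping of names to tensors works, e.g. a module's state dict.
+#
+# import torch
+# print(string_dict_shapes(torch.nn.Linear(4, 2).state_dict()))
+# # {
+# #   "weight": "(2, 4)",
+# #   "bias": "(2,)"
+# # }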
diff --git a/notebooks/demo_dataset.ipynb b/notebooks/demo_dataset.ipynb
index aa612675..2b81acb9 100644
--- a/notebooks/demo_dataset.ipynb
+++ b/notebooks/demo_dataset.ipynb
@@ -1,5 +1,15 @@
{
"cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Basics\n",
+ "\n",
+ "to start, let's import a few things we'll need:"
+ ]
+ },
{
"cell_type": "code",
"execution_count": 1,
@@ -9,25 +19,28 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "dict_keys(['test-g3-n5-a_dfs-h81250', 'demo_small-g3-n100-a_dfs-h12257', 'demo-g6-n10K-a_dfs-h76502'])\n"
+ "dict_keys(['test-g3-n5-a_dfs-h75556', 'demo_small-g3-n100-a_dfs-h88371', 'demo-g6-n10K-a_dfs-h30615'])\n"
]
}
],
"source": [
- "import json\n",
- "\n",
+ "# other package imports\n",
"import matplotlib.pyplot as plt # keep this import for CI to work\n",
- "from zanj import ZANJ\n",
+ "from zanj import ZANJ # saving/loading data\n",
+ "from muutils.mlutils import pprint_summary # pretty printing as json\n",
"\n",
- "from maze_dataset.generation import LatticeMazeGenerators\n",
- "from maze_dataset import SolvedMaze, MazeDataset, MazeDatasetConfig\n",
+ "# maze_dataset imports\n",
+ "from maze_dataset import LatticeMaze, SolvedMaze, MazeDataset, MazeDatasetConfig\n",
+ "from maze_dataset.generation import LatticeMazeGenerators, GENERATORS_MAP\n",
+ "from maze_dataset.generation.default_generators import DEFAULT_GENERATORS\n",
"from maze_dataset.dataset.configs import MAZE_DATASET_CONFIGS\n",
"from maze_dataset.plotting import plot_dataset_mazes, print_dataset_mazes\n",
- "from muutils.mlutils import pprint_summary\n",
"\n",
+ "# check the configs\n",
"print(MAZE_DATASET_CONFIGS.keys())\n",
- "\n",
- "LOCAL_DATA_PATH: str = \"../data/maze_dataset/\""
+ "# for saving/loading things\n",
+ "LOCAL_DATA_PATH: str = \"../data/maze_dataset/\"\n",
+ "zanj: ZANJ = ZANJ(external_list_threshold=256)"
]
},
{
@@ -35,7 +48,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "you should always see `test-g3-n5-a_dfs-h81250` in the list of available dataset configs"
+ "You should always see `test-g3-n5-a_dfs-h9136` in the list of available dataset configs above.\n",
+ "\n",
+ "Now, let's set up our initial config and dataset:"
]
},
{
@@ -47,19 +62,22 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "test-g3-n32-a_dfs-h36643\n"
+ "test-g5-n4-a_dfs-h84708\n"
]
}
],
"source": [
- "zanj: ZANJ = ZANJ(external_list_threshold=256)\n",
"cfg: MazeDatasetConfig = MazeDatasetConfig(\n",
- "\tname=\"test\",\n",
- "\tgrid_n=3,\n",
- "\tn_mazes=32,\n",
- "\tmaze_ctor=LatticeMazeGenerators.gen_dfs,\n",
+ "\tname=\"test\", # name is only for you to keep track of things\n",
+ "\tgrid_n=5, # number of rows/columns in the lattice\n",
+ "\tn_mazes=4, # number of mazes to generate\n",
+ "\tmaze_ctor=LatticeMazeGenerators.gen_dfs, # algorithm to generate the maze\n",
+ " # there are a few more arguments here, to be discussed later\n",
")\n",
"\n",
+ "# each config will use this function to get the name of the dataset\n",
+ "# it contains some basic info about the algorithm, size, and number of mazes\n",
+ "# at the end after \"h\" is a stable hash of the config to avoid collisions\n",
"print(cfg.to_fname())"
]
},
@@ -72,6 +90,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
+ "trying to get the dataset 'test-g5-n4-a_dfs-h84708'\n",
"generating dataset...\n"
]
},
@@ -79,14 +98,14 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "generating & solving mazes: 100%|██████████| 32/32 [00:00<00:00, 695.67maze/s]"
+ "generating & solving mazes: 100%|██████████| 4/4 [00:00<00:00, 181.81maze/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "saving dataset to ..\\data\\maze_dataset\\test-g3-n32-a_dfs-h36643.zanj\n"
+ "saving dataset to ..\\data\\maze_dataset\\test-g5-n4-a_dfs-h84708.zanj\n"
]
},
{
@@ -100,13 +119,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Got dataset test with 32 items. output.cfg.to_fname() = 'test-g3-n32-a_dfs-h36643'\n"
+ "Got dataset test with 4 items. output.cfg.to_fname() = 'test-g5-n4-a_dfs-h84708'\n"
]
}
],
"source": [
+ "# to create a dataset, just call MazeDataset.from_config\n",
"dataset: MazeDataset = MazeDataset.from_config(\n",
+ " # your config\n",
"\tcfg,\n",
+ " # and all this below is completely optional\n",
"\tdo_download=False,\n",
"\tload_local=False,\n",
"\tdo_generate=True,\n",
@@ -118,75 +140,66 @@
")"
]
},
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "now that we have our dataset, let's take a look at it!"
+ ]
+ },
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "loading dataset from ../data/maze_dataset/test-g3-n32-a_dfs-h36643.zanj\n",
- "Got dataset test with 32 items. output.cfg.to_fname() = 'test-g3-n32-a_dfs-h36643'\n"
- ]
+ "data": {
+ "text/plain": [
+ "(