Modify configs to allow specification of pos encoding type (#201)
* Modify configs to allow specification of pos encoding type (fc02dc0)

* fix up pos embed (41466c0)

* fix test (07eab25)

---------

Co-authored-by: mivanit <[email protected]>
afspies and mivanit authored Jul 26, 2024
1 parent 70eef54 commit 98ea056
Showing 2 changed files with 7 additions and 0 deletions.
6 changes: 6 additions & 0 deletions maze_transformer/training/config.py
@@ -39,6 +39,10 @@ class BaseGPTConfig(SerializableDataclass):
     d_model: int
     d_head: int
     n_layers: int
+    positional_embedding_type: str = serializable_field(
+        default="standard",
+        loading_fn=lambda data: data.get("positional_embedding_type", "standard"),
+    )
 
     weight_processing: dict[str, bool] = serializable_field(
         default_factory=lambda: dict(
@@ -59,6 +63,7 @@ def summary(self) -> dict:
             d_model=self.d_model,
             d_head=self.d_head,
             n_layers=self.n_layers,
+            positional_embedding_type=self.positional_embedding_type,
             weight_processing=self.weight_processing,
             n_heads=self.n_heads,
         )
@@ -501,6 +506,7 @@ def hooked_transformer_cfg(self) -> HookedTransformerConfig:
             d_model=self.model_cfg.d_model,
             d_head=self.model_cfg.d_head,
             n_layers=self.model_cfg.n_layers,
+            positional_embedding_type=self.model_cfg.positional_embedding_type,
             n_ctx=self.dataset_cfg.seq_len_max,
             d_vocab=self.maze_tokenizer.vocab_size,
         )
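
As a usage sketch (not part of this commit): the new field can be set when constructing a BaseGPTConfig, and hooked_transformer_cfg forwards it to HookedTransformerConfig as in the hunk above. In the sketch below, name and act_fn are assumed to be the remaining required BaseGPTConfig fields (they are not visible in this diff), and "rotary" is assumed to be one of the values TransformerLens accepts for positional_embedding_type:

# Hypothetical sketch; fields outside this diff and the "rotary" value are assumptions.
from maze_transformer.training.config import BaseGPTConfig

cfg = BaseGPTConfig(
    name="tiny-rotary",                   # hypothetical config name
    act_fn="gelu",                        # assumed required field, not shown in this diff
    d_model=128,
    d_head=32,
    n_layers=4,
    positional_embedding_type="rotary",   # new field; omitting it keeps the "standard" default
)
print(cfg.summary()["positional_embedding_type"])  # -> "rotary"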

1 change: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ def _custom_serialized_config():
         "d_head": 1,
         "n_layers": 1,
         "n_heads": 1,
+        "positional_embedding_type": "standard",
         "weight_processing": {
             "are_layernorms_folded": False,
             "are_weights_processed": False,
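
A second small sketch of why the loading_fn fallback matters: configs serialized before this commit have no positional_embedding_type key, and the fallback keeps them loading with the "standard" default. The dict below is hypothetical; only the lambda is taken from the diff:

# Hypothetical pre-change serialized data, missing the new key.
old_data = {"d_model": 128, "d_head": 32, "n_layers": 4}

# The loading_fn introduced in this commit:
loading_fn = lambda data: data.get("positional_embedding_type", "standard")

assert loading_fn(old_data) == "standard"  # old configs fall back to "standard"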
