Merge pull request #33 from Acellera/llama2
Llama2
albertbou92 authored Jun 14, 2024
2 parents 9c16ec7 + 34114fc commit a82cb95
Showing 31 changed files with 508 additions and 26 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -181,9 +181,9 @@ We provide a variety of default priors that can be selected in the configuration


- A Llama2 model (requires installation of HuggingFace's `transformers` library)
-  - pre-training dataset: [ChEMBL](https://www.ebi.ac.uk/chembl/)
-  - number of parameters: 2,809,216
-  - to select set the field `model` to `mamba` in any configuration file
+  - pre-training dataset: [REAL Database, 6B cpds, CXSMILES](https://enamine.net/compound-collections/real-compounds/real-database)
+  - number of parameters: 5,965,760
+  - to select set the field `model` to `llama2` in any configuration file

### Integration of custom models

15 changes: 15 additions & 0 deletions acegen/models/__init__.py
@@ -8,11 +8,17 @@
create_gpt2_actor_critic,
create_gpt2_critic,
)

from acegen.models.gru import (
create_gru_actor,
create_gru_actor_critic,
create_gru_critic,
)
from acegen.models.llama2 import (
create_llama2_actor,
create_llama2_actor_critic,
create_llama2_critic,
)
from acegen.models.lstm import (
create_lstm_actor,
create_lstm_actor_critic,
@@ -25,6 +31,7 @@
)
from acegen.models.utils import adapt_state_dict
from acegen.vocabulary.tokenizers import (
AsciiSMILESTokenizer,
SMILESTokenizerChEMBL,
SMILESTokenizerEnamine,
SMILESTokenizerGuacaMol,
@@ -76,6 +83,14 @@ def extract(path):
resources.files("acegen.priors") / "gpt2_enamine_real.ckpt",
SMILESTokenizerEnamine(),
),
"llama2": (
create_llama2_actor,
create_llama2_critic,
create_llama2_actor_critic,
resources.files("acegen.priors") / "ascii.pt",
resources.files("acegen.priors") / "llama2_enamine_real_6B.ckpt",
AsciiSMILESTokenizer(),
),
"mamba": (
create_mamba_actor,
create_mamba_critic,
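For orientation (not part of the commit): the new "llama2" entry follows the same layout as the existing ones: three model factories, a default vocabulary file, a pre-trained prior checkpoint, and a tokenizer. The sketch below mirrors that structure to spell out what each position holds; example_registry is a placeholder name, since the enclosing dict is not visible in this hunk.

# Sketch only, not part of this commit: mirrors the "llama2" registry entry above.
# "example_registry" is a placeholder; the real dict lives in acegen/models/__init__.py.
from importlib import resources

from acegen.models.llama2 import (
    create_llama2_actor,
    create_llama2_actor_critic,
    create_llama2_critic,
)
from acegen.vocabulary.tokenizers import AsciiSMILESTokenizer

example_registry = {
    "llama2": (
        create_llama2_actor,         # actor factory
        create_llama2_critic,        # critic factory
        create_llama2_actor_critic,  # shared actor-critic factory
        resources.files("acegen.priors") / "ascii.pt",  # default vocabulary file (assumed from the name)
        resources.files("acegen.priors") / "llama2_enamine_real_6B.ckpt",  # pre-trained prior weights
        AsciiSMILESTokenizer(),      # default tokenizer
    ),
}

actor_factory, critic_factory, ac_factory, vocab_file, prior_ckpt, tokenizer = example_registry["llama2"]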
290 changes: 290 additions & 0 deletions acegen/models/llama2.py
@@ -0,0 +1,290 @@
import torch
import torch.nn as nn
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.envs import ExplorationType
from torchrl.modules import ActorValueOperator, ProbabilisticActor

try:
from transformers import LlamaConfig, LlamaModel

_has_transformers = True
except ImportError as err:
_has_transformers = False
TRANSFORMERS_ERR = err


class Llama2(nn.Module):
"""Llama2 model for language modeling. This model is a simple wrapper around the HuggingFace Llama22Model."""

def __init__(self, config=None):
if not _has_transformers:
raise RuntimeError(
"transformers library not found, please install with pip install transformers."
) from TRANSFORMERS_ERR

super(Llama2, self).__init__()

# Define model
if config is not None:
self.feature_extractor = LlamaModel(config)
else:
self.feature_extractor = None

# Start in evaluation mode
self._train_mode = False

@property
def train_mode(self):
return self._train_mode

def set_train_mode(self, train_mode: bool = True):
if train_mode is self._train_mode:
return self
out = Llama2()
out.feature_extractor = self.feature_extractor
out._train_mode = train_mode
return out

def forward(self, sequence, sequence_mask):

out = self.feature_extractor(
input_ids=sequence,
attention_mask=sequence_mask.long().reshape(*sequence.shape),
).last_hidden_state

if self.train_mode is False: # Data collection, return only last token
obs_length = sequence_mask.sum(-1)
out = out[torch.arange(len(out)), obs_length.to(torch.int64) - 1]

return out


def define_llama2_configuration(
vocabulary_size: int,
n_positions: int = 2048,
n_head: int = 16,
n_kv_head: int = 4,
n_layer: int = 4,
n_embd: int = 320,
attn_pdrop: float = 0.0,
):
"""Define a Llama2 configuration.
This function is a simple wrapper around the HuggingFace LlamaConfig, allowing the relevant parameters to be specified.
"""
# Define model
config = LlamaConfig()

# Adjust model parameters
config.vocab_size = vocabulary_size
config.max_position_embeddings = n_positions
config.num_attention_heads = n_head
config.num_key_value_heads = n_kv_head
config.num_hidden_layers = n_layer
config.hidden_size = n_embd
config.intermediate_size = 4 * n_embd
config.attention_dropout = attn_pdrop
return config


def create_llama2_actor(
vocabulary_size: int,
n_positions: int = 2048,
n_head: int = 16,
n_kv_head: int = 4,
n_layer: int = 4,
n_embd: int = 320,
attn_pdrop: float = 0.0,
return_log_prob=True,
):
"""Create a Llama2 actor for language modeling."""
config = define_llama2_configuration(
vocabulary_size,
n_positions,
n_head,
n_kv_head,
n_layer,
n_embd,
attn_pdrop,
)
# Define transformer
lm = Llama2(config)

# Wrap the transformer in a TensorDictModule to make TensorDict compatible
lm_training = TensorDictModule(
lm.set_train_mode(True),
in_keys=["sequence", "sequence_mask"],
out_keys=["features"],
)
lm_inference = TensorDictModule(
lm,
in_keys=["sequence", "sequence_mask"],
out_keys=["features"],
)

# Define final layer and also make it a TensorDictModule
lm_head = TensorDictModule(
nn.Linear(config.hidden_size, vocabulary_size, bias=False),
in_keys=["features"],
out_keys=["logits"],
)

# Concatenate lm and head, similar to torch.nn.Sequential
policy_training = TensorDictSequential(lm_training, lm_head)
policy_inference = TensorDictSequential(lm_inference, lm_head)

# To make the actor probabilistic, wrap the policy in a ProbabilisticActor
# This module will take care of sampling and computing log probabilities
probabilistic_policy_training = ProbabilisticActor(
module=policy_training,
in_keys=["logits"],
out_keys=["action"],
distribution_class=torch.distributions.Categorical,
return_log_prob=return_log_prob,
default_interaction_type=ExplorationType.RANDOM,
)
probabilistic_policy_inference = ProbabilisticActor(
module=policy_inference,
in_keys=["logits"],
out_keys=["action"],
distribution_class=torch.distributions.Categorical,
return_log_prob=return_log_prob,
default_interaction_type=ExplorationType.RANDOM,
)
return probabilistic_policy_training, probabilistic_policy_inference


def create_llama2_critic(
vocabulary_size: int,
n_positions: int = 2048,
n_head: int = 16,
n_kv_head: int = 4,
n_layer: int = 4,
n_embd: int = 320,
attn_pdrop: float = 0.0,
critic_value_per_action=False,
):
"""Create a Llama2 critic for language modeling."""
config = define_llama2_configuration(
vocabulary_size,
n_positions,
n_head,
n_kv_head,
n_layer,
n_embd,
attn_pdrop,
)
# Define transformer
lm = Llama2(config)

# Wrap the transformer in a TensorDictModule to make TensorDict compatible
lm_training = TensorDictModule(
lm.set_train_mode(True),
in_keys=["sequence", "sequence_mask"],
out_keys=["features"],
)
lm_inference = TensorDictModule(
lm,
in_keys=["sequence", "sequence_mask"],
out_keys=["features"],
)

# Define final layer and also make it a TensorDictModule
lm_head = TensorDictModule(
nn.Linear(
config.hidden_size,
vocabulary_size if critic_value_per_action else 1,
bias=False,
),
in_keys=["features"],
out_keys=["action_value"] if critic_value_per_action else ["state_value"],
)

# Concatenate lm and head, similar to torch.nn.Sequential
# Critic does not need to be probabilistic, so we can return directly
critic_training = TensorDictSequential(lm_training, lm_head)
critic_inference = TensorDictSequential(lm_inference, lm_head)
return critic_training, critic_inference


def create_llama2_actor_critic(
vocabulary_size: int,
n_positions: int = 2048,
n_head: int = 16,
n_kv_head: int = 4,
n_layer: int = 4,
n_embd: int = 320,
attn_pdrop: float = 0.0,
return_log_prob=True,
critic_value_per_action=False,
):
"""Create a Llama2 shared actor-critic network for language modeling."""
config = define_llama2_configuration(
vocabulary_size,
n_positions,
n_head,
n_kv_head,
n_layer,
n_embd,
attn_pdrop,
)
# Define transformer
lm = Llama2(config)

# Wrap the transformer in a TensorDictModule to make TensorDict compatible
lm_training = TensorDictModule(
lm.set_train_mode(True),
in_keys=["sequence", "sequence_mask"],
out_keys=["features"],
)
lm_inference = TensorDictModule(
lm,
in_keys=["sequence", "sequence_mask"],
out_keys=["features"],
)

# Define actor head and also make it a TensorDictModule and Probabilistic
actor_head = TensorDictModule(
nn.Linear(config.hidden_size, vocabulary_size, bias=False),
in_keys=["features"],
out_keys=["logits"],
)
actor_head = ProbabilisticActor(
module=actor_head,
in_keys=["logits"],
out_keys=["action"],
distribution_class=torch.distributions.Categorical,
return_log_prob=return_log_prob,
default_interaction_type=ExplorationType.RANDOM,
)

# Define critic head and also make it a TensorDictModule
critic_head = TensorDictModule(
nn.Linear(
config.hidden_size,
vocabulary_size if critic_value_per_action else 1,
bias=False,
),
in_keys=["features"],
out_keys=["action_value"] if critic_value_per_action else ["state_value"],
)

# Create shared actor-critic TensorDictModule
actor_critic_train = ActorValueOperator(
common_operator=lm_training,
policy_operator=actor_head,
value_operator=critic_head,
)
actor_critic_inference = ActorValueOperator(
common_operator=lm_inference,
policy_operator=actor_head,
value_operator=critic_head,
)

# Get individual operators
actor_training = actor_critic_train.get_policy_operator()
critic_training = actor_critic_train.get_value_operator()
actor_inference = actor_critic_inference.get_policy_operator()
critic_inference = actor_critic_inference.get_value_operator()

return actor_training, actor_inference, critic_training, critic_inference
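
As a quick reference (not part of the commit), a minimal usage sketch of create_llama2_actor is shown below. It assumes transformers, torchrl and tensordict are installed; the vocabulary size and batch shapes are arbitrary placeholders.

# Minimal usage sketch, not part of this commit.
import torch
from tensordict import TensorDict

from acegen.models.llama2 import create_llama2_actor

vocabulary_size = 77  # placeholder; in practice it comes from the tokenizer's vocabulary
actor_training, actor_inference = create_llama2_actor(vocabulary_size)

batch_size, sequence_length = 2, 16
batch = TensorDict(
    {
        "sequence": torch.randint(0, vocabulary_size, (batch_size, sequence_length)),
        "sequence_mask": torch.ones(batch_size, sequence_length, dtype=torch.bool),
    },
    batch_size=[batch_size],
)

# The inference policy extracts features for the last valid token only and samples
# the next token; "logits", "action" and the sampled log-probability are written
# back into the TensorDict.
batch = actor_inference(batch)
print(batch["action"].shape)  # torch.Size([2])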
Binary file added acegen/priors/ascii.pt
Binary file added acegen/priors/llama2_enamine_real_6B.ckpt
2 changes: 2 additions & 0 deletions acegen/vocabulary/__init__.py
@@ -1,6 +1,7 @@
from acegen.vocabulary.base import Tokenizer, Vocabulary
from acegen.vocabulary.tokenizers import (
AISTokenizer,
AsciiSMILESTokenizer,
DeepSMILESTokenizer,
SAFETokenizer,
SELFIESTokenizer,
@@ -18,4 +19,5 @@
"SMILESTokenizerChEMBL": SMILESTokenizerChEMBL,
"SMILESTokenizerEnamine": SMILESTokenizerEnamine,
"SMILESTokenizerGuacaMol": SMILESTokenizerGuacaMol,
"AsciiSMILESTokenizer": AsciiSMILESTokenizer,
}
24 changes: 24 additions & 0 deletions acegen/vocabulary/tokenizers.py
@@ -519,3 +519,27 @@ def untokenize(self, tokens, convert_to_smiles=True):
else:
smi = ",".join(ntokens)
return smi


class AsciiSMILESTokenizer:
"""Deals with the tokenization and untokenization of SMILES.
Uses ASCII characters as tokens.
"""

def tokenize(self, data, with_begin_and_end=True):
"""Tokenizes a SMILES string."""
tokens = list(data)
if with_begin_and_end:
tokens = ["^"] + tokens + ["$"]
return tokens

def untokenize(self, tokens):
"""Untokenizes a SMILES string."""
smi = ""
for token in tokens:
if token == "$":
break
if token != "^":
smi += token
return smi
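
A short round-trip illustration of the tokenizer above (not part of the commit):

# Illustrative round trip, not part of this commit.
from acegen.vocabulary.tokenizers import AsciiSMILESTokenizer

tokenizer = AsciiSMILESTokenizer()

tokens = tokenizer.tokenize("CC(=O)Oc1ccccc1C(=O)O")  # aspirin
print(tokens[:4])  # ['^', 'C', 'C', '(']
assert tokenizer.untokenize(tokens) == "CC(=O)Oc1ccccc1C(=O)O"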
2 changes: 1 addition & 1 deletion scripts/a2c/config_denovo.yaml
@@ -21,7 +21,7 @@ prompt: null # e.g. c1ccccc # Fix the beginning of the generated molecules

# Model architecture
shared_nets: False
-model: gru # gru, lstm, gpt2 or mamba
+model: gru # gru, lstm, gpt2, mamba or llama2
# The default prior varies for each model. Refer to the README file in the root directory for more information.
# The default vocabulary varies for each prior. Refer to the README file in the root directory for more information.
custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model)
