diff --git a/README.md b/README.md index 53de786b..3f8b6de7 100644 --- a/README.md +++ b/README.md @@ -181,9 +181,9 @@ We provide a variety of default priors that can be selected in the configuration - A Llama2 model (requires installation of HuggingFace's `transformers` library) - - pre-training dataset: [ChEMBL](https://www.ebi.ac.uk/chembl/) - - number of parameters: 2,809,216 - - to select set the field `model` to `mamba` in any configuration file + - pre-training dataset: [REAL Database, 6B cpds, CXSMILES](https://enamine.net/compound-collections/real-compounds/real-database) + - number of parameters: 5,965,760 + - to select set the field `model` to `llama2` in any configuration file ### Integration of custom models diff --git a/acegen/models/__init__.py b/acegen/models/__init__.py index a8726393..ea4af07b 100644 --- a/acegen/models/__init__.py +++ b/acegen/models/__init__.py @@ -8,11 +8,17 @@ create_gpt2_actor_critic, create_gpt2_critic, ) + from acegen.models.gru import ( create_gru_actor, create_gru_actor_critic, create_gru_critic, ) +from acegen.models.llama2 import ( + create_llama2_actor, + create_llama2_actor_critic, + create_llama2_critic, +) from acegen.models.lstm import ( create_lstm_actor, create_lstm_actor_critic, @@ -25,6 +31,7 @@ ) from acegen.models.utils import adapt_state_dict from acegen.vocabulary.tokenizers import ( + AsciiSMILESTokenizer, SMILESTokenizerChEMBL, SMILESTokenizerEnamine, SMILESTokenizerGuacaMol, @@ -76,6 +83,14 @@ def extract(path): resources.files("acegen.priors") / "gpt2_enamine_real.ckpt", SMILESTokenizerEnamine(), ), + "llama2": ( + create_llama2_actor, + create_llama2_critic, + create_llama2_actor_critic, + resources.files("acegen.priors") / "ascii.pt", + resources.files("acegen.priors") / "llama2_enamine_real_6B.ckpt", + AsciiSMILESTokenizer(), + ), "mamba": ( create_mamba_actor, create_mamba_critic, diff --git a/acegen/models/llama2.py b/acegen/models/llama2.py new file mode 100644 index 00000000..ba1475a4 --- /dev/null +++ b/acegen/models/llama2.py @@ -0,0 +1,290 @@ +import torch +import torch.nn as nn +from tensordict.nn import TensorDictModule, TensorDictSequential +from torchrl.envs import ExplorationType +from torchrl.modules import ActorValueOperator, ProbabilisticActor + +try: + from transformers import LlamaConfig, LlamaModel + + _has_transformers = True +except ImportError as err: + _has_transformers = False + TRANSFORMERS_ERR = err + + +class Llama2(nn.Module): + """Llama2 model for language modeling. This model is a simple wrapper around the HuggingFace Llama22Model.""" + + def __init__(self, config=None): + if not _has_transformers: + raise RuntimeError( + "transformers library not found, please install with pip install transformers." 
+ ) from TRANSFORMERS_ERR + + super(Llama2, self).__init__() + + # Define model + if config is not None: + self.feature_extractor = LlamaModel(config) + else: + self.feature_extractor = None + + # Start in evaluation mode + self._train_mode = False + + @property + def train_mode(self): + return self._train_mode + + def set_train_mode(self, train_mode: bool = True): + if train_mode is self._train_mode: + return self + out = Llama2() + out.feature_extractor = self.feature_extractor + out._train_mode = train_mode + return out + + def forward(self, sequence, sequence_mask): + + out = self.feature_extractor( + input_ids=sequence, + attention_mask=sequence_mask.long().reshape(*sequence.shape), + ).last_hidden_state + + if self.train_mode is False: # Data collection, return only last token + obs_length = sequence_mask.sum(-1) + out = out[torch.arange(len(out)), obs_length.to(torch.int64) - 1] + + return out + + +def define_llama2_configuration( + vocabulary_size: int, + n_positions: int = 2048, + n_head: int = 16, + n_kv_head: int = 4, + n_layer: int = 4, + n_embd: int = 320, + attn_pdrop: float = 0.0, +): + """Define a Llama2 configuration. + + This function is a simple wrapper around the HuggingFace Llama2Config, allowing to specify relevant parameters. + """ + # Define model + config = LlamaConfig() + + # Adjust model parameters + config.vocab_size = vocabulary_size + config.max_position_embeddings = n_positions + config.num_attention_heads = n_head + config.num_key_value_heads = n_kv_head + config.num_hidden_layers = n_layer + config.hidden_size = n_embd + config.intermediate_size = 4 * n_embd + config.attention_dropout = attn_pdrop + return config + + +def create_llama2_actor( + vocabulary_size: int, + n_positions: int = 2048, + n_head: int = 16, + n_kv_head: int = 4, + n_layer: int = 4, + n_embd: int = 320, + attn_pdrop: float = 0.0, + return_log_prob=True, +): + """Create a Llama2 actor for language modeling.""" + config = define_llama2_configuration( + vocabulary_size, + n_positions, + n_head, + n_kv_head, + n_layer, + n_embd, + attn_pdrop, + ) + # Define transformer + lm = Llama2(config) + + # Wrap the transformer in a TensorDictModule to make TensorDict compatible + lm_training = TensorDictModule( + lm.set_train_mode(True), + in_keys=["sequence", "sequence_mask"], + out_keys=["features"], + ) + lm_inference = TensorDictModule( + lm, + in_keys=["sequence", "sequence_mask"], + out_keys=["features"], + ) + + # Define final layer and also make it a TensorDictModule + lm_head = TensorDictModule( + nn.Linear(config.hidden_size, vocabulary_size, bias=False), + in_keys=["features"], + out_keys=["logits"], + ) + + # Concatenate lm and head, similar to torch.nn.Sequential + policy_training = TensorDictSequential(lm_training, lm_head) + policy_inference = TensorDictSequential(lm_inference, lm_head) + + # To make the actor probabilistic, wrap the policy in a ProbabilisticActor + # This module will take care of sampling and computing log probabilities + probabilistic_policy_training = ProbabilisticActor( + module=policy_training, + in_keys=["logits"], + out_keys=["action"], + distribution_class=torch.distributions.Categorical, + return_log_prob=return_log_prob, + default_interaction_type=ExplorationType.RANDOM, + ) + probabilistic_policy_inference = ProbabilisticActor( + module=policy_inference, + in_keys=["logits"], + out_keys=["action"], + distribution_class=torch.distributions.Categorical, + return_log_prob=return_log_prob, + default_interaction_type=ExplorationType.RANDOM, + ) + return 
probabilistic_policy_training, probabilistic_policy_inference + + +def create_llama2_critic( + vocabulary_size: int, + n_positions: int = 2048, + n_head: int = 16, + n_kv_head: int = 4, + n_layer: int = 4, + n_embd: int = 320, + attn_pdrop: float = 0.0, + critic_value_per_action=False, +): + """Create a Llama2 critic for language modeling.""" + config = define_llama2_configuration( + vocabulary_size, + n_positions, + n_head, + n_kv_head, + n_layer, + n_embd, + attn_pdrop, + ) + # Define transformer + lm = Llama2(config) + + # Wrap the transformer in a TensorDictModule to make TensorDict compatible + lm_training = TensorDictModule( + lm.set_train_mode(True), + in_keys=["sequence", "sequence_mask"], + out_keys=["features"], + ) + lm_inference = TensorDictModule( + lm, + in_keys=["sequence", "sequence_mask"], + out_keys=["features"], + ) + + # Define final layer and also make it a TensorDictModule + lm_head = TensorDictModule( + nn.Linear( + config.hidden_size, + vocabulary_size if critic_value_per_action else 1, + bias=False, + ), + in_keys=["features"], + out_keys=["action_value"] if critic_value_per_action else ["state_value"], + ) + + # Concatenate lm and head, similar to torch.nn.Sequential + # Critic does not need to be probabilistic, so we can return directly + critic_training = TensorDictSequential(lm_training, lm_head) + critic_inference = TensorDictSequential(lm_inference, lm_head) + return critic_training, critic_inference + + +def create_llama2_actor_critic( + vocabulary_size: int, + n_positions: int = 2048, + n_head: int = 16, + n_kv_head: int = 4, + n_layer: int = 4, + n_embd: int = 320, + attn_pdrop: float = 0.0, + return_log_prob=True, + critic_value_per_action=False, +): + """Create a Llama2 shared actor-critic network for language modeling.""" + config = define_llama2_configuration( + vocabulary_size, + n_positions, + n_head, + n_kv_head, + n_layer, + n_embd, + attn_pdrop, + ) + # Define transformer + lm = Llama2(config) + + # Wrap the transformer in a TensorDictModule to make TensorDict compatible + lm_training = TensorDictModule( + lm.set_train_mode(True), + in_keys=["sequence", "sequence_mask"], + out_keys=["features"], + ) + lm_inference = TensorDictModule( + lm, + in_keys=["sequence", "sequence_mask"], + out_keys=["features"], + ) + + # Define actor head and also make it a TensorDictModule and Probabilistic + actor_head = TensorDictModule( + nn.Linear(config.hidden_size, vocabulary_size, bias=False), + in_keys=["features"], + out_keys=["logits"], + ) + actor_head = ProbabilisticActor( + module=actor_head, + in_keys=["logits"], + out_keys=["action"], + distribution_class=torch.distributions.Categorical, + return_log_prob=return_log_prob, + default_interaction_type=ExplorationType.RANDOM, + ) + + # Define critic head and also make it a TensorDictModule + critic_head = TensorDictModule( + nn.Linear( + config.hidden_size, + vocabulary_size if critic_value_per_action else 1, + bias=False, + ), + in_keys=["features"], + out_keys=["action_value"] if critic_value_per_action else ["state_value"], + ) + + # Create shared actor-critic TensorDictModule + actor_critic_train = ActorValueOperator( + common_operator=lm_training, + policy_operator=actor_head, + value_operator=critic_head, + ) + actor_critic_inference = ActorValueOperator( + common_operator=lm_inference, + policy_operator=actor_head, + value_operator=critic_head, + ) + + # Get individual operators + actor_training = actor_critic_train.get_policy_operator() + critic_training = actor_critic_train.get_value_operator() + 
actor_inference = actor_critic_inference.get_policy_operator() + critic_inference = actor_critic_inference.get_value_operator() + + return actor_training, actor_inference, critic_training, critic_inference diff --git a/acegen/priors/ascii.pt b/acegen/priors/ascii.pt new file mode 100644 index 00000000..bd15a277 Binary files /dev/null and b/acegen/priors/ascii.pt differ diff --git a/acegen/priors/llama2_enamine_real_6B.ckpt b/acegen/priors/llama2_enamine_real_6B.ckpt new file mode 100644 index 00000000..0cd4e96b Binary files /dev/null and b/acegen/priors/llama2_enamine_real_6B.ckpt differ diff --git a/acegen/vocabulary/__init__.py b/acegen/vocabulary/__init__.py index 4d662e3c..4bca3cd7 100644 --- a/acegen/vocabulary/__init__.py +++ b/acegen/vocabulary/__init__.py @@ -1,6 +1,7 @@ from acegen.vocabulary.base import Tokenizer, Vocabulary from acegen.vocabulary.tokenizers import ( AISTokenizer, + AsciiSMILESTokenizer, DeepSMILESTokenizer, SAFETokenizer, SELFIESTokenizer, @@ -18,4 +19,5 @@ "SMILESTokenizerChEMBL": SMILESTokenizerChEMBL, "SMILESTokenizerEnamine": SMILESTokenizerEnamine, "SMILESTokenizerGuacaMol": SMILESTokenizerGuacaMol, + "AsciiSMILESTokenizer": AsciiSMILESTokenizer, } diff --git a/acegen/vocabulary/tokenizers.py b/acegen/vocabulary/tokenizers.py index f23e7444..dafd1743 100644 --- a/acegen/vocabulary/tokenizers.py +++ b/acegen/vocabulary/tokenizers.py @@ -519,3 +519,27 @@ def untokenize(self, tokens, convert_to_smiles=True): else: smi = ",".join(ntokens) return smi + + +class AsciiSMILESTokenizer: + """Deals with the tokenization and untokenization of SMILES. + + Uses ASCII characters as tokens. + """ + + def tokenize(self, data, with_begin_and_end=True): + """Tokenizes a SMILES string.""" + tokens = list(data) + if with_begin_and_end: + tokens = ["^"] + tokens + ["$"] + return tokens + + def untokenize(self, tokens): + """Untokenizes a SMILES string.""" + smi = "" + for token in tokens: + if token == "$": + break + if token != "^": + smi += token + return smi diff --git a/scripts/a2c/config_denovo.yaml b/scripts/a2c/config_denovo.yaml index a12b6918..f082c650 100644 --- a/scripts/a2c/config_denovo.yaml +++ b/scripts/a2c/config_denovo.yaml @@ -21,7 +21,7 @@ prompt: null # e.g. c1ccccc # Fix the beginning of the generated molecules # Model architecture shared_nets: False -model: gru # gru, lstm, gpt2 or mamba +model: gru # gru, lstm, gpt2, mamba or llama2 # The default prior varies for each model. Refer to the README file in the root directory for more information. # The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model) diff --git a/scripts/a2c/config_fragment.yaml b/scripts/a2c/config_fragment.yaml index 64fc6497..1e06d606 100644 --- a/scripts/a2c/config_fragment.yaml +++ b/scripts/a2c/config_fragment.yaml @@ -24,7 +24,7 @@ promptsmiles_multi: False # Model architecture shared_nets: False -model: gru # gru, lstm, gpt2 or mamba +model: gru # gru, lstm, gpt2, mamba or llama2 # The default prior varies for each model. Refer to the README file in the root directory for more information. # The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. custom_model_factory: null # Path to a custom model factory (e.g. 
my_module.create_model) diff --git a/scripts/a2c/config_scaffold.yaml b/scripts/a2c/config_scaffold.yaml index 230e5431..84c0a669 100644 --- a/scripts/a2c/config_scaffold.yaml +++ b/scripts/a2c/config_scaffold.yaml @@ -24,7 +24,7 @@ promptsmiles_multi: False # Model architecture shared_nets: False -model: gru # gru, lstm, gpt2 or mamba +model: gru # gru, lstm, gpt2, mamba or llama2 # The default prior varies for each model. Refer to the README file in the root directory for more information. # The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model) diff --git a/scripts/ahc/config_denovo.yaml b/scripts/ahc/config_denovo.yaml index 8ecf46c8..4d796d36 100644 --- a/scripts/ahc/config_denovo.yaml +++ b/scripts/ahc/config_denovo.yaml @@ -20,7 +20,7 @@ custom_task: null # Requires molscore to be set to null prompt: null # e.g. c1ccccc # Fix the beginning of the generated molecules # Model architecture -model: gru # gru, lstm, gpt2 or mamba +model: gru # gru, lstm, gpt2, mamba or llama2 # The default prior varies for each model. Refer to the README file in the root directory for more information. # The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model) diff --git a/scripts/ahc/config_fragment.yaml b/scripts/ahc/config_fragment.yaml index e2a90403..775651c7 100644 --- a/scripts/ahc/config_fragment.yaml +++ b/scripts/ahc/config_fragment.yaml @@ -23,7 +23,7 @@ promptsmiles_shuffle: True promptsmiles_multi: False # Model architecture -model: gru # gru, lstm, gpt2 or mamba +model: gru # gru, lstm, gpt2, mamba or llama2 # The default prior varies for each model. Refer to the README file in the root directory for more information. # The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model) diff --git a/scripts/ahc/config_scaffold.yaml b/scripts/ahc/config_scaffold.yaml index 8f51713a..5493e30f 100644 --- a/scripts/ahc/config_scaffold.yaml +++ b/scripts/ahc/config_scaffold.yaml @@ -23,7 +23,7 @@ promptsmiles_shuffle: True promptsmiles_multi: False # Model architecture -model: gru # gru, lstm, gpt2 or mamba +model: gru # gru, lstm, gpt2, mamba or llama2 # The default prior varies for each model. Refer to the README file in the root directory for more information. # The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model) diff --git a/scripts/dpo/config_denovo.yaml b/scripts/dpo/config_denovo.yaml index fff1ea6a..67903646 100644 --- a/scripts/dpo/config_denovo.yaml +++ b/scripts/dpo/config_denovo.yaml @@ -20,7 +20,7 @@ custom_task: null # Requires molscore to be set to null prompt: null # e.g. c1ccccc # Fix the beginning of the generated molecules # Model architecture -model: gru # gru, lstm, gpt2 or mamba +model: gru # gru, lstm, gpt2, mamba or llama2 # The default prior varies for each model. Refer to the README file in the root directory for more information. # The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. 
custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model) diff --git a/scripts/dpo/config_fragment.yaml b/scripts/dpo/config_fragment.yaml index e2e7f54e..3b73a71b 100644 --- a/scripts/dpo/config_fragment.yaml +++ b/scripts/dpo/config_fragment.yaml @@ -23,7 +23,7 @@ promptsmiles_shuffle: True promptsmiles_multi: False # Model architecture -model: gru # gru, lstm, gpt2 or mamba +model: gru # gru, lstm, gpt2, mamba or llama2 # The default prior varies for each model. Refer to the README file in the root directory for more information. # The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model) diff --git a/scripts/dpo/config_scaffold.yaml b/scripts/dpo/config_scaffold.yaml index 5e8c773a..989422d6 100644 --- a/scripts/dpo/config_scaffold.yaml +++ b/scripts/dpo/config_scaffold.yaml @@ -23,7 +23,7 @@ promptsmiles_shuffle: True promptsmiles_multi: False # Model architecture -model: gru # gru, lstm, gpt2 or mamba +model: gru # gru, lstm, gpt2, mamba or llama2 # The default prior varies for each model. Refer to the README file in the root directory for more information. # The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model) diff --git a/scripts/hill_climb/config_denovo.yaml b/scripts/hill_climb/config_denovo.yaml index 2ffee06c..67590e38 100644 --- a/scripts/hill_climb/config_denovo.yaml +++ b/scripts/hill_climb/config_denovo.yaml @@ -20,7 +20,7 @@ custom_task: null # Requires molscore to be set to null prompt: null # e.g. c1ccccc # Fix the beginning of the generated molecules # Model architecture -model: gru # gru, lstm, gpt2 or mamba +model: gru # gru, lstm, gpt2, mamba or llama2 # The default prior varies for each model. Refer to the README file in the root directory for more information. # The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model) diff --git a/scripts/hill_climb/config_fragment.yaml b/scripts/hill_climb/config_fragment.yaml index 0a04aad1..76998a15 100644 --- a/scripts/hill_climb/config_fragment.yaml +++ b/scripts/hill_climb/config_fragment.yaml @@ -23,7 +23,7 @@ promptsmiles_shuffle: True promptsmiles_multi: False # Model architecture -model: gru # gru, lstm, gpt2 or mamba +model: gru # gru, lstm, gpt2, mamba or llama2 # The default prior varies for each model. Refer to the README file in the root directory for more information. # The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model) diff --git a/scripts/hill_climb/config_scaffold.yaml b/scripts/hill_climb/config_scaffold.yaml index c79d66d3..05f48ad4 100644 --- a/scripts/hill_climb/config_scaffold.yaml +++ b/scripts/hill_climb/config_scaffold.yaml @@ -23,7 +23,7 @@ promptsmiles_shuffle: True promptsmiles_multi: False # Model architecture -model: gru # gru, lstm, gpt2 or mamba +model: gru # gru, lstm, gpt2, mamba or llama2 # The default prior varies for each model. Refer to the README file in the root directory for more information. # The default vocabulary varies for each prior. 
Refer to the README file in the root directory for more information. custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model) diff --git a/scripts/ppo/config_denovo.yaml b/scripts/ppo/config_denovo.yaml index d117527e..d38b6a0c 100644 --- a/scripts/ppo/config_denovo.yaml +++ b/scripts/ppo/config_denovo.yaml @@ -21,7 +21,7 @@ prompt: null # e.g. c1ccccc # Fix the beginning of the generated molecules # Model architecture shared_nets: False -model: gru # gru, lstm, gpt2 or mamba +model: gru # gru, lstm, gpt2, mamba or llama2 # The default prior varies for each model. Refer to the README file in the root directory for more information. # The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model) diff --git a/scripts/ppo/config_fragment.yaml b/scripts/ppo/config_fragment.yaml index 6aa7bb0b..0b69f212 100644 --- a/scripts/ppo/config_fragment.yaml +++ b/scripts/ppo/config_fragment.yaml @@ -24,7 +24,7 @@ promptsmiles_multi: False # Model architecture shared_nets: False -model: gru # gru, lstm, gpt2 or mamba +model: gru # gru, lstm, gpt2, mamba or llama2 # The default prior varies for each model. Refer to the README file in the root directory for more information. # The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model) diff --git a/scripts/ppo/config_scaffold.yaml b/scripts/ppo/config_scaffold.yaml index c1b476b4..29570697 100644 --- a/scripts/ppo/config_scaffold.yaml +++ b/scripts/ppo/config_scaffold.yaml @@ -24,7 +24,7 @@ promptsmiles_multi: False # Model architecture shared_nets: False -model: gru # gru, lstm, gpt2 or mamba +model: gru # gru, lstm, gpt2, mamba or llama2 # The default prior varies for each model. Refer to the README file in the root directory for more information. # The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model) diff --git a/scripts/pretrain/config.yaml b/scripts/pretrain/config.yaml index 8be36e8e..44844dee 100644 --- a/scripts/pretrain/config.yaml +++ b/scripts/pretrain/config.yaml @@ -12,7 +12,7 @@ recompute_dataset: True # if False and dataset_log_dir contains a dataset, it wi dataset_log_dir: /tmp/pretrain # if recomputing dataset, save it here # Model configuration -model: gru # gru, lstm, gpt2 or mamba +model: gru # gru, lstm, gpt2, mamba or llama2 custom_model_factory: null # Path to a custom model factory (e.g. 
my_module.create_model) model_log_dir: test #/tmp/pretrain # save model here diff --git a/scripts/pretrain/config_mamba.yaml b/scripts/pretrain/config_mamba.yaml index 2cad85a4..65574354 100644 --- a/scripts/pretrain/config_mamba.yaml +++ b/scripts/pretrain/config_mamba.yaml @@ -13,7 +13,7 @@ recompute_dataset: True # if False and dataset_log_dir contains a dataset, it wi dataset_log_dir: /tmp/pretrain # Model configuration -model: mamba # gru, lstm, gpt2 or mamba +model: mamba # gru, lstm, gpt2, mamba or llama2 model_log_dir: mamba/chembl28p # Training configuration diff --git a/scripts/reinforce/config_denovo.yaml b/scripts/reinforce/config_denovo.yaml index 3defe034..453afd22 100644 --- a/scripts/reinforce/config_denovo.yaml +++ b/scripts/reinforce/config_denovo.yaml @@ -20,7 +20,7 @@ custom_task: null # Requires molscore to be set to null prompt: null # e.g. c1ccccc # Fix the beginning of the generated molecules # Model architecture -model: gru # gru, lstm, gpt2 or mamba +model: gru # gru, lstm, gpt2, mamba or llama2 # The default prior varies for each model. Refer to the README file in the root directory for more information. # The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model) diff --git a/scripts/reinforce/config_fragment.yaml b/scripts/reinforce/config_fragment.yaml index 21013dbb..43941a5b 100644 --- a/scripts/reinforce/config_fragment.yaml +++ b/scripts/reinforce/config_fragment.yaml @@ -23,7 +23,7 @@ promptsmiles_shuffle: True promptsmiles_multi: False # Model architecture -model: gru # gru, lstm, gpt2 or mamba +model: gru # gru, lstm, gpt2, mamba or llama2 # The default prior varies for each model. Refer to the README file in the root directory for more information. # The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model) diff --git a/scripts/reinforce/config_scaffold.yaml b/scripts/reinforce/config_scaffold.yaml index 94561cd8..231cdb20 100644 --- a/scripts/reinforce/config_scaffold.yaml +++ b/scripts/reinforce/config_scaffold.yaml @@ -23,7 +23,7 @@ promptsmiles_shuffle: True promptsmiles_multi: False # Model architecture -model: gru # gru, lstm, gpt2 or mamba +model: gru # gru, lstm, gpt2, mamba or llama2 # The default prior varies for each model. Refer to the README file in the root directory for more information. # The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model) diff --git a/scripts/reinvent/config_denovo.yaml b/scripts/reinvent/config_denovo.yaml index 58516774..7fce6b65 100644 --- a/scripts/reinvent/config_denovo.yaml +++ b/scripts/reinvent/config_denovo.yaml @@ -20,7 +20,7 @@ custom_task: null # Requires molscore to be set to null prompt: null # e.g. c1ccccc # Fix the beginning of the generated molecules # Model architecture -model: gru # gru, lstm, gpt2 or mamba +model: gru # gru, lstm, gpt2, mamba or llama2 # The default prior varies for each model. Refer to the README file in the root directory for more information. # The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. custom_model_factory: null # Path to a custom model factory (e.g. 
my_module.create_model) diff --git a/scripts/reinvent/config_fragment.yaml b/scripts/reinvent/config_fragment.yaml index d233f53e..67b44713 100644 --- a/scripts/reinvent/config_fragment.yaml +++ b/scripts/reinvent/config_fragment.yaml @@ -23,7 +23,7 @@ promptsmiles_shuffle: True promptsmiles_multi: False # Model architecture -model: gru # gru, lstm, gpt2 or mamba +model: gru # gru, lstm, gpt2, mamba or llama2 # The default prior varies for each model. Refer to the README file in the root directory for more information. # The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model) diff --git a/scripts/reinvent/config_scaffold.yaml b/scripts/reinvent/config_scaffold.yaml index bfea5782..99d0c99a 100644 --- a/scripts/reinvent/config_scaffold.yaml +++ b/scripts/reinvent/config_scaffold.yaml @@ -23,7 +23,7 @@ promptsmiles_shuffle: True promptsmiles_multi: False # Model architecture -model: gru # gru, lstm, gpt2 or mamba +model: gru # gru, lstm, gpt2, mamba or llama2 # The default prior varies for each model. Refer to the README file in the root directory for more information. # The default vocabulary varies for each prior. Refer to the README file in the root directory for more information. custom_model_factory: null # Path to a custom model factory (e.g. my_module.create_model) diff --git a/tests/test_llama2_model.py b/tests/test_llama2_model.py new file mode 100644 index 00000000..5c508664 --- /dev/null +++ b/tests/test_llama2_model.py @@ -0,0 +1,151 @@ +import pytest +import torch +from acegen.data import smiles_to_tensordict +from acegen.models.llama2 import ( + create_llama2_actor, + create_llama2_actor_critic, + create_llama2_critic, +) +from utils import get_default_devices + + +def generate_valid_data_batch( + vocabulary_size: int, batch_size: int, sequence_length: int +): + tokens = torch.randint(0, vocabulary_size, (batch_size, sequence_length + 1)) + batch = smiles_to_tensordict( + tokens, replace_mask_value=0 + ) # batch_size, sequence_length) + batch.set("sequence", batch.get("observation")) + batch.set("sequence_mask", batch.get("mask")) + return batch + + +@pytest.mark.parametrize("vocabulary_size", [10]) +@pytest.mark.parametrize("device", get_default_devices()) +def test_llama2_actor(vocabulary_size, device, sequence_length=5, batch_size=10): + torch.manual_seed(0) + # Create the model and a data batch + training_actor, inference_actor = create_llama2_actor( + vocabulary_size, + n_head=32, + n_layer=2, + ) + training_batch = generate_valid_data_batch( + vocabulary_size, batch_size, sequence_length + ) + inference_batch = training_batch.clone() + inference_batch.batch_size = [batch_size] + + # Check that the inference model works + inference_actor = inference_actor.to(device) + inference_batch = inference_batch.to(device) + inference_batch = inference_actor(inference_batch) + assert "logits" in inference_batch.keys() + assert "action" in inference_batch.keys() + + # Check that the training model works + training_actor = training_actor.to(device) + training_batch = training_batch.to(device) + training_batch = training_actor(training_batch) + assert "logits" in training_batch.keys() + assert "action" in training_batch.keys() + + +@pytest.mark.parametrize("vocabulary_size", [10]) +@pytest.mark.parametrize("device", get_default_devices()) +@pytest.mark.parametrize("critic_value_per_action", [True, False]) +def test_llama2_critic( + 
vocabulary_size, + device, + critic_value_per_action, + sequence_length=5, + batch_size=10, +): + torch.manual_seed(0) + # Create the model and a data batch + training_critic, inference_critic = create_llama2_critic( + vocabulary_size, + critic_value_per_action=critic_value_per_action, + n_head=32, + n_layer=2, + ) + training_batch = generate_valid_data_batch( + vocabulary_size, batch_size, sequence_length + ) + inference_batch = training_batch.clone() + inference_batch.batch_size = [batch_size] + + # Check that the inference model works + inference_critic = inference_critic.to(device) + inference_batch = inference_batch.to(device) + inference_batch = inference_critic(inference_batch) + if critic_value_per_action: + assert "action_value" in inference_batch.keys() + else: + assert "state_value" in inference_batch.keys() + + # Check that the training model works + training_critic = training_critic.to(device) + training_batch = training_batch.to(device) + training_batch = training_critic(training_batch) + if critic_value_per_action: + assert "action_value" in training_batch.keys() + else: + assert "state_value" in training_batch.keys() + + +@pytest.mark.parametrize("vocabulary_size", [10]) +@pytest.mark.parametrize("device", get_default_devices()) +@pytest.mark.parametrize("critic_value_per_action", [True, False]) +def test_llama2_actor_critic( + vocabulary_size, + device, + critic_value_per_action, + sequence_length=5, + batch_size=10, +): + torch.manual_seed(0) + # Create the model and a data batch + ( + training_actor, + inference_actor, + training_critic, + inference_critic, + ) = create_llama2_actor_critic( + vocabulary_size, + critic_value_per_action=critic_value_per_action, + n_head=32, + n_layer=2, + ) + training_batch = generate_valid_data_batch( + vocabulary_size, batch_size, sequence_length + ) + inference_batch = training_batch.clone() + inference_batch.batch_size = [batch_size] + + # Check that the inference model works + inference_actor = inference_actor.to(device) + inference_critic = inference_critic.to(device) + inference_batch = inference_batch.to(device) + inference_batch = inference_actor(inference_batch) + inference_batch = inference_critic(inference_batch) + assert "logits" in inference_batch.keys() + assert "action" in inference_batch.keys() + if critic_value_per_action: + assert "action_value" in inference_batch.keys() + else: + assert "state_value" in inference_batch.keys() + + # Check that the training model works + training_actor = training_actor.to(device) + training_critic = training_critic.to(device) + training_batch = training_batch.to(device) + training_batch = training_actor(training_batch) + training_batch = training_critic(training_batch) + assert "logits" in training_batch.keys() + assert "action" in training_batch.keys() + if critic_value_per_action: + assert "action_value" in training_batch.keys() + else: + assert "state_value" in training_batch.keys()
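
Usage notes (not part of the patch). The sketches below are illustrative only: they assume HuggingFace's `transformers` library is installed, follow the shapes exercised in tests/test_llama2_model.py, and build the "sequence"/"sequence_mask" tensordict by hand rather than via smiles_to_tensordict. First, the actor factory added in acegen/models/llama2.py:

import torch
from tensordict import TensorDict

from acegen.models.llama2 import create_llama2_actor

vocabulary_size, batch_size, sequence_length = 10, 4, 6

# The factory returns a training module (per-token features) and an inference
# module (features of the last non-padded token only).
actor_training, actor_inference = create_llama2_actor(
    vocabulary_size,
    n_head=32,
    n_layer=2,
)

batch = TensorDict(
    {
        "sequence": torch.randint(0, vocabulary_size, (batch_size, sequence_length)),
        "sequence_mask": torch.ones(batch_size, sequence_length, dtype=torch.bool),
    },
    batch_size=[batch_size],
)

# Inference: "logits" has shape (batch, vocabulary_size) and "action" has shape (batch,).
out = actor_inference(batch.clone())
print(out["logits"].shape, out["action"].shape)

# Training: "logits" has shape (batch, sequence_length, vocabulary_size).
out = actor_training(batch.clone())
print(out["logits"].shape)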
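
The new AsciiSMILESTokenizer (acegen/vocabulary/tokenizers.py above) treats every character of a SMILES string as a token and brackets the sequence with "^" and "$". A quick round trip of the behaviour defined in the patch:

from acegen.vocabulary.tokenizers import AsciiSMILESTokenizer

tokenizer = AsciiSMILESTokenizer()
tokens = tokenizer.tokenize("c1ccccc1")
# tokens == ['^', 'c', '1', 'c', 'c', 'c', 'c', 'c', '1', '$']
assert tokenizer.untokenize(tokens) == "c1ccccc1"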
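
The critic factory follows the same pattern; the only branching point is critic_value_per_action, which decides whether the value head emits one value per vocabulary token ("action_value") or a single scalar per step ("state_value"). A sketch under the same assumptions as the actor example:

import torch
from tensordict import TensorDict

from acegen.models.llama2 import create_llama2_critic

vocabulary_size, batch_size, sequence_length = 10, 4, 6
critic_training, critic_inference = create_llama2_critic(
    vocabulary_size,
    critic_value_per_action=True,  # False would write a single "state_value" instead
    n_head=32,
    n_layer=2,
)
batch = TensorDict(
    {
        "sequence": torch.randint(0, vocabulary_size, (batch_size, sequence_length)),
        "sequence_mask": torch.ones(batch_size, sequence_length, dtype=torch.bool),
    },
    batch_size=[batch_size],
)
print(critic_inference(batch.clone())["action_value"].shape)  # (batch_size, vocabulary_size)
print(critic_training(batch.clone())["action_value"].shape)   # (batch_size, sequence_length, vocabulary_size)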
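
Finally, create_llama2_actor_critic shares a single Llama2 feature extractor between the policy and value heads and returns four modules (training and inference variants of each). This is what the "llama2" entry added to acegen/models/__init__.py registers, alongside the ascii.pt vocabulary, the llama2_enamine_real_6B.ckpt prior and the AsciiSMILESTokenizer, for the `model: llama2` setting in the config files. A brief sketch of the factory call:

from acegen.models.llama2 import create_llama2_actor_critic

(
    actor_training,
    actor_inference,
    critic_training,
    critic_inference,
) = create_llama2_actor_critic(
    vocabulary_size=10,
    n_head=32,
    n_layer=2,
    critic_value_per_action=False,
)
# The actor modules write "logits"/"action", the critic modules write
# "state_value", and the underlying LlamaModel weights are shared by all four.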