From d4499c6a193453282401929360a3dc9168997f05 Mon Sep 17 00:00:00 2001
From: albert bou
Date: Fri, 26 Apr 2024 14:35:50 +0200
Subject: [PATCH] fix

---
 scripts/ahc/ahc.py                            |  2 +-
 scripts/ppo/ppo.py                            |  2 +-
 scripts/reinforce/config_denovo.yaml          |  5 +++
 scripts/reinforce/config_fragment.yaml        |  5 +++
 scripts/reinforce/config_scaffold.yaml        |  5 +++
 scripts/reinforce/reinforce.py                | 44 +++++++++++++++++++
 scripts/reinvent/reinvent.py                  |  2 +-
 tests/check_scripts/run-example-scripts.sh    |  3 ++
 tests/check_scripts/run_reinforce_denovo.sh   | 42 ++++++++++++++++++
 tests/check_scripts/run_reinforce_fragment.sh | 40 +++++++++++++++++
 tests/check_scripts/run_reinforce_scaffold.sh | 40 +++++++++++++++++
 11 files changed, 187 insertions(+), 3 deletions(-)
 create mode 100755 tests/check_scripts/run_reinforce_denovo.sh
 create mode 100755 tests/check_scripts/run_reinforce_fragment.sh
 create mode 100755 tests/check_scripts/run_reinforce_scaffold.sh

diff --git a/scripts/ahc/ahc.py b/scripts/ahc/ahc.py
index d2b3b0b3..46a1fe3c 100644
--- a/scripts/ahc/ahc.py
+++ b/scripts/ahc/ahc.py
@@ -300,7 +300,7 @@ def create_env_fn():
         optim.step()
 
         # Then add new experiences to the replay buffer
-        if cfg.experience_replay is True:
+        if cfg.experience_replay:
 
             replay_data = data.clone()
 
diff --git a/scripts/ppo/ppo.py b/scripts/ppo/ppo.py
index 81ee426c..7f414633 100644
--- a/scripts/ppo/ppo.py
+++ b/scripts/ppo/ppo.py
@@ -413,7 +413,7 @@ def create_env_fn():
                 log_info.update({f"train/{key}": value.item()})
 
         # Add new experiences to the replay buffer
-        if cfg.experience_replay is True:
+        if cfg.experience_replay:
 
             # MaxValueWriter is not compatible with storages of more than one dimension.
             replay_data.batch_size = [replay_data.batch_size[0]]
diff --git a/scripts/reinforce/config_denovo.yaml b/scripts/reinforce/config_denovo.yaml
index 434ebdf8..15a13167 100644
--- a/scripts/reinforce/config_denovo.yaml
+++ b/scripts/reinforce/config_denovo.yaml
@@ -26,3 +26,8 @@ model: gru # gru, lstm, or gpt2
 lr: 0.0001
 eps: 1.0e-08
 weight_decay: 0.0
+
+# Data replay configuration
+replay_buffer_size: 100
+experience_replay: False
+replay_batch_size: 10
diff --git a/scripts/reinforce/config_fragment.yaml b/scripts/reinforce/config_fragment.yaml
index fa286628..173ada55 100644
--- a/scripts/reinforce/config_fragment.yaml
+++ b/scripts/reinforce/config_fragment.yaml
@@ -29,3 +29,8 @@ model: gru # gru, lstm, or gpt2
 lr: 0.0001
 eps: 1.0e-08
 weight_decay: 0.0
+
+# Data replay configuration
+replay_buffer_size: 100
+experience_replay: False
+replay_batch_size: 10
diff --git a/scripts/reinforce/config_scaffold.yaml b/scripts/reinforce/config_scaffold.yaml
index 9a6156f8..8cde46d0 100644
--- a/scripts/reinforce/config_scaffold.yaml
+++ b/scripts/reinforce/config_scaffold.yaml
@@ -29,3 +29,8 @@ model: gru # gru, lstm, or gpt2
 lr: 0.0001
 eps: 1.0e-08
 weight_decay: 0.0
+
+# Data replay configuration
+replay_buffer_size: 100
+experience_replay: False
+replay_batch_size: 10
diff --git a/scripts/reinforce/reinforce.py b/scripts/reinforce/reinforce.py
index 8ad5b8ee..9737d9ef 100644
--- a/scripts/reinforce/reinforce.py
+++ b/scripts/reinforce/reinforce.py
@@ -178,6 +178,18 @@ def create_env_fn():
 
     env = create_env_fn()
 
+    # Create replay buffer
+    ####################################################################################################################
+
+    storage = LazyTensorStorage(cfg.replay_buffer_size, device=device)
+    experience_replay_buffer = TensorDictReplayBuffer(
+        storage=storage,
+        sampler=PrioritizedSampler(storage.max_size, alpha=1.0, beta=1.0),
+        batch_size=cfg.replay_batch_size,
+        writer=TensorDictMaxValueWriter(rank_key="priority"),
+        priority_key="priority",
+    )
+
     # Create optimizer
     ####################################################################################################################
 
@@ -256,6 +268,15 @@ def create_env_fn():
 
         data, loss = compute_loss(data, actor_training)
 
+        # Compute experience replay loss
+        if (
+            cfg.experience_replay
+            and len(experience_replay_buffer) > cfg.replay_batch_size
+        ):
+            replay_batch = experience_replay_buffer.sample()
+            _, replay_loss = compute_loss(replay_batch, actor_training)
+            loss = torch.cat((loss, replay_loss), 0)
+
         # Average loss over the batch
         loss = loss.mean()
 
@@ -264,6 +285,29 @@ def create_env_fn():
         loss.backward()
         optim.step()
 
+        # Then add new experiences to the replay buffer
+        if cfg.experience_replay:
+
+            replay_data = data.clone()
+
+            # MaxValueWriter is not compatible with storages of more than one dimension.
+            replay_data.batch_size = [replay_data.batch_size[0]]
+
+            # Remove SMILES that are already in the replay buffer
+            if len(experience_replay_buffer) > 0:
+                is_duplicated = isin(
+                    input=replay_data,
+                    key="action",
+                    reference=experience_replay_buffer[:],
+                )
+                replay_data = replay_data[~is_duplicated]
+
+            # Add data to the replay buffer
+            reward = replay_data.get(("next", "reward"))
+            replay_data.set("priority", reward)
+            if len(replay_data) > 0:
+                experience_replay_buffer.extend(replay_data)
+
         # Log info
         if logger:
             for key, value in log_info.items():
diff --git a/scripts/reinvent/reinvent.py b/scripts/reinvent/reinvent.py
index d7b47f3f..ec43e414 100644
--- a/scripts/reinvent/reinvent.py
+++ b/scripts/reinvent/reinvent.py
@@ -296,7 +296,7 @@ def create_env_fn():
         optim.step()
 
         # Then add new experiences to the replay buffer
-        if cfg.experience_replay is True:
+        if cfg.experience_replay:
 
             replay_data = data.clone()
 
diff --git a/tests/check_scripts/run-example-scripts.sh b/tests/check_scripts/run-example-scripts.sh
index 2f48f7c7..bce1b3c2 100755
--- a/tests/check_scripts/run-example-scripts.sh
+++ b/tests/check_scripts/run-example-scripts.sh
@@ -48,18 +48,21 @@ scripts=(
     run_pretrain_single_node.sh
     run_pretrain_distributed.sh
 
+    run_reinforce_denovo.sh
     run_reinvent_denovo.sh
     run_ahc_denovo.sh
     run_a2c_denovo.sh
     run_ppo_denovo.sh
     run_ppod_denovo.sh
 
+    run_reinforce_scaffold.sh
     run_reinvent_scaffold.sh
     run_ahc_scaffold.sh
     run_a2c_scaffold.sh
     run_ppo_scaffold.sh
     run_ppod_scaffold.sh
 
+    run_reinforce_fragment.sh
     run_reinvent_fragment.sh
     run_ahc_fragment.sh
     run_a2c_fragment.sh
diff --git a/tests/check_scripts/run_reinforce_denovo.sh b/tests/check_scripts/run_reinforce_denovo.sh
new file mode 100755
index 00000000..9b3cd7cd
--- /dev/null
+++ b/tests/check_scripts/run_reinforce_denovo.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+#SBATCH --job-name=reinforce_denovo
+#SBATCH --ntasks=6
+#SBATCH --cpus-per-task=1
+#SBATCH --gres=gpu:1
+#SBATCH --output=slurm_logs/reinforce_denovo%j.txt
+#SBATCH --error=slurm_errors/reinforce_denovo%j.txt
+
+current_commit=$(git rev-parse --short HEAD)
+project_name="acegen-scripts-check-$current_commit"
+agent_name="reinforce_denovo"
+if [ -z "$N_RUN" ]; then
+    echo "N_RUN is not set. Setting to default value of 1."
+    N_RUN=1
+fi
+if [ -z "$ACEGEN_MODEL" ]; then
+    echo "ACEGEN_MODEL is not set. Setting to default value of gru. Choose from [gru, lstm, gpt2]"
+    ACEGEN_MODEL="gru"
+fi
+
+export PYTHONPATH=$(dirname $(dirname $PWD))
+python $PYTHONPATH/scripts/reinforce/reinforce.py \
+    logger_backend=wandb \
+    experiment_name="$project_name" \
+    agent_name="$agent_name" \
+    molscore=MolOpt \
+    molscore_include=[Albuterol_similarity] \
+    seed=$N_RUN \
+    log_dir="$agent_name"_seed"$N_RUN" \
+    model=$ACEGEN_MODEL
+
+# Capture the exit status of the Python command
+exit_status=$?
+# Write the exit status to a file
+if [ $exit_status -eq 0 ]; then
+    echo "${agent_name}_${SLURM_JOB_ID}=success" >> report.log
+else
+    echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log
+fi
+
+mv "$agent_name"_seed"$N_RUN"* slurm_logs/
diff --git a/tests/check_scripts/run_reinforce_fragment.sh b/tests/check_scripts/run_reinforce_fragment.sh
new file mode 100755
index 00000000..15bc14b8
--- /dev/null
+++ b/tests/check_scripts/run_reinforce_fragment.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+#SBATCH --job-name=reinforce_fragment
+#SBATCH --ntasks=6
+#SBATCH --cpus-per-task=1
+#SBATCH --gres=gpu:1
+#SBATCH --output=slurm_logs/reinforce_fragment%j.txt
+#SBATCH --error=slurm_errors/reinforce_fragment%j.txt
+
+current_commit=$(git rev-parse --short HEAD)
+project_name="acegen-scripts-check-$current_commit"
+agent_name="reinforce_fragment"
+if [ -z "$N_RUN" ]; then
+    echo "N_RUN is not set. Setting to default value of 1."
+    N_RUN=1
+fi
+if [ -z "$ACEGEN_MODEL" ]; then
+    echo "ACEGEN_MODEL is not set. Setting to default value of gru. Choose from [gru, lstm, gpt2]"
+    ACEGEN_MODEL="gru"
+fi
+
+export PYTHONPATH=$(dirname $(dirname $PWD))
+python $PYTHONPATH/scripts/reinforce/reinforce.py --config-name config_fragment \
+    logger_backend=wandb \
+    experiment_name="$project_name" \
+    agent_name="$agent_name" \
+    seed=$N_RUN \
+    log_dir="$agent_name"_seed"$N_RUN" \
+    model=$ACEGEN_MODEL
+
+# Capture the exit status of the Python command
+exit_status=$?
+# Write the exit status to a file
+if [ $exit_status -eq 0 ]; then
+    echo "${agent_name}_${SLURM_JOB_ID}=success" >> report.log
+else
+    echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log
+fi
+
+mv "$agent_name"_seed"$N_RUN"* slurm_logs/
diff --git a/tests/check_scripts/run_reinforce_scaffold.sh b/tests/check_scripts/run_reinforce_scaffold.sh
new file mode 100755
index 00000000..118e736d
--- /dev/null
+++ b/tests/check_scripts/run_reinforce_scaffold.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+#SBATCH --job-name=reinforce_scaffold
+#SBATCH --ntasks=6
+#SBATCH --cpus-per-task=1
+#SBATCH --gres=gpu:1
+#SBATCH --output=slurm_logs/reinforce_scaffold%j.txt
+#SBATCH --error=slurm_errors/reinforce_scaffold%j.txt
+
+current_commit=$(git rev-parse --short HEAD)
+project_name="acegen-scripts-check-$current_commit"
+agent_name="reinforce_scaffold"
+if [ -z "$N_RUN" ]; then
+    echo "N_RUN is not set. Setting to default value of 1."
+    N_RUN=1
+fi
+if [ -z "$ACEGEN_MODEL" ]; then
+    echo "ACEGEN_MODEL is not set. Setting to default value of gru. Choose from [gru, lstm, gpt2]"
+    ACEGEN_MODEL="gru"
+fi
+
+export PYTHONPATH=$(dirname $(dirname $PWD))
+python $PYTHONPATH/scripts/reinforce/reinforce.py --config-name config_scaffold \
+    logger_backend=wandb \
+    experiment_name="$project_name" \
+    agent_name="$agent_name" \
+    seed=$N_RUN \
+    log_dir="$agent_name"_seed"$N_RUN" \
+    model=$ACEGEN_MODEL
+
+# Capture the exit status of the Python command
+exit_status=$?
+# Write the exit status to a file
+if [ $exit_status -eq 0 ]; then
+    echo "${agent_name}_${SLURM_JOB_ID}=success" >> report.log
+else
+    echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log
+fi
+
+mv "$agent_name"_seed"$N_RUN"* slurm_logs/
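
Note: the replay path added to reinforce.py mirrors the one already present in reinvent.py, ahc.py, and ppo.py. The standalone sketch below shows how the buffer behaves in isolation, assuming TorchRL's replay-buffer API (LazyTensorStorage, PrioritizedSampler, TensorDictMaxValueWriter, TensorDictReplayBuffer, as named in the diff). The tensor shapes, the 50-token vocabulary, and the literal sizes 100 and 10 (standing in for cfg.replay_buffer_size and cfg.replay_batch_size) are illustrative only, not acegen's actual values.

import torch
from tensordict import TensorDict
from torchrl.data import (
    LazyTensorStorage,
    PrioritizedSampler,
    TensorDictMaxValueWriter,
    TensorDictReplayBuffer,
)

# Buffer configured as in the patch: fixed-size storage, entries ranked
# and sampled by a "priority" entry that is set to the episode reward.
storage = LazyTensorStorage(100)  # cfg.replay_buffer_size
buffer = TensorDictReplayBuffer(
    storage=storage,
    sampler=PrioritizedSampler(storage.max_size, alpha=1.0, beta=1.0),
    batch_size=10,  # cfg.replay_batch_size
    writer=TensorDictMaxValueWriter(rank_key="priority"),
    priority_key="priority",
)

# A fake batch of 32 token sequences standing in for collected SMILES.
data = TensorDict(
    {
        "action": torch.randint(0, 50, (32, 20)),
        ("next", "reward"): torch.rand(32, 1),
    },
    batch_size=[32],
)

# As in the patch: rank rows by reward, then extend. Once the storage is
# full, the MaxValueWriter keeps only the highest-priority rows.
data.set("priority", data.get(("next", "reward")).squeeze(-1))
buffer.extend(data)

# As in the patch's loss step: replay only once enough entries exist.
if len(buffer) > 10:
    replay_batch = buffer.sample()  # 10 rows, drawn by priority

Because TensorDictMaxValueWriter ranks rows by "priority" and PrioritizedSampler draws proportionally to it, the buffer acts as a running top-k of high-reward sequences that are both retained and replayed preferentially; this is also why reinforce.py deduplicates on "action" before extending, so repeated high scorers cannot crowd out the rest of the buffer.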