
Commit

fix
albertbou92 committed Apr 26, 2024
1 parent 34cf86c commit d4499c6
Showing 11 changed files with 187 additions and 3 deletions.
2 changes: 1 addition & 1 deletion scripts/ahc/ahc.py
@@ -300,7 +300,7 @@ def create_env_fn():
optim.step()

# Then add new experiences to the replay buffer
- if cfg.experience_replay is True:
+ if cfg.experience_replay:

replay_data = data.clone()

2 changes: 1 addition & 1 deletion scripts/ppo/ppo.py
@@ -413,7 +413,7 @@ def create_env_fn():
log_info.update({f"train/{key}": value.item()})

# Add new experiences to the replay buffer
- if cfg.experience_replay is True:
+ if cfg.experience_replay:

# MaxValueWriter is not compatible with storages of more than one dimension.
replay_data.batch_size = [replay_data.batch_size[0]]
5 changes: 5 additions & 0 deletions scripts/reinforce/config_denovo.yaml
@@ -26,3 +26,8 @@ model: gru # gru, lstm, or gpt2
lr: 0.0001
eps: 1.0e-08
weight_decay: 0.0

# Data replay configuration
replay_buffer_size: 100
experience_replay: False
replay_batch_size: 10
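
The three new keys mirror the experience-replay options the other ACEGEN scripts already use: experience_replay toggles the feature (off by default), replay_buffer_size caps how many high-reward molecules are kept, and replay_batch_size sets how many are re-sampled per update. A minimal sketch of reading them, assuming OmegaConf and this relative path (the scripts themselves receive cfg through Hydra, so the real loading code differs):

from omegaconf import OmegaConf

# Hypothetical standalone load; in the scripts Hydra builds `cfg`, and
# command-line overrides such as experience_replay=True flip the switch.
cfg = OmegaConf.load("scripts/reinforce/config_denovo.yaml")
print(cfg.experience_replay, cfg.replay_buffer_size, cfg.replay_batch_size)
# -> False 100 10 with the defaults above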
5 changes: 5 additions & 0 deletions scripts/reinforce/config_fragment.yaml
@@ -29,3 +29,8 @@ model: gru # gru, lstm, or gpt2
lr: 0.0001
eps: 1.0e-08
weight_decay: 0.0

# Data replay configuration
replay_buffer_size: 100
experience_replay: False
replay_batch_size: 10
5 changes: 5 additions & 0 deletions scripts/reinforce/config_scaffold.yaml
@@ -29,3 +29,8 @@ model: gru # gru, lstm, or gpt2
lr: 0.0001
eps: 1.0e-08
weight_decay: 0.0

# Data replay configuration
replay_buffer_size: 100
experience_replay: False
replay_batch_size: 10
44 changes: 44 additions & 0 deletions scripts/reinforce/reinforce.py
@@ -178,6 +178,18 @@ def create_env_fn():

env = create_env_fn()

# Create replay buffer
####################################################################################################################

storage = LazyTensorStorage(cfg.replay_buffer_size, device=device)
experience_replay_buffer = TensorDictReplayBuffer(
storage=storage,
sampler=PrioritizedSampler(storage.max_size, alpha=1.0, beta=1.0),
batch_size=cfg.replay_batch_size,
writer=TensorDictMaxValueWriter(rank_key="priority"),
priority_key="priority",
)

# Create optimizer
####################################################################################################################

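A hedged, self-contained sketch of how the TorchRL pieces above fit together: LazyTensorStorage holds the TensorDicts, TensorDictMaxValueWriter keeps only the highest-ranked entries (ranked by the "priority" key) once the buffer is full, and PrioritizedSampler draws replay batches in proportion to that priority. The tensor shapes and the fake batch below are made up for illustration; only the buffer wiring mirrors the script.

import torch
from tensordict import TensorDict
from torchrl.data import (
    LazyTensorStorage,
    PrioritizedSampler,
    TensorDictMaxValueWriter,
    TensorDictReplayBuffer,
)

storage = LazyTensorStorage(100)  # cfg.replay_buffer_size
buffer = TensorDictReplayBuffer(
    storage=storage,
    sampler=PrioritizedSampler(storage.max_size, alpha=1.0, beta=1.0),
    batch_size=10,  # cfg.replay_batch_size
    writer=TensorDictMaxValueWriter(rank_key="priority"),
    priority_key="priority",
)

# Stand-in for a batch of generated molecules: token sequences plus a
# per-molecule reward stored under "priority".
fake_batch = TensorDict(
    {"action": torch.randint(0, 30, (64, 50)), "priority": torch.rand(64)},
    batch_size=[64],
)
buffer.extend(fake_batch)        # entries are ranked by "priority" as they arrive
replay_batch = buffer.sample()   # TensorDict with batch_size 10
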
@@ -256,6 +268,15 @@ def create_env_fn():

data, loss = compute_loss(data, actor_training)

# Compute experience replay loss
if (
cfg.experience_replay
and len(experience_replay_buffer) > cfg.replay_batch_size
):
replay_batch = experience_replay_buffer.sample()
_, replay_loss = compute_loss(replay_batch, actor_training)
loss = torch.cat((loss, replay_loss), 0)

# Average loss over the batch
loss = loss.mean()

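The replay loss is simply appended to the per-sample on-policy losses before the mean, so replayed high-reward SMILES contribute to the gradient with the same weight as freshly generated ones; note the guard above only samples once the buffer holds more than replay_batch_size entries. A toy sketch of that pattern (the sizes and random tensors are stand-ins; the real per-sample losses come from compute_loss):

import torch

on_policy_loss = torch.randn(64, requires_grad=True)  # one term per new molecule
replay_loss = torch.randn(10, requires_grad=True)     # one term per replayed molecule

loss = torch.cat((on_policy_loss, replay_loss), 0).mean()
loss.backward()
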
@@ -264,6 +285,29 @@ def create_env_fn():
loss.backward()
optim.step()

# Then add new experiences to the replay buffer
if cfg.experience_replay:

replay_data = data.clone()

# MaxValueWriter is not compatible with storages of more than one dimension.
replay_data.batch_size = [replay_data.batch_size[0]]

# Remove SMILES that are already in the replay buffer
if len(experience_replay_buffer) > 0:
is_duplicated = isin(
input=replay_data,
key="action",
reference=experience_replay_buffer[:],
)
replay_data = replay_data[~is_duplicated]

# Add data to the replay buffer
reward = replay_data.get(("next", "reward"))
replay_data.set("priority", reward)
if len(replay_data) > 0:
experience_replay_buffer.extend(replay_data)

# Log info
if logger:
for key, value in log_info.items():
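
Before extending the buffer, molecules whose "action" token sequence is already stored are filtered out via the isin call above, and the reward from ("next", "reward") is copied into the "priority" key so the max-value writer and prioritized sampler both rank by reward. Below is a hedged illustration of the deduplication idea only, not ACEGEN's actual isin implementation (it assumes equal-length, padded sequences):

import torch

def already_buffered(candidates: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
    """Row-wise membership test: True where a candidate sequence already
    appears in the reference set ([N, L] compared against [M, L])."""
    if reference.numel() == 0:
        return torch.zeros(candidates.shape[0], dtype=torch.bool)
    return (candidates.unsqueeze(1) == reference.unsqueeze(0)).all(dim=-1).any(dim=1)

new = torch.tensor([[1, 2, 3], [4, 5, 6]])
seen = torch.tensor([[4, 5, 6]])
mask = already_buffered(new, seen)  # tensor([False, True])
unique_new = new[~mask]             # keeps only [[1, 2, 3]]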
2 changes: 1 addition & 1 deletion scripts/reinvent/reinvent.py
@@ -296,7 +296,7 @@ def create_env_fn():
optim.step()

# Then add new experiences to the replay buffer
- if cfg.experience_replay is True:
+ if cfg.experience_replay:

replay_data = data.clone()

3 changes: 3 additions & 0 deletions tests/check_scripts/run-example-scripts.sh
@@ -48,18 +48,21 @@ scripts=(
run_pretrain_single_node.sh
run_pretrain_distributed.sh

run_reinforce_denovo.sh
run_reinvent_denovo.sh
run_ahc_denovo.sh
run_a2c_denovo.sh
run_ppo_denovo.sh
run_ppod_denovo.sh

run_reinforce_scaffold.sh
run_reinvent_scaffold.sh
run_ahc_scaffold.sh
run_a2c_scaffold.sh
run_ppo_scaffold.sh
run_ppod_scaffold.sh

run_reinforce_fragment.sh
run_reinvent_fragment.sh
run_ahc_fragment.sh
run_a2c_fragment.sh
42 changes: 42 additions & 0 deletions tests/check_scripts/run_reinforce_denovo.sh
@@ -0,0 +1,42 @@
#!/bin/bash

#SBATCH --job-name=reinforce_denovo
#SBATCH --ntasks=6
#SBATCH --cpus-per-task=1
#SBATCH --gres=gpu:1
#SBATCH --output=slurm_logs/reinforce_denovo%j.txt
#SBATCH --error=slurm_errors/reinforce_denovo%j.txt

current_commit=$(git rev-parse --short HEAD)
project_name="acegen-scripts-check-$current_commit"
agent_name="reinforce_denovo"
if [ -z "$N_RUN" ]; then
echo "N_RUN is not set. Setting to default value of 1."
N_RUN=1
fi
if [ -z "$ACEGEN_MODEL" ]; then
echo "ACEGEN_MODEL is not set. Setting to default value of gru. Choose from [gru, lstm, gpt2]"
ACEGEN_MODEL="gru"
fi

export PYTHONPATH=$(dirname $(dirname $PWD))
python $PYTHONPATH/scripts/reinforce/reinforce.py \
logger_backend=wandb \
experiment_name="$project_name" \
agent_name="$agent_name" \
molscore=MolOpt \
molscore_include=[Albuterol_similarity] \
seed=$N_RUN \
log_dir="$agent_name"_seed"$N_RUN" \
model=$ACEGEN_MODEL

# Capture the exit status of the Python command
exit_status=$?
# Write the exit status to a file
if [ $exit_status -eq 0 ]; then
echo "${agent_name}_${SLURM_JOB_ID}=success" >> report.log
else
echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log
fi

mv "$agent_name"_seed"$N_RUN"* slurm_logs/
40 changes: 40 additions & 0 deletions tests/check_scripts/run_reinforce_fragment.sh
@@ -0,0 +1,40 @@
#!/bin/bash

#SBATCH --job-name=reinforce_fragment
#SBATCH --ntasks=6
#SBATCH --cpus-per-task=1
#SBATCH --gres=gpu:1
#SBATCH --output=slurm_logs/reinforce_fragment%j.txt
#SBATCH --error=slurm_errors/reinforce_fragment%j.txt

current_commit=$(git rev-parse --short HEAD)
project_name="acegen-scripts-check-$current_commit"
agent_name="reinforce_fragment"
if [ -z "$N_RUN" ]; then
echo "N_RUN is not set. Setting to default value of 1."
N_RUN=1
fi
if [ -z "$ACEGEN_MODEL" ]; then
echo "ACEGEN_MODEL is not set. Setting to default value of gru. Choose from [gru, lstm, gpt2]"
ACEGEN_MODEL="gru"
fi

export PYTHONPATH=$(dirname $(dirname $PWD))
python $PYTHONPATH/scripts/reinforce/reinforce.py --config-name config_fragment \
logger_backend=wandb \
experiment_name="$project_name" \
agent_name="$agent_name" \
seed=$N_RUN \
log_dir="$agent_name"_seed"$N_RUN" \
model=$ACEGEN_MODEL

# Capture the exit status of the Python command
exit_status=$?
# Write the exit status to a file
if [ $exit_status -eq 0 ]; then
echo "${agent_name}_${SLURM_JOB_ID}=success" >> report.log
else
echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log
fi

mv "$agent_name"_seed"$N_RUN"* slurm_logs/
40 changes: 40 additions & 0 deletions tests/check_scripts/run_reinforce_scaffold.sh
@@ -0,0 +1,40 @@
#!/bin/bash

#SBATCH --job-name=reinforce_scaffold
#SBATCH --ntasks=6
#SBATCH --cpus-per-task=1
#SBATCH --gres=gpu:1
#SBATCH --output=slurm_logs/reinforce_scaffold%j.txt
#SBATCH --error=slurm_errors/reinforce_scaffold%j.txt

current_commit=$(git rev-parse --short HEAD)
project_name="acegen-scripts-check-$current_commit"
agent_name="reinforce_scaffold"
if [ -z "$N_RUN" ]; then
echo "N_RUN is not set. Setting to default value of 1."
N_RUN=1
fi
if [ -z "$ACEGEN_MODEL" ]; then
echo "ACEGEN_MODEL is not set. Setting to default value of gru. Choose from [gru, lstm, gpt2]"
ACEGEN_MODEL="gru"
fi

export PYTHONPATH=$(dirname $(dirname $PWD))
python $PYTHONPATH/scripts/reinforce/reinforce.py --config-name config_scaffold \
logger_backend=wandb \
experiment_name="$project_name" \
agent_name="$agent_name" \
seed=$N_RUN \
log_dir="$agent_name"_seed"$N_RUN" \
model=$ACEGEN_MODEL

# Capture the exit status of the Python command
exit_status=$?
# Write the exit status to a file
if [ $exit_status -eq 0 ]; then
echo "${agent_name}_${SLURM_JOB_ID}=success" >> report.log
else
echo "${agent_name}_${SLURM_JOB_ID}=error" >> report.log
fi

mv "$agent_name"_seed"$N_RUN"* slurm_logs/
