Commit bb97b90 (parent 1819452): 19 changed files with 1,218 additions and 0 deletions.
@@ -0,0 +1,21 @@
{
    "search_params": [
        {
            "num_filters": [16, 32],
            "smi_filter_len": [3, 4],
            "prot_filter_len": [4, 6],
            "embedding_dim": [128]
        },
        {
            "batch_size": [128, 256],
            "optimizer": ["adam"],
            "learning_rate": [0.001]
        }
    ],

    "fixed_params": {
        "max_smi_len": 100,
        "max_prot_len": 1000,
        "n_epochs": 200
    }
}
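This first file is a hyperparameter search grid: candidate lists live under `search_params`, constants under `fixed_params`. A minimal sketch of how such a grid could be expanded into concrete configurations, assuming the training code takes the Cartesian product of the search lists and merges in the fixed values (this snippet is illustrative, not part of the commit):

import itertools

search_params = [
    {"num_filters": [16, 32], "smi_filter_len": [3, 4],
     "prot_filter_len": [4, 6], "embedding_dim": [128]},
    {"batch_size": [128, 256], "optimizer": ["adam"], "learning_rate": [0.001]},
]
fixed_params = {"max_smi_len": 100, "max_prot_len": 1000, "n_epochs": 200}

# Flatten all search groups into one name -> candidate-list mapping.
grid = {k: v for group in search_params for k, v in group.items()}
names = list(grid)

# One concrete hyperparameter set per element of the Cartesian product.
candidates = [dict(zip(names, combo), **fixed_params)
              for combo in itertools.product(*grid.values())]
print(len(candidates))  # 2*2*2*1*2*1*1 = 16 configurations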
@@ -0,0 +1,21 @@
{
    "search_params": [
        {
            "num_filters": [0],
            "smi_filter_len": [0],
            "prot_filter_len": [0],
            "embedding_dim": [0]
        }
    ],

    "fixed_params": {
        "batch_size": 128,
        "optimizer": "adam",
        "learning_rate": 0.001,
        "lm_ligand_embed_size": 768,
        "lm_protein_embed_size": 1024,
        "max_smi_len": 100,
        "max_prot_len": 1000,
        "n_epochs": 200
    }
}
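In this second config, apparently for the language-model variant, the search lists are zeroed out, suggesting the convolutional hyperparameter search is bypassed when pretrained LM embeddings are used. The fixed `lm_ligand_embed_size` of 768 and `lm_protein_embed_size` of 1024 match common transformer hidden widths (e.g., a BERT-base-sized chemical LM and a 1024-dimensional protein LM), though the diff does not name the specific models.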
@@ -0,0 +1,35 @@
{
    "chembl27_raw": "./data/corpora/chembl27_chemreps.txt",
    "chembl27_smi": "./data/corpora/chembl27.smi",
    "chembl27_smi_encoded": "./data/corpora/chembl27.smi.enc",
    "chembl27_vocab": "./data/corpora/chembl27.vocab",
    "chembl27_brics": "./data/corpora/chembl27.brics",
    "chembl27_brics_lm": "./data/corpora/chembl27.brics.lm",
    "chembl27_selfies_vocab": "./data/corpora/chembl27_selfies.vocab",
    "uniprot_raw": "./data/corpora/uniprot-reviewed_yes.fasta",
    "uniprot_aa": "./data/corpora/uniprot-reviewed_yes.aa",
    "natural_lang": "./data/natural_langs/",

    "chem_vocab": "./data/vocabs/chemical/",
    "prot_vocab": "./data/vocabs/protein/",
    "natural_lang_vocab": "./data/vocabs/natural_langs/",
    "models": "./models/",

    "bdb": {
        "folds": "./data/bdb/setups/",
        "old_setups": "./data/bdb/leave_out/",
        "ligands": "./data/bdb/ligands.json",
        "proteins": "./data/bdb/proteins.json",
        "pfams": "./data/bdb/pfams.csv",
        "sw_sim_matrix": "./data/bdb/sw_sim_matrix.csv"
    },

    "kiba": {
        "folds": "./data/kiba/setups/",
        "old_setups": "./data/kiba/leave_out/",
        "ligands": "./data/kiba/ligands.json",
        "proteins": "./data/kiba/proteins.json",
        "pfams": "./data/kiba/pfams.csv",
        "sw_sim_matrix": "./data/kiba/sw_sim_matrix.csv"
    }
}
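The driver script below loads this file as `./paths.json`. Since a missing corpus or fold directory would only surface mid-run, a small helper like the following (hypothetical, not part of the commit) can check every referenced path up front:

import json
import os

def missing_paths(node):
    """Recursively collect relative-path string values that do not exist on disk."""
    if isinstance(node, dict):
        return [p for v in node.values() for p in missing_paths(v)]
    if isinstance(node, str) and node.startswith("./"):
        return [] if os.path.exists(node) else [node]
    return []

with open("./paths.json") as f:
    paths = json.load(f)
for p in missing_paths(paths):
    print("missing:", p)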
@@ -0,0 +1,130 @@
import datetime
import json
import time

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeRegressor

from src.dta_models import BoWDTA, IDDTA
from src.dta_models import DebiasedDTA, BPEDeepDTA, LMDeepDTA
from src.tokenization_methods import SMIAwareBPE, LanguageModel
from src.utils import init_debiasing_parser


def create_weak_learner(name):
    # Uses the module-level `tokenizer`, `paths`, and `model_name`, all of which
    # are bound below before this function is called.
    if name == 'BOWDTA':
        # The last argument flags whether the strong model consumes LM embeddings
        # (model names are the lowercase keys 'bpedta', 'deepdta', 'lmdta').
        return BoWDTA(DecisionTreeRegressor(), tokenizer, paths['chembl27_vocab'], model_name == 'lmdta')
    if name == 'IDDTA':
        return IDDTA(DecisionTreeRegressor())


parser = init_debiasing_parser()
args = vars(parser.parse_args())  # use the parsed namespace as a dict

start = time.time()
mini_val_frac = 0.2

model_name = args['model_name']
weak_learner_name = args['weak_learner_name'].upper()
chem_vocab_size = int(args['chem_vocab_size']) if model_name != 'lmdta' else 'NA'
prot_vocab_size = int(args['prot_vocab_size']) if model_name != 'lmdta' else 'NA'
dataset = args['dataset']
n_bootstrapping = args['n_bootstrapping']
decay_type = args['decay_type']
lm_ligand_path = args['lm_ligand_path']
lm_protein_path = args['lm_protein_path']
standardize_labels = args['scale']
val = args['val']

with open('./paths.json') as f:
    paths = json.load(f)
dataset_paths = paths[dataset]

# 'lin_decrase' (sic) appears to be the spelling the CLI uses for the linearly
# decreasing weight schedule.
decay_mode = 'BD' if decay_type == 'lin_decrase' else 'BG'
model_dir = paths['models'] + f'{dataset}/debiaseddta/{model_name}/{weak_learner_name}-{decay_mode}/'

DEBUG_MODE = False
test_scores = []

for setup_ix in range(5):
    train = pd.read_csv(dataset_paths['folds'] + f'setup_{setup_ix}/train.csv')
    test_data = {
        'warm': pd.read_csv(dataset_paths['folds'] + f'setup_{setup_ix}/test_warm.csv'),
        'cold_lig': pd.read_csv(dataset_paths['folds'] + f'setup_{setup_ix}/test_cold_lig.csv'),
        'cold_prot': pd.read_csv(dataset_paths['folds'] + f'setup_{setup_ix}/test_cold_prot.csv'),
        'cold_both': pd.read_csv(dataset_paths['folds'] + f'setup_{setup_ix}/test_cold_both.csv'),
    }
    if val:
        test_data['val'] = pd.read_csv(dataset_paths['folds'] + f'setup_{setup_ix}/val_warm.csv')

    scaler = None
    if standardize_labels:
        # Fit the scaler on training labels only, then apply it to every test fold.
        scaler = StandardScaler()
        train['affinity_score'] = scaler.fit_transform(
            train['affinity_score'].values.reshape(-1, 1)).ravel()
        for set_type in test_data:
            test_data[set_type]['affinity_score'] = scaler.transform(
                test_data[set_type]['affinity_score'].values.reshape(-1, 1)).ravel()

    if model_name == 'bpedta':
        strong_model_params_path = paths['models'] + f'{dataset}/{model_name}/chem_{chem_vocab_size}_prot_{prot_vocab_size}/setup_{setup_ix}/params.json'
    else:
        strong_model_params_path = paths['models'] + f'{dataset}/{model_name}/setup_{setup_ix}/params.json'
    with open(strong_model_params_path) as f:
        strong_model_params = json.load(f)

    if DEBUG_MODE:
        train = train.iloc[:100, :]
        test_data = {k: v.iloc[:100, :] for k, v in test_data.items()}
        strong_model_params['n_epochs'] = 1
        n_bootstrapping = 2

    tokenization_methods = {'bpedta': SMIAwareBPE, 'deepdta': SMIAwareBPE, 'lmdta': LanguageModel}
    strong_learner_methods = {'bpedta': BPEDeepDTA, 'deepdta': BPEDeepDTA, 'lmdta': LMDeepDTA}
    tokenization_method = tokenization_methods[model_name]
    strong_learner_method = strong_learner_methods[model_name]

    extra_tokenizer_args = {'dataset_name': dataset,
                            'lm_protein_path': lm_protein_path,
                            'lm_ligand_path': lm_ligand_path}
    tokenizer = tokenization_method(paths, chem_vocab_size, prot_vocab_size, **extra_tokenizer_args)
    if model_name == 'lmdta':
        # The LM tokenizer dictates its own vocabulary sizes.
        strong_model_params['chem_vocab_size'] = tokenizer.chem_vocab_size
        strong_model_params['prot_vocab_size'] = tokenizer.prot_vocab_size

    strong_model_params['scaler'] = scaler

    strong_learner = strong_learner_method(strong_model_params, tokenizer, paths['chembl27_vocab'])
    weak_learner = create_weak_learner(weak_learner_name)
    debiaseddta = DebiasedDTA(weak_learner, strong_learner, mini_val_frac, n_bootstrapping, decay_type)

    setup_dir = model_dir + f'setup_{setup_ix}/'
    if val:
        debiaseddta.train(train, val_data=test_data['val'], savedir=setup_dir)
    else:
        debiaseddta.train(train, savedir=setup_dir)

    setup_scores = debiaseddta.evaluate(test_data, setup_dir, 'test')
    test_scores.append(setup_scores)
    debiaseddta.save(setup_dir)
    debiaseddta.plot_loss(setup_dir)

# Aggregate the five setups into mean/std per test fold and metric.
cv_test_results = {}
for fold in ['warm', 'cold_lig', 'cold_prot', 'cold_both']:
    cv_test_results[fold] = {}
    for metric in ['mse', 'ci', 'rmse', 'r2']:
        metric_scores = [score[fold][metric] for score in test_scores]
        cv_test_results[fold][metric] = {'mean': np.mean(metric_scores),
                                         'std': np.std(metric_scores)}

with open(model_dir + 'test_scores.json', 'w') as f:
    json.dump(cv_test_results, f, indent=4)

end = time.time()
print('The program took:', datetime.timedelta(seconds=end - start))
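The diff shows neither `init_debiasing_parser` nor this script's filename, but given the argument keys read above, an invocation could plausibly look like the following (the `debias.py` name, the flag spellings, and the vocabulary sizes and bootstrap count are all assumed for illustration):

python debias.py --model_name bpedta --weak_learner_name bowdta \
    --chem_vocab_size 8000 --prot_vocab_size 32000 \
    --dataset bdb --n_bootstrapping 10 --decay_type lin_decrase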
@@ -0,0 +1,6 @@
from src.dta_models.bowdta import BoWDTA
from src.dta_models.deepdta import DeepDTA
from src.dta_models.bpedeepdta import BPEDeepDTA
from src.dta_models.debiaseddta import DebiasedDTA
from src.dta_models.iddta import IDDTA
from src.dta_models.lmdeepdta import LMDeepDTA
@@ -0,0 +1,47 @@
import datetime
import time

import numpy as np
from tensorflow.keras.preprocessing.text import Tokenizer

from src.utils import encode_smiles


class BoWDTA:
    """Weak DTA learner: predicts affinity from bag-of-words vectors over
    BPE chemical and protein tokens."""

    def __init__(self, prediction_model, bpe_tokenizer, smi_encoding_vocab_path, strong_LM=False):
        self.bpe_tokenizer = bpe_tokenizer
        self.smi_encoding_vocab_path = smi_encoding_vocab_path
        self.prediction_model = prediction_model
        # oov_token maps tokens unseen during fitting to a known symbol at test time.
        self.chem_bow_vectorizer = Tokenizer(filters=None, lower=False, oov_token='C')
        self.prot_bow_vectorizer = Tokenizer(filters=None, lower=False, oov_token='$')
        self.strong_LM = strong_LM

    def __preprocess_data_for_bow(self, data):
        data = data.copy()
        if not self.strong_LM:
            # LM-based strong models consume raw SMILES; otherwise re-encode first.
            data['smiles'] = data['smiles'].apply(encode_smiles, encoding_vocab_path=self.smi_encoding_vocab_path)
        chemicals = self.bpe_tokenizer.chem_tokenizer.identify_words(data['smiles'], out_type='int')
        proteins = self.bpe_tokenizer.prot_tokenizer.identify_words(data['aa_sequence'], out_type='int')
        return chemicals, proteins

    def __get_bow_representations(self, chemicals, proteins):
        # mode='freq' yields token frequencies normalized by sequence length.
        X_chem = self.chem_bow_vectorizer.texts_to_matrix(chemicals, mode='freq')
        X_prot = self.prot_bow_vectorizer.texts_to_matrix(proteins, mode='freq')
        return np.hstack([X_chem, X_prot])

    def train(self, train):
        chemicals, proteins = self.__preprocess_data_for_bow(train)
        self.chem_bow_vectorizer.fit_on_texts(chemicals)
        self.prot_bow_vectorizer.fit_on_texts(proteins)

        X_train = self.__get_bow_representations(chemicals, proteins)
        start = time.time()
        print('Started training decision-tree on bow vectors')
        self.prediction_model.fit(X_train, train['affinity_score'])
        end = time.time()
        print('Weak model training took:', datetime.timedelta(seconds=end - start))

    def predict(self, test):
        chemicals, proteins = self.__preprocess_data_for_bow(test)
        X_test = self.__get_bow_representations(chemicals, proteins)
        return self.prediction_model.predict(X_test)
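To make the bag-of-words step concrete, here is a minimal standalone sketch of what `texts_to_matrix(mode='freq')` produces on already-tokenized sequences (toy tokens, unrelated to the real BPE vocabularies; not part of the commit):

from tensorflow.keras.preprocessing.text import Tokenizer

# Pretend these are BPE token sequences for two chemicals.
sequences = [['CC', 'O', 'CC'], ['N', 'O']]

vectorizer = Tokenizer(filters=None, lower=False, oov_token='?')
vectorizer.fit_on_texts(sequences)

# Each row holds token frequencies normalized by sequence length;
# column 0 is reserved and stays unused.
print(vectorizer.texts_to_matrix(sequences, mode='freq'))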
@@ -0,0 +1,35 @@
from src.dta_models import DeepDTA


class BPEDeepDTA:
    """Adapter that BPE-tokenizes and pads raw inputs before delegating to DeepDTA."""

    def __init__(self, model_configs, bpe_tokenizer, smi_encoding_vocab_path):
        self.model = DeepDTA(**model_configs)
        self.bpe_tokenizer = bpe_tokenizer
        self.smi_encoding_vocab_path = smi_encoding_vocab_path

    def train(self, train_data, val_data=None, sample_weights=None, decay_type=None):
        pp_train = self.bpe_tokenizer.fn_pp(data=train_data.copy(),
                                            max_smi_len=self.model.max_smi_len,
                                            max_prot_len=self.model.max_prot_len,
                                            smi_encoding_vocab_path=self.smi_encoding_vocab_path)
        pp_val = None
        if val_data is not None:
            pp_val = self.bpe_tokenizer.fn_pp(data=val_data.copy(),
                                              max_smi_len=self.model.max_smi_len,
                                              max_prot_len=self.model.max_prot_len,
                                              smi_encoding_vocab_path=self.smi_encoding_vocab_path)
        return self.model.train(pp_train, pp_val, sample_weights, decay_type)

    def evaluate(self, evaluation_data, savedir=None, mode='train'):
        pp_test = {name: self.bpe_tokenizer.fn_pp(data=fold,
                                                  max_smi_len=self.model.max_smi_len,
                                                  max_prot_len=self.model.max_prot_len,
                                                  smi_encoding_vocab_path=self.smi_encoding_vocab_path)
                   for name, fold in evaluation_data.items()}
        return self.model.evaluate(pp_test, savedir, mode)

    def save(self, savedir):
        self.model.save(savedir)

    def plot_loss(self, savedir):
        self.model.plot_loss(savedir)
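Keeping tokenization and padding in this thin wrapper, rather than inside DeepDTA itself, lets the same network code serve both pipelines: only the preprocessing function (`fn_pp`) changes between tokenizers, and LMDeepDTA is presumably the analogous adapter for language-model embeddings.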
@@ -0,0 +1,71 @@
import os
import pickle

import numpy as np
import pandas as pd


class DebiasedDTA:
    def __init__(self, weak_learner, strong_learner, mini_val_frac, n_bootstrapping, decay_type=None):
        self.weak_learner = weak_learner
        self.strong_learner = strong_learner
        self.mini_val_frac = mini_val_frac
        self.n_bootstrapping = n_bootstrapping
        self.decay_type = decay_type

    def learn_sample_weights(self, train_data, savedir=None):
        train = train_data.copy()
        train['interaction_id'] = list(range(len(train)))
        mini_val_data_size = int(len(train) * self.mini_val_frac) + 1
        interaction_id_to_sq_diff = [[] for _ in range(len(train))]

        for i in range(self.n_bootstrapping):
            print(f'Bootstrapping ix:{i + 1}/{self.n_bootstrapping}')
            train = train.sample(frac=1)  # reshuffle before carving out mini-validation folds
            n_mini_val = int(1 / self.mini_val_frac)
            for mini_val_ix in range(n_mini_val):
                print(f'Mini val ix:{mini_val_ix + 1}/{n_mini_val}')
                val_start_ix = mini_val_ix * mini_val_data_size
                val_end_ix = val_start_ix + mini_val_data_size
                mini_val = train.iloc[val_start_ix: val_end_ix, :].copy()
                mini_train = pd.concat([train.iloc[:val_start_ix, :],
                                        train.iloc[val_end_ix:, :]])
                assert len(mini_train) + len(mini_val) == len(train)

                # Train the weak learner on the rest of the data and record its
                # squared error on each held-out interaction.
                self.weak_learner.train(mini_train)
                preds = self.weak_learner.predict(mini_val)
                mini_val['sq_diff'] = (mini_val['affinity_score'] - preds) ** 2
                dct = mini_val.groupby('interaction_id')['sq_diff'].first().to_dict()
                for k, v in dct.items():
                    interaction_id_to_sq_diff[k].append(v)

        # Each interaction is held out exactly once per bootstrapping round.
        for diffs in interaction_id_to_sq_diff:
            assert len(diffs) == self.n_bootstrapping

        # Weight = median squared error across rounds, normalized to sum to 1,
        # so interactions the weak learner finds hard get larger weights.
        interaction_id_to_med_diff = [np.median(diffs) for diffs in interaction_id_to_sq_diff]
        weights = [med / sum(interaction_id_to_med_diff) for med in interaction_id_to_med_diff]
        if savedir is not None:
            # `train` has been reshuffled, so align the saved diagnostics by id.
            train['sq_diff'] = train['interaction_id'].map(lambda ix: interaction_id_to_med_diff[ix])
            train['weights'] = train['interaction_id'].map(lambda ix: weights[ix])
            if not os.path.exists(savedir):
                os.makedirs(savedir)
            train.to_csv(savedir + 'train_weights.csv', index=None)
            with open(savedir + 'weak_model.pkl', 'wb') as f:
                pickle.dump(self.weak_learner, f)
        # Indexed by interaction_id, i.e. the original row order of train_data.
        return np.array(weights)

    def train(self, train_data, val_data=None, savedir=None):
        sample_weights = self.learn_sample_weights(train_data, savedir)
        return self.strong_learner.train(train_data, val_data, sample_weights, self.decay_type)

    def only_weak_train(self, train_data, savedir=None):
        return self.learn_sample_weights(train_data, savedir)

    def evaluate(self, test_data, savedir=None, mode='train'):
        return self.strong_learner.evaluate(test_data, savedir, mode)

    def save(self, savedir):
        self.strong_learner.save(savedir)

    def plot_loss(self, savedir):
        self.strong_learner.plot_loss(savedir)
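As a quick sanity check of the weighting rule in isolation (a toy sketch, not code from this commit): with median squared errors of 0.1, 0.3, and 0.6 across three interactions, the normalized weights come out to 0.1, 0.3, and 0.6, so the hardest interaction gets six times the weight of the easiest.

import numpy as np

# Toy median squared errors for three training interactions.
med_sq_err = np.array([0.1, 0.3, 0.6])

# DebiasedDTA's rule: normalize so the weights sum to 1.
weights = med_sq_err / med_sq_err.sum()
print(weights)  # [0.1 0.3 0.6]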