
Commit v0.1
zhangxiaoyu11 committed May 21, 2021
1 parent 8fe9bf3 · commit 345953c
Showing 10 changed files with 116 additions and 112 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -1,8 +1,8 @@
 # OmiEmbed
 
-**OmiEmbed: reconstruct comprehensive phenotypic information from multi-omics data using multi-task deep learning**
+**OmiEmbed: A Unified Multi-task Deep Learning Framework for Multi-omics Data**
 
-**Xiaoyu Zhang, Kai Sun, Yike Guo**
+**Xiaoyu Zhang, Yuting Xing, Kai Sun, Yike Guo**
 
 Data Science Institute, Imperial College London
 
13 changes: 7 additions & 6 deletions datasets/__init__.py
@@ -51,13 +51,14 @@ class CustomDataLoader:
     """
     Create a dataloader for certain dataset.
     """
-    def __init__(self, dataset, param, shuffle=True):
+    def __init__(self, dataset, param, shuffle=True, enable_drop_last=False):
         self.dataset = dataset
         self.param = param
 
         drop_last = False
-        if len(dataset) % param.batch_size < 3*len(param.gpu_ids):
-            drop_last = True
+        if enable_drop_last:
+            if len(dataset) % param.batch_size < 3*len(param.gpu_ids):
+                drop_last = True
 
         # Create dataloader for this dataset
         self.dataloader = DataLoaderPrefetch(
@@ -115,12 +116,12 @@ def get_sample_list(self):
         return self.dataset.sample_list
 
 
-def create_single_dataloader(param, shuffle=True):
+def create_single_dataloader(param, shuffle=True, enable_drop_last=False):
     """
     Create a single dataloader
     """
     dataset = create_dataset(param)
-    dataloader = CustomDataLoader(dataset, param, shuffle=shuffle)
+    dataloader = CustomDataLoader(dataset, param, shuffle=shuffle, enable_drop_last=enable_drop_last)
     sample_list = dataset.sample_list
 
     return dataloader, sample_list
@@ -163,7 +164,7 @@ def create_separate_dataloader(param):
     test_dataset = Subset(full_dataset, test_idx)
 
     full_dataloader = CustomDataLoader(full_dataset, param)
-    train_dataloader = CustomDataLoader(train_dataset, param)
+    train_dataloader = CustomDataLoader(train_dataset, param, enable_drop_last=True)
     val_dataloader = CustomDataLoader(val_dataset, param, shuffle=False)
     test_dataloader = CustomDataLoader(test_dataset, param, shuffle=False)
 
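Note: the new enable_drop_last flag gates the batch-dropping heuristic so that only training loaders drop a too-small final batch (one with fewer than 3 samples per configured GPU), which could otherwise break multi-GPU scattering or batch-dependent layers; validation, test, and full-dataset loaders now always keep every sample. A minimal self-contained sketch of the same logic, with batch_size and gpu_ids standing in for the fields of param:

    # Standalone illustration of the enable_drop_last gating, not the repo's exact API.
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def make_loader(dataset, batch_size, gpu_ids, shuffle=True, enable_drop_last=False):
        drop_last = False
        if enable_drop_last:
            # Drop the final incomplete batch only when it is too small to be
            # split across all GPUs (mirrors 3*len(param.gpu_ids) above; with an
            # empty gpu_ids list the tail batch is never dropped).
            if len(dataset) % batch_size < 3 * len(gpu_ids):
                drop_last = True
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)

    dataset = TensorDataset(torch.randn(97, 8))
    loader = make_loader(dataset, batch_size=32, gpu_ids=[0], enable_drop_last=True)
    print(len(loader))  # 3 -- the 1-sample tail batch is dropped because 1 < 3*1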
140 changes: 69 additions & 71 deletions models/networks.py

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions params/basic_params.py
@@ -31,13 +31,13 @@ def initialize(self, parser):
                             help='name of the folder in the checkpoint directory')
 
         # Dataset parameters
-        parser.add_argument('--omics_mode', type=str, default='abc',
+        parser.add_argument('--omics_mode', type=str, default='a',
                             help='omics types would like to use in the model, options: [abc | ab | a | b | c]')
         parser.add_argument('--data_root', required=True,
                             help='path to input data')
         parser.add_argument('--batch_size', type=int, default=32,
                             help='input data batch size')
-        parser.add_argument('--num_threads', default=6, type=int,
+        parser.add_argument('--num_threads', default=0, type=int,
                             help='number of threads for loading data')
         parser.add_argument('--set_pin_memory', action='store_true',
                             help='set pin_memory in the dataloader to increase data loading performance')
@@ -55,7 +55,7 @@ def initialize(self, parser):
         # Model parameters
         parser.add_argument('--model', type=str, default='vae_classifier',
                             help='chooses which model want to use, options: [vae_classifier | vae_regression | vae_survival | vae_multitask]')
-        parser.add_argument('--net_VAE', type=str, default='conv_1d',
+        parser.add_argument('--net_VAE', type=str, default='fc_sep',
                             help='specify the backbone of the VAE, default is the one dimensional CNN, options: [conv_1d | fc_sep | fc]')
         parser.add_argument('--net_down', type=str, default='multi_FC_classifier',
                             help='specify the backbone of the downstream task network, default is the multi-layer FC classifier, options: [multi_FC_classifier | multi_FC_regression | multi_FC_survival | multi_FC_multitask]')
@@ -65,7 +65,7 @@ def initialize(self, parser):
                             help='number of filters in the last convolution layer in the generator')
         parser.add_argument('--conv_k_size', type=int, default=9,
                             help='the kernel size of convolution layer, default kernel size is 9, the kernel is one dimensional.')
-        parser.add_argument('--dropout_p', type=float, default=0,
+        parser.add_argument('--dropout_p', type=float, default=0.2,
                             help='probability of an element to be zeroed in a dropout layer, default is 0 which means no dropout.')
         parser.add_argument('--leaky_slope', type=float, default=0.2,
                             help='the negative slope of the Leaky ReLU activation function')
@@ -81,11 +81,11 @@ def initialize(self, parser):
         # Loss parameters
         parser.add_argument('--recon_loss', type=str, default='BCE',
                             help='chooses the reconstruction loss function, options: [BCE | MSE | L1]')
-        parser.add_argument('--reduction', type=str, default='sum',
+        parser.add_argument('--reduction', type=str, default='mean',
                             help='chooses the reduction to apply to the loss function, options: [sum | mean]')
-        parser.add_argument('--k_kl', type=float, default=1,
+        parser.add_argument('--k_kl', type=float, default=0.01,
                             help='weight for the kl loss')
-        parser.add_argument('--k_embed', type=float, default=0,
+        parser.add_argument('--k_embed', type=float, default=0.001,
                             help='weight for the embedding loss')
 
         # Other parameters
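Note: the loss defaults move from sum reduction with k_kl=1, k_embed=0 to mean reduction with k_kl=0.01 and k_embed=0.001. Mean reduction shrinks the raw loss magnitudes by the batch and feature sizes, which plausibly motivates the much smaller weights (and the smaller learning rate in the training params below). A minimal sketch of how weights like these typically enter a VAE-classifier objective; the real composition lives in models/networks.py (not rendered on this page), so the term names here are illustrative assumptions:

    import torch
    import torch.nn.functional as F

    def vae_classifier_loss(x, x_recon, mu, log_var, logits, labels,
                            k_kl=0.01, k_embed=0.001):
        # Reconstruction term (--recon_loss BCE, --reduction mean); BCE assumes
        # inputs scaled to [0, 1].
        recon = F.binary_cross_entropy(x_recon, x, reduction='mean')
        # Mean-reduced KL divergence between q(z|x) and the N(0, I) prior.
        kl = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())
        # Unsupervised embedding loss, weighted against the downstream task loss.
        embed = recon + k_kl * kl
        down = F.cross_entropy(logits, labels)
        return down + k_embed * embed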
8 changes: 4 additions & 4 deletions params/train_params.py
@@ -12,23 +12,23 @@ def initialize(self, parser):
         # Training parameters
         parser.add_argument('--epoch_num_p1', type=int, default=50,
                             help='epoch number for phase 1')
-        parser.add_argument('--epoch_num_p2', type=int, default=0,
+        parser.add_argument('--epoch_num_p2', type=int, default=50,
                             help='epoch number for phase 2')
         parser.add_argument('--epoch_num_p3', type=int, default=100,
                             help='epoch number for phase 3')
-        parser.add_argument('--lr', type=float, default=1e-3,
+        parser.add_argument('--lr', type=float, default=1e-4,
                             help='initial learning rate')
         parser.add_argument('--beta1', type=float, default=0.5,
                             help='momentum term of adam')
         parser.add_argument('--lr_policy', type=str, default='linear',
                             help='The learning rate policy for the scheduler. [linear | step | plateau | cosine]')
         parser.add_argument('--epoch_count', type=int, default=1,
                             help='the starting epoch count, default start from 1')
-        parser.add_argument('--epoch_num_decay', type=int, default=0,
+        parser.add_argument('--epoch_num_decay', type=int, default=50,
                             help='Number of epoch to linearly decay learning rate to zero (lr_policy == linear)')
         parser.add_argument('--decay_step_size', type=int, default=50,
                             help='The original learning rate multiply by a gamma every decay_step_size epoch (lr_policy == step)')
-        parser.add_argument('--weight_decay', type=float, default=0,
+        parser.add_argument('--weight_decay', type=float, default=1e-4,
                             help='weight decay (L2 penalty)')
 
         # Network saving and loading parameters
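Note: with epoch_num_decay now 50, the 'linear' policy actually decays the learning rate to zero instead of holding it constant. A sketch of what these options describe, following the common LambdaLR pattern the option names suggest (the repository's exact rule may differ; epochs_at_full_lr is an assumed name):

    import torch

    def linear_scheduler(optimizer, epoch_count=1, epochs_at_full_lr=100, epoch_num_decay=50):
        def lr_lambda(epoch):
            # Keep the full lr until epochs_at_full_lr, then fade linearly to
            # zero over epoch_num_decay epochs.
            over = max(0, epoch + epoch_count - epochs_at_full_lr)
            return 1.0 - over / float(epoch_num_decay + 1)
        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.5, 0.999), weight_decay=1e-4)
    scheduler = linear_scheduler(optimizer)  # call scheduler.step() once per epoch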
8 changes: 4 additions & 4 deletions params/train_test_params.py
@@ -12,23 +12,23 @@ def initialize(self, parser):
         # Training parameters
         parser.add_argument('--epoch_num_p1', type=int, default=50,
                             help='epoch number for phase 1')
-        parser.add_argument('--epoch_num_p2', type=int, default=0,
+        parser.add_argument('--epoch_num_p2', type=int, default=50,
                             help='epoch number for phase 2')
         parser.add_argument('--epoch_num_p3', type=int, default=100,
                             help='epoch number for phase 3')
-        parser.add_argument('--lr', type=float, default=1e-3,
+        parser.add_argument('--lr', type=float, default=1e-4,
                             help='initial learning rate')
         parser.add_argument('--beta1', type=float, default=0.5,
                             help='momentum term of adam')
         parser.add_argument('--lr_policy', type=str, default='linear',
                             help='The learning rate policy for the scheduler. [linear | step | plateau | cosine]')
         parser.add_argument('--epoch_count', type=int, default=1,
                             help='the starting epoch count, default start from 1')
-        parser.add_argument('--epoch_num_decay', type=int, default=0,
+        parser.add_argument('--epoch_num_decay', type=int, default=50,
                             help='Number of epoch to linearly decay learning rate to zero (lr_policy == linear)')
         parser.add_argument('--decay_step_size', type=int, default=50,
                             help='The original learning rate multiply by a gamma every decay_step_size epoch (lr_policy == step)')
-        parser.add_argument('--weight_decay', type=float, default=0,
+        parser.add_argument('--weight_decay', type=float, default=1e-4,
                             help='weight decay (L2 penalty)')
 
         # Network saving and loading parameters
4 changes: 3 additions & 1 deletion train.py
@@ -2,6 +2,7 @@
 Separated training for OmiEmbed
 """
 import time
+import warnings
 from util import util
 from params.train_params import TrainParams
 from datasets import create_single_dataloader
@@ -10,13 +11,14 @@
 
 
 if __name__ == "__main__":
+    warnings.filterwarnings('ignore')
     # Get parameters
     param = TrainParams().parse()
     if param.deterministic:
         util.setup_seed(param.seed)
 
     # Dataset related
-    dataloader, sample_list = create_single_dataloader(param)
+    dataloader, sample_list = create_single_dataloader(param, enable_drop_last=True)
     print('The size of training set is {}'.format(len(dataloader)))
     # Get the dimension of input omics data
     param.omics_dims = dataloader.get_omics_dims()
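Note: warnings.filterwarnings('ignore') silences every warning in the process, including ones that signal real problems. A narrower alternative (not what this commit does) keeps unexpected warnings visible:

    import warnings

    # Silence only the known-noisy categories instead of everything.
    warnings.filterwarnings('ignore', category=UserWarning)
    warnings.filterwarnings('ignore', category=FutureWarning)
    # RuntimeWarning and other categories still surface.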
3 changes: 2 additions & 1 deletion train_test.py
@@ -2,6 +2,7 @@
 Training and testing for OmiEmbed
 """
 import time
+import warnings
 from util import util
 from params.train_test_params import TrainTestParams
 from datasets import create_separate_dataloader
@@ -10,8 +11,8 @@
 
 
 if __name__ == "__main__":
+    warnings.filterwarnings('ignore')
     full_start_time = time.time()
-
     # Get parameters
     param = TrainTestParams().parse()
     if param.deterministic:
6 changes: 3 additions & 3 deletions util/util.py
@@ -16,7 +16,7 @@ def mkdir(path):
         path(str) -- a directory path we would like to create
     """
     if not os.path.exists(path):
-        os.mkdir(path)
+        os.makedirs(path)
 
 
 def clear_dir(path):
@@ -27,8 +27,8 @@ def clear_dir(path):
         path(str) -- a directory path that we would like to delete all files in it
     """
     if os.path.exists(path):
-        shutil.rmtree(path)
-        os.mkdir(path)
+        shutil.rmtree(path, ignore_errors=True)
+        os.makedirs(path, exist_ok=True)
 
 
 def setup_seed(seed):
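Note: os.makedirs also creates missing parent directories, and the ignore_errors/exist_ok flags make clear_dir robust when the directory is already gone or gets re-created concurrently (the check-then-create pattern alone can still race). In isolation, with illustrative paths:

    import os
    import shutil

    os.makedirs('checkpoints/run1/logs', exist_ok=True)          # os.mkdir would raise FileNotFoundError for missing parents
    shutil.rmtree('checkpoints/run1/logs', ignore_errors=True)   # silent if already removed
    os.makedirs('checkpoints/run1/logs', exist_ok=True)          # silent if it already exists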
28 changes: 15 additions & 13 deletions util/visualizer.py
@@ -206,11 +206,11 @@ def get_epoch_metrics(self, output_dict):
                 y_prob = y_prob[:, 1]
 
             accuracy = sk.metrics.accuracy_score(y_true, y_pred)
-            precision = sk.metrics.precision_score(y_true, y_pred, average='weighted', zero_division=0)
-            recall = sk.metrics.recall_score(y_true, y_pred, average='weighted', zero_division=0)
-            f1 = sk.metrics.f1_score(y_true, y_pred, average='weighted', zero_division=0)
+            precision = sk.metrics.precision_score(y_true, y_pred, average='macro', zero_division=0)
+            recall = sk.metrics.recall_score(y_true, y_pred, average='macro', zero_division=0)
+            f1 = sk.metrics.f1_score(y_true, y_pred, average='macro', zero_division=0)
             try:
-                auc = sk.metrics.roc_auc_score(y_true_binary, y_prob, multi_class='ovo', average='weighted')
+                auc = sk.metrics.roc_auc_score(y_true_binary, y_prob, multi_class='ovo', average='macro')
             except ValueError:
                 auc = -1
                 print('ValueError: ROC AUC score is not defined in this case.')
@@ -284,28 +284,30 @@ def get_epoch_metrics(self, output_dict):
             if self.param.class_num == 2:
                 y_prob_cla = y_prob_cla[:, 1]
             accuracy = sk.metrics.accuracy_score(y_true_cla, y_pred_cla)
-            # precision = sk.metrics.precision_score(y_true_cla, y_pred_cla, average='weighted', zero_division=0)
-            # recall = sk.metrics.recall_score(y_true_cla, y_pred_cla, average='weighted', zero_division=0)
-            f1 = sk.metrics.f1_score(y_true_cla, y_pred_cla, average='weighted', zero_division=0)
+            precision = sk.metrics.precision_score(y_true_cla, y_pred_cla, average='macro', zero_division=0)
+            recall = sk.metrics.recall_score(y_true_cla, y_pred_cla, average='macro', zero_division=0)
+            f1 = sk.metrics.f1_score(y_true_cla, y_pred_cla, average='macro', zero_division=0)
             '''
             try:
-                auc = sk.metrics.roc_auc_score(y_true_cla_binary, y_prob_cla, multi_class='ovo', average='weighted')
+                auc = sk.metrics.roc_auc_score(y_true_cla_binary, y_prob_cla, multi_class='ovo', average='macro')
             except ValueError:
                 auc = -1
                 print('ValueError: ROC AUC score is not defined in this case.')
             '''
 
             # Regression
             y_true_reg = output_dict['y_true_reg'].cpu().numpy()
             y_pred_reg = output_dict['y_pred_reg'].cpu().detach().numpy()
             # mse = sk.metrics.mean_squared_error(y_true_reg, y_pred_reg)
             rmse = sk.metrics.mean_squared_error(y_true_reg, y_pred_reg, squared=False)
-            # mae = sk.metrics.mean_absolute_error(y_true_reg, y_pred_reg)
-            # medae = sk.metrics.median_absolute_error(y_true_reg, y_pred_reg)
+            mae = sk.metrics.mean_absolute_error(y_true_reg, y_pred_reg)
+            medae = sk.metrics.median_absolute_error(y_true_reg, y_pred_reg)
             r2 = sk.metrics.r2_score(y_true_reg, y_pred_reg)
 
             metrics_time = time.time() - metrics_start_time
             print('Metrics computing time: {:.3f}s'.format(metrics_time))
 
-            return {'c-index': c_index, 'ibs': ibs, 'accuracy': accuracy, 'f1': f1, 'auc': auc, 'rmse': rmse, 'r2': r2}
+            return {'c-index': c_index, 'ibs': ibs, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'rmse': rmse, 'mae': mae, 'medae': medae, 'r2': r2}
 
         elif self.param.downstream_task == 'alltask':
             metrics_start_time = time.time()
@@ -339,9 +341,9 @@ def get_epoch_metrics(self, output_dict):
                 if self.param.class_num[i] == 2:
                     y_prob_cla = y_prob_cla[:, 1]
                 accuracy.append(sk.metrics.accuracy_score(y_true_cla, y_pred_cla))
-                f1.append(sk.metrics.f1_score(y_true_cla, y_pred_cla, average='weighted', zero_division=0))
+                f1.append(sk.metrics.f1_score(y_true_cla, y_pred_cla, average='macro', zero_division=0))
                 try:
-                    auc.append(sk.metrics.roc_auc_score(y_true_cla_binary, y_prob_cla, multi_class='ovo', average='weighted'))
+                    auc.append(sk.metrics.roc_auc_score(y_true_cla_binary, y_prob_cla, multi_class='ovo', average='macro'))
                 except ValueError:
                     auc.append(-1)
                     print('ValueError: ROC AUC score is not defined in this case.')
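Note: the commit switches classification metrics from 'weighted' to 'macro' averaging throughout. With class-imbalanced labels this is a meaningful change: 'weighted' lets the majority class dominate the score, while 'macro' weights every class equally. A small runnable illustration:

    from sklearn.metrics import f1_score

    y_true = [0] * 90 + [1] * 10
    y_pred = [0] * 100  # classifier that ignores the minority class
    print(f1_score(y_true, y_pred, average='weighted', zero_division=0))  # ~0.85, looks fine
    print(f1_score(y_true, y_pred, average='macro', zero_division=0))     # ~0.47, exposes the failure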
