Fix two bugs reported in Issues.

Disable numpy auto-broadcasting & turn on mem_efficient_mode by default
ruochiz committed Mar 11, 2024
1 parent 107b3e5 commit 1333de2
149 changes: 89 additions & 60 deletions higashi/Higashi_wrapper.py
@@ -235,7 +235,7 @@ def sum_duplicates(col, data):

def one_thread_generate_neg(edges_part, edges_chrom, edge_weight,
collect_num=1, training=False, chroms_in_batch=None):
global sparse_chrom_list_GCN, neg_num
# global sparse_chrom_list_GCN, neg_num
if neg_num == 0:
y = np.ones((len(edges_part), 1))
w = np.ones((len(edges_part), 1)) * edge_weight.reshape((-1, 1))
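The hunk above comments out the `global` declaration, plausibly because `one_thread_generate_neg` only reads `sparse_chrom_list_GCN` and `neg_num` and never rebinds them. A minimal sketch of the Python scoping rule this relies on:

```python
neg_num = 5

def read_only():
    # Reading a module-level name needs no declaration:
    # the lookup falls through to module scope.
    return neg_num + 1

def rebind():
    # Only rebinding a module-level name requires `global`.
    global neg_num
    neg_num = 0
```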
@@ -317,7 +317,7 @@ def one_thread_generate_neg(edges_part, edges_chrom, edge_weight,

# Force-append an empty list and then remove it, so that np.array won't broadcast the shapes
to_neighs.append([])
to_neighs = np.array(to_neighs)[:-1]
to_neighs = np.array(to_neighs, dtype='object')[:-1]
to_neighs = np.array(to_neighs, dtype='object').reshape((len(x), 2))

size = int(len(x) / collect_num)
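This is the "disable numpy autobroadcasting" fix from the commit message. On numpy >= 1.24, building an array from ragged nested lists without `dtype='object'` raises a ValueError (older releases only warned), and even with `dtype='object'`, equal-length rows are still broadcast into a 2-D array, which is why the code force-appends an empty list before converting. A minimal sketch, assuming a recent numpy:

```python
import numpy as np

ragged = [[1, 2, 3], [4, 5]]
# np.array(ragged)  # ValueError on numpy >= 1.24 (inhomogeneous shape)
arr = np.array(ragged, dtype='object')   # shape (2,), elements stay lists

# Equal-length rows would still auto-broadcast, even as dtype object:
square = [[1, 2], [3, 4]]
assert np.array(square, dtype='object').shape == (2, 2)

# Appending an empty list forces raggedness; dropping it restores the data:
arr2 = np.array(square + [[]], dtype='object')[:-1]   # shape (2,), lists intact
```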
@@ -453,10 +453,10 @@ def __init__(self, config_path):


# For processing data: old Process.py
def process_data(self):
def process_data(self, disable_mpl=False):
self.generate_chrom_start_end()
self.extract_table()
self.create_matrix()
self.create_matrix(disable_mpl)
try:
from .Process import process_signal, impute_all
except:
@@ -484,12 +484,12 @@ def extract_table(self):
from Process import extract_table
extract_table(self.config)

def create_matrix(self):
def create_matrix(self, disable_mpl=False):
try:
from .Process import create_matrix
except:
from Process import create_matrix
create_matrix(self.config)
create_matrix(self.config, disable_mpl)

# fetch information from config.JSON
def fetch_info_from_config(self):
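The two hunks above thread a new `disable_mpl` flag from `process_data` down to `Process.create_matrix`; its actual effect lives in Process.py, which this diff does not touch. A hedged usage sketch, assuming the usual instantiation pattern from the project README:

```python
from higashi.Higashi_wrapper import Higashi

h = Higashi("config.JSON")        # path is a placeholder
h.process_data(disable_mpl=True)  # forwarded as create_matrix(self.config, disable_mpl)
```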
@@ -513,7 +513,7 @@ def fetch_info_from_config(self):
else:
self.current_device = 'cpu'
torch.set_num_threads(self.cpu_num_torch)

self.data_dir = config['data_dir']
self.temp_dir = config['temp_dir']
self.embed_dir = os.path.join(self.temp_dir, "embed")
@@ -539,7 +539,12 @@ def fetch_info_from_config(self):
self.coassay = config['coassay']
else:
self.coassay = False


if 'pre_cell_embed' in config:
self.pre_cell_embed = config['pre_cell_embed']
else:
self.pre_cell_embed = False

self.save_path = os.path.join(self.temp_dir, "model")
if not os.path.exists(self.save_path):
os.makedirs(self.save_path)
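This hunk adds an optional `pre_cell_embed` entry to the config, read the same way as `coassay`. A hedged sketch of the new key, written here as the equivalent Python dict (the path and array layout are assumptions; `generate_attributes` below loads the file with `np.load`):

```python
config = {
    # ... existing Higashi keys (data_dir, temp_dir, ...) ...
    # Optional: precomputed cell embeddings saved with np.save,
    # one row per cell; standardized before use.
    "pre_cell_embed": "/path/to/cell_embeddings.npy",
}
```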
@@ -595,9 +600,16 @@ def generate_attributes(self):

for c in self.chrom_list:
a = np.array(save_file["cell"]["%d" % self.chrom_list.index(c)])
# a = np.eye(a.shape[0]).astype('float32')
cell_node_feats.append(a)

if self.coassay:

if self.pre_cell_embed:
print ("pre_cell_embed")
cell_node_feats = np.load(self.pre_cell_embed).astype('float32')
targets.append(StandardScaler().fit_transform(cell_node_feats.reshape((-1, 1))).reshape((len(cell_node_feats), -1)))
embeddings.append(StandardScaler().fit_transform(cell_node_feats.reshape((-1, 1))).reshape((len(cell_node_feats), -1)))

elif self.coassay:
print("coassay")
cell_node_feats = np.load(os.path.join(self.temp_dir, "pretrain_coassay.npy")).astype('float32')
targets.append(cell_node_feats)
@@ -615,21 +627,25 @@ def generate_attributes(self):
else:
cell_node_feats1 = remove_BE_linear(cell_node_feats, self.config, self.data_dir, self.cell_feats1)
cell_node_feats2 = cell_node_feats1

targets.append(cell_node_feats2.astype('float32'))
embeddings.append(cell_node_feats1.astype('float32'))

for i, c in enumerate(self.chrom_list):
temp = np.array(save_file["%d" % i]).astype('float32')
# temp = StandardScaler().fit_transform(temp.reshape((-1, 1))).reshape((len(temp), -1))
# temp = np.eye(temp.shape[0]).astype('float32')
# print ("min_max", np.min(temp), np.max(temp), np.median(temp))
temp = StandardScaler().fit_transform(temp.reshape((-1, 1))).reshape((len(temp), -1))
# print(np.min(temp), np.max(temp), np.median(temp))
chrom = np.zeros((len(temp), len(self.chrom_list))).astype('float32')
chrom[:, i] = 1
list1 = [temp, chrom]

temp = np.concatenate(list1, axis=-1)
embeddings.append(temp)
targets.append(temp)



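Both the new `pre_cell_embed` branch and the per-chromosome features use the same scaling idiom: flatten to a single column so one global mean and variance are fit over every entry, then restore the (cells, dims) shape. A standalone sketch with made-up sizes:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

feats = np.random.rand(100, 32).astype('float32')  # made-up (cells, dims)

flat = feats.reshape((-1, 1))                  # (3200, 1): one column
scaled = StandardScaler().fit_transform(flat)  # single global mean/std
scaled = scaled.reshape((len(feats), -1))      # back to (100, 32)
```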
print("start making attribute")
attribute_all = []
for i in range(len(self.num) - 1):
@@ -686,14 +702,21 @@ def prep_model(self):
# If there are more than 95% zero entries at 1Mb, then rely more on Higashi non-linear transformation
self.median_total_sparsity_cell = np.median(self.total_sparsity_cell)
print("total_sparsity_cell", self.median_total_sparsity_cell)
if self.median_total_sparsity_cell >= 0.05:
if self.coassay or self.pre_cell_embed:
print("contractive loss")
self.contractive_flag = True
self.contractive_loss_weight = 1e-3

else:
print("no contractive loss")
self.contractive_flag = False
self.contractive_loss_weight = 0.0
if self.median_total_sparsity_cell >= 0.05:
print("contractive loss")
self.contractive_flag = True
self.contractive_loss_weight = 1e-3
else:
print("no contractive loss")
self.contractive_flag = False
self.contractive_loss_weight = 0.0

start_end_dict = np.array(input_f['start_end_dict'])
self.cell_feats = np.array(input_f['extra_cell_feats'])
self.cell_feats1 = np.array(input_f['cell2weight'])
Expand All @@ -707,8 +730,8 @@ def prep_model(self):
num_list = np.cumsum(self.num)
self.num_list = num_list
max_bin = int(np.max(self.num[1:]))
mem_efficient_flag = self.cell_num > 30

# mem_efficient_flag = self.cell_num > 30
mem_efficient_flag = True

total_possible = 0

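Two behavioral changes land here. First, the contractive loss is now forced on whenever `coassay` or `pre_cell_embed` supplies the cell features, and the sparsity threshold only decides the remaining case. Second, `mem_efficient_flag`, formerly enabled only for datasets with more than 30 cells, is now always on; this is the second headline fix. A condensed paraphrase of the new control flow:

```python
# Paraphrase of the new decision logic in prep_model:
if self.coassay or self.pre_cell_embed:
    self.contractive_flag, self.contractive_loss_weight = True, 1e-3
elif self.median_total_sparsity_cell >= 0.05:
    self.contractive_flag, self.contractive_loss_weight = True, 1e-3
else:
    self.contractive_flag, self.contractive_loss_weight = False, 0.0

mem_efficient_flag = True   # was: self.cell_num > 30
```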
@@ -791,8 +814,8 @@ def prep_model(self):
train_weight, test_weight = transform_weight_class(train_weight, train_weight_mean, neg_num), \
transform_weight_class(test_weight, train_weight_mean, neg_num)
elif self.mode == 'rank':
train_weight += 1
test_weight += 1
train_weight = [x + 1 for x in train_weight]
test_weight = [x + 1 for x in test_weight]


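The switch from in-place `+= 1` to a comprehension sidesteps a type trap: `+= 1` works on a numpy array (mutating it in place) but raises TypeError on a plain Python list, while the comprehension returns a fresh list for either input. A minimal sketch:

```python
import numpy as np

w_list = [0.3, 0.7]
# w_list += 1              # TypeError: 'int' object is not iterable
w_arr = np.array([0.3, 0.7])
w_arr += 1                 # fine, but mutates w_arr in place

w_new = [x + 1 for x in w_list]   # [1.3, 1.7]; works for both input types
```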
# Constructing the model
@@ -932,18 +955,18 @@ def train_epoch(self, training_data_generator, optimizer_list, train_pool, train
j * size: min((j + 1) * size,
len(batch_edge_big))], \
batch_to_neighs_big[j]

pred, loss_bce, loss_mse = self.forward_batch_hyperedge(batch_edge,
batch_edge_weight, batch_chrom,
batch_to_neighs, y=batch_y,
chroms_in_batch=chroms_in_batch)

y_list.append(batch_y.detach().cpu())
w_list.append(batch_edge_weight.detach().cpu())
pred_list.append(pred.detach().cpu())

final_batch_num += 1

if self.use_recon:
for opt in optimizer_list:
opt.zero_grad(set_to_none=True)
@@ -952,45 +975,45 @@
main_norm = self.node_embedding_init.wstack[0].weight_list[0].grad.data.norm(2)
except:
main_norm = 0.0

for opt in optimizer_list:
opt.zero_grad(set_to_none=True)
loss_mse.backward(retain_graph=True)

recon_norm = self.node_embedding_init.wstack[0].weight_list[0].grad.data.norm(2)
shape = self.node_embedding_init.wstack[0].weight_list[0].shape[1]
ratio = self.beta * main_norm / recon_norm
ratio1 = max(ratio, 100 * self.median_total_sparsity_cell - 3)

if self.contractive_flag:
contractive_loss = 0.0
for i in range(len(self.node_embedding_init.wstack[0].weight_list)):
contractive_loss += torch.sum(self.node_embedding_init.wstack[0].weight_list[i] ** 2)
contractive_loss += torch.sum(self.node_embedding_init.wstack[0].reverse_weight_list[i] ** 2)

else:
contractive_loss = 0.0

else:
contractive_loss = 0.0
ratio = 0.0
ratio1 = 0.0

train_loss = self.alpha * loss_bce + ratio1 * loss_mse + self.contractive_loss_weight * contractive_loss
for opt in optimizer_list:
opt.zero_grad(set_to_none=True)
# backward
train_loss.backward()

# update parameters
for opt in optimizer_list:
opt.step()

bar.update(n=1)
bar.set_description("- (Train) BCE: %.3f MSE: %.3f norm_ratio: %.2f" %
(loss_bce.item(), loss_mse.item(), ratio1),
refresh=False)

bce_total_loss += loss_bce.item()
mse_total_loss += loss_mse.item()

@@ -1086,7 +1109,8 @@ def train(self, training_data_generator, validation_data_generator, optimizer, e
model.load_state_dict(checkpoint['model_link'])

best_train_loss = 1000



if save_embed:
self.save_embeddings()

@@ -1196,7 +1220,7 @@ def save_embeddings(self):
model = self.higashi_model
model.eval()
with torch.no_grad():
ids = torch.arange(1, num_list[-1] + 1).long().to(device, non_blocking=True).view(-1)
ids = torch.arange(1, self.num_list[-1] + 1).long().to(device, non_blocking=True).view(-1)
embeddings = []
for j in range(math.ceil(len(ids) / self.batch_size)):
x = ids[j * self.batch_size:min((j + 1) * self.batch_size, len(ids))]
@@ -1206,9 +1230,9 @@ def save_embeddings(self):
embeddings.append(embed)

embeddings = np.concatenate(embeddings, axis=0)
for i in range(len(num_list)):
start = 0 if i == 0 else num_list[i - 1]
static = embeddings[int(start):int(num_list[i])]
for i in range(len(self.num_list)):
start = 0 if i == 0 else self.num_list[i - 1]
static = embeddings[int(start):int(self.num_list[i])]

if i == 0:
try:
Expand All @@ -1224,7 +1248,7 @@ def save_embeddings(self):
return embeddings

def get_cell_neighbor_be(self, start=1):
v = self.cell_embeddings
v = self.cell_embeddings if not self.pre_cell_embed else np.load(self.pre_cell_embed)
distance = pairwise_distances(v, metric='euclidean')
distance_sorted = np.sort(distance, axis=-1)
distance /= np.quantile(distance_sorted[:, 1:self.neighbor_num].reshape((-1)), q=0.25)
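The `save_embeddings` hunks replace the stale module-level `num_list` with `self.num_list`, the cumulative node counts built in `prep_model`, fixing a NameError now that the global is gone. A standalone sketch of the slicing this restores, with made-up sizes:

```python
import numpy as np

num_list = np.cumsum([100, 2500, 2400])   # cells, then bins per chromosome (made up)
embeddings = np.random.rand(int(num_list[-1]), 64)

blocks = []
for i in range(len(num_list)):
    start = 0 if i == 0 else num_list[i - 1]
    blocks.append(embeddings[int(start):int(num_list[i])])
# blocks[0]: cell embeddings; blocks[1:]: per-chromosome bin embeddings
```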
@@ -1277,7 +1301,7 @@ def get_cell_neighbor_be(self, start=1):
return np.array(cell_neighbor_list_local), np.array(cell_neighbor_weight_list_local)

def get_cell_neighbor(self, start=1):
v = self.cell_embeddings
v = self.cell_embeddings if not self.pre_cell_embed else np.load(self.pre_cell_embed)
distance = pairwise_distances(v, metric='euclidean')
distance_sorted = np.sort(distance, axis=-1)
distance /= np.quantile(distance_sorted[:, 1:self.neighbor_num].reshape((-1)), q=0.25)
@@ -1299,7 +1323,7 @@ def get_cell_neighbor(self, start=1):
return np.array(cell_neighbor_list_local), np.array(cell_neighbor_weight_list_local)


def train_for_embeddings(self):
def train_for_embeddings(self, max_epochs=None):
global steps, pair_ratio

optimizer = torch.optim.Adam(
@@ -1336,7 +1360,7 @@ def train_for_embeddings(self):
self.train(
training_data_generator=self.training_data_generator,
validation_data_generator=self.validation_data_generator,
optimizer=[optimizer], epochs=self.embedding_epoch,
optimizer=[optimizer], epochs=self.embedding_epoch if max_epochs is None else max_epochs,
load_first=False, save_embed=True, save_name="_stage1")

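`train_for_embeddings` gains an optional `max_epochs` override for the configured `embedding_epoch`. A hedged usage sketch (instance name assumed):

```python
h.train_for_embeddings()              # stage 1 runs for config's embedding_epoch
h.train_for_embeddings(max_epochs=5)  # or cap it explicitly, e.g. for a quick dry run
```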
checkpoint = {
@@ -1586,29 +1610,34 @@ def fetch_map(self, chrom, cell):
c = self.chrom_list.index(chrom)
s, e = self.chrom_start_end[c]
size = e - s

with h5py.File(os.path.join(self.temp_dir, "%s_%s_nbr_%d_impute.hdf5" % (chrom, self.embedding_name, 0)), "r") as f:
coordinates = np.array(f['coordinates']).astype('int')
p = np.array(f["cell_%d" % cell])

m1 = csr_matrix((p, (coordinates[:, 0], coordinates[:, 1])), shape=(size, size), dtype='float32')
m1 = m1 + m1.T

with h5py.File(os.path.join(self.temp_dir, "%s_%s_nbr_%d_impute.hdf5" % (chrom, self.embedding_name, self.neighbor_num - 1)), "r") as f:
coordinates = np.array(f['coordinates']).astype('int')
p = np.array(f["cell_%d" % cell])

m2 = csr_matrix((p, (coordinates[:, 0], coordinates[:, 1])), shape=(size, size), dtype='float32')
m2 = m2 + m2.T

try:
with h5py.File(os.path.join(self.temp_dir, "%s_%s_nbr_%d_impute.hdf5" % (chrom, self.embedding_name, 0)), "r") as f:
coordinates = np.array(f['coordinates']).astype('int')
p = np.array(f["cell_%d" % cell])

m1 = csr_matrix((p, (coordinates[:, 0], coordinates[:, 1])), shape=(size, size), dtype='float32')
m1 = m1 + m1.T
except Exception as e:
m1 = np.zeros((size, size))
print ("No 0 nbr imputation for %s %d" % (chrom, cell))

try:
with h5py.File(os.path.join(self.temp_dir, "%s_%s_nbr_%d_impute.hdf5" % (chrom, self.embedding_name, self.neighbor_num - 1)), "r") as f:
coordinates = np.array(f['coordinates']).astype('int')
p = np.array(f["cell_%d" % cell])

m2 = csr_matrix((p, (coordinates[:, 0], coordinates[:, 1])), shape=(size, size), dtype='float32')
m2 = m2 + m2.T
except Exception as e:
m2 = np.zeros((size, size))
print ("No %d nbr imputation for %s %d" % (self.neighbor_num, chrom, cell))

if chrom not in self.ori_sparse_list:
self.ori_sparse_list[chrom] = np.load(os.path.join(self.temp_dir, "raw", "%s_sparse_adj.npy" % chrom), allow_pickle=True)

m3 = self.ori_sparse_list[chrom][cell]

return m3, m1, m2
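Finally, `fetch_map` wraps each HDF5 read in try/except, so a missing imputation file now degrades to an all-zero matrix plus a printed notice instead of crashing. A hedged usage sketch (instance name and chromosome label assumed):

```python
# Returns (raw, imputed with 0 neighbors, imputed with neighbor_num - 1 neighbors);
# a missing imputation file now yields a zero matrix for that slot.
raw, imp0, impk = h.fetch_map("chr1", cell=0)
```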



if __name__ == '__main__':
# Get parameters from config file