-
Notifications
You must be signed in to change notification settings - Fork 1
/
dataset_graph.py
112 lines (72 loc) · 3.65 KB
/
dataset_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from pathlib import Path
import anndata
import numpy as np
import os
import pickle
import torch
import torch.nn.functional as F
from graph_construct import genegene, cellgene, cellcell
class Dataset:
    """Wrapper around a processed ``.h5ad`` file that also builds graphs from it.

    On construction it reads ``processed_data/<dataset_name>.h5ad`` and splits the cells
    into train / validation / test index tensors. It then reorders the binned expression
    matrix as ``[train, valid, test]`` rows and builds graphs with ``genegene``,
    ``cellgene`` and ``cellcell`` from ``graph_construct``.
    """

    # Default location of the precomputed train/validation split (an ``.npz`` holding
    # ``tr_indices`` and ``val_indices``). ``{dataset_name}`` is filled in at load time.
    INDICES_PATH_TEMPLATE = (
        "/auto/k2/aykut3/scgpt/scGPT/scgpt_gcn/save_scgcn/"
        "scgpt_{dataset_name}_median/indices.npz"
    )

    def __init__(self, dataset_name, indices_path=None):
        """Load the dataset, compute the splits and build the graphs.

        Args:
            dataset_name: Base name of the ``.h5ad`` file under ``processed_data``.
            indices_path: Optional path to the split ``.npz`` file. Defaults to
                ``INDICES_PATH_TEMPLATE`` formatted with ``dataset_name``.
        """
        self.dataset_name = dataset_name
        data_path = os.path.join("processed_data", self.dataset_name + ".h5ad")
        self.y, self.train_ids, self.test_ids, self.valid_ids = self.train_test_ids(
            data_path, indices_path
        )
        # Read the AnnData file once and reuse it for both the binned and raw views.
        binned, raw = self.expression_values(data_path)
        # Rows are reordered as [train, valid, test]. NOTE(review): self.y keeps the
        # original file order, so downstream code must map labels through the id tensors.
        # A numpy index works for both dense ndarrays and scipy sparse matrices.
        row_order = torch.hstack([self.train_ids, self.valid_ids, self.test_ids]).numpy()
        self.expression_matrix_binned = binned[row_order]
        # The raw views below are not used in this module. The original author noted they
        # are currently unused elsewhere as well.
        # All cells in file order: train and validation cells are NOT separated here.
        self.expression_matrix_raw_all = raw
        # Only the test cells (batch_id == 1).
        self.expression_matrix_raw_test = raw[raw.obs["batch_id"] == 1]
        self.generate_graph()

    @staticmethod
    def _split_by_batch(batch_ids):
        """Return int64 arrays of positions whose batch id is 0 (train pool) and 1 (test)."""
        train_pool = np.array([i for i, b in enumerate(batch_ids) if b == 0], dtype=np.int64)
        test_ids = np.array([i for i, b in enumerate(batch_ids) if b == 1], dtype=np.int64)
        return train_pool, test_ids

    def train_test_ids(self, path, indices_path=None):
        """Return labels and the train / test / validation cell indices.

        Cells with ``obs["batch_id"] == 0`` form a training pool. That pool is split into
        train and validation using ``tr_indices`` / ``val_indices`` from the split ``.npz``
        file; those indices are positions within the pool. Cells with ``batch_id == 1``
        form the test set.

        Args:
            path: Path to the ``.h5ad`` file.
            indices_path: Optional split file path. Defaults to ``INDICES_PATH_TEMPLATE``.

        Returns:
            Tuple ``(y, train_ids, test_ids, valid_ids)`` of tensors. The id tensors hold
            row positions in the AnnData file; ``y`` holds ``obs["celltype_id"]`` values.
        """
        loaded_ann = anndata.read_h5ad(path)
        train_pool, test_ids = self._split_by_batch(loaded_ann.obs["batch_id"].tolist())
        y = np.array(loaded_ann.obs["celltype_id"].tolist())
        if indices_path is None:
            indices_path = self.INDICES_PATH_TEMPLATE.format(dataset_name=self.dataset_name)
        # NpzFile keeps its file open until closed; the context manager releases it.
        with np.load(indices_path) as loaded_data:
            train_indices = train_pool[loaded_data["tr_indices"]]
            valid_indices = train_pool[loaded_data["val_indices"]]
        return (
            torch.tensor(y),
            torch.tensor(train_indices),
            torch.tensor(test_ids),
            torch.tensor(valid_indices),
        )

    def expression_values(self, path):
        """Read the AnnData file at ``path``.

        Returns:
            Tuple ``(binned, raw)``: ``binned`` is ``layers["X_binned"]`` and ``raw`` is
            the full AnnData object.
        """
        adata = anndata.read_h5ad(path)
        return adata.layers["X_binned"], adata

    def generate_graph(self, n_bins=51):
        """Build the graphs from ``self.expression_matrix_binned``.

        Sets ``GG`` from ``genegene``, ``CG`` from ``cellgene`` and ``CC`` from
        ``cellcell``, and sets ``GC`` to the transpose of ``CG``.

        Args:
            n_bins: Bin count forwarded to ``cellgene``. NOTE(review): presumably must
                match the binning used to build ``X_binned``; confirm.
        """
        expression_matrix = self.expression_matrix_binned
        self.GG = genegene(expression_matrix)
        self.CG = cellgene(expression_matrix, n_bins=n_bins)
        self.CC = cellcell(expression_matrix)
        self.GC = self.CG.T

    def __repr__(self) -> str:
        """Summarize the dataset: name, cell count, gene count, number of cell types."""
        # .shape works for dense arrays and sparse matrices alike (len() fails on sparse).
        n_cells, n_genes = self.expression_matrix_binned.shape[:2]
        return (
            f"Dataset({self.dataset_name})"
            f"\nTotal Number of cells: {n_cells}"
            f"\nTotal Number of genes: {n_genes}"
            f"\nNumber of unique cell type is: {len(self.y.unique())}"
        )
def load_processed_dataset(dataset_name) -> "Dataset":
    """Unpickle and return ``processed_data/<dataset_name>.pkl``.

    The path is relative to the current working directory. The file is unpickled as-is,
    so it must come from a trusted source.
    """
    pkl_path = Path("processed_data") / f"{dataset_name}.pkl"
    with pkl_path.open("rb") as handle:
        return pickle.load(handle)
if __name__ == "__main__":
    # For each dataset: build it and cache it as a pickle, or, if the pickle already
    # exists, load it and print a quick summary of the test split.
    for dataset_name in ["ms"]:  # also used: "pancreas", "myeloid"
        pkl_path = os.path.join("processed_data", dataset_name + ".pkl")
        if os.path.exists(pkl_path):
            print(f"{pkl_path} already exists.")
            dataset = load_processed_dataset(dataset_name)
            print(len(dataset.test_ids))
            print(dataset.expression_matrix_raw_test.obs)
            continue
        dataset = Dataset(dataset_name=dataset_name)
        print(dataset)
        with open(pkl_path, "wb") as f:
            pickle.dump(dataset, f)
        print("Dataset object is pickled.")