# main.py — GNN training script for cell-graph classification.
# (Web-scrape artifacts — page chrome and a line-number gutter — removed.)
import torch
import os
import os.path as osp
import shutil
import numpy as np
import torch
from torch_geometric.data import InMemoryDataset, download_url, extract_zip
from torch_geometric.data import Dataset
from loguru import logger
from tqdm.notebook import tqdm
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.data import DataLoader
from modules import GNN,GCN,CellGraphDataset
# from visualizations import visualize_graph
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 8
def main():
    """Train and evaluate a GNN graph classifier on the cell-graph dataset.

    Loads the dataset, prints summary statistics, makes a deterministic
    ~67/33 train/test split, then trains a GNN for 200 epochs, printing
    train/test accuracy each epoch.
    """
    dataset = CellGraphDataset(root='./data', name='DS', use_node_attr=False, use_edge_attr=True)
    # visualize_graph(dataset[0])
    print_graph_stats(dataset)

    # Deterministic shuffle so the split is reproducible across runs.
    torch.manual_seed(12345)
    dataset = dataset.shuffle()
    n_trn = int(round(0.67 * len(dataset)))
    train_dataset = dataset[:n_trn]
    test_dataset = dataset[n_trn:]
    print(f'Number of training graphs: {len(train_dataset)}')
    print(f'Number of test graphs: {len(test_dataset)}')

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    for step, data in enumerate(train_loader):
        print(f'Step {step + 1}:')
        print('=======')
        print(f'Number of graphs in the current batch: {data.num_graphs}')
        print(data)
        print()

    model = GNN(hidden_channels=64, dataset=dataset).to(DEVICE)
    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()
    for epoch in range(1, 201):
        train_results = train(model, train_loader, optimizer, criterion)
        test_results = test(model, test_loader)
        # The LAST entry of 'accuracy' is the accuracy over the full epoch.
        # Averaging the list (as before) mixed in running partial ratios
        # from mid-epoch batches and systematically understated accuracy.
        train_acc = train_results['accuracy'][-1]
        test_acc = test_results['accuracy'][-1]
        print(f'Epoch: {epoch}, Train Acc: {train_acc}, Test Acc: {test_acc}')
def train(model, train_loader, optimizer, criterion):
    """Run one training epoch over `train_loader`.

    Performs one optimizer step per batch. Batches are expected to expose
    'x', 'edge_index', 'batch' and 'y' entries (PyG batch objects do).

    Returns:
        dict with:
            'loss'     — per-batch loss values as Python floats,
            'accuracy' — single-element list holding the epoch accuracy
                         (kept as a list for backward compatibility).
    """
    model.train()
    results = {'loss': [], 'accuracy': []}
    correct = 0
    for data in train_loader:  # Iterate in batches over the training dataset.
        optimizer.zero_grad()  # Clear gradients before the backward pass.
        out = model(data['x'].to(DEVICE), data['edge_index'].to(DEVICE), data['batch'].to(DEVICE))
        loss = criterion(out, data['y'].to(DEVICE))
        loss.backward()        # Derive gradients.
        optimizer.step()       # Update parameters.
        pred = out.argmax(dim=1)  # Class with the highest logit.
        correct += int((pred == data['y'].to(DEVICE)).sum())
        # .item() detaches the scalar; appending the tensor itself would
        # keep every batch's autograd graph alive for the whole epoch.
        results['loss'].append(loss.item())
    # Epoch accuracy computed ONCE over the full dataset. The original
    # appended a running partial ratio (correct-so-far / full dataset size)
    # per batch, which made averaging the list later incorrect.
    results['accuracy'].append(correct / len(train_loader.dataset))
    return results
def test(model, test_loader):
    """Evaluate `model` on `test_loader`.

    Returns:
        dict with 'accuracy' — a single-element list holding the overall
        accuracy on the loader's dataset (list kept for backward compatibility).
    """
    model.eval()
    correct = 0
    results = {'accuracy': []}
    with torch.no_grad():  # Evaluation needs no gradients; skip graph building.
        for data in test_loader:  # Iterate in batches over the test dataset.
            out = model(data['x'].to(DEVICE), data['edge_index'].to(DEVICE), data['batch'].to(DEVICE))
            pred = out.argmax(dim=1)  # Class with the highest logit.
            correct += int((pred == data['y'].to(DEVICE)).sum())
    # Compute the ratio ONCE over the full dataset. The original appended a
    # running partial ratio per batch, which made averaging the list wrong.
    results['accuracy'].append(correct / len(test_loader.dataset))
    return results
def print_graph_stats(dataset):
print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
data = dataset[0] # Get the first graph object.
print()
print(data)
print('=============================================================')
# Gather some statistics about the first graph.
print('FOR THE FIRST GRAPH:')
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Contains isolated nodes: {data.contains_isolated_nodes()}')
print(f'Contains self-loops: {data.contains_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')
# Run the training pipeline only when executed as a script, not on import.
if __name__=='__main__':
    main()