diff --git a/.gitignore b/.gitignore index 87f69bd..69b42d1 100644 --- a/.gitignore +++ b/.gitignore @@ -5,5 +5,9 @@ lp/preprocessed/Douban_Movie/neg_ratings_offset.npy lp/preprocessed/Douban_Movie/unconnected_pairs_offset.npy lp/preprocessed/Yelp/neg_ratings_offset.npy lp/preprocessed/Yelp/unconnected_pairs_offset.npy +lp/log/ -nc/data/ \ No newline at end of file +nc/data/ +nc/log/ + +*.pyc \ No newline at end of file diff --git a/lp/arch.py b/lp/arch.py new file mode 100644 index 0000000..37d3230 --- /dev/null +++ b/lp/arch.py @@ -0,0 +1,18 @@ +archs = { + "Amazon" : { + "source" : ([[4, 3, 2, 0]], [[1, 1, 9, 9, 0, 8]]), + "target" : ([[5, 4, 2, 1]], [[9, 2, 7, 9, 8, 6]]) + }, + "Yelp" : { + "source" : ([[6, 5, 4, 3]], [[9, 4, 10, 10, 9, 9]]), + "target" : ([[4, 5, 9, 2]], [[3, 2, 8, 10, 5, 10]]) + }, + "Douban_Movie" : { + "source" : ([[5, 7, 0, 1]], [[6, 0, 3, 11, 12, 11]]), + "target" : ([[10, 0, 9, 2]], [[7, 5, 6, 12, 11, 5]]) + }, + "Try" : { + "source" : ([[3, 6, 9, 3], [3, 7, 0]], [[10, 9, 5, 3, 10, 1], [9, 9, 9]]), + "target" : ([[6, 3, 6, 2], [4, 4, 7]], [[5, 6, 10, 5, 7, 10], [1, 2, 10]]) + } +} \ No newline at end of file diff --git a/lp/model.py b/lp/model.py new file mode 100644 index 0000000..6621bb1 --- /dev/null +++ b/lp/model.py @@ -0,0 +1,84 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +class Op(nn.Module): + + def __init__(self): + super(Op, self).__init__() + + def forward(self, x, adjs, idx): + return torch.spmm(adjs[idx], x) + +class Cell(nn.Module): + + def __init__(self, n_step, n_hid_prev, n_hid, use_norm = True, use_nl = True): + super(Cell, self).__init__() + + self.affine = nn.Linear(n_hid_prev, n_hid) + self.n_step = n_step + self.norm = nn.LayerNorm(n_hid) if use_norm is True else lambda x : x + self.use_nl = use_nl + self.ops_seq = nn.ModuleList() + self.ops_res = nn.ModuleList() + for i in range(self.n_step): + self.ops_seq.append(Op()) + for i in range(1, self.n_step): + for j in range(i): + self.ops_res.append(Op()) + + def forward(self, x, adjs, idxes_seq, idxes_res): + + x = self.affine(x) + states = [x] + offset = 0 + for i in range(self.n_step): + seqi = self.ops_seq[i](states[i], adjs[:-1], idxes_seq[i]) #! exclude zero Op + resi = sum(self.ops_res[offset + j](h, adjs, idxes_res[offset + j]) for j, h in enumerate(states[:i])) + offset += i + states.append(seqi + resi) + #assert(offset == len(self.ops_res)) + + output = self.norm(states[-1]) + if self.use_nl: + output = F.gelu(output) + return output + + +class Model(nn.Module): + + def __init__(self, in_dims, n_hid, n_steps, dropout = None, attn_dim = 64, use_norm = True, out_nl = True): + super(Model, self).__init__() + self.n_hid = n_hid + self.ws = nn.ModuleList() + assert(isinstance(in_dims, list)) + for i in range(len(in_dims)): + self.ws.append(nn.Linear(in_dims[i], n_hid)) + assert(isinstance(n_steps, list)) + self.metas = nn.ModuleList() + for i in range(len(n_steps)): + self.metas.append(Cell(n_steps[i], n_hid, n_hid, use_norm = use_norm, use_nl = out_nl)) + + #* [Optional] Combine more than one meta graph? 
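+        #* Semantic attention over meta graph outputs: attn_fc1 -> tanh -> attn_fc2 gives each
+        #* node one scalar score per meta graph; forward() softmax-normalizes these scores across
+        #* meta graphs and returns the attention-weighted sum of the cell outputs.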
+ self.attn_fc1 = nn.Linear(n_hid, attn_dim) + self.attn_fc2 = nn.Linear(attn_dim, 1) + + self.feats_drop = nn.Dropout(dropout) if dropout is not None else lambda x : x + + def forward(self, node_feats, node_types, adjs, idxes_seq, idxes_res): + hid = torch.zeros((node_types.size(0), self.n_hid)).cuda() + for i in range(len(node_feats)): + hid[node_types == i] = self.ws[i](node_feats[i]) + hid = self.feats_drop(hid) + temps = []; attns = [] + for i, meta in enumerate(self.metas): + hidi = meta(hid, adjs, idxes_seq[i], idxes_res[i]) + temps.append(hidi) + attni = self.attn_fc2(torch.tanh(self.attn_fc1(temps[-1]))) + attns.append(attni) + + hids = torch.stack(temps, dim=0).transpose(0, 1) + attns = F.softmax(torch.cat(attns, dim=-1), dim=-1) + out = (attns.unsqueeze(dim=-1) * hids).sum(dim=1) + return out \ No newline at end of file diff --git a/lp/model_search.py b/lp/model_search.py new file mode 100644 index 0000000..bab18ff --- /dev/null +++ b/lp/model_search.py @@ -0,0 +1,206 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +class Op(nn.Module): + + def __init__(self): + super(Op, self).__init__() + + def forward(self, x, adjs, ws, idx): + #assert(ws.size(0) == len(adjs)) + return ws[idx] * torch.spmm(adjs[idx], x) + +class Cell(nn.Module): + + def __init__(self, n_step, n_hid_prev, n_hid, cstr, use_norm = True, use_nl = True): + super(Cell, self).__init__() + + self.affine = nn.Linear(n_hid_prev, n_hid) + self.n_step = n_step + self.norm = nn.LayerNorm(n_hid, elementwise_affine = False) if use_norm is True else lambda x : x + self.use_nl = use_nl + assert(isinstance(cstr, list)) + self.cstr = cstr + + self.ops_seq = nn.ModuleList() ##! exclude last step + for i in range(self.n_step - 1): + self.ops_seq.append(Op()) + self.ops_res = nn.ModuleList() ##! exclude last step + for i in range(1, self.n_step - 1): + for j in range(i): + self.ops_res.append(Op()) + + self.last_seq = Op() + self.last_res = nn.ModuleList() + for i in range(self.n_step - 1): + self.last_res.append(Op()) + + + def forward(self, x, adjs, ws_seq, idxes_seq, ws_res, idxes_res): + #assert(isinstance(ws_seq, list)) + #assert(len(ws_seq) == 2) + x = self.affine(x) + states = [x] + offset = 0 + for i in range(self.n_step - 1): + seqi = self.ops_seq[i](states[i], adjs[:-1], ws_seq[0][i], idxes_seq[0][i]) #! 
exclude zero Op + resi = sum(self.ops_res[offset + j](h, adjs, ws_res[0][offset + j], idxes_res[0][offset + j]) for j, h in enumerate(states[:i])) + offset += i + states.append(seqi + resi) + #assert(offset == len(self.ops_res)) + + adjs_cstr = [adjs[i] for i in self.cstr] + out_seq = self.last_seq(states[-1], adjs_cstr, ws_seq[1], idxes_seq[1]) + adjs_cstr.append(adjs[-1]) + out_res = sum(self.last_res[i](h, adjs_cstr, ws_res[1][i], idxes_res[1][i]) for i, h in enumerate(states[:-1])) + output = self.norm(out_seq + out_res) + if self.use_nl: + output = F.gelu(output) + return output + + +class Model(nn.Module): + + def __init__(self, in_dims, n_hid, n_adjs, n_steps, cstr, attn_dim = 64, use_norm = True, out_nl = True): + super(Model, self).__init__() + self.cstr = cstr + self.n_adjs = n_adjs + self.n_hid = n_hid + self.ws = nn.ModuleList() + assert(isinstance(in_dims, list)) + for i in range(len(in_dims)): + self.ws.append(nn.Linear(in_dims[i], n_hid)) + assert(isinstance(n_steps, list)) + self.metas = nn.ModuleList() + for i in range(len(n_steps)): + self.metas.append(Cell(n_steps[i], n_hid, n_hid, cstr, use_norm = use_norm, use_nl = out_nl)) + + self.as_seq = [] + self.as_last_seq = [] + for i in range(len(n_steps)): + if n_steps[i] > 1: + ai = 1e-3 * torch.randn(n_steps[i] - 1, n_adjs - 1) #! exclude zero Op + ai = ai.cuda() + ai.requires_grad_(True) + self.as_seq.append(ai) + else: + self.as_seq.append(None) + ai_last = 1e-3 * torch.randn(len(cstr)) + ai_last = ai_last.cuda() + ai_last.requires_grad_(True) + self.as_last_seq.append(ai_last) + + ks = [sum(1 for i in range(1, n_steps[k] - 1) for j in range(i)) for k in range(len(n_steps))] + self.as_res = [] + self.as_last_res = [] + for i in range(len(n_steps)): + if ks[i] > 0: + ai = 1e-3 * torch.randn(ks[i], n_adjs) + ai = ai.cuda() + ai.requires_grad_(True) + self.as_res.append(ai) + else: + self.as_res.append(None) + + if n_steps[i] > 1: + ai_last = 1e-3 * torch.randn(n_steps[i] - 1, len(cstr) + 1) + ai_last = ai_last.cuda() + ai_last.requires_grad_(True) + self.as_last_res.append(ai_last) + else: + self.as_last_res.append(None) + + assert(ks[0] + n_steps[0] + (0 if self.as_last_res[0] is None else self.as_last_res[0].size(0)) == (1 + n_steps[0]) * n_steps[0] // 2) + + #* [Optional] Combine more than one meta graph? 
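+        #* Note: the attention MLP below is an ordinary network weight (reached via parameters()
+        #* and updated by optimizer_w); only the as_* tensors above are architecture parameters,
+        #* exposed through alphas() and updated by optimizer_a in train_search.py.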
+ self.attn_fc1 = nn.Linear(n_hid, attn_dim) + self.attn_fc2 = nn.Linear(attn_dim, 1) + + def alphas(self): + alphas = [] + for each in self.as_seq: + if each is not None: + alphas.append(each) + for each in self.as_last_seq: + alphas.append(each) + for each in self.as_res: + if each is not None: + alphas.append(each) + for each in self.as_last_res: + if each is not None: + alphas.append(each) + return alphas + + def sample(self, eps): + idxes_seq = [] + idxes_res = [] + if np.random.uniform() < eps: + for i in range(len(self.metas)): + temp = [] + temp.append(None if self.as_seq[i] is None else torch.randint(low=0, high=self.as_seq[i].size(-1), size=self.as_seq[i].size()[:-1]).cuda()) + temp.append(torch.randint(low=0, high=self.as_last_seq[i].size(-1), size=(1,)).cuda()) + idxes_seq.append(temp) + for i in range(len(self.metas)): + temp = [] + temp.append(None if self.as_res[i] is None else torch.randint(low=0, high=self.as_res[i].size(-1), size=self.as_res[i].size()[:-1]).cuda()) + temp.append(None if self.as_last_res[i] is None else torch.randint(low=0, high=self.as_last_res[i].size(-1), size=self.as_last_res[i].size()[:-1]).cuda()) + idxes_res.append(temp) + else: + for i in range(len(self.metas)): + temp = [] + temp.append(None if self.as_seq[i] is None else torch.argmax(F.softmax(self.as_seq[i], dim=-1), dim=-1)) + temp.append(torch.argmax(F.softmax(self.as_last_seq[i], dim=-1), dim=-1)) + idxes_seq.append(temp) + for i in range(len(self.metas)): + temp = [] + temp.append(None if self.as_res[i] is None else torch.argmax(F.softmax(self.as_res[i], dim=-1), dim=-1)) + temp.append(None if self.as_last_res[i] is None else torch.argmax(F.softmax(self.as_last_res[i], dim=-1), dim=-1)) + idxes_res.append(temp) + return idxes_seq, idxes_res + + def forward(self, node_feats, node_types, adjs, idxes_seq, idxes_res): + hid = torch.zeros((node_types.size(0), self.n_hid)).cuda() + for i in range(len(node_feats)): + hid[node_types == i] = self.ws[i](node_feats[i]) + temps = []; attns = [] + for i, meta in enumerate(self.metas): + ws_seq = [] + ws_seq.append(None if self.as_seq[i] is None else F.softmax(self.as_seq[i], dim=-1)) + ws_seq.append(F.softmax(self.as_last_seq[i], dim=-1)) + ws_res = [] + ws_res.append(None if self.as_res[i] is None else F.softmax(self.as_res[i], dim=-1)) + ws_res.append(None if self.as_last_res[i] is None else F.softmax(self.as_last_res[i], dim=-1)) + hidi = meta(hid, adjs, ws_seq, idxes_seq[i], ws_res, idxes_res[i]) + temps.append(hidi) + attni = self.attn_fc2(torch.tanh(self.attn_fc1(temps[-1]))) + attns.append(attni) + + hids = torch.stack(temps, dim=0).transpose(0, 1) + attns = F.softmax(torch.cat(attns, dim=-1), dim=-1) + out = (attns.unsqueeze(dim=-1) * hids).sum(dim=1) + return out + + def parse(self): + idxes_seq, idxes_res = self.sample(0.) 
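+        #* With eps = 0. sample() always takes the argmax branch, i.e. the strongest candidate
+        #* for every edge; the loops below map the constrained last-step choices back to global
+        #* adjacency indices, so the result has the same (seq, res) layout as the entries in arch.py.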
+ msg_seq = []; msg_res = [] + for i in range(len(idxes_seq)): + map_seq = [self.cstr[idxes_seq[i][1].item()]] + msg_seq.append(map_seq if idxes_seq[i][0] is None else idxes_seq[i][0].tolist() + map_seq) + assert(len(msg_seq[i]) == self.metas[i].n_step) + + temp_res = [] + if idxes_res[i][1] is not None: + for item in idxes_res[i][1].tolist(): + if item < len(self.cstr): + temp_res.append(self.cstr[item]) + else: + assert(item == len(self.cstr)) + temp_res.append(self.n_adjs - 1) + if idxes_res[i][0] is not None: + temp_res = idxes_res[i][0].tolist() + temp_res + assert(len(temp_res) == self.metas[i].n_step * (self.metas[i].n_step - 1) // 2) + msg_res.append(temp_res) + + return msg_seq, msg_res \ No newline at end of file diff --git a/lp/train.py b/lp/train.py new file mode 100644 index 0000000..5eacfd3 --- /dev/null +++ b/lp/train.py @@ -0,0 +1,157 @@ +import os +import sys +import time +import numpy as np +import pickle +import scipy.sparse as sp +import logging +import argparse +import torch +import torch.nn.functional as F +from sklearn.metrics import roc_auc_score + +from model import Model +from preprocess import normalize_sym, normalize_row, sparse_mx_to_torch_sparse_tensor +from arch import archs + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', type=float, default=0.01, help='learning rate') +parser.add_argument('--wd', type=float, default=0.001, help='weight decay') +parser.add_argument('--n_hid', type=int, default=64, help='hidden dimension') +parser.add_argument('--dataset', type=str, default='Yelp') +parser.add_argument('--gpu', type=int, default=0) +parser.add_argument('--epochs', type=int, default=200) +parser.add_argument('--dropout', type=float, default=0.2) +parser.add_argument('--seed', type=int, default=1) +args = parser.parse_args() + +prefix = "lr" + str(args.lr) + "_wd" + str(args.wd) + "_h" + str(args.n_hid) + \ + "_drop" + str(args.dropout) + "_epoch" + str(args.epochs) + "_cuda" + str(args.gpu) + +logdir = os.path.join("log/eval", args.dataset) +if not os.path.exists(logdir): + os.makedirs(logdir) + +log_format = '%(message)s' +logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format) +fh = logging.FileHandler(os.path.join(logdir, prefix + ".txt")) +fh.setFormatter(logging.Formatter(log_format)) +logging.getLogger().addHandler(fh) + +def main(): + + torch.cuda.set_device(args.gpu) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + + steps_s = [len(meta) for meta in archs[args.dataset]["source"][0]] + steps_t = [len(meta) for meta in archs[args.dataset]["target"][0]] + #print(steps_s, steps_t) + + datadir = "preprocessed" + prefix = os.path.join(datadir, args.dataset) + + #* load data + node_types = np.load(os.path.join(prefix, "node_types.npy")) + num_node_types = node_types.max() + 1 + node_types = torch.from_numpy(node_types).cuda() + + adjs_offset = pickle.load(open(os.path.join(prefix, "adjs_offset.pkl"), "rb")) + adjs_pt = [] + if '0' in adjs_offset: + adjs_pt.append(sparse_mx_to_torch_sparse_tensor(normalize_sym(adjs_offset['0'] + sp.eye(adjs_offset['0'].shape[0], dtype=np.float32))).cuda()) + for i in range(1, int(max(adjs_offset.keys())) + 1): + adjs_pt.append(sparse_mx_to_torch_sparse_tensor(normalize_row(adjs_offset[str(i)] + sp.eye(adjs_offset[str(i)].shape[0], dtype=np.float32))).cuda()) + adjs_pt.append(sparse_mx_to_torch_sparse_tensor(normalize_row(adjs_offset[str(i)].T + sp.eye(adjs_offset[str(i)].shape[0], dtype=np.float32))).cuda()) + 
adjs_pt.append(sparse_mx_to_torch_sparse_tensor(sp.eye(adjs_offset['1'].shape[0], dtype=np.float32).tocoo()).cuda()) + adjs_pt.append(torch.sparse.FloatTensor(size=adjs_offset['1'].shape).cuda()) + print("Loading {} adjs...".format(len(adjs_pt))) + + #* load labels + pos = np.load(os.path.join(prefix, "pos_pairs_offset.npz")) + pos_train = pos['train'] + pos_val = pos['val'] + pos_test = pos['test'] + + neg = np.load(os.path.join(prefix, "neg_pairs_offset.npz")) + neg_train = neg['train'] + neg_val = neg['val'] + neg_test = neg['test'] + + #* inputs + in_dims = [] + node_feats = [] + for k in range(num_node_types): + in_dims.append((node_types == k).sum().item()) + i = torch.stack((torch.arange(in_dims[-1], dtype=torch.long), torch.arange(in_dims[-1], dtype=torch.long))) + v = torch.ones(in_dims[-1]) + node_feats.append(torch.sparse.FloatTensor(i, v, torch.Size([in_dims[-1], in_dims[-1]])).cuda()) + assert(len(in_dims) == len(node_feats)) + + model_s = Model(in_dims, args.n_hid, steps_s, dropout = args.dropout).cuda() + model_t = Model(in_dims, args.n_hid, steps_t, dropout = args.dropout).cuda() + + optimizer = torch.optim.Adam( + list(model_s.parameters()) + list(model_t.parameters()), + lr=args.lr, + weight_decay=args.wd + ) + + best_val = None + final = None + anchor = None + for epoch in range(args.epochs): + train_loss = train(node_feats, node_types, adjs_pt, pos_train, neg_train, model_s, model_t, optimizer) + val_loss, auc_val, auc_test = infer(node_feats, node_types, adjs_pt, pos_val, neg_val, pos_test, neg_test, model_s, model_t) + logging.info("Epoch {}; Train err {}; Val err {}; Val auc {}".format(epoch + 1, train_loss, val_loss, auc_val)) + if best_val is None or auc_val > best_val: + best_val = auc_val + final = auc_test + anchor = epoch + 1 + logging.info("Best val auc {} at epoch {}; Test auc {}".format(best_val, anchor, final)) + +def train(node_feats, node_types, adjs, pos_train, neg_train, model_s, model_t, optimizer): + + model_s.train() + model_t.train() + optimizer.zero_grad() + out_s = model_s(node_feats, node_types, adjs, archs[args.dataset]["source"][0], archs[args.dataset]["source"][1]) + out_t = model_t(node_feats, node_types, adjs, archs[args.dataset]["target"][0], archs[args.dataset]["target"][1]) + loss = - torch.mean(F.logsigmoid(torch.mul(out_s[pos_train[:, 0]], out_t[pos_train[:, 1]]).sum(dim=-1)) + \ + F.logsigmoid(- torch.mul(out_s[neg_train[:, 0]], out_t[neg_train[:, 1]]).sum(dim=-1))) + loss.backward() + optimizer.step() + return loss.item() + +def infer(node_feats, node_types, adjs, pos_val, neg_val, pos_test, neg_test, model_s, model_t): + + model_s.eval() + model_t.eval() + with torch.no_grad(): + out_s = model_s(node_feats, node_types, adjs, archs[args.dataset]["source"][0], archs[args.dataset]["source"][1]) + out_t = model_t(node_feats, node_types, adjs, archs[args.dataset]["target"][0], archs[args.dataset]["target"][1]) + + #* validation performance + pos_val_prod = torch.mul(out_s[pos_val[:, 0]], out_t[pos_val[:, 1]]).sum(dim=-1) + neg_val_prod = torch.mul(out_s[neg_val[:, 0]], out_t[neg_val[:, 1]]).sum(dim=-1) + loss = - torch.mean(F.logsigmoid(pos_val_prod) + F.logsigmoid(- neg_val_prod)) + + y_true_val = np.zeros((pos_val.shape[0] + neg_val.shape[0]), dtype=np.long) + y_true_val[:pos_val.shape[0]] = 1 + y_pred_val = np.concatenate((torch.sigmoid(pos_val_prod).cpu().numpy(), torch.sigmoid(neg_val_prod).cpu().numpy())) + auc_val = roc_auc_score(y_true_val, y_pred_val) + + #* test performance + pos_test_prod = torch.mul(out_s[pos_test[:, 0]], 
out_t[pos_test[:, 1]]).sum(dim=-1) + neg_test_prod = torch.mul(out_s[neg_test[:, 0]], out_t[neg_test[:, 1]]).sum(dim=-1) + + y_true_test = np.zeros((pos_test.shape[0] + neg_test.shape[0]), dtype=np.long) + y_true_test[:pos_test.shape[0]] = 1 + y_pred_test = np.concatenate((torch.sigmoid(pos_test_prod).cpu().numpy(), torch.sigmoid(neg_test_prod).cpu().numpy())) + auc_test = roc_auc_score(y_true_test, y_pred_test) + + return loss.item(), auc_val, auc_test + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/lp/train_search.py b/lp/train_search.py new file mode 100644 index 0000000..2a8639b --- /dev/null +++ b/lp/train_search.py @@ -0,0 +1,141 @@ +import os +import sys +import time +import numpy as np +import pickle +import scipy.sparse as sp +import logging +import argparse +import torch +import torch.nn.functional as F +import time + +from model_search import Model +from preprocess import normalize_sym, normalize_row, sparse_mx_to_torch_sparse_tensor +from preprocess import cstr_source, cstr_target + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', type=float, default=0.005, help='learning rate') +parser.add_argument('--wd', type=float, default=0.001, help='weight decay') +parser.add_argument('--n_hid', type=int, default=64, help='hidden dimension') +parser.add_argument('--alr', type=float, default=3e-4, help='learning rate for architecture parameters') +parser.add_argument('--steps_s', type=int, nargs='+', help='number of intermediate states in the meta graph for source node type') +parser.add_argument('--steps_t', type=int, nargs='+', help='number of intermediate states in the meta graph for target node type') +parser.add_argument('--dataset', type=str, default='Yelp') +parser.add_argument('--gpu', type=int, default=0) +parser.add_argument('--epochs', type=int, default=100) +parser.add_argument('--eps', type=float, default=0., help='probability of random sampling') +parser.add_argument('--decay', type=float, default=0.9, help='decay factor for eps') +parser.add_argument('--seed', type=int, default=0) +args = parser.parse_args() + +prefix = "lr" + str(args.lr) + "_wd" + str(args.wd) + \ + "_h" + str(args.n_hid) + "_alr" + str(args.alr) + \ + "_s" + str(args.steps_s) + "_t" + str(args.steps_t) + "_epoch" + str(args.epochs) + \ + "_cuda" + str(args.gpu) + "_eps" + str(args.eps) + "_d" + str(args.decay) + +logdir = os.path.join("log/search", args.dataset) +if not os.path.exists(logdir): + os.makedirs(logdir) + +log_format = '%(message)s' +logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format) +fh = logging.FileHandler(os.path.join(logdir, prefix + ".txt")) +fh.setFormatter(logging.Formatter(log_format)) +logging.getLogger().addHandler(fh) + +def main(): + + torch.cuda.set_device(args.gpu) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + + datadir = "preprocessed" + prefix = os.path.join(datadir, args.dataset) + + #* load data + node_types = np.load(os.path.join(prefix, "node_types.npy")) + num_node_types = node_types.max() + 1 + node_types = torch.from_numpy(node_types).cuda() + + adjs_offset = pickle.load(open(os.path.join(prefix, "adjs_offset.pkl"), "rb")) + adjs_pt = [] + if '0' in adjs_offset: + adjs_pt.append(sparse_mx_to_torch_sparse_tensor(normalize_sym(adjs_offset['0'] + sp.eye(adjs_offset['0'].shape[0], dtype=np.float32))).cuda()) + for i in range(1, int(max(adjs_offset.keys())) + 1): + adjs_pt.append(sparse_mx_to_torch_sparse_tensor(normalize_row(adjs_offset[str(i)] + 
sp.eye(adjs_offset[str(i)].shape[0], dtype=np.float32))).cuda()) + adjs_pt.append(sparse_mx_to_torch_sparse_tensor(normalize_row(adjs_offset[str(i)].T + sp.eye(adjs_offset[str(i)].shape[0], dtype=np.float32))).cuda()) + adjs_pt.append(sparse_mx_to_torch_sparse_tensor(sp.eye(adjs_offset['1'].shape[0], dtype=np.float32).tocoo()).cuda()) + adjs_pt.append(torch.sparse.FloatTensor(size=adjs_offset['1'].shape).cuda()) + print("Loading {} adjs...".format(len(adjs_pt))) + + #* load labels + pos = np.load(os.path.join(prefix, "pos_pairs_offset.npz")) + pos_train = pos['train'] + pos_val = pos['val'] + pos_test = pos['test'] + + neg = np.load(os.path.join(prefix, "neg_pairs_offset.npz")) + neg_train = neg['train'] + neg_val = neg['val'] + neg_test = neg['test'] + + #* inputs + in_dims = [] + node_feats = [] + for k in range(num_node_types): + in_dims.append((node_types == k).sum().item()) + i = torch.stack((torch.arange(in_dims[-1], dtype=torch.long), torch.arange(in_dims[-1], dtype=torch.long))) + v = torch.ones(in_dims[-1]) + node_feats.append(torch.sparse.FloatTensor(i, v, torch.Size([in_dims[-1], in_dims[-1]])).cuda()) + assert(len(in_dims) == len(node_feats)) + + model_s = Model(in_dims, args.n_hid, len(adjs_pt), args.steps_s, cstr_source[args.dataset]).cuda() + model_t = Model(in_dims, args.n_hid, len(adjs_pt), args.steps_t, cstr_target[args.dataset]).cuda() + + optimizer_w = torch.optim.Adam( + list(model_s.parameters()) + list(model_t.parameters()), + lr=args.lr, + weight_decay=args.wd + ) + + optimizer_a = torch.optim.Adam( + model_s.alphas() + model_t.alphas(), + lr=args.alr + ) + + eps = args.eps + start_t = time.time() + for epoch in range(args.epochs): + train_error, val_error = train(node_feats, node_types, adjs_pt, pos_train, neg_train, pos_val, neg_val, model_s, model_t, optimizer_w, optimizer_a, eps) + logging.info("Epoch {}; Train err {}; Val err {}; Source arch {}; Target arch {}".format(epoch + 1, train_error, val_error, model_s.parse(), model_t.parse())) + eps = eps * args.decay + end_t = time.time() + print("Search time (in minutes): {}".format((end_t - start_t) / 60)) + +def train(node_feats, node_types, adjs, pos_train, neg_train, pos_val, neg_val, model_s, model_t, optimizer_w, optimizer_a, eps): + + idxes_seq_s, idxes_res_s = model_s.sample(eps) + idxes_seq_t, idxes_res_t = model_t.sample(eps) + + optimizer_w.zero_grad() + out_s = model_s(node_feats, node_types, adjs, idxes_seq_s, idxes_res_s) + out_t = model_t(node_feats, node_types, adjs, idxes_seq_t, idxes_res_t) + loss_w = - torch.mean(F.logsigmoid(torch.mul(out_s[pos_train[:, 0]], out_t[pos_train[:, 1]]).sum(dim=-1)) + \ + F.logsigmoid(- torch.mul(out_s[neg_train[:, 0]], out_t[neg_train[:, 1]]).sum(dim=-1))) + loss_w.backward() + optimizer_w.step() + + optimizer_a.zero_grad() + out_s = model_s(node_feats, node_types, adjs, idxes_seq_s, idxes_res_s) + out_t = model_t(node_feats, node_types, adjs, idxes_seq_t, idxes_res_t) + loss_a = - torch.mean(F.logsigmoid(torch.mul(out_s[pos_val[:, 0]], out_t[pos_val[:, 1]]).sum(dim=-1)) + \ + F.logsigmoid(- torch.mul(out_s[neg_val[:, 0]], out_t[neg_val[:, 1]]).sum(dim=-1))) + loss_a.backward() + optimizer_a.step() + + return loss_w.item(), loss_a.item() + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/nc/arch.py b/nc/arch.py new file mode 100644 index 0000000..b8a83a6 --- /dev/null +++ b/nc/arch.py @@ -0,0 +1,5 @@ +archs = { + "DBLP" : ([[4, 2, 1, 1]], [[5, 5, 2, 5, 5, 1]]), + "ACM" : ([[3, 4, 0, 0]], [[1, 5, 5, 5, 2, 0]]), + "IMDB" : ([[0, 1, 3, 2]], 
[[5, 5, 5, 5, 5, 5]]) +} \ No newline at end of file diff --git a/nc/model_search.py b/nc/model_search.py new file mode 100644 index 0000000..350a61e --- /dev/null +++ b/nc/model_search.py @@ -0,0 +1,211 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +class Op(nn.Module): + + def __init__(self): + super(Op, self).__init__() + + def forward(self, x, adjs, ws, idx): + #assert(ws.size(0) == len(adjs)) + return ws[idx] * torch.spmm(adjs[idx], x) + +class Cell(nn.Module): + + def __init__(self, n_step, n_hid_prev, n_hid, cstr, use_norm = True, use_nl = True): + super(Cell, self).__init__() + + self.affine = nn.Linear(n_hid_prev, n_hid) + self.n_step = n_step + self.norm = nn.LayerNorm(n_hid, elementwise_affine = False) if use_norm is True else lambda x : x + self.use_nl = use_nl + assert(isinstance(cstr, list)) + self.cstr = cstr + + self.ops_seq = nn.ModuleList() ##! exclude last step + for i in range(self.n_step - 1): + self.ops_seq.append(Op()) + self.ops_res = nn.ModuleList() ##! exclude last step + for i in range(1, self.n_step - 1): + for j in range(i): + self.ops_res.append(Op()) + + self.last_seq = Op() + self.last_res = nn.ModuleList() + for i in range(self.n_step - 1): + self.last_res.append(Op()) + + + def forward(self, x, adjs, ws_seq, idxes_seq, ws_res, idxes_res): + #assert(isinstance(ws_seq, list)) + #assert(len(ws_seq) == 2) + x = self.affine(x) + states = [x] + offset = 0 + for i in range(self.n_step - 1): + seqi = self.ops_seq[i](states[i], adjs[:-1], ws_seq[0][i], idxes_seq[0][i]) #! exclude zero Op + resi = sum(self.ops_res[offset + j](h, adjs, ws_res[0][offset + j], idxes_res[0][offset + j]) for j, h in enumerate(states[:i])) + offset += i + states.append(seqi + resi) + #assert(offset == len(self.ops_res)) + + adjs_cstr = [adjs[i] for i in self.cstr] + out_seq = self.last_seq(states[-1], adjs_cstr, ws_seq[1], idxes_seq[1]) + adjs_cstr.append(adjs[-1]) + out_res = sum(self.last_res[i](h, adjs_cstr, ws_res[1][i], idxes_res[1][i]) for i, h in enumerate(states[:-1])) + output = self.norm(out_seq + out_res) + if self.use_nl: + output = F.gelu(output) + return output + + +class Model(nn.Module): + + def __init__(self, in_dim, n_hid, num_node_types, n_adjs, n_classes, n_steps, cstr, attn_dim = 64, use_norm = True, out_nl = True): + super(Model, self).__init__() + self.num_node_types = num_node_types + self.cstr = cstr + self.n_adjs = n_adjs + self.n_hid = n_hid + self.ws = nn.ModuleList() + for i in range(num_node_types): + self.ws.append(nn.Linear(in_dim, n_hid)) + assert(isinstance(n_steps, list)) + self.metas = nn.ModuleList() + for i in range(len(n_steps)): + self.metas.append(Cell(n_steps[i], n_hid, n_hid, cstr, use_norm = use_norm, use_nl = out_nl)) + + self.as_seq = [] + self.as_last_seq = [] + for i in range(len(n_steps)): + if n_steps[i] > 1: + ai = 1e-3 * torch.randn(n_steps[i] - 1, n_adjs - 1) #! 
exclude zero Op + ai = ai.cuda() + ai.requires_grad_(True) + self.as_seq.append(ai) + else: + self.as_seq.append(None) + ai_last = 1e-3 * torch.randn(len(cstr)) + ai_last = ai_last.cuda() + ai_last.requires_grad_(True) + self.as_last_seq.append(ai_last) + + ks = [sum(1 for i in range(1, n_steps[k] - 1) for j in range(i)) for k in range(len(n_steps))] + self.as_res = [] + self.as_last_res = [] + for i in range(len(n_steps)): + if ks[i] > 0: + ai = 1e-3 * torch.randn(ks[i], n_adjs) + ai = ai.cuda() + ai.requires_grad_(True) + self.as_res.append(ai) + else: + self.as_res.append(None) + + if n_steps[i] > 1: + ai_last = 1e-3 * torch.randn(n_steps[i] - 1, len(cstr) + 1) + ai_last = ai_last.cuda() + ai_last.requires_grad_(True) + self.as_last_res.append(ai_last) + else: + self.as_last_res.append(None) + + assert(ks[0] + n_steps[0] + (0 if self.as_last_res[0] is None else self.as_last_res[0].size(0)) == (1 + n_steps[0]) * n_steps[0] // 2) + + #* [Optional] Combine more than one meta graph? + self.attn_fc1 = nn.Linear(n_hid, attn_dim) + self.attn_fc2 = nn.Linear(attn_dim, 1) + + #* node classification + self.classifier = nn.Linear(n_hid, n_classes) + + def alphas(self): + alphas = [] + for each in self.as_seq: + if each is not None: + alphas.append(each) + for each in self.as_last_seq: + alphas.append(each) + for each in self.as_res: + if each is not None: + alphas.append(each) + for each in self.as_last_res: + if each is not None: + alphas.append(each) + return alphas + + def sample(self, eps): + idxes_seq = [] + idxes_res = [] + if np.random.uniform() < eps: + for i in range(len(self.metas)): + temp = [] + temp.append(None if self.as_seq[i] is None else torch.randint(low=0, high=self.as_seq[i].size(-1), size=self.as_seq[i].size()[:-1]).cuda()) + temp.append(torch.randint(low=0, high=self.as_last_seq[i].size(-1), size=(1,)).cuda()) + idxes_seq.append(temp) + for i in range(len(self.metas)): + temp = [] + temp.append(None if self.as_res[i] is None else torch.randint(low=0, high=self.as_res[i].size(-1), size=self.as_res[i].size()[:-1]).cuda()) + temp.append(None if self.as_last_res[i] is None else torch.randint(low=0, high=self.as_last_res[i].size(-1), size=self.as_last_res[i].size()[:-1]).cuda()) + idxes_res.append(temp) + else: + for i in range(len(self.metas)): + temp = [] + temp.append(None if self.as_seq[i] is None else torch.argmax(F.softmax(self.as_seq[i], dim=-1), dim=-1)) + temp.append(torch.argmax(F.softmax(self.as_last_seq[i], dim=-1), dim=-1)) + idxes_seq.append(temp) + for i in range(len(self.metas)): + temp = [] + temp.append(None if self.as_res[i] is None else torch.argmax(F.softmax(self.as_res[i], dim=-1), dim=-1)) + temp.append(None if self.as_last_res[i] is None else torch.argmax(F.softmax(self.as_last_res[i], dim=-1), dim=-1)) + idxes_res.append(temp) + return idxes_seq, idxes_res + + def forward(self, node_feats, node_types, adjs, idxes_seq, idxes_res): + hid = torch.zeros((node_types.size(0), self.n_hid)).cuda() + for i in range(self.num_node_types): + idx = (node_types == i) + hid[idx] = self.ws[i](node_feats[idx]) + temps = []; attns = [] + for i, meta in enumerate(self.metas): + ws_seq = [] + ws_seq.append(None if self.as_seq[i] is None else F.softmax(self.as_seq[i], dim=-1)) + ws_seq.append(F.softmax(self.as_last_seq[i], dim=-1)) + ws_res = [] + ws_res.append(None if self.as_res[i] is None else F.softmax(self.as_res[i], dim=-1)) + ws_res.append(None if self.as_last_res[i] is None else F.softmax(self.as_last_res[i], dim=-1)) + hidi = meta(hid, adjs, ws_seq, idxes_seq[i], 
ws_res, idxes_res[i]) + temps.append(hidi) + attni = self.attn_fc2(torch.tanh(self.attn_fc1(temps[-1]))) + attns.append(attni) + + hids = torch.stack(temps, dim=0).transpose(0, 1) + attns = F.softmax(torch.cat(attns, dim=-1), dim=-1) + out = (attns.unsqueeze(dim=-1) * hids).sum(dim=1) + logits = self.classifier(out) + return logits + + def parse(self): + idxes_seq, idxes_res = self.sample(0.) + msg_seq = []; msg_res = [] + for i in range(len(idxes_seq)): + map_seq = [self.cstr[idxes_seq[i][1].item()]] + msg_seq.append(map_seq if idxes_seq[i][0] is None else idxes_seq[i][0].tolist() + map_seq) + assert(len(msg_seq[i]) == self.metas[i].n_step) + + temp_res = [] + if idxes_res[i][1] is not None: + for item in idxes_res[i][1].tolist(): + if item < len(self.cstr): + temp_res.append(self.cstr[item]) + else: + assert(item == len(self.cstr)) + temp_res.append(self.n_adjs - 1) + if idxes_res[i][0] is not None: + temp_res = idxes_res[i][0].tolist() + temp_res + assert(len(temp_res) == self.metas[i].n_step * (self.metas[i].n_step - 1) // 2) + msg_res.append(temp_res) + + return msg_seq, msg_res \ No newline at end of file diff --git a/nc/preprocess.py b/nc/preprocess.py index 1b93af7..3af0204 100644 --- a/nc/preprocess.py +++ b/nc/preprocess.py @@ -3,6 +3,7 @@ import numpy as np import torch import pickle as pkl +import scipy.sparse as sp cstr_nc = { "DBLP" : [1, 4], @@ -10,6 +11,31 @@ "IMDB" : [0, 2, 4] } +def normalize_sym(adj): + """Symmetrically normalize adjacency matrix.""" + rowsum = np.array(adj.sum(1)) + d_inv_sqrt = np.power(rowsum, -0.5).flatten() + d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. + d_mat_inv_sqrt = sp.diags(d_inv_sqrt) + return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo() + +def normalize_row(mx): + """Row-normalize sparse matrix""" + rowsum = np.array(mx.sum(1)) + r_inv = np.power(rowsum, -1).flatten() + r_inv[np.isinf(r_inv)] = 0. 
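+    # zero-degree rows give inf under the -1 power; resetting them keeps isolated nodes all-zero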
+ r_mat_inv = sp.diags(r_inv) + mx = r_mat_inv.dot(mx) + return mx.tocoo() + +def sparse_mx_to_torch_sparse_tensor(sparse_mx): + """Convert a scipy sparse matrix to a torch sparse tensor.""" + indices = torch.from_numpy( + np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) + values = torch.from_numpy(sparse_mx.data) + shape = torch.Size(sparse_mx.shape) + return torch.sparse.FloatTensor(indices, values, shape) + def main(dataset): prefix = os.path.join("./data/", dataset) with open(os.path.join(prefix, "edges.pkl"), "rb") as f: diff --git a/nc/train_search.py b/nc/train_search.py new file mode 100644 index 0000000..7eba96e --- /dev/null +++ b/nc/train_search.py @@ -0,0 +1,127 @@ +import os +import sys +import time +import numpy as np +import pickle +import scipy.sparse as sp +import logging +import argparse +import torch +import torch.nn.functional as F + +from model_search import Model +from preprocess import normalize_sym, normalize_row, sparse_mx_to_torch_sparse_tensor +from preprocess import cstr_nc + +parser = argparse.ArgumentParser() +parser.add_argument('--lr', type=float, default=0.005, help='learning rate') +parser.add_argument('--wd', type=float, default=0.001, help='weight decay') +parser.add_argument('--n_hid', type=int, default=64, help='hidden dimension') +parser.add_argument('--alr', type=float, default=3e-4, help='learning rate for architecture parameters') +parser.add_argument('--steps', type=int, nargs='+', help='number of intermediate states in the meta graph') +parser.add_argument('--dataset', type=str, default='DBLP') +parser.add_argument('--gpu', type=int, default=0) +parser.add_argument('--epochs', type=int, default=50) +parser.add_argument('--eps', type=float, default=0.3, help='probability of random sampling') +parser.add_argument('--decay', type=float, default=0.9, help='decay factor for eps') +parser.add_argument('--seed', type=int, default=0) +args = parser.parse_args() + +prefix = "lr" + str(args.lr) + "_wd" + str(args.wd) + \ + "_h" + str(args.n_hid) + "_alr" + str(args.alr) + \ + "_s" + str(args.steps) + "_epoch" + str(args.epochs) + \ + "_cuda" + str(args.gpu) + "_eps" + str(args.eps) + "_d" + str(args.decay) + +logdir = os.path.join("log/search", args.dataset) +if not os.path.exists(logdir): + os.makedirs(logdir) + +log_format = '%(message)s' +logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format) +fh = logging.FileHandler(os.path.join(logdir, prefix + ".txt")) +fh.setFormatter(logging.Formatter(log_format)) +logging.getLogger().addHandler(fh) + +def main(): + + torch.cuda.set_device(args.gpu) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + + datadir = "data" + prefix = os.path.join(datadir, args.dataset) + + #* load data + with open(os.path.join(prefix, "node_features.pkl"), "rb") as f: + node_feats = pickle.load(f) + f.close() + node_feats = torch.from_numpy(node_feats.astype(np.float32)).cuda() + + node_types = np.load(os.path.join(prefix, "node_types.npy")) + num_node_types = node_types.max() + 1 + node_types = torch.from_numpy(node_types).cuda() + + with open(os.path.join(prefix, "edges.pkl"), "rb") as f: + edges = pickle.load(f) + f.close() + + adjs_pt = [] + for mx in edges: + adjs_pt.append(sparse_mx_to_torch_sparse_tensor(normalize_row(mx.astype(np.float32) + sp.eye(mx.shape[0], dtype=np.float32))).cuda()) + adjs_pt.append(sparse_mx_to_torch_sparse_tensor(sp.eye(edges[0].shape[0], dtype=np.float32).tocoo()).cuda()) + 
adjs_pt.append(torch.sparse.FloatTensor(size=edges[0].shape).cuda()) + print("Loading {} adjs...".format(len(adjs_pt))) + + #* load labels + with open(os.path.join(prefix, "labels.pkl"), "rb") as f: + labels = pickle.load(f) + f.close() + + train_idx = torch.from_numpy(np.array(labels[0])[:, 0]).type(torch.long).cuda() + train_target = torch.from_numpy(np.array(labels[0])[:, 1]).type(torch.long).cuda() + valid_idx = torch.from_numpy(np.array(labels[1])[:, 0]).type(torch.long).cuda() + valid_target = torch.from_numpy(np.array(labels[1])[:, 1]).type(torch.long).cuda() + + n_classes = train_target.max().item() + 1 + print("Number of classes: {}".format(n_classes), "Number of node types: {}".format(num_node_types)) + + model = Model(node_feats.size(1), args.n_hid, num_node_types, len(adjs_pt), n_classes, args.steps, cstr_nc[args.dataset]).cuda() + + optimizer_w = torch.optim.Adam( + model.parameters(), + lr=args.lr, + weight_decay=args.wd + ) + + optimizer_a = torch.optim.Adam( + model.alphas(), + lr=args.alr + ) + + eps = args.eps + for epoch in range(args.epochs): + train_error, val_error = train(node_feats, node_types, adjs_pt, train_idx, train_target, valid_idx, valid_target, model, optimizer_w, optimizer_a, eps) + logging.info("Epoch {}; Train err {}; Val err {}; Arch {}".format(epoch + 1, train_error, val_error, model.parse())) + eps = eps * args.decay + +def train(node_feats, node_types, adjs, train_idx, train_target, valid_idx, valid_target, model, optimizer_w, optimizer_a, eps): + + idxes_seq, idxes_res = model.sample(eps) + + optimizer_w.zero_grad() + out = model(node_feats, node_types, adjs, idxes_seq, idxes_res) + loss_w = F.cross_entropy(out[train_idx], train_target) + loss_w.backward() + optimizer_w.step() + + optimizer_a.zero_grad() + out = model(node_feats, node_types, adjs, idxes_seq, idxes_res) + loss_a = F.cross_entropy(out[valid_idx], valid_target) + loss_a.backward() + optimizer_a.step() + + return loss_w.item(), loss_a.item() + +if __name__ == '__main__': + main() \ No newline at end of file