Commit 462802a

add more comments
skeletondyh committed Jul 5, 2021
1 parent 55effbe commit 462802a
Showing 8 changed files with 72 additions and 44 deletions.
4 changes: 4 additions & 0 deletions lp/README.md
@@ -67,6 +67,10 @@ epoch number; training error; validation error; architecture for source node typ

To obtain a good architecture, we usually need to run the search algorithm several times with different random seeds.

## Architecture Interpretation

Similar to the node classification task; please refer to the README in `nc/`.

## Evaluation

Run the following commands to train the derived architectures from scratch:
48 changes: 29 additions & 19 deletions lp/model_search.py
@@ -4,7 +4,9 @@
import numpy as np

class Op(nn.Module):

'''
operation for one link in the DAG search space
'''
def __init__(self):
super(Op, self).__init__()

@@ -13,27 +15,29 @@ def forward(self, x, adjs, ws, idx):
return ws[idx] * torch.spmm(adjs[idx], x)

class Cell(nn.Module):

'''
the DAG search space
'''
def __init__(self, n_step, n_hid_prev, n_hid, cstr, use_norm = True, use_nl = True):
super(Cell, self).__init__()

self.affine = nn.Linear(n_hid_prev, n_hid)
self.n_step = n_step
self.n_step = n_step #* number of intermediate states (i.e., K)
self.norm = nn.LayerNorm(n_hid, elementwise_affine = False) if use_norm is True else lambda x : x
self.use_nl = use_nl
assert(isinstance(cstr, list))
self.cstr = cstr
self.cstr = cstr #* type constraint

self.ops_seq = nn.ModuleList() ##! exclude last step
for i in range(self.n_step - 1):
self.ops_seq = nn.ModuleList() #* state (i - 1) -> state i, 1 <= i < K
for i in range(1, self.n_step):
self.ops_seq.append(Op())
self.ops_res = nn.ModuleList() ##! exclude last step
for i in range(1, self.n_step - 1):
for j in range(i):
self.ops_res = nn.ModuleList() #* state j -> state i, 0 <= j < i - 1, 2 <= i < K
for i in range(2, self.n_step):
for j in range(i - 1):
self.ops_res.append(Op())

self.last_seq = Op()
self.last_res = nn.ModuleList()
self.last_seq = Op() #* state (K - 1) -> state K
self.last_res = nn.ModuleList() #* state i -> state K, 0 <= i < K - 1
for i in range(self.n_step - 1):
self.last_res.append(Op())

@@ -68,17 +72,17 @@ def __init__(self, in_dims, n_hid, n_adjs, n_steps, cstr, attn_dim = 64, use_nor
self.cstr = cstr
self.n_adjs = n_adjs
self.n_hid = n_hid
self.ws = nn.ModuleList()
self.ws = nn.ModuleList() #* node type-specific transformation
assert(isinstance(in_dims, list))
for i in range(len(in_dims)):
self.ws.append(nn.Linear(in_dims[i], n_hid))
assert(isinstance(n_steps, list))
assert(isinstance(n_steps, list)) #* [optional] combine more than one meta graph?
self.metas = nn.ModuleList()
for i in range(len(n_steps)):
self.metas.append(Cell(n_steps[i], n_hid, n_hid, cstr, use_norm = use_norm, use_nl = out_nl))

self.as_seq = []
self.as_last_seq = []
self.as_seq = [] #* arch parameters for ops_seq
self.as_last_seq = [] #* arch parameters for last_seq
for i in range(len(n_steps)):
if n_steps[i] > 1:
ai = 1e-3 * torch.randn(n_steps[i] - 1, n_adjs - 1) #! exclude zero Op
@@ -92,9 +96,9 @@ def __init__(self, in_dims, n_hid, n_adjs, n_steps, cstr, attn_dim = 64, use_nor
ai_last.requires_grad_(True)
self.as_last_seq.append(ai_last)

ks = [sum(1 for i in range(1, n_steps[k] - 1) for j in range(i)) for k in range(len(n_steps))]
self.as_res = []
self.as_last_res = []
ks = [sum(1 for i in range(2, n_steps[k]) for j in range(i - 1)) for k in range(len(n_steps))]
self.as_res = [] #* arch parameters for ops_res
self.as_last_res = [] #* arch parameters for last_res
for i in range(len(n_steps)):
if ks[i] > 0:
ai = 1e-3 * torch.randn(ks[i], n_adjs)
@@ -114,7 +118,7 @@ def __init__(self, in_dims, n_hid, n_adjs, n_steps, cstr, attn_dim = 64, use_nor

assert(ks[0] + n_steps[0] + (0 if self.as_last_res[0] is None else self.as_last_res[0].size(0)) == (1 + n_steps[0]) * n_steps[0] // 2)

#* [Optional] Combine more than one meta graph?
#* [optional] combine more than one meta graph?
self.attn_fc1 = nn.Linear(n_hid, attn_dim)
self.attn_fc2 = nn.Linear(attn_dim, 1)

@@ -134,6 +138,9 @@ def alphas(self):
return alphas

def sample(self, eps):
'''
to sample one candidate edge type per link
'''
idxes_seq = []
idxes_res = []
if np.random.uniform() < eps:
@@ -183,6 +190,9 @@ def forward(self, node_feats, node_types, adjs, idxes_seq, idxes_res):
return out

def parse(self):
'''
to derive a meta graph indicated by arch parameters
'''
idxes_seq, idxes_res = self.sample(0.)
msg_seq = []; msg_res = []
for i in range(len(idxes_seq)):
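
For reference, the link bookkeeping in `Cell.__init__` above adds up as follows: with K intermediate states there are K sequential links (`ops_seq` plus `last_seq`) and K(K - 1)/2 skip links (`ops_res` plus `last_res`), i.e. K(K + 1)/2 links in total, which is exactly what the assert on the arch-parameter sizes checks. A standalone sketch of the count:

```python
# Number of links in a DAG with K intermediate states plus one output state.
for K in range(1, 8):
    n_seq = (K - 1) + 1                                 # ops_seq + last_seq
    n_res = sum(i - 1 for i in range(2, K)) + (K - 1)   # ops_res + last_res
    assert n_seq + n_res == K * (K + 1) // 2
print("link counts check out for K = 1..7")
```
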
4 changes: 2 additions & 2 deletions lp/train.py
@@ -20,7 +20,7 @@
parser.add_argument('--n_hid', type=int, default=64, help='hidden dimension')
parser.add_argument('--dataset', type=str, default='Yelp')
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--epochs', type=int, default=200)
parser.add_argument('--epochs', type=int, default=200, help='number of training epochs')
parser.add_argument('--dropout', type=float, default=0.2)
parser.add_argument('--seed', type=int, default=1)
args = parser.parse_args()
@@ -79,7 +79,7 @@ def main():
neg_val = neg['val']
neg_test = neg['test']

#* inputs
#* one-hot IDs as input features
in_dims = []
node_feats = []
for k in range(num_node_types):
4 changes: 2 additions & 2 deletions lp/train_search.py
@@ -23,7 +23,7 @@
parser.add_argument('--steps_t', type=int, nargs='+', help='number of intermediate states in the meta graph for target node type')
parser.add_argument('--dataset', type=str, default='Yelp')
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--epochs', type=int, default=100, help='number of epochs for supernet training')
parser.add_argument('--eps', type=float, default=0., help='probability of random sampling')
parser.add_argument('--decay', type=float, default=0.9, help='decay factor for eps')
parser.add_argument('--seed', type=int, default=0)
@@ -81,7 +81,7 @@ def main():
neg_val = neg['val']
neg_test = neg['test']

#* inputs
#* one-hot IDs as input features
in_dims = []
node_feats = []
for k in range(num_node_types):
4 changes: 4 additions & 0 deletions nc/README.md
@@ -57,6 +57,10 @@ epoch number; training error; validation error; architecture derived at the end

To obtain a good architecture, we usually need to run the search algorithm several times with different random seeds.

## Architecture Interpretation

Suppose the number of intermediate states in a meta graph is K. An encoding of a meta graph consists of two lists. The first list has length K and contains the indexes of the selected edge types between state (i - 1) and state i (1 <= i <= K). The second list has length K(K - 1)/2 and contains the indexes of the selected edge types between state j and state i (0 <= j < i - 1, 2 <= i <= K). See the function `parse` in `model_search.py` for how an encoding is derived from the architecture parameters; a hypothetical example is sketched below. For search and evaluation, the mapping between edge types and indexes must be consistent so that an encoding can be correctly recognized.
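
For instance, with K = 3 the two lists have lengths 3 and 3 (= 3 * 2 / 2). A hand-made illustration (the edge-type indexes below are hypothetical, not produced by the code):

```python
K = 3

# First list: edge types for the sequential links (0 -> 1), (1 -> 2), (2 -> 3).
seq = [2, 0, 1]
# Second list: edge types for the skip links (0 -> 2), (0 -> 3), (1 -> 3);
# the exact ordering is whatever `parse` in model_search.py produces.
res = [4, 4, 1]

assert len(seq) == K
assert len(res) == K * (K - 1) // 2
```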

## Evaluation

Run the following commands to train the derived architectures from scratch:
48 changes: 29 additions & 19 deletions nc/model_search.py
@@ -4,7 +4,9 @@
import numpy as np

class Op(nn.Module):

'''
operation for one link in the DAG search space
'''
def __init__(self):
super(Op, self).__init__()

@@ -13,27 +15,29 @@ def forward(self, x, adjs, ws, idx):
return ws[idx] * torch.spmm(adjs[idx], x)

class Cell(nn.Module):

'''
the DAG search space
'''
def __init__(self, n_step, n_hid_prev, n_hid, cstr, use_norm = True, use_nl = True):
super(Cell, self).__init__()

self.affine = nn.Linear(n_hid_prev, n_hid)
self.n_step = n_step
self.n_step = n_step #* number of intermediate states (i.e., K)
self.norm = nn.LayerNorm(n_hid, elementwise_affine = False) if use_norm is True else lambda x : x
self.use_nl = use_nl
assert(isinstance(cstr, list))
self.cstr = cstr
self.cstr = cstr #* type constraint

self.ops_seq = nn.ModuleList() ##! exclude last step
for i in range(self.n_step - 1):
self.ops_seq = nn.ModuleList() #* state (i - 1) -> state i, 1 <= i < K
for i in range(1, self.n_step):
self.ops_seq.append(Op())
self.ops_res = nn.ModuleList() ##! exclude last step
for i in range(1, self.n_step - 1):
for j in range(i):
self.ops_res = nn.ModuleList() #* state j -> state i, 0 <= j < i - 1, 2 <= i < K
for i in range(2, self.n_step):
for j in range(i - 1):
self.ops_res.append(Op())

self.last_seq = Op()
self.last_res = nn.ModuleList()
self.last_seq = Op() #* state (K - 1) -> state K
self.last_res = nn.ModuleList() #* state i -> state K, 0 <= i < K - 1
for i in range(self.n_step - 1):
self.last_res.append(Op())

@@ -69,16 +73,16 @@ def __init__(self, in_dim, n_hid, num_node_types, n_adjs, n_classes, n_steps, cs
self.cstr = cstr
self.n_adjs = n_adjs
self.n_hid = n_hid
self.ws = nn.ModuleList()
self.ws = nn.ModuleList() #* node type-specific transformation
for i in range(num_node_types):
self.ws.append(nn.Linear(in_dim, n_hid))
assert(isinstance(n_steps, list))
assert(isinstance(n_steps, list)) #* [optional] combine more than one meta graph?
self.metas = nn.ModuleList()
for i in range(len(n_steps)):
self.metas.append(Cell(n_steps[i], n_hid, n_hid, cstr, use_norm = use_norm, use_nl = out_nl))

self.as_seq = []
self.as_last_seq = []
self.as_seq = [] #* arch parameters for ops_seq
self.as_last_seq = [] #* arch parameters for last_seq
for i in range(len(n_steps)):
if n_steps[i] > 1:
ai = 1e-3 * torch.randn(n_steps[i] - 1, n_adjs - 1) #! exclude zero Op
@@ -92,9 +96,9 @@ def __init__(self, in_dim, n_hid, num_node_types, n_adjs, n_classes, n_steps, cs
ai_last.requires_grad_(True)
self.as_last_seq.append(ai_last)

ks = [sum(1 for i in range(1, n_steps[k] - 1) for j in range(i)) for k in range(len(n_steps))]
self.as_res = []
self.as_last_res = []
ks = [sum(1 for i in range(2, n_steps[k]) for j in range(i - 1)) for k in range(len(n_steps))]
self.as_res = [] #* arch parameters for ops_res
self.as_last_res = [] #* arch parameters for last_res
for i in range(len(n_steps)):
if ks[i] > 0:
ai = 1e-3 * torch.randn(ks[i], n_adjs)
@@ -114,7 +118,7 @@ def __init__(self, in_dim, n_hid, num_node_types, n_adjs, n_classes, n_steps, cs

assert(ks[0] + n_steps[0] + (0 if self.as_last_res[0] is None else self.as_last_res[0].size(0)) == (1 + n_steps[0]) * n_steps[0] // 2)

#* [Optional] Combine more than one meta graph?
#* [optional] combine more than one meta graph?
self.attn_fc1 = nn.Linear(n_hid, attn_dim)
self.attn_fc2 = nn.Linear(attn_dim, 1)

@@ -137,6 +141,9 @@ def alphas(self):
return alphas

def sample(self, eps):
'''
to sample one candidate edge type per link
'''
idxes_seq = []
idxes_res = []
if np.random.uniform() < eps:
@@ -188,6 +195,9 @@ def forward(self, node_feats, node_types, adjs, idxes_seq, idxes_res):
return logits

def parse(self):
'''
to derive a meta graph indicated by arch parameters
'''
idxes_seq, idxes_res = self.sample(0.)
msg_seq = []; msg_res = []
for i in range(len(idxes_seq)):
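
Judging from the visible branch in `sample` above, the choice is eps-greedy: with probability `eps` a random candidate edge type is picked for every link, and otherwise (the branch is collapsed in this view) presumably the edge type with the largest architecture weight; `parse` calls `sample(0.)`, making the derivation deterministic. A minimal sketch of that selection rule for a single arch-parameter matrix, ignoring the type-constraint (`cstr`) masking:

```python
import numpy as np
import torch

def sample_indices(alpha: torch.Tensor, eps: float) -> torch.Tensor:
    """Pick one candidate edge type per link (one row of alpha per link)."""
    if np.random.uniform() < eps:
        # explore: a uniformly random edge type for every link
        return torch.randint(alpha.size(1), (alpha.size(0),))
    # exploit: the edge type with the largest architecture weight
    return torch.softmax(alpha, dim=-1).argmax(dim=-1)

alpha = 1e-3 * torch.randn(4, 6)      # 4 links, 6 candidate edge types
print(sample_indices(alpha, eps=0.3))
print(sample_indices(alpha, eps=0.))  # deterministic, as used by parse()
```
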
2 changes: 1 addition & 1 deletion nc/train.py
@@ -20,7 +20,7 @@
parser.add_argument('--n_hid', type=int, default=64, help='hidden dimension')
parser.add_argument('--dataset', type=str, default='DBLP')
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--epochs', type=int, default=50, help='maximum number of training epochs')
parser.add_argument('--dropout', type=float, default=0.3)
parser.add_argument('--seed', type=int, default=24)
parser.add_argument('--no_norm', action='store_true', default=False, help='disable layer norm')
2 changes: 1 addition & 1 deletion nc/train_search.py
@@ -21,7 +21,7 @@
parser.add_argument('--steps', type=int, nargs='+', help='number of intermediate states in the meta graph')
parser.add_argument('--dataset', type=str, default='DBLP')
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--epochs', type=int, default=50, help='number of epochs for supernet training')
parser.add_argument('--eps', type=float, default=0.3, help='probability of random sampling')
parser.add_argument('--decay', type=float, default=0.9, help='decay factor for eps')
parser.add_argument('--seed', type=int, default=0)
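
Here `--eps` is the probability of random sampling and `--decay` its decay factor. Assuming the training loop multiplies `eps` by `decay` once per epoch (an assumption; the loop itself is not shown in this diff), exploration decays geometrically:

```python
# Hypothetical eps schedule using the defaults above (eps = 0.3, decay = 0.9).
eps, decay = 0.3, 0.9
for epoch in range(50):
    if epoch % 10 == 0:
        print(f"epoch {epoch:2d}: eps = {eps:.4f}")
    eps *= decay
```
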
