diff --git a/helpers.py b/helpers.py index f59c883..4865be1 100644 --- a/helpers.py +++ b/helpers.py @@ -14,7 +14,8 @@ def set_seeds(seed=0): np.random.seed(seed) _ = torch.manual_seed(seed) - _ = torch.cuda.manual_seed(seed) + if torch.cuda.is_available(): + _ = torch.cuda.manual_seed(seed) def to_numpy(x): diff --git a/utils/convert.py b/utils/convert.py index c44dbc5..a301a23 100644 --- a/utils/convert.py +++ b/utils/convert.py @@ -56,6 +56,10 @@ def save_problem(problem, outpath): assert validate_problem(problem) assert not os.path.exists(outpath), 'save_problem: %s already exists' % outpath + if 'sparse' in problem and problem['sparse']: + problem['adj'] = spadj2edgelist(problem['adj']) + problem['train_adj'] = spadj2edgelist(problem['train_adj']) + f = h5py.File(outpath) for k,v in problem.items(): if v is not None: @@ -212,11 +216,11 @@ def parse_args(): # "n_classes" : n_classes, # "sparse" : True, - # "adj" : spadj2edgelist(adj), - # "train_adj" : spadj2edgelist(train_adj), + # "adj" : adj, + # "train_adj" : train_adj, # "feats" : aug_feats, # "targets" : aug_targets, # "folds" : aug_folds, # }, './data/reddit/sparse-problem.h5') - # # << \ No newline at end of file + # # <<