Enhancements for Node Classification
Enhancements for Node Classification in DGL
wassimj committed Aug 19, 2022
1 parent e2499fb commit 38f66b0
Showing 10 changed files with 239 additions and 79 deletions.
6 changes: 5 additions & 1 deletion __init__.py
@@ -14,10 +14,11 @@
# * You should have received a copy of the GNU Affero General Public License
# * along with this program. If not, see <https://www.gnu.org/licenses/>.


bl_info = {
"name": "Topologic",
"author": "Wassim Jabi",
"version": (0, 8, 2, 2),
"version": "0.8.2.3",
"blender": (3, 2, 0),
"location": "Node Editor",
"category": "Node",
@@ -57,6 +58,9 @@
import topologic
from topologic import Vertex, Edge, Wire, Face, Shell, Cell, CellComplex, Cluster, Topology

__version__ = '0.8.2.3'
__version_info__ = tuple([ int(num) for num in __version__.split('.')])

#from topologicsverchok import icons
# make sverchok the root module name, (if sverchok dir not named exactly "sverchok")

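The new __version_info__ expression above derives an integer tuple from the version string. A quick sketch (not part of the commit) of what it evaluates to for the new version:

# Hypothetical check: the committed code uses a list comprehension; a generator gives the same result.
__version__ = '0.8.2.3'
__version_info__ = tuple(int(num) for num in __version__.split('.'))
assert __version_info__ == (0, 8, 2, 3)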
Binary file modified examples/Machine Learning/DGL-Testing.blend
Binary file modified examples/Machine Learning/DGL-Testing.blend1
19 changes: 14 additions & 5 deletions nodes/Topologic/DGLDatasetBySamples_NC.py
@@ -9,12 +9,16 @@

from . import Replication

samples = [("Cora", "Cora", "", 1)]
samples = [("Cora", "Cora", "", 1), ("Citeseer", "Citeseer", "", 2), ("Pubmed", "Pubmed", "", 3)]

def processItem(item):
sample = item
if sample == 'Cora':
return dgl.data.CoraGraphDataset()
return [dgl.data.CoraGraphDataset(), 7]
elif sample == 'Citeseer':
return [dgl.data.CiteseerGraphDataset(), 6]
elif sample == 'Pubmed':
return [dgl.data.PubmedGraphDataset(), 3]
else:
raise NotImplementedError

@@ -31,16 +35,21 @@ def sv_init(self, context):
self.width=200
self.inputs.new('SvStringsSocket', 'Sample').prop_name="SamplesProp"
self.outputs.new('SvStringsSocket', 'DGL Dataset')
self.outputs.new('SvStringsSocket', 'Num Labels')

def process(self):
if not any(socket.is_linked for socket in self.outputs):
return
sampleList = self.inputs['Sample'].sv_get(deepcopy=True)
sampleList = Replication.flatten(sampleList)
outputs = []
datasets = []
numLabels = []
for anInput in sampleList:
outputs.append(processItem(anInput))
self.outputs['DGL Dataset'].sv_set(outputs)
dataset, numLabel = processItem(anInput)
datasets.append(dataset)
numLabels.append(numLabel)
self.outputs['DGL Dataset'].sv_set(datasets)
self.outputs['Num Labels'].sv_set(numLabels)

def register():
bpy.utils.register_class(SvDGLDatasetBySamples_NC)
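The sample list now offers all three of DGL's built-in citation datasets, and processItem returns each dataset together with its label count. A minimal sketch, not part of the commit, of how the hard-coded counts (7, 6, 3) can be cross-checked against the datasets' own num_classes attribute, assuming a working DGL install that can download the data:

import dgl.data

# Map the node's sample names to DGL's built-in citation datasets.
samples = {
    "Cora": dgl.data.CoraGraphDataset,
    "Citeseer": dgl.data.CiteseerGraphDataset,
    "Pubmed": dgl.data.PubmedGraphDataset,
}

for name, dataset_class in samples.items():
    dataset = dataset_class()   # downloads and caches the data on first use
    graph = dataset[0]          # each citation dataset holds a single graph
    print(name, dataset.num_classes, graph.num_nodes(), graph.num_edges())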
173 changes: 124 additions & 49 deletions nodes/Topologic/DGLPredict_NC.py
@@ -4,38 +4,53 @@
from sverchok.data_structure import updateNode

import torch
from . import Replication
from . import Replication, DGLDatasetGraphs_NC

def processItem(item):
"""
Parameters
----------
model_checkpoint_path : str
Path for the entire model
test_dataset : list
A list containing several dgl graphs for prediction
Returns
-------
Labels for the test graphs in test_dataset
"""
classifier, dataset = item
predictions = []
labels = []
if dataset.name == "cora_v2":
g = dataset[0]
else:
g = dataset[0][0]
features = g.ndata['feat']
labels.append(g.ndata['label'].tolist())
# Forward
logits = classifier(g, features)
# Compute prediction
pred = logits.argmax(1).tolist()
predictions.append(pred)
return [Replication.flatten(labels), Replication.flatten(predictions)]
allLabels = []
allPredictions = []
trainLabels = []
trainPredictions = []
valLabels = []
valPredictions = []
testLabels = []
testPredictions = []

graphs = DGLDatasetGraphs_NC.processItem(dataset)
for g in graphs:
if not g.ndata:
continue
train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask']
test_mask = g.ndata['test_mask']
features = g.ndata['feat']
labels = g.ndata['label']
train_labels = labels[train_mask]
val_labels = labels[val_mask]
test_labels = labels[test_mask]
allLabels.append(labels.tolist())
trainLabels.append(train_labels.tolist())
valLabels.append(val_labels.tolist())
testLabels.append(test_labels.tolist())

# Forward
logits = classifier(g, features)
train_logits = logits[train_mask]
val_logits = logits[val_mask]
test_logits = logits[test_mask]

# Compute prediction
predictions = logits.argmax(1)
train_predictions = train_logits.argmax(1)
val_predictions = val_logits.argmax(1)
test_predictions = test_logits.argmax(1)
allPredictions.append(predictions.tolist())
trainPredictions.append(train_predictions.tolist())
valPredictions.append(val_predictions.tolist())
testPredictions.append(test_predictions.tolist())

return [Replication.flatten(allLabels), Replication.flatten(allPredictions),Replication.flatten(trainLabels), Replication.flatten(trainPredictions), Replication.flatten(valLabels), Replication.flatten(valPredictions), Replication.flatten(testLabels), Replication.flatten(testPredictions)]

replication = [("Default", "Default", "", 1),("Trim", "Trim", "", 2),("Iterate", "Iterate", "", 3),("Repeat", "Repeat", "", 4),("Interlace", "Interlace", "", 5)]

@@ -54,8 +69,15 @@ class SvDGLPredict_NC(bpy.types.Node, SverchCustomTreeNode):
def sv_init(self, context):
self.inputs.new('SvStringsSocket', 'Classifier')
self.inputs.new('SvStringsSocket', 'Dataset')
self.outputs.new('SvStringsSocket', 'Labels')
self.outputs.new('SvStringsSocket', 'Predictions')
self.outputs.new('SvStringsSocket', 'All Labels')
self.outputs.new('SvStringsSocket', 'All Predictions')
self.outputs.new('SvStringsSocket', 'Train Labels')
self.outputs.new('SvStringsSocket', 'Train Predictions')
self.outputs.new('SvStringsSocket', 'Val Labels')
self.outputs.new('SvStringsSocket', 'Val Predictions')
self.outputs.new('SvStringsSocket', 'Test Labels')
self.outputs.new('SvStringsSocket', 'Test Predictions')

self.width = 175
for socket in self.inputs:
if socket.prop_name != '':
@@ -83,32 +105,85 @@ def process(self):
inputs_nested.append(inp)
inputs_flat.append(Replication.flatten(inp))
inputs_replicated = Replication.replicateInputs(inputs_flat, self.Replication)
labels = []
predictions = []
allLabels = []
allPredictions = []
trainLabels = []
trainPredictions = []
valLabels = []
valPredictions = []
testLabels = []
testPredictions = []

for anInput in inputs_replicated:
label, prediction = processItem(anInput)
labels.append(label)
predictions.append(prediction)
all_labels, all_predictions, train_labels, train_predictions, val_labels, val_predictions, test_labels, test_predictions = processItem(anInput)
allLabels.append(all_labels)
allPredictions.append(all_predictions)
trainLabels.append(train_labels)
trainPredictions.append(train_predictions)
valLabels.append(val_labels)
valPredictions.append(val_predictions)
testLabels.append(test_labels)
testPredictions.append(test_predictions)

inputs_flat = []
for anInput in self.inputs:
inp = anInput.sv_get(deepcopy=True)
inputs_flat.append(Replication.flatten(inp))
if self.Replication == "Interlace":
labels = Replication.re_interlace(labels, inputs_flat)
predictions = Replication.re_interlace(predictions, inputs_flat)
allLabels = Replication.re_interlace(allLabels, inputs_flat)
allPredictions = Replication.re_interlace(allPredictions, inputs_flat)
trainLabels = Replication.re_interlace(trainLabels, inputs_flat)
trainPredictions = Replication.re_interlace(trainPredictions, inputs_flat)
valLabels = Replication.re_interlace(valLabels, inputs_flat)
valPredictions = Replication.re_interlace(valPredictions, inputs_flat)
testLabels = Replication.re_interlace(testLabels, inputs_flat)
testPredictions = Replication.re_interlace(testPredictions, inputs_flat)

else:
match_list = Replication.best_match(inputs_nested, inputs_flat, self.Replication)
labels = Replication.unflatten(labels, match_list)
predictions = Replication.unflatten(predictions, match_list)
if len(labels) == 1:
if isinstance(labels[0], list):
labels = labels[0]
if len(predictions) == 1:
if isinstance(predictions[0], list):
predictions = predictions[0]
self.outputs['Labels'].sv_set([labels])
self.outputs['Predictions'].sv_set([predictions])

allLabels = Replication.unflatten(allLabels, match_list)
allPredictions = Replication.unflatten(allPredictions, match_list)
trainLabels = Replication.unflatten(trainLabels, match_list)
trainPredictions = Replication.unflatten(trainPredictions, match_list)
valLabels = Replication.unflatten(valLabels, match_list)
valPredictions = Replication.unflatten(valPredictions, match_list)
testLabels = Replication.unflatten(testLabels, match_list)
testPredictions = Replication.unflatten(testPredictions, match_list)

if len(allLabels) == 1:
if isinstance(allLabels[0], list):
allLabels = allLabels[0]
self.outputs['All Labels'].sv_set([allLabels])
if len(allPredictions) == 1:
if isinstance(allPredictions[0], list):
allPredictions = allPredictions[0]
self.outputs['All Predictions'].sv_set([allPredictions])
if len(trainLabels) == 1:
if isinstance(trainLabels[0], list):
trainLabels = trainLabels[0]
self.outputs['Train Labels'].sv_set([trainLabels])
if len(trainPredictions) == 1:
if isinstance(trainPredictions[0], list):
trainPredictions = trainPredictions[0]
self.outputs['Train Predictions'].sv_set([trainPredictions])
if len(valLabels) == 1:
if isinstance(valLabels[0], list):
valLabels = valLabels[0]
self.outputs['Val Labels'].sv_set([valLabels])
if len(valPredictions) == 1:
if isinstance(valPredictions[0], list):
valPredictions = valPredictions[0]
self.outputs['Val Predictions'].sv_set([valPredictions])
if len(testLabels) == 1:
if isinstance(testLabels[0], list):
testLabels = testLabels[0]
self.outputs['Test Labels'].sv_set([testLabels])
if len(testPredictions) == 1:
if isinstance(testPredictions[0], list):
testPredictions = testPredictions[0]
self.outputs['Test Predictions'].sv_set([testPredictions])


def register():
bpy.utils.register_class(SvDGLPredict_NC)

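processItem now splits labels and predictions by the standard train/val/test node masks instead of returning a single labels/predictions pair. A minimal sketch of the same masking pattern used to score a trained classifier per split; classifier and g are stand-ins for a trained GNN and a dataset graph carrying 'feat', 'label' and the three boolean masks:

import torch

def masked_accuracy(classifier, g):
    features = g.ndata['feat']
    labels = g.ndata['label']
    with torch.no_grad():
        logits = classifier(g, features)    # forward pass over every node
        predictions = logits.argmax(1)      # predicted class per node
    scores = {}
    for split in ('train', 'val', 'test'):
        mask = g.ndata[split + '_mask']     # boolean mask selecting this split's nodes
        correct = (predictions[mask] == labels[mask]).float().mean()
        scores[split] = correct.item()
    return scores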