Fix handling of datanode in models, updated methods getting attributes from Datanode in skeleton mode, usage of skeleton in metrics
Andrzej Uszok committed Jul 19, 2023
1 parent 7b74120 commit 986c6ca
Showing 6 changed files with 78 additions and 25 deletions.
3 changes: 2 additions & 1 deletion .vscode/launch.json
@@ -19,7 +19,8 @@
"request": "launch",
"program": "${file}",
"cwd": "${fileDirname}",
"console": "integratedTerminal"
"console": "integratedTerminal",
"justMyCode": false
}
]
}
2 changes: 1 addition & 1 deletion Tasks
38 changes: 30 additions & 8 deletions domiknows/graph/dataNode.py
@@ -235,6 +235,11 @@ def hasAttribute(self, key):
return True
elif keyInVariableSet in rootDataNode.attributes["propertySet"]:
return True
elif "variableSet" in self.attributes:
if key in self.attributes["variableSet"]:
return True
elif key in self.attributes["propertySet"]:
return True
else:
return False
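
For orientation, a hedged sketch of the attribute layout this skeleton-mode fallback assumes (the "word" concept and the sample tensors are invented for illustration): attribute tensors live in a shared variableSet/propertySet, keyed as "conceptName/attributeKey", so hasAttribute can answer either from the root DataNode's sets or, with the new elif, from a variableSet stored on the node itself.

import torch

# Hypothetical skeleton-mode layout (illustration only, not the library's internals):
root_attributes = {
    "variableSet": {"word/embedding": torch.zeros(3, 4)},   # one row per "word" instance
    "propertySet": {"word/text": ["a", "b", "c"]},
}

# A child node would check "word/" + key against the root's sets; a node that
# carries its own "variableSet" now also checks the plain key locally.
print("word/embedding" in root_attributes["variableSet"])   # True
print("text" in root_attributes["propertySet"])             # False (stored as "word/text")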

@@ -288,14 +293,21 @@ def getAttribute(self, *keys):
return self.attributes[keyBis]
else:
return self.attributes[keyBis][index]
elif "rootDataNode" in self.attributes:
rootDataNode = self.attributes["rootDataNode"]
if "variableSet" in rootDataNode.attributes:
elif "rootDataNode" in self.attributes or "variableSet" in self.attributes:
if "rootDataNode" in self.attributes:
rootDataNode = self.attributes["rootDataNode"]
keyInVariableSet = self.ontologyNode.name + "/" + key
if keyInVariableSet in rootDataNode.attributes["variableSet"]:
return rootDataNode.attributes["variableSet"][keyInVariableSet][self.instanceID]
elif keyInVariableSet in rootDataNode.attributes["propertySet"]:
return rootDataNode.attributes["propertySet"][keyInVariableSet][self.instanceID]

if "variableSet" in rootDataNode.attributes:
if keyInVariableSet in rootDataNode.attributes["variableSet"]:
return rootDataNode.attributes["variableSet"][keyInVariableSet][self.instanceID]
elif keyInVariableSet in rootDataNode.attributes["propertySet"]:
return rootDataNode.attributes["propertySet"][keyInVariableSet][self.instanceID]
elif "variableSet" in self.attributes:
if key in self.attributes["variableSet"]:
return self.attributes["variableSet"][key]
elif key in self.attributes["propertySet"]:
return self.attributes["propertySet"][key]

return None
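
A minimal standalone sketch of the lookup order the rewritten getAttribute appears to implement (skeleton_get, its arguments, and the sample data are made up; the real method works on DataNode objects): try the root DataNode's shared sets with the compound key and the instance id, otherwise fall back to a variableSet held on the node itself.

# Illustration only: mimics the fallback order, not the real DataNode class.
def skeleton_get(node_attrs, concept_name, key, instance_id):
    full_key = concept_name + "/" + key
    if "rootDataNode" in node_attrs:
        root = node_attrs["rootDataNode"]        # a dict standing in for the root DataNode
        if "variableSet" in root:
            if full_key in root["variableSet"]:
                return root["variableSet"][full_key][instance_id]
            if full_key in root.get("propertySet", {}):
                return root["propertySet"][full_key][instance_id]
    elif "variableSet" in node_attrs:            # the node itself holds the shared sets
        if key in node_attrs["variableSet"]:
            return node_attrs["variableSet"][key]
        if key in node_attrs.get("propertySet", {}):
            return node_attrs["propertySet"][key]
    return None

root = {"variableSet": {"word/emb": [[0.1], [0.2]]}}
print(skeleton_get({"rootDataNode": root}, "word", "emb", 1))   # [0.2]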

@@ -972,6 +984,15 @@ def collectInferredResults(self, concept, inferKey):

if not rootConceptDns:
return torch.tensor(collectAttributeList)

if getDnSkeletonMode() and "variableSet" in self.attributes:
vKeyInVariableSet = rootConcept.name + "/<" + concept[0].name +">"

# inferKey
inferKeyInVariableSet = vKeyInVariableSet + "/" + inferKey

if self.hasAttribute(inferKeyInVariableSet):
return self.getAttribute(inferKeyInVariableSet)

keys = [concept, inferKey]
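
The skeleton fast path added above composes a key of the form rootConceptName/<conceptName>/inferKey and returns the cached tensor straight from the shared set. A hedged illustration of the key it would look up (the concept names here are invented):

# Hypothetical names; the real ones come from the graph's concepts.
rootConcept_name = "sentence"
concept_name = "people"            # concept[0].name in the diff
inferKey = "ILP"

vKeyInVariableSet = rootConcept_name + "/<" + concept_name + ">"
inferKeyInVariableSet = vKeyInVariableSet + "/" + inferKey
print(inferKeyInVariableSet)       # sentence/<people>/ILP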

@@ -1150,7 +1171,8 @@ def inferLocal(self, keys=("softmax", "argmax"), Acc=None):
if needSoftmax:
localSoftmaxKeyInVariableSet = vKeyInVariableSet + "/local/softmax"

inferLocalKeys.remove("softmax")
if "softmax" in inferLocalKeys:
inferLocalKeys.remove("softmax")

if not self.hasAttribute(localSoftmaxKeyInVariableSet):
v = self.attributes["variableSet"][vKeyInVariableSet]
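The membership check added before inferLocalKeys.remove("softmax") avoids the ValueError that list.remove raises when the element is absent, e.g. when only "argmax" was requested. A two-line illustration:

inferLocalKeys = ["argmax"]            # "softmax" was not requested
if "softmax" in inferLocalKeys:        # without this guard, remove() would raise ValueError
    inferLocalKeys.remove("softmax")
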
26 changes: 20 additions & 6 deletions domiknows/program/metric.py
@@ -10,17 +10,19 @@


class CMWithLogitsMetric(torch.nn.Module):
def forward(self, input, target, data_item, prop, weight=None):
def forward(self, input, target, _, prop, weight=None):
if weight is None:
weight = torch.tensor(1, device=input.device)
else:
weight = weight.to(input.device)

preds = input.argmax(dim=-1).clone().detach().to(dtype=weight.dtype)
labels = target.clone().detach().to(dtype=weight.dtype, device=input.device)
tp = (preds * labels * weight).sum()
fp = (preds * (1 - labels) * weight).sum()
tn = ((1 - preds) * (1 - labels) * weight).sum()
fn = ((1 - preds) * labels * weight).sum()

return {'TP': tp, 'FP': fp, 'TN': tn, 'FN': fn}
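
As a quick sanity check of the TP/FP/TN/FN arithmetic above, here is a tiny worked example with unit weight (the logits are made up):

import torch

input = torch.tensor([[0.2, 0.8], [0.9, 0.1], [0.3, 0.7], [0.6, 0.4]])  # logits
target = torch.tensor([1, 0, 0, 1])
weight = torch.tensor(1.0)

preds = input.argmax(dim=-1).to(dtype=weight.dtype)    # tensor([1., 0., 1., 0.])
labels = target.to(dtype=weight.dtype)
tp = (preds * labels * weight).sum()                   # 1.0 (sample 0)
fp = (preds * (1 - labels) * weight).sum()             # 1.0 (sample 2)
tn = ((1 - preds) * (1 - labels) * weight).sum()       # 1.0 (sample 1)
fn = ((1 - preds) * labels * weight).sum()             # 1.0 (sample 3)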


@@ -68,8 +70,7 @@ def __init__(self, inferType='ILP'):
self.inferType = inferType

def forward(self, input, target, data_item, prop, weight=None):
data_item.createBatchRootDN()
datanode = data_item.getDataNode(context=self.inferType)
datanode = data_item
result = datanode.getInferMetrics(prop.name, inferType=self.inferType)
if len(result.keys())==2:
if str(prop.name) in result:
@@ -162,6 +163,9 @@ def __str__(self):
return str(value)

class MacroAverageTracker(MetricTracker):
def __init__(self, metric):
super().__init__(metric)

def forward(self, values):
def func(value):
return value.clone().detach().mean()
@@ -218,14 +222,24 @@ def forward(self, values):
else:
tn = CM['TN'].sum().float()

# check if tp, fp, fn, tn are tensors if not make them tensors
if not torch.is_tensor(tp):
tp = torch.tensor(tp)
if not torch.is_tensor(fp):
fp = torch.tensor(fp)
if not torch.is_tensor(fn):
fn = torch.tensor(fn)
if not torch.is_tensor(tn):
tn = torch.tensor(tn)

if tp:
p = tp / (tp + fp)
r = tp / (tp + fn)
f1 = 2 * p * r / (p + r)
else:
p = torch.zeros_like(torch.tensor(tp))
r = torch.zeros_like(torch.tensor(tp))
f1 = torch.zeros_like(torch.tensor(tp))
p = torch.zeros_like(tp)
r = torch.zeros_like(tp)
f1 = torch.zeros_like(tp)
if (tp + fp + fn + tn):
accuracy=(tp + tn) / (tp + fp + fn + tn)
return {'P': p, 'R': r, 'F1': f1,"accuracy":accuracy}
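For reference, the same P/R/F1/accuracy formulas applied to a small made-up confusion matrix (values chosen only to make the numbers easy to check):

import torch

tp, fp, fn, tn = (torch.tensor(v, dtype=torch.float) for v in (8, 2, 4, 6))

p = tp / (tp + fp)                           # 0.8
r = tp / (tp + fn)                           # ~0.6667
f1 = 2 * p * r / (p + r)                     # ~0.7273
accuracy = (tp + tn) / (tp + fp + fn + tn)   # 0.7
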
2 changes: 1 addition & 1 deletion domiknows/program/model/gbi.py
@@ -35,7 +35,7 @@ def __init__(self, graph, solver_model=None, gbi_iters = 100, device='auto'):
nconstr = len(self.constr)
if nconstr == 0:
warnings.warn('No logical constraint detected in the graph. '
'PrimalDualModel will not generate any constraint loss.')
'GBIModel will not generate any constraint loss.')

self.lmbd = torch.nn.Parameter(torch.zeros(nconstr).float())
self.lmbd_index = {}
32 changes: 24 additions & 8 deletions domiknows/program/model/pytorch.py
@@ -92,8 +92,16 @@ def forward(self, data_item, build=None):
if build:
data_item.update({"graph": self.graph, 'READER': 0})
builder = DataNodeBuilder(data_item)
datanode, loss, metric = self.populate(builder)

out = self.populate(builder)

if len(out) == 2:
builder.createBatchRootDN()
datanode = builder.getDataNode(context="build", device=self.device)
loss = out[0]
metric = out[1]
else:
datanode, loss, metric = out

return (loss, metric, datanode, builder)
else:
*out, = self.populate(data_item)
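
In other words, forward now accepts either return shape from populate. A trivial sketch of that dispatch (the tuple contents are placeholders):

out = ("loss", "metric")                 # base PoiModel.populate: (loss, metric)
if len(out) == 2:
    loss, metric = out
    datanode = "built from the builder"  # builder.createBatchRootDN() + getDataNode() in the diff
else:
    datanode, loss, metric = out         # SolverModel.populate already returns the datanode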
@@ -165,31 +173,35 @@ def reset(self):
def poi_loss(self, data_item, _, sensors):
if not self.loss:
return 0

outs = [sensor(data_item) for sensor in sensors]

if len(outs[0]) == 0:
return None
local_loss = self.loss[(*sensors,)](*outs)

selfLoss = self.loss[(*sensors,)]
local_loss = selfLoss(*outs)

if local_loss != local_loss:
raise Exception("Calculated local_loss is nan")

return local_loss

def poi_metric(self, data_item, prop, sensors):
def poi_metric(self, data_item, prop, sensors, datanode=None):
if not self.metric:
return None
outs = [sensor(data_item) for sensor in sensors]
if len(outs[0]) == 0:
return None
local_metric = {}
for key, metric in self.metric.items():
local_metric[key] = metric[(*sensors,)](*outs, data_item=data_item, prop=prop)
local_metric[key] = metric[(*sensors,)](*outs, data_item=datanode, prop=prop)
if len(local_metric) == 1:
local_metric = list(local_metric.values())[0]

return local_metric

def populate(self, builder, run=True):
def populate(self, builder, datanode = None, run=True):
loss = 0
metric = {}

@@ -207,7 +219,11 @@ def populate(self, builder, run=True):
if local_loss is not None:
loss += local_loss
if self.metric:
local_metric = self.poi_metric(builder, prop, sensors)
if datanode is None:
builder.createBatchRootDN()
datanode = builder.getDataNode()

local_metric = self.poi_metric(builder, prop, sensors, datanode=datanode)
if local_metric is not None:
metric[(*sensors,)] = local_metric

@@ -273,7 +289,7 @@ def inference(self, builder):

def populate(self, builder, run=True):
datanode = self.inference(builder)
lose, metric = super().populate(builder, run=False)
lose, metric = super().populate(builder, datanode = datanode, run=False)

return datanode, lose, metric

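Taken together, the pytorch.py and metric.py changes route an already-built DataNode into the metric callables instead of the raw builder. A hedged end-to-end sketch of that flow (ToyMetric, toy_populate and the sample data are illustrative, not the library's API):

# Illustration only: metric objects now receive the datanode as their data_item
# argument, matching DatanodeCMMetric.forward treating it directly as the DataNode.
class ToyMetric:
    def __call__(self, *outs, data_item=None, prop=None):
        # data_item is expected to be a DataNode (or None), not the builder
        return {"prop": prop, "has_datanode": data_item is not None}

def toy_populate(builder, datanode=None):
    if datanode is None:
        datanode = {"from": "builder"}         # stands in for builder.getDataNode()
    metric = ToyMetric()(data_item=datanode, prop="people")
    return 0.0, metric                         # (loss, metric), as in PoiModel.populate

print(toy_populate({"raw": "builder"}))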
