Skip to content

Commit

Permalink
Fixed sensor force option, updated GBI model, fixed lc verify, fixed …
Browse files Browse the repository at this point in the history
…infer method for addiotnal keys
  • Loading branch information
Andrzej Uszok committed Jul 12, 2023
1 parent f30da25 commit 8bdd83e
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 32 deletions.
2 changes: 1 addition & 1 deletion Docs
16 changes: 8 additions & 8 deletions domiknows/graph/dataNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,7 +1196,7 @@ def inferLocal(self, keys=("softmax", "argmax"), Acc=None):
if "normalizedProb" in keys:
keyNormalizedProb = "<" + c[0].name + ">/local/normalizedProb"
if not dn.hasAttribute(keyNormalizedProb): # Already calculated ?
vSoftmaxT = dn.attributes[keySoftmax]
vSoftmaxT = dn.getAttribute(keySoftmax)

# Clamps the softmax probabilities
vector = torch.clamp(vSoftmaxT, min=1e-18, max=1 - 1e-18)
Expand All @@ -1212,7 +1212,7 @@ def inferLocal(self, keys=("softmax", "argmax"), Acc=None):
if "normalizedProbAcc" in keys:
keyNormalizedProb = "<" + c[0].name + ">/local/normalizedProbAcc"
if not dn.hasAttribute(keyNormalizedProb): # Already calculated ?
vSoftmaxT = dn.attributes[keySoftmax]
vSoftmaxT = dn.getAttribute(keySoftmax)

# Clamps the softmax probabilities
vector = torch.clamp(vSoftmaxT, min=1e-18, max=1 - 1e-18)
Expand All @@ -1237,7 +1237,7 @@ def inferLocal(self, keys=("softmax", "argmax"), Acc=None):
if "entropyNormalizedProbAcc" in keys:
keyNormalizedProb = "<" + c[0].name + ">/local/entropyNormalizedProbAcc"
if not dn.hasAttribute(keyNormalizedProb): # Already calculated ?
vSoftmaxT = dn.attributes[keySoftmax]
vSoftmaxT = dn.getAttribute(keySoftmax)

# Clamps the softmax probabilities
vector = torch.clamp(vSoftmaxT, min=1e-18, max=1 - 1e-18)
Expand All @@ -1262,7 +1262,7 @@ def inferLocal(self, keys=("softmax", "argmax"), Acc=None):
if "normalizedJustAcc" in keys:
keyNormalizedProb = "<" + c[0].name + ">/local/normalizedJustAcc"
if not dn.hasAttribute(keyNormalizedProb): # Already calculated ?
vSoftmaxT = dn.attributes[keySoftmax]
vSoftmaxT = dn.getAttribute(keySoftmax)

### Calculate the multiplier factor
if Acc and c[0].name in Acc:
Expand All @@ -1283,7 +1283,7 @@ def inferLocal(self, keys=("softmax", "argmax"), Acc=None):
if "meanNormalizedProb" in keys:
keyNormalizedProb = "<" + c[0].name + ">/local/meanNormalizedProb"
if not dn.hasAttribute(keyNormalizedProb): # Already calculated ?
vSoftmaxT = dn.attributes[keySoftmax]
vSoftmaxT = dn.getAttribute(keySoftmax)

vector = vSoftmaxT

Expand All @@ -1295,7 +1295,7 @@ def inferLocal(self, keys=("softmax", "argmax"), Acc=None):
if "normalizedProbAll" in keys:
keyNormalizedProb = "<" + c[0].name + ">/local/normalizedProbAll"
if not dn.hasAttribute(keyNormalizedProb): # Already calculated ?
vSoftmaxT = dn.attributes[keySoftmax]
vSoftmaxT = dn.getAttribute(keySoftmax)

# Clamps the softmax probabilities
vector = torch.clamp(vSoftmaxT, min=1e-18, max=1 - 1e-18)
Expand All @@ -1316,7 +1316,7 @@ def inferLocal(self, keys=("softmax", "argmax"), Acc=None):
if "meanNormalizedProbStd" in keys:
keyNormalizedProb = "<" + c[0].name + ">/local/meanNormalizedProbStd"
if not dn.hasAttribute(keyNormalizedProb): # Already calculated ?
vSoftmaxT = dn.attributes[keySoftmax]
vSoftmaxT = dn.getAttribute(keySoftmax)

vector = vSoftmaxT

Expand Down Expand Up @@ -1381,7 +1381,7 @@ def inferGBIResults(self, *_conceptsRelations, model, builder):
_conceptsRelations = self.collectConceptsAndRelations(_conceptsRelations) # Collect all concepts and relations from graph as default set

from domiknows.program.model.gbi import GBIModel
myGBIModel = GBIModel(self.graph, model)
myGBIModel = GBIModel(self.graph, solver_model=model)
myGBIModel.calculateGBISelection(builder, _conceptsRelations)

# Calculate the percentage of results satisfying each logical constraint
Expand Down
2 changes: 1 addition & 1 deletion domiknows/program/batchprogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, *args, batch_size=1, **kwargs):
self.batch_size = batch_size

def train_epoch(self, dataset):
# do not use super() because it call zero_grad for every step definitly
# do not use super() because it call zero_grad for every step definitely
self.model.mode(Mode.TRAIN)
self.model.reset()
self.opt.zero_grad()
Expand Down
10 changes: 5 additions & 5 deletions domiknows/program/lossprogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,10 @@ def populate_epoch(self, dataset, grad = False):
yield detuple(*output[:1])
else:
for i, data_item in enumerate(dataset):
for dataKey in data_item:
if data_item[dataKey].dtype in [torch.float32, torch.float64, torch.complex64, torch.complex128]:
data_item[dataKey].requires_grad= True

_, _, *output = self.model(data_item)
yield detuple(*output[:1])

Expand Down Expand Up @@ -376,10 +380,6 @@ class GBIProgram(LossProgram):
logger = logging.getLogger(__name__)

def __init__(self, graph, Model, poi, beta=1, **kwargs):
mySolverModel= SolverModel(graph,
poi=poi,
inferTypes=['local/argmax', 'local/softmax'],
metric={})
super().__init__(graph, Model, CModel=GBIModel, beta=beta, solver_model = mySolverModel, poi=poi, **kwargs)
super().__init__(graph, Model, CModel=GBIModel, beta=beta, poi=poi, **kwargs)
from domiknows.utils import setDnSkeletonMode
setDnSkeletonMode(True)
52 changes: 43 additions & 9 deletions domiknows/program/model/gbi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
import torch
import copy
from torch.optim import SGD
from domiknows.graph import DataNodeBuilder

from domiknows.program.metric import MacroAverageTracker, MetricTracker
from domiknows.sensor.pytorch.sensors import TorchSensor
from domiknows.sensor.pytorch.learners import ModuleLearner
from domiknows.program.model.base import Mode

from .pytorch import SolverModel
Expand All @@ -17,10 +20,14 @@

# Gradient-based Inference
class GBIModel(torch.nn.Module):
def __init__(self, graph, solver_model, gbi_iters = 100, device='auto'):
def __init__(self, graph, solver_model=None, gbi_iters = 100, device='auto'):
super().__init__()

self.server_model= solver_model
if solver_model is None:
solver_model =self
else:
self.server_model= solver_model

self.gbi_iters = gbi_iters
self.device = device

Expand Down Expand Up @@ -102,9 +109,21 @@ def populate_forward(self, model, data_item):
Forward pass through torch model.
Returns DataNode and DataNodeBuilder.
"""
_, _, *output = model(data_item)
node = detuple(*output[:1])
return node, output[1]

#loss, metric, datanode, builder = model(data_item)

data_item.update({"graph": model.graph, 'READER': 0})
builder = DataNodeBuilder(data_item)

for i, prop in enumerate(model.poi):
for sensor in prop.find(ModuleLearner):
sensor(builder, force=True)

builder.createBatchRootDN()
datanode = builder.getDataNode(context="build", device=self.device)

return datanode, builder

# ----

def forward(self, builder, build=None):
Expand Down Expand Up @@ -168,21 +187,33 @@ def forward(self, builder, build=None):

# -- Constraint loss: NLL * binary satisfaction + regularization loss
# reg loss is calculated based on L2 distance of weights between optimized model and original weights
c_loss = -1 * log_probs * is_satisfied + self.reg_loss(model_l, self.server_model)
reg_loss = self.reg_loss(model_l, self.server_model)
c_loss = -1 * log_probs * is_satisfied + reg_loss

if c_loss != c_loss:
continue

print("iter=%d, c_loss=%d, num_constraints_l=%d, satisfied=%d"%(c_iter, c_loss.item(), num_constraints_l, num_satisfied_l))

# --- Check if constraints are satisfied
if num_satisfied_l == num_constraints_l:
# --- End early if constraints are satisfied
self.server_model
return c_loss, datanode, builder
elif no_of_not_satisfied > 3: # three consecutive iterations where constraints are not satisfied
return c_loss, datanode, builder # ? float("nan")

# --- Backward pass on model_l
if c_loss.requires_grad:
c_loss.backward(retain_graph=True)

print("Step after backward")
for name, x in model_l.named_parameters():
if x.grad is None:
print(name, 'no grad')
continue

print(name, 'grad: ', torch.sum(torch.abs(x.grad)))

# -- Update model_l
c_opt.step()

Expand All @@ -205,8 +236,11 @@ def calculateGBISelection(self, builder, conceptsRelations):
continue

for dn in dns:
v = dn.getAttribute(c[0])

v = dn.getAttribute(c[0]) # Get ILP results
if v is None:
continue

# Create GBI results
vGBI = torch.zeros(v.size())
vArgmaxIndex = torch.argmax(v).item()
vGBI[vArgmaxIndex] = 1
Expand Down
5 changes: 3 additions & 2 deletions domiknows/sensor/pytorch/sensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ def __init__(self, *pres, edges=None, label=False, device='auto'):

def __call__(
self,
data_item: Dict[str, Any]
data_item: Dict[str, Any],
force=False
) -> Any:
self.context_helper = data_item
try:
self.update_context(data_item)
self.update_context(data_item, force=force)
except Exception as ex:
print('Error {} during updating data item {} with sensor {}'.format(ex, data_item, self.fullname))
raise
Expand Down
4 changes: 2 additions & 2 deletions domiknows/solver/gurobiILPOntSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2062,7 +2062,7 @@ def verifyResultsLC(self, dn, key = "/argmax"):
if verifyListLen:
current_verifyResult['satisfied'] = (verifyListSatisfied / verifyListLen) * 100
else:
current_verifyResult['satisfied'] = float("nan")
current_verifyResult['satisfied'] = 0 # float("nan")

# If this if logical constraints
if type(lc) is ifL or type(lc) is forAllL: # if LC
Expand Down Expand Up @@ -2100,7 +2100,7 @@ def verifyResultsLC(self, dn, key = "/argmax"):
if ifVerifyListLen:
current_verifyResult['ifSatisfied'] = (ifVerifyListSatisfied / ifVerifyListLen) *100
else:
current_verifyResult['ifSatisfied'] = float("nan")
current_verifyResult['ifSatisfied'] = 0 #float("nan")

endLC = process_time_ns() # timer()
elapsedInNsLC = endLC - startLC
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
long_description_content_type='text/markdown',

url='https://github.com/HLR/DomiKnowS',
author='Andrzej Uszok',
author_email='[email protected]',

author='Andrzej Uszok, Parisa Kordjamshidi',
author_email='[email protected], [email protected]',
packages=find_packages(include=['domiknows', 'domiknows.*', 'README.md']),

install_requires=[
Expand Down

0 comments on commit 8bdd83e

Please sign in to comment.