Skip to content

Commit

Permalink
Implemented dummy datanode and test for it
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrzej Uszok committed Sep 25, 2023
1 parent dba31f1 commit ecbe94b
Show file tree
Hide file tree
Showing 9 changed files with 233 additions and 14 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,4 @@ domiknows/program/model/gbi copy.py
# venv
venv
.externalToolBuilders/org.eclipse.jdt.core.javabuilder.launch
.settings/org.eclipse.ltk.core.refactoring.prefs
2 changes: 1 addition & 1 deletion Tasks
3 changes: 2 additions & 1 deletion domiknows/graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@
from .candidates import CandidateSelection, combinationC
from .property import Property
from .trial import Trial
from .dataNode import DataNode, DataNodeBuilder
from .dataNode import DataNode, DataNodeBuilder
from .dataNodeDummy import createDummyDataNode
11 changes: 7 additions & 4 deletions domiknows/graph/candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,15 @@ def intersection_of_lists(lists):
return ordered_common_elements

def findDatanodesForRootConcept(dn, rootConcept):
    """Return the datanodes found for ``rootConcept``, cached on the builder when present.

    The lookup result of ``dn.findDatanodes(select=rootConcept)`` is stored under
    ``dn.myBuilder["DataNodesConcepts"][rootConcept.name]`` so repeated calls for
    the same concept reuse it. Datanodes without a builder (``myBuilder`` is
    None) are supported; in that case nothing is cached.

    Args:
        dn: Datanode providing ``myBuilder`` and ``findDatanodes``.
        rootConcept: Concept (anything with a ``name``) to search for.

    Returns:
        Whatever ``dn.findDatanodes`` returned for the concept (possibly from cache).
    """
    builder = dn.myBuilder

    # Serve from the builder cache when this concept was already resolved.
    if builder is not None and "DataNodesConcepts" in builder:
        cached = builder["DataNodesConcepts"]
        if rootConcept.name in cached:
            return cached[rootConcept.name]

    dns = dn.findDatanodes(select = rootConcept)

    if builder is not None:
        # Create the cache on first use; writing into a missing
        # "DataNodesConcepts" entry would otherwise raise KeyError.
        if "DataNodesConcepts" not in builder:
            builder["DataNodesConcepts"] = {}
        builder["DataNodesConcepts"][rootConcept.name] = dns

    return dns

def getCandidates(dn, e, variable, lcVariablesDns, lc, logger, integrate = False):
Expand All @@ -83,7 +86,7 @@ def getCandidates(dn, e, variable, lcVariablesDns, lc, logger, integrate = False

# Check if we already found this variable
if variable.name in lcVariablesDns:
dnsList = lcVariablesDns[variable.name]
dnsList = lcVariablesDns[variable.name]
else:
rootConcept = dn.findRootConceptOrRelation(conceptName)
rootDns = findDatanodesForRootConcept(dn, rootConcept)
Expand Down
25 changes: 18 additions & 7 deletions domiknows/graph/dataNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,18 @@

# Class representing single data instance with relation links to other data nodes
class DataNode:
def __init__(self, myBuilder = None, instanceID = None, instanceValue = None, ontologyNode = None, relationLinks = {}, attributes = {}):
def __init__(self, myBuilder = None, instanceID = None, instanceValue = None, ontologyNode = None, graph = None, relationLinks = {}, attributes = {}):

self.myBuilder = myBuilder # DatanodeBuilder used to construct this datanode
self.instanceID = instanceID # The data instance id (e.g. paragraph number, sentence number, phrase number, image number, etc.)
self.instanceValue = instanceValue # Optional value of the instance (e.g. paragraph text, sentence text, phrase text, image bitmap, etc.)
self.ontologyNode = ontologyNode # Reference to the node in the ontology graph (e.g. Concept) which is the type of this instance (e.g. paragraph, sentence, phrase, etc.)

self.graph = self.ontologyNode.sup
if ontologyNode is not None:
self.graph = self.ontologyNode.sup
if graph is not None:
self.graph = graph

if relationLinks:
self.relationLinks = relationLinks # Dictionary mapping relation name to RelationLinks
else:
Expand All @@ -92,7 +96,13 @@ def __init__(self, myBuilder = None, instanceID = None, instanceValue = None, on
else:
self.attributes = {}

self.current_device = 'auto'
self.current_device = 'cpu'
if torch.cuda.is_available():
if torch.cuda.device_count() > 1:
self.current_device = 'cuda:1'
else:
self.current_device = 'cuda' if torch.cuda.is_available() else 'cpu'

self.gurobiModel = None

self.myLoggerTime = getRegrTimer_logger()
Expand Down Expand Up @@ -232,12 +242,12 @@ def hasAttribute(self, key):
keyInVariableSet = self.ontologyNode.name + "/" + key
if keyInVariableSet in rootDataNode.attributes["variableSet"]:
return True
elif keyInVariableSet in rootDataNode.attributes["propertySet"]:
elif "propertySet" in rootDataNode.attributes and keyInVariableSet in rootDataNode.attributes["propertySet"]:
return True
elif "variableSet" in self.attributes:
if key in self.attributes["variableSet"]:
return True
elif key in self.attributes["propertySet"]:
elif "propertySet" in self and key in self.attributes["propertySet"]:
return True
else:
return False
Expand Down Expand Up @@ -1397,10 +1407,11 @@ def inferILPResults(self, *_conceptsRelations, key = ("local" , "softmax"), fun=
_DataNode__Logger.info('Called with - %s - list of concepts and relations for inference'%([x.name if isinstance(x, Concept) else x for x in _conceptsRelations]))

# Check if full data node is created and create it if not -it is needed for ILP inference
self.myBuilder.createFullDataNode(self)
if self.myBuilder:
self.myBuilder.createFullDataNode(self)

# Check if concepts and/or relations have been provided for inference, if provide translate then to tuple concept info form
_conceptsRelations = self.collectConceptsAndRelations(_conceptsRelations) # Collect all concepts and relations from graph as default set
_conceptsRelations = self.collectConceptsAndRelations(_conceptsRelations) # Collect all concepts and relations from data graph as default set
if len(_conceptsRelations) == 0:
_DataNode__Logger.error('Not found any concepts or relations for inference in provided DataNode %s'%(self))
raise DataNode.DataNodeError('Not found any concepts or relations for inference in provided DataNode %s'%(self))
Expand Down
134 changes: 134 additions & 0 deletions domiknows/graph/dataNodeDummy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from . import DataNode, Graph
import torch

dataSizeInit = 5

def findConcept(conceptName, usedGraph):
    """Look up a concept by name among the objects registered in ``usedGraph._objs``.

    Every entry of ``usedGraph._objs`` is scanned in iteration order and the
    first concept whose key equals ``conceptName`` is returned.

    Args:
        conceptName: Name of the concept to find.
        usedGraph: Object exposing ``_objs``, whose entries expose ``concepts``.

    Returns:
        The matching concept, or None when no entry defines it.
    """
    # Snapshot the keys first, then resolve each entry by key.
    for key in list(usedGraph._objs):
        known = usedGraph._objs[key].concepts
        for name in known:
            if name != conceptName:
                continue
            return known[name]

    return None
def findConceptInfo(usedGraph, concept):
    """Summarise how ``concept`` is connected to other concepts.

    Args:
        usedGraph: Graph passed to findConcept to resolve relation arguments.
        concept: Concept to describe; must provide ``has_a()``, ``_out`` and ``_in``.

    Returns:
        dict with the keys:
            'concept'       - the concept itself
            'relation'      - True when ``concept.has_a()`` is non-empty
            'relationAttrs' - has_a relation name -> concept found for its ``dst``
            'contains'      - ``dst`` of each outgoing 'contains' edge
            'containedIn'   - ``src`` of each incoming 'contains' edge
            'is_a'          - ``dst`` of each outgoing 'is_a' edge
            'root'          - True when the concept is not contained, has no
                              is_a parent and is not a relation
    """
    is_relation = bool(concept.has_a())

    relation_attrs = {}
    for link in concept.has_a():
        relation_attrs[link.name] = findConcept(link.dst.name, usedGraph)

    contains = [edge.dst for edge in concept._out.get('contains', [])]
    contained_in = [edge.src for edge in concept._in.get('contains', [])]
    parents = [edge.dst for edge in concept._out.get('is_a', [])]

    return {
        'concept': concept,
        'relation': is_relation,
        'relationAttrs': relation_attrs,
        'contains': contains,
        'containedIn': contained_in,
        'is_a': parents,
        'root': not (contained_in or parents or is_relation),
    }

def addDatanodes(concept, conceptInfos, datanodes, allDns, level=1):
    """Create child datanodes of type ``concept`` under every datanode in ``datanodes``.

    Each parent receives ``dataSizeInit * level`` children; the concepts listed
    in the concept's 'contains' info are then populated recursively one level
    deeper. The concept's entry in ``conceptInfos`` is updated in place with the
    running 'count' and with the created datanodes under 'dns' (grouped by the
    parent's ontology node name and under 'all').

    Args:
        concept: Concept whose instances are created (looked up by ``concept.name``).
        conceptInfos: Mapping of concept name to the dict built by findConceptInfo.
        datanodes: Parent datanodes that receive the new children.
        allDns: List extended in place with every datanode created.
        level: Depth in the containment hierarchy; scales the number of children.
    """
    info = conceptInfos[concept.name]
    nextID = info.get('count', 0)
    perParent = dataSizeInit * level

    for parent in datanodes:
        children = []
        for _ in range(perParent):
            child = DataNode(instanceID = nextID, ontologyNode = info['concept'])
            parent.addChildDataNode(child)
            children.append(child)
            nextID += 1

        # Populate the concepts contained in this one beneath the new children.
        for contained in info['contains']:
            addDatanodes(contained, conceptInfos, children, allDns, level = level + 1)

        # Bookkeeping for this parent: running count plus per-parent-type and global lists.
        info['count'] = info.get('count', 0) + perParent
        byParent = info.setdefault('dns', {})
        byParent.setdefault(parent.ontologyNode.name, []).extend(children)
        info['dns'].setdefault('all', []).extend(children)

        allDns.extend(children)

def createDummyDataNode(graph):
    """Build a dummy DataNode hierarchy for ``graph`` filled with random probabilities.

    Concept descriptions are collected from ``graph.concepts`` and from the
    concepts of every entry in ``graph.subgraphs``. The last concept flagged as
    root by findConceptInfo becomes the root datanode. Datanodes for the concepts
    it contains are generated by addDatanodes; every relation concept without an
    is_a parent gets one datanode per instance of its first argument; and every
    concept with an is_a parent gets a random two-column tensor stored in the
    root's "variableSet" under ``"<parent>/<<concept>>"``.

    Args:
        graph: Graph exposing ``concepts`` and ``subgraphs`` mappings.

    Returns:
        The root DataNode, or None when no root concept was found.
    """
    rootDataNode = None
    rootConcept = None

    conceptInfos = {}
    allDns = []

    # Describe the concepts declared directly on the graph ...
    for currentConceptKey, currentConcept in graph.concepts.items():
        conceptInfo = findConceptInfo(graph, currentConcept)
        conceptInfos[currentConceptKey] = conceptInfo
        if conceptInfo['root']:
            rootConcept = currentConceptKey

    # ... and those declared in its subgraphs.
    for _, subGraph in graph.subgraphs.items():
        for currentConceptKey, currentConcept in subGraph.concepts.items():
            conceptInfo = findConceptInfo(subGraph, currentConcept)
            conceptInfos[currentConceptKey] = conceptInfo
            if conceptInfo['root']:
                rootConcept = currentConceptKey

    if rootConcept:
        # Add root datanode
        rootConceptInfo = conceptInfos[rootConcept]
        rootDataNode = DataNode(instanceID = 1, ontologyNode = rootConceptInfo['concept'])
        rootDataNode.attributes["variableSet"] = {}

        # Add children datanodes
        for contain in rootConceptInfo['contains']:
            addDatanodes(contain, conceptInfos, [rootDataNode], allDns)

        # Add relation
        for relationConceptInfo in conceptInfos.values():
            if not relationConceptInfo['relation'] or relationConceptInfo['is_a']:
                continue

            relationDns = []
            # Instance ids continue after any instances already counted for this concept.
            instanceID = relationConceptInfo.get('count', 0)

            for d, attr in enumerate(relationConceptInfo['relationAttrs']):
                attrConceptInfo = conceptInfos[relationConceptInfo['relationAttrs'][attr].name]
                attrCount = attrConceptInfo['count']
                attrDns = attrConceptInfo["dns"]["all"]

                for i in range(attrCount):
                    if d == 0:
                        # The first argument drives creation: one relation datanode per
                        # argument instance, typed with the relation concept itself
                        # (previously it was typed with the argument's concept).
                        newDN = DataNode(instanceID = instanceID, ontologyNode = relationConceptInfo['concept'])
                        relationDns.append(newDN)
                        instanceID += 1
                    elif i < len(relationDns):
                        # Later arguments are linked to the i-th relation datanode
                        # (previously every link went to relationDns[1]).
                        newDN = relationDns[i]
                    else:
                        # More argument instances than relation datanodes: nothing left to link.
                        break

                    newDN.addRelationLink(attr, attrDns[i])

            # instanceID already includes the previous count, so assign it directly
            # instead of adding it once per argument (which double-counted).
            relationConceptInfo['count'] = instanceID

            allDns.extend(relationDns)
            relationConceptInfo.setdefault('dns', {})['all'] = relationDns

        # Add probabilities
        for conceptInfo in conceptInfos.values():
            if not conceptInfo['is_a']:
                continue

            conceptRootConceptInfo = conceptInfos[conceptInfo['is_a'][0].name]

            # No instances were generated for the parent concept - nothing to score.
            if 'count' not in conceptRootConceptInfo:
                continue

            m = conceptRootConceptInfo['count']
            # Two columns per instance, (1 - p, p), so every row sums to 1.
            random_tensor = torch.rand(m, 1)
            final_tensor = torch.cat((1 - random_tensor, random_tensor), dim=1)
            key = conceptRootConceptInfo['concept'].name + '/<' + conceptInfo['concept'].name + '>'
            rootDataNode.attributes["variableSet"][key] = final_tensor

    # Iterate over the data nodes in "allDns" and add the "rootDataNode" attribute to them
    for dn in allDns:
        if dn == rootDataNode:
            continue
        dn.attributes["rootDataNode"] = rootDataNode

    return rootDataNode
Expand Down
51 changes: 51 additions & 0 deletions test_regr/dummy_datanode/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from domiknows.graph import Graph, Concept, Relation
from domiknows.graph.logicalConstrain import ifL, andL, nandL, atMostL, existsL

# Clear the Graph, Concept and Relation classes before declaring the test graph.
# NOTE(review): presumably this resets previously registered definitions so the
# module can be imported repeatedly in one process - confirm.
Graph.clear()
Concept.clear()
Relation.clear()

with Graph('global') as graph:
    # Base concepts and their containment structure.
    with Graph('linguistic') as ling_graph:
        char = Concept(name='char')
        word = Concept(name='word')
        phrase = Concept(name='phrase')
        sentence = Concept(name='sentence')
        # Each contains() call yields a 1-tuple holding the created relation.
        (rel_sentence_contains_word,) = sentence.contains(word)
        (rel_sentence_contains_phrase,) = sentence.contains(phrase)
        (rel_phrase_contains_word,) = phrase.contains(word)
        (rel_word_contains_char,) = word.contains(char)

        # Binary relation concept over two words.
        pair = Concept(name='pair')
        (rel_pair_word1, rel_pair_word2, ) = pair.has_a(arg1=word, arg2=word)

    with Graph('application', auto_constraint=True) as app_graph:
        # Concepts derived from word.
        # NOTE(review): presumably calling a concept creates an is_a sub-concept - confirm.
        people = word(name='people')
        organization = word(name='organization')
        location = word(name='location')
        other = word(name='other')
        o = word(name='O')

        #disjoint(people, organization, location, other, o)

        # LC0
        nandL(people, organization, active = True)

        # Concepts derived from pair.
        work_for = pair(name='work_for')
        located_in = pair(name='located_in')
        live_in = pair(name='live_in')
        orgbase_on = pair(name='orgbase_on')
        kill = pair(name='kill')

        # Argument concepts of each derived relation.
        work_for.has_a(people, organization)
        located_in.has_a(location, location)
        live_in.has_a(people, location)
        orgbase_on.has_a(organization, location)
        kill.has_a(people, people)

        # LC2
        ifL(existsL(work_for('x')), andL(people(path=('x', rel_pair_word1.name)), organization(path=('x', rel_pair_word2.name))), active = True)

        # LC3
        ifL(word, atMostL(people, organization, location, other, o), active = True)

18 changes: 18 additions & 0 deletions test_regr/dummy_datanode/test_dummy_datanode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
from domiknows.graph import createDummyDataNode
from graph import graph

def test_dummy_data_node_inference():
    """Smoke-test that both inference entry points run on a dummy DataNode."""
    testDummyDn = createDummyDataNode(graph)

    # createDummyDataNode returns None when it finds no root concept; fail on that
    # directly instead of reporting it below as an inference exception.
    assert testDummyDn is not None, "createDummyDataNode returned None"

    # Checking if inferILPResults doesn't raise any exception
    try:
        testDummyDn.inferILPResults()
    except Exception:
        pytest.fail("inferILPResults raised an exception")

    # Checking if infer doesn't raise any exception
    try:
        testDummyDn.infer()
    except Exception:
        pytest.fail("infer raised an exception")

0 comments on commit ecbe94b

Please sign in to comment.