Skip to content

Commit

Permalink
make the memory graph self-initialize, move functionality of get_standard_by_db_id to get_nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
northdpole committed Aug 11, 2024
1 parent 8814dea commit f8406c5
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 46 deletions.
64 changes: 27 additions & 37 deletions application/database/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,13 +742,11 @@ def __get_all_nodes_and_cres(self) -> List[cre_defs.Document]:
cres = []
node_ids = self.session.query(Node.id).all()
for nid in node_ids:
nodes.append(self.get_node_by_db_id(nid[0]))
[result.append(node) for node in nodes]
result.extend(self.get_nodes(db_id=nid[0]))

cre_ids = self.session.query(CRE.id).all()
for cid in cre_ids:
cres.append(self.get_cre_by_db_id(cid[0]))
[result.append(cre) for cre in cres]
result.append(self.get_cre_by_db_id(cid[0]))
return result

def __introduces_cycle(self, node_from: str, node_to: str) -> Any:
Expand Down Expand Up @@ -983,19 +981,24 @@ def get_nodes(
description: Optional[str] = None,
ntype: str = cre_defs.Standard.__name__,
sectionID: Optional[str] = None,
db_id: Optional[str] = None,
) -> Optional[List[cre_defs.Node]]:
nodes = []
nodes_query = self.__get_nodes_query__(
name=name,
section=section,
subsection=subsection,
link=link,
version=version,
partial=partial,
ntype=ntype,
description=description,
sectionID=sectionID,
)
nodes_query = None
if db_id:
nodes_query = self.session.query(Node).filter(Node.id == db_id)
else:
nodes_query = self.__get_nodes_query__(
name=name,
section=section,
subsection=subsection,
link=link,
version=version,
partial=partial,
ntype=ntype,
description=description,
sectionID=sectionID,
)
dbnodes = nodes_query.all()
if dbnodes:
for dbnode in dbnodes:
Expand Down Expand Up @@ -1032,22 +1035,6 @@ def get_nodes(

return []

def get_node_by_db_id(self, id: str) -> Optional[cre_defs.Node]:
    """Return the node document stored under database id ``id``.

    The row is fetched by primary key and then re-resolved through
    ``get_nodes`` using its name, section, subsection, type and section id;
    the first match is returned.

    Args:
        id: database id of the ``Node`` row. The name shadows the builtin
            but is left unchanged so keyword callers keep working.

    Returns:
        The first document ``get_nodes`` yields for the row, or ``None``
        (after logging an error) when the row does not exist or the lookup
        comes back empty.
    """
    dbnode = self.session.query(Node).filter(Node.id == id).first()
    if not dbnode:
        logger.error(f"Node {id} does not exist in the db")
        return None

    # A ``Links`` query used to run here but its result was never read,
    # so the extra database round trip has been dropped.
    matches = self.get_nodes(
        name=dbnode.name,
        section=dbnode.section,
        subsection=dbnode.subsection,
        ntype=dbnode.ntype,
        sectionID=dbnode.section_id,
    )
    if not matches:
        # Indexing an empty result would raise IndexError; report it and
        # hand back None like the "row not found" case above.
        logger.error(f"Node {id} could not be resolved through get_nodes")
        return None
    return matches[0]

def get_cre_by_db_id(self, id: str) -> cre_defs.CRE:
"""internal method, returns a shallow cre (no links) by its database id
Expand Down Expand Up @@ -1194,12 +1181,13 @@ def get_CREs(
for ls in linked_nodes:
nd = self.session.query(Node).filter(Node.id == ls.node).first()
if not include_only or (include_only and nd.name in include_only):
cre.add_link(
cre_defs.Link(
document=nodeFromDB(nd),
ltype=cre_defs.LinkTypes.from_str(ls.type),
n = nodeFromDB(nd)
if not cre.link_exists(n):
cre.add_link(
cre_defs.Link(
document=n, ltype=cre_defs.LinkTypes.from_str(ls.type)
)
)
)
# TODO figure the query to merge the following two
internal_links = (
self.session.query(InternalLinks)
Expand Down Expand Up @@ -1231,7 +1219,9 @@ def get_CREs(
elif il.group == dbcre.id:
res = q.filter(CRE.id == il.cre).first()
ltype = cre_defs.LinkTypes.from_str(il.type)
cre.add_link(cre_defs.Link(document=CREfromDB(res), ltype=ltype))
c = CREfromDB(res)
if not cre.link_exists(c):
cre.add_link(cre_defs.Link(document=c, ltype=ltype))
cres.append(cre)
return cres

Expand Down
13 changes: 9 additions & 4 deletions application/database/inmemory_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class CRE_Graph:
def instance(cls, documents: List[defs.Document] = None) -> "CRE_Graph":
    """Return the shared ``CRE_Graph``, building it on the first call.

    Args:
        documents: documents handed to the graph loader when the singleton
            is first created. They are not used on later calls, because
            the loader only runs while no instance exists yet. May be
            omitted, in which case the loader receives an empty list.

    Returns:
        The single shared ``CRE_Graph`` instance.
    """
    if cls.__instance is None:
        cls.__instance = cls.__new__(cls)
        # The loader reads ``cls.graph``, so it has to exist beforehand.
        cls.graph = nx.DiGraph()
        # The loader iterates over ``documents``; forwarding the ``None``
        # default unchanged would raise TypeError on the first call.
        cls.graph = cls.__load_cre_graph(documents=documents or [])
    return cls.__instance

Expand Down Expand Up @@ -87,15 +88,15 @@ def get_path(self, start: str, end: str) -> List[Tuple[str, str]]:
@classmethod
def add_cre(cls, dbcre: defs.CRE, graph: nx.DiGraph) -> nx.DiGraph:
    """Register ``dbcre`` as a ``"CRE: <id>"`` node.

    The node is added to ``graph`` and, when the class-level graph is a
    different object, to that one as well so both stay in step.

    Args:
        dbcre: the CRE to add; a falsy value is logged and ignored.
        graph: the graph that receives the node.

    Returns:
        The same ``graph`` object that was passed in.
    """
    if not dbcre:
        logger.error("Called with dbcre being none")
        return graph

    node_key = f"CRE: {dbcre.id}"
    graph.add_node(node_key, internal_id=dbcre.id)

    # Touch the class-level graph only when it exists and is not the graph
    # just updated: this avoids a redundant second write on the same object
    # and an AttributeError when no class-level graph has been set up yet.
    shared = getattr(cls, "graph", None)
    if shared is not None and shared is not graph:
        shared.add_node(node_key, internal_id=dbcre.id)
    return graph

@classmethod
def add_dbnode(cls, dbnode: defs.Node, graph: nx.DiGraph) -> nx.DiGraph:
if dbnode:
graph.add_node(
cls.graph.add_node(
"Node: " + str(dbnode.id),
internal_id=dbnode.id,
)
Expand All @@ -105,7 +106,10 @@ def add_dbnode(cls, dbnode: defs.Node, graph: nx.DiGraph) -> nx.DiGraph:

@classmethod
def __load_cre_graph(cls, documents: List[defs.Document]) -> nx.Graph:
graph = nx.DiGraph()
graph = cls.graph
if not graph:
graph = nx.DiGraph()

for doc in documents:
from_doctype = None
if doc.doctype == defs.Credoctypes.CRE:
Expand All @@ -122,9 +126,10 @@ def __load_cre_graph(cls, documents: List[defs.Document]) -> nx.Graph:
else:
graph = cls.add_dbnode(dbnode=link.document, graph=graph)
to_doctype = "Node"
graph = graph.add_edge(
graph.add_edge(
f"{from_doctype}: {doc.id}",
f"{to_doctype}: {link.document.id}",
ltype=link.ltype,
)
cls.graph = graph
return graph
4 changes: 2 additions & 2 deletions application/prompt_client/prompt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def generate_embeddings(
logger.info(f"generating {len(missing_embeddings)} embeddings")
for id in missing_embeddings:
cre = database.get_cre_by_db_id(id)
node = database.get_node_by_db_id(id)
node = database.get_nodes(db_id=id)
content = ""
if node:
if is_valid_url(node.hyperlink):
Expand Down Expand Up @@ -464,7 +464,7 @@ def generate_text(self, prompt: str) -> Dict[str, str]:
)
closest_object = None
if closest_id:
closest_object = self.database.get_node_by_db_id(closest_id)
closest_object = self.database.get_nodes(db_id=closest_id)

logger.info(
f"The prompt {prompt}, was most similar to object \n{closest_object}\n, with similarity:{similarity}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def parse(self, cache: db.Node_collection, ph: prompt_client.PromptHandler):
)
standard_id = ph.get_id_of_most_similar_node(cnsc_embeddings)
if standard_id:
dbstandard = cache.get_node_by_db_id(standard_id)
dbstandard = cache.get_nodes(db_id=standard_id)
logger.info(
f"found an appropriate standard for Cloud Native Security Control {cnsc.section}:{cnsc.subsection}, it is: {dbstandard.name}:{dbstandard.section}"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def parse(
f"could not find an appropriate CRE for Juiceshop challenge {chal.section}, findings similarities with standards instead"
)
standard_id = ph.get_id_of_most_similar_node(challenge_embeddings)
dbstandard = cache.get_node_by_db_id(standard_id)
dbstandard = cache.get_nodes(db_id=standard_id)
logger.info(
f"found an appropriate standard for Juiceshop challenge {chal.section}, it is: {dbstandard.section}"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __parse(
f"could not find an appropriate CRE for pci {pci_control.section}, findings similarities with standards instead"
)
standard_id = prompt.get_id_of_most_similar_node(control_embeddings)
dbstandard = cache.get_node_by_db_id(standard_id)
dbstandard = cache.get_nodes(db_id=standard_id)
logger.info(
f"found an appropriate standard for pci {pci_control.section}, it is: {dbstandard.section}"
)
Expand Down

0 comments on commit f8406c5

Please sign in to comment.