From e567f245e9a7ad9c2a57ee102011de44072b9301 Mon Sep 17 00:00:00 2001 From: Spyros Date: Sun, 22 Sep 2024 19:15:28 +0100 Subject: [PATCH] make inmemory graph a singleton, add comments in cre retrieval tests and ensure all graph relationships are Contains or Equal with the CRE-id and Node-id instead of the db ids --- application/database/db.py | 54 ++--- application/database/inmemory_graph.py | 184 ++++++++++------- application/tests/capec_parser_test.py | 15 +- ...ud_native_security_controls_parser_test.py | 5 +- application/tests/cre_main_test.py | 27 ++- application/tests/cwe_parser_test.py | 20 +- application/tests/db_test.py | 193 +++++++++++++----- application/tests/defs_test.py | 5 +- application/tests/juiceshop_test.py | 5 +- application/tests/oscal_utils_test.py | 7 +- application/tests/spreadsheet_test.py | 4 +- application/tests/web_main_test.py | 37 ++-- application/web/web_main.py | 1 - 13 files changed, 367 insertions(+), 190 deletions(-) diff --git a/application/database/db.py b/application/database/db.py index e143af5b..1f2df3ff 100644 --- a/application/database/db.py +++ b/application/database/db.py @@ -329,6 +329,7 @@ class NeoCRE(NeoDocument): # type: ignore auto_linked_to = RelationshipTo( "NeoStandard", "AUTOMATICALLY_LINKED_TO", model=AutoLinkedToRel ) + @classmethod def to_cre_def(self, node, parse_links=True) -> cre_defs.CRE: return cre_defs.CRE( @@ -678,7 +679,12 @@ def __init__(self) -> None: def with_graph(self) -> "Node_collection": logger.info("Loading CRE graph in memory, memory-heavy operation!") - self.graph = inmemory_graph.CRE_Graph.instance(documents=self.__get_all_nodes_and_cres()) + self.graph = inmemory_graph.CRE_Graph() + graph_singleton = inmemory_graph.Singleton_Graph_Storage.instance() + self.graph.with_graph( + graph=graph_singleton, + graph_data=self.__get_all_nodes_and_cres(), + ) return self def __get_external_links(self) -> List[Tuple[CRE, Node, str]]: @@ -1345,17 +1351,9 @@ def get_cre_hierarchy(self, cre: cre_defs.CRE) -> int: if not self.graph: self.with_graph() roots = self.get_root_cres() - root_cre_db_ids = [] - for r in roots: - dbid = self.session.query(CRE.id).filter(CRE.external_id == r.id).first()[0] - root_cre_db_ids.append(dbid) - - credbid = self.session.query(CRE.id).filter(CRE.external_id == cre.id).first() - if not credbid: - raise ValueError(f"CRE {cre.id} does not exist in the database") - credbid = credbid[0] - return self.graph.get_hierarchy(rootIDs=root_cre_db_ids, creID=credbid) + root_cre_ids = [r.id for r in roots] + return self.graph.get_hierarchy(rootIDs=root_cre_ids, creID=cre.id) # def all_nodes_with_pagination( # self, page: int = 1, per_page: int = 10 @@ -1414,10 +1412,8 @@ def delete_gapanalysis_results_for(self, node_name): return res def add_cre(self, cre: cre_defs.CRE) -> CRE: - entry: CRE - query: sqla.Query = self.session.query(CRE).filter( - func.lower(CRE.name) == cre.name.lower() - ) + entry: CRE = None + query = self.session.query(CRE).filter(func.lower(CRE.name) == cre.name.lower()) if cre.id: entry = query.filter(CRE.external_id == cre.id).first() else: @@ -1450,7 +1446,7 @@ def add_cre(self, cre: cre_defs.CRE) -> CRE: self.session.add(entry) self.session.commit() if self.graph: - self.graph = self.graph.add_cre(dbcre=entry, graph=self.graph) + self.graph.add_cre(cre=cre) return entry def add_node( @@ -1481,14 +1477,14 @@ def add_node( self.session.add(dbnode) self.session.commit() if self.graph: - self.graph.add_dbnode(dbnode=dbnode) + self.graph.add_dbnode(dbnode=node) return dbnode def add_internal_link( self, higher: CRE, lower: CRE, - ltype, + ltype: cre_defs.LinkTypes, ) -> cre_defs.Link: """ adds a link between two CREs in the database, @@ -1501,6 +1497,11 @@ def add_internal_link( if ltype == None: raise ValueError("Every link should have a link type") + if ltype == cre_defs.LinkTypes.PartOf: + raise ValueError( + "internal_link does not support 'PartOf' relationships, call it with higher/lower being opposite and the ltype being 'Contains' " + ) + if lower.id is None: if lower.external_id is None: lower = ( @@ -1573,7 +1574,7 @@ def add_internal_link( # ) # entry_exists.ltype = ltype.value # self.session.commit() - return cre_defs.Link(document=CREfromDB(lower),ltype=ltype) + return cre_defs.Link(document=CREfromDB(lower), ltype=ltype) logger.info( "did not know of internal link" @@ -1587,11 +1588,8 @@ def add_internal_link( ) higher_cre = CREfromDB(higher) - lower_cre = CREfromDB(higher) + lower_cre = CREfromDB(lower) link_to = cre_defs.Link(document=lower_cre, ltype=ltype) - - if type(self.graph) != inmemory_graph.CRE_Graph: - raise ValueError("wtf?") cycle = self.graph.introduces_cycle(doc_from=higher_cre, link_to=link_to) if not cycle: @@ -1600,7 +1598,7 @@ def add_internal_link( ) self.session.commit() if self.graph: - self.graph.add_link(doc_from=higher_cre,link_to=link_to) + self.graph.add_link(doc_from=higher_cre, link_to=link_to) else: for item in cycle: from_id = item[0].replace("CRE: ", "") @@ -1617,7 +1615,7 @@ def add_internal_link( f"would introduce cycle {cycle}, skipping" ) return None - return cre_defs.Link(document=lower,ltype=ltype) + return cre_defs.Link(document=lower, ltype=ltype) def add_link( self, @@ -1657,8 +1655,10 @@ def add_link( ) self.session.add(Links(type=ltype.value, cre=cre.id, node=node.id)) if self.graph: - self.graph.add_link(doc_from=CREfromDB(cre),link_to=cre_defs.Link(document=nodeFromDB(node),ltype=ltype.value)) - + self.graph.add_link( + doc_from=CREfromDB(cre), + link_to=cre_defs.Link(document=nodeFromDB(node), ltype=ltype.value), + ) self.session.commit() diff --git a/application/database/inmemory_graph.py b/application/database/inmemory_graph.py index f54d7267..2f49abcf 100644 --- a/application/database/inmemory_graph.py +++ b/application/database/inmemory_graph.py @@ -1,60 +1,87 @@ import sys +import logging import networkx as nx from typing import List, Tuple from application.defs import cre_defs as defs -class CRE_Graph: - graph: nx.Graph = None - __parent_child_subgraph = None - __instance: "CRE_Graph" = None +logging.basicConfig() +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class Singleton_Graph_Storage(nx.DiGraph): + __instance: "Singleton_Graph_Storage" = None @classmethod - def instance(cls, documents: List[defs.Document] = None) -> "CRE_Graph": + def instance(cls) -> "Singleton_Graph_Storage": if cls.__instance is None: - cls.__instance = cls.__new__(cls) - cls.graph = nx.DiGraph() - cls.graph = cls.__load_cre_graph(documents=documents) + cls.__instance = nx.DiGraph() return cls.__instance - def __init__(sel): + def __init__(): raise ValueError("CRE_Graph is a singleton, please call instance() instead") - def add_node(self, *args, **kwargs): - return self.graph.add_node(*args, **kwargs) + +class CRE_Graph: + __graph: nx.Graph = None + __parent_child_subgraph = None + + def get_raw_graph(self): + return self.__graph + + def with_graph(cls, graph: nx.Graph, graph_data: List[defs.Document]): + cls.__graph = graph + if not len(graph.edges): + cls.__load_cre_graph(graph_data) def introduces_cycle(self, doc_from: defs.Document, link_to: defs.Link): try: - existing_cycle = nx.find_cycle(self.graph) - if existing_cycle: + ex = self.has_cycle() + if ex: raise ValueError( "Existing graph contains cycle," "this not a recoverable error," - f" manual database actions are required {existing_cycle}" + f" manual database actions are required {ex}" ) except nx.exception.NetworkXNoCycle: pass # happy path, we don't want cycles - new_graph = self.graph.copy() + + # TODO: when this becomes too slow (e.g. when we are importing 1000s of CREs at once) + # we can instead add the edge find the cycle and then remove the edge + new_graph = self.__graph.copy() # this needs our special add_edge but with the copied graph - new_graph = self.add_graph_edge( + new_graph = self.__add_graph_edge( doc_from=doc_from, link_to=link_to, graph=new_graph ) try: return nx.find_cycle(new_graph) - except nx.NetworkXNoCycle: - return False + except nx.exception.NetworkXNoCycle: + return None + + def has_cycle(self): + try: + ex = nx.find_cycle(self.__graph) + return ex + except nx.exception.NetworkXNoCycle: + return None def get_hierarchy(self, rootIDs: List[str], creID: str): + + if len(self.__graph.edges) == 0: + logger.error("graph is empty") + return -1 + if creID in rootIDs: return 0 if self.__parent_child_subgraph == None: - if len(self.graph.edges) == 0: + if len(self.__graph.edges) == 0: raise ValueError("Graph has no edges") include_cres = [] - for el in self.graph.edges: - edge_data = self.graph.get_edge_data(*el) + for el in self.__graph.edges: + edge_data = self.__graph.get_edge_data(*el) if ( el[0].startswith("CRE") and el[1].startswith("CRE") @@ -71,7 +98,7 @@ def get_hierarchy(self, rootIDs: List[str], creID: str): el not in include_cres ): # If the root is not in the parent/children graph, add it to prevent an error and continue, there is not path to our CRE anyway include_cres.append(f"CRE: {el}") - self.__parent_child_subgraph = self.graph.subgraph(set(include_cres)) + self.__parent_child_subgraph = self.__graph.subgraph(set(include_cres)) shortest_path = sys.maxsize for root in rootIDs: @@ -87,49 +114,61 @@ def get_hierarchy(self, rootIDs: List[str], creID: str): ) - 1, ) - except ( - nx.NodeNotFound - ) as nnf: # If the CRE is not in the parent/children graph it means that it's a lone CRE, so it's a root and we return 0 + except nx.exception.NodeNotFound: + # If the CRE is not in the parent/children graph it means that it's a lone CRE, so it's a root and we return 0 + if f"CRE: {root}" not in self.__graph.nodes(): + raise ValueError(f"CRE: {root} isn't in the graph") + if f"CRE: {creID}" not in self.__graph.nodes(): + raise ValueError(f"CRE: {creID} isn't in the graph") return 0 - except ( - nx.NetworkXNoPath - ) as nxnp: # If there is no path to the CRE, continue + except nx.exception.NetworkXNoPath: + # If there is no path to the CRE, continue continue return shortest_path def get_path(self, start: str, end: str) -> List[Tuple[str, str]]: try: - return nx.shortest_path(self.graph, start, end) + return nx.shortest_path(self.__graph, start, end) except nx.NetworkXNoPath: return [] - @classmethod - def add_cre(cls, dbcre: defs.CRE, graph: nx.DiGraph) -> nx.DiGraph: - if dbcre: - cls.graph.add_node(f"CRE: {dbcre.id}", internal_id=dbcre.id) - else: - logger.error("Called with dbcre being none") - return graph - - @classmethod - def add_dbnode(cls, dbnode: defs.Node, graph: nx.DiGraph) -> nx.DiGraph: - if dbnode: - cls.graph.add_node( - "Node: " + str(dbnode.id), - internal_id=dbnode.id, + def add_cre(cls, cre: defs.CRE): + if not isinstance(cre, defs.CRE): + raise ValueError( + f"inmemory graph add_cre takes only cre objects, instead got {type(cre)}" ) + graph_cre = f"{defs.Credoctypes.CRE.value}: {cre.id}" + if cre and graph_cre not in cls.__graph.nodes(): + cls.__graph.add_node(graph_cre, internal_id=cre.id) + + def add_dbnode(cls, dbnode: defs.Node): + graph_node = "Node: " + str(dbnode.id) + + if dbnode and graph_node not in cls.__graph.nodes(): + cls.__graph.add_node(graph_node, internal_id=dbnode.id) else: logger.error("Called with dbnode being none") - return graph - @classmethod - def add_graph_edge( + def add_link(self, doc_from: defs.Document, link_to: defs.Link): + self.__graph = self.__add_graph_edge( + doc_from=doc_from, link_to=link_to, graph=self.__graph + ) + + def __add_graph_edge( cls, doc_from: defs.Document, link_to: defs.Link, graph: nx.DiGraph, ) -> nx.digraph: - + """ + Adds a directed edge to the graph provided + called by both graph population and speculative cycle finding methods + hence why it accepts a graph and returns a graph + """ + if doc_from.name == link_to.document.name: + raise ValueError( + f"cannot add an edge from a document to itself, from: {doc_from}, to: {link_to.document}" + ) to_doctype = defs.Credoctypes.CRE.value if link_to.document.doctype != defs.Credoctypes.CRE.value: to_doctype = "Node" @@ -137,56 +176,57 @@ def add_graph_edge( if doc_from.doctype == defs.Credoctypes.CRE: if link_to.ltype == defs.LinkTypes.Contains: graph.add_edge( - f"{doc_from.doctype}: {doc_from.id}", + f"{doc_from.doctype.value}: {doc_from.id}", f"{to_doctype}: {link_to.document.id}", - ltype=link_to.ltype, + ltype=link_to.ltype.value, ) elif link_to.ltype == defs.LinkTypes.PartOf: graph.add_edge( f"{to_doctype}: {link_to.document.id}", - f"{doc_from.doctype}: {doc_from.id}", - ltype=defs.LinkTypes.Contains, + f"{doc_from.doctype.value}: {doc_from.id}", + ltype=defs.LinkTypes.Contains.value, ) elif link_to.ltype == defs.LinkTypes.Related: # do nothing if the opposite already exists in the graph, otherwise we introduce a cycle if graph.has_edge( f"{to_doctype}: {link_to.document.id}", - f"{doc_from.doctype}: {doc_from.id}", + f"{doc_from.doctype.value}: {doc_from.id}", ): return graph graph.add_edge( - f"{doc_from.doctype}: {doc_from.id}", + f"{doc_from.doctype.value}: {doc_from.id}", + f"{to_doctype}: {link_to.document.id}", + ltype=defs.LinkTypes.Related.value, + ) + elif ( + link_to.ltype == defs.LinkTypes.LinkedTo + or link_to.ltype == defs.LinkTypes.AutomaticallyLinkedTo + ): + graph.add_edge( + f"{doc_from.doctype.value}: {doc_from.id}", f"{to_doctype}: {link_to.document.id}", - ltype=defs.LinkTypes.Related, + ltype=link_to.ltype.value, ) + else: + raise ValueError(f"link type {link_to.ltype.value} not recognized") else: graph.add_edge( - f"{doc_from.doctype}: {doc_from.id}", + f"{doc_from.doctype.value}: {doc_from.id}", f"{to_doctype}: {link_to.document.id}", - ltype=link_to.ltype, + ltype=link_to.ltype.value, ) return graph - @classmethod - def __load_cre_graph(cls, documents: List[defs.Document]) -> nx.Graph: - graph = cls.graph - if not graph: - graph = nx.DiGraph() - + def __load_cre_graph(cls, documents: List[defs.Document]): for doc in documents: - from_doctype = None if doc.doctype == defs.Credoctypes.CRE: - graph = cls.add_cre(dbcre=doc, graph=graph) - from_doctype = defs.Credoctypes.CRE + cls.add_cre(cre=doc) else: - graph = cls.add_dbnode(dbnode=doc, graph=graph) - from_doctype = doc.doctype + cls.add_dbnode(dbnode=doc) for link in doc.links: if link.document.doctype == defs.Credoctypes.CRE: - graph = cls.add_cre(dbcre=link.document, graph=graph) + cls.add_cre(cre=link.document) else: - graph = cls.add_dbnode(dbnode=link.document, graph=graph) - graph = cls.add_graph_edge(doc_from=doc, link_to=link, graph=graph) - cls.graph = graph - return graph + cls.add_dbnode(dbnode=link.document) + cls.__add_graph_edge(doc_from=doc, link_to=link, graph=cls.__graph) diff --git a/application/tests/capec_parser_test.py b/application/tests/capec_parser_test.py index 8f0002c3..278d7609 100644 --- a/application/tests/capec_parser_test.py +++ b/application/tests/capec_parser_test.py @@ -45,9 +45,18 @@ class fakeRequest: name="CAPEC", doctype=defs.Credoctypes.Standard, links=[ - defs.Link(document=defs.CRE(name="CRE-276", id="276-276"),ltype=defs.LinkTypes.LinkedTo), - defs.Link(document=defs.CRE(name="CRE-285", id="285-285"),ltype=defs.LinkTypes.LinkedTo), - defs.Link(document=defs.CRE(name="CRE-434", id="434-434"),ltype=defs.LinkTypes.LinkedTo), + defs.Link( + document=defs.CRE(name="CRE-276", id="276-276"), + ltype=defs.LinkTypes.LinkedTo, + ), + defs.Link( + document=defs.CRE(name="CRE-285", id="285-285"), + ltype=defs.LinkTypes.LinkedTo, + ), + defs.Link( + document=defs.CRE(name="CRE-434", id="434-434"), + ltype=defs.LinkTypes.LinkedTo, + ), ], hyperlink="https://capec.mitre.org/data/definitions/1.html", sectionID="1", diff --git a/application/tests/cloud_native_security_controls_parser_test.py b/application/tests/cloud_native_security_controls_parser_test.py index 26dc1fe2..e3a3fbf3 100644 --- a/application/tests/cloud_native_security_controls_parser_test.py +++ b/application/tests/cloud_native_security_controls_parser_test.py @@ -81,7 +81,10 @@ class fakeRequest: embeddings_text="Secrets are injected at runtime, such as environment variables or as a file", hyperlink="https://github.com/cloud-native-security-controls/controls-catalog/blob/main/controls/controls_catalog.csv#L2", links=[ - defs.Link(document=defs.CRE(name="CRE-123", id="123-123"),ltype=defs.LinkTypes.LinkedTo), + defs.Link( + document=defs.CRE(name="CRE-123", id="123-123"), + ltype=defs.LinkTypes.LinkedTo, + ), ], name="Cloud Native Security Controls", section="Access", diff --git a/application/tests/cre_main_test.py b/application/tests/cre_main_test.py index 97b20ee2..d4a93cc7 100644 --- a/application/tests/cre_main_test.py +++ b/application/tests/cre_main_test.py @@ -47,20 +47,23 @@ def test_register_node_with_links(self) -> None: doctype=defs.Credoctypes.Standard, name="CWE", sectionID="598", - ),ltype=defs.LinkTypes.LinkedTo + ), + ltype=defs.LinkTypes.LinkedTo, ), defs.Link( document=defs.Code( doctype=defs.Credoctypes.Code, description="print(10)", name="CodemcCodeFace", - ),ltype=defs.LinkTypes.LinkedTo + ), + ltype=defs.LinkTypes.LinkedTo, ), defs.Link( document=defs.Tool( description="awesome hacking tool", name="ToolmcToolFace", - ),ltype=defs.LinkTypes.LinkedTo + ), + ltype=defs.LinkTypes.LinkedTo, ), ], ) @@ -156,20 +159,21 @@ def test_register_standard_with_groupped_cre_links(self) -> None: description="", name="standard_with", links=[ - defs.Link(document=credoc3,ltype=defs.LinkTypes.LinkedTo), + defs.Link(document=credoc3, ltype=defs.LinkTypes.LinkedTo), defs.Link( document=defs.Standard( doctype=defs.Credoctypes.Standard, name="CWE", sectionID="598" ), - ltype=defs.LinkTypes.LinkedTo + ltype=defs.LinkTypes.LinkedTo, ), - defs.Link(document=credoc2,ltype=defs.LinkTypes.LinkedTo), + defs.Link(document=credoc2, ltype=defs.LinkTypes.LinkedTo), defs.Link( document=defs.Standard( doctype=defs.Credoctypes.Standard, name="ASVS", section="SESSION-MGT-TOKEN-DIRECTIVES-DISCRETE-HANDLING", - ),ltype=defs.LinkTypes.LinkedTo + ), + ltype=defs.LinkTypes.LinkedTo, ), ], section="Session Management", @@ -208,7 +212,10 @@ def test_register_cre(self) -> None: id="100-100", description="CREdesc", name="CREname", - links=[defs.Link(document=standard,ltype=defs.LinkTypes.LinkedTo), defs.Link(document=tool,ltype=defs.LinkTypes.LinkedTo)], + links=[ + defs.Link(document=standard, ltype=defs.LinkTypes.LinkedTo), + defs.Link(document=tool, ltype=defs.LinkTypes.LinkedTo), + ], tags=["CREt1", "CREt2"], metadata={"tags": ["CREl1", "CREl2"]}, ) @@ -263,8 +270,8 @@ def test_register_cre(self) -> None: self.assertCountEqual( c.links, [ - defs.Link(document=standard,ltype=defs.LinkTypes.LinkedTo), - defs.Link(document=tool,ltype=defs.LinkTypes.LinkedTo), + defs.Link(document=standard, ltype=defs.LinkTypes.LinkedTo), + defs.Link(document=tool, ltype=defs.LinkTypes.LinkedTo), defs.Link( document=c_lower.shallow_copy(), ltype=defs.LinkTypes.Contains ), diff --git a/application/tests/cwe_parser_test.py b/application/tests/cwe_parser_test.py index 48b7ef34..2c0f7232 100644 --- a/application/tests/cwe_parser_test.py +++ b/application/tests/cwe_parser_test.py @@ -65,8 +65,14 @@ def iter_content(self, chunk_size=None): name="CWE", doctype=defs.Credoctypes.Standard, links=[ - defs.Link(document=defs.CRE(name="CRE-732", id="732-732"),ltype=defs.LinkTypes.LinkedTo), - defs.Link(document=defs.CRE(name="CRE-733", id="733-733"),ltype=defs.LinkTypes.LinkedTo), + defs.Link( + document=defs.CRE(name="CRE-732", id="732-732"), + ltype=defs.LinkTypes.LinkedTo, + ), + defs.Link( + document=defs.CRE(name="CRE-733", id="733-733"), + ltype=defs.LinkTypes.LinkedTo, + ), ], hyperlink="https://CWE.mitre.org/data/definitions/1004.html", sectionID="1004", @@ -79,8 +85,14 @@ def iter_content(self, chunk_size=None): sectionID="1007", section="Another CWE", links=[ - defs.Link(document=defs.CRE(name="CRE-451", id="451-451"),ltype=defs.LinkTypes.LinkedTo), - defs.Link(document=defs.CRE(name="CRE-632", id="632-632"),ltype=defs.LinkTypes.LinkedTo), + defs.Link( + document=defs.CRE(name="CRE-451", id="451-451"), + ltype=defs.LinkTypes.LinkedTo, + ), + defs.Link( + document=defs.CRE(name="CRE-632", id="632-632"), + ltype=defs.LinkTypes.LinkedTo, + ), ], ), ] diff --git a/application/tests/db_test.py b/application/tests/db_test.py index a71d941f..1db0545a 100644 --- a/application/tests/db_test.py +++ b/application/tests/db_test.py @@ -1,4 +1,4 @@ -from pprint import pprint +import networkx as nx from application.utils.gap_analysis import make_resources_key, make_subresources_key import string import random @@ -11,7 +11,6 @@ from copy import copy, deepcopy from pprint import pprint from typing import Any, Dict, List, Union -import redis from flask import json as flask_json import yaml @@ -34,6 +33,10 @@ def setUp(self) -> None: sqla.create_all() self.collection = db.Node_collection().with_graph() + self.collection.graph.with_graph( + graph=nx.DiGraph(), graph_data=[] + ) # initialize the graph singleton for the tests to be unique + collection = self.collection dbcre = collection.add_cre( @@ -231,7 +234,9 @@ def test_export(self) -> None: ), ltype=defs.LinkTypes.LinkedTo, ), - defs.Link(document=defs.Code(name="co0"),ltype=defs.LinkTypes.LinkedTo), + defs.Link( + document=defs.Code(name="co0"), ltype=defs.LinkTypes.LinkedTo + ), ], ), defs.Standard( @@ -518,12 +523,26 @@ def test_get_CREs(self) -> None: version="gc3.1.2", ) + parent_cre = db.CRE( + external_id="999-999", description="parent cre", name="pcre" + ) + parent_cre2 = db.CRE( + external_id="888-888", description="parent cre2", name="pcre2" + ) + partOf_cre = db.CRE( + external_id="777-777", description="part of cre", name="poc" + ) + collection.session.add(dbc1) collection.session.add(dbc2) collection.session.add(dbc3) collection.session.add(dbs1) collection.session.add(dbs2) collection.session.add(db_id_only) + + collection.session.add(parent_cre) + collection.session.add(parent_cre2) + collection.session.add(partOf_cre) collection.session.commit() collection.session.add( @@ -534,11 +553,65 @@ def test_get_CREs(self) -> None: ) collection.session.add(db.Links(type="Linked To", cre=dbc1.id, node=dbs1.id)) + collection.session.add( + db.InternalLinks( + type=defs.LinkTypes.Contains.value, + group=parent_cre.id, + cre=partOf_cre.id, + ) + ) + collection.session.add( + db.InternalLinks( + type=defs.LinkTypes.Contains.value, + group=parent_cre2.id, + cre=partOf_cre.id, + ) + ) collection.session.commit() - cd1 = defs.CRE(id="123-123", description="gcCD1", name="gcC1", links=[]) - cd2 = defs.CRE(description="gcCD2", name="gcC2", id="444-444") - cd3 = defs.CRE(description="gcCD3", name="gcC3", id="555-555") + # we can retrieve children cres + self.assertEqual( + [ + db.CREfromDB(parent_cre).add_link( + defs.Link( + document=db.CREfromDB(partOf_cre), ltype=defs.LinkTypes.Contains + ) + ) + ], + collection.get_CREs(external_id=parent_cre.external_id), + ) + self.assertEqual( + [ + db.CREfromDB(parent_cre2).add_link( + defs.Link( + document=db.CREfromDB(partOf_cre), ltype=defs.LinkTypes.Contains + ) + ) + ], + collection.get_CREs(external_id=parent_cre2.external_id), + ) + + # we can retrieve children cres with inverted multiple (PartOf) links to their parents + self.assertEqual( + [ + db.CREfromDB(partOf_cre) + .add_link( + defs.Link( + document=db.CREfromDB(parent_cre), ltype=defs.LinkTypes.PartOf + ) + ) + .add_link( + defs.Link( + document=db.CREfromDB(parent_cre2), ltype=defs.LinkTypes.PartOf + ) + ) + ], + collection.get_CREs(external_id=partOf_cre.external_id), + ) + + cd1 = defs.CRE(id="123-123", description="gcCD1", name="gcC1") + cd2 = defs.CRE(id="444-444", description="gcCD2", name="gcC2") + cd3 = defs.CRE(id="555-555", description="gcCD3", name="gcC3") c_id_only = defs.CRE( id="666-666", description="c_get_by_internal_id_only", name="cgbiio" ) @@ -570,34 +643,49 @@ def test_get_CREs(self) -> None: shallow_cd1.links = [] cd2.add_link(defs.Link(ltype=defs.LinkTypes.PartOf, document=shallow_cd1)) cd3.add_link(defs.Link(ltype=defs.LinkTypes.PartOf, document=shallow_cd1)) + + # empty returns empty self.assertEqual([], collection.get_CREs()) + # getting "group cre 1" by name returns gcC1 res = collection.get_CREs(name="gcC1") self.assertEqual(len(expected), len(res)) self.assertCountEqual(expected[0].todict(), res[0].todict()) + # getting "group cre 1" by id returns gcC1 res = collection.get_CREs(external_id="123-123") self.assertEqual(len(expected), len(res)) self.assertCountEqual(expected[0].todict(), res[0].todict()) + # getting "group cre 1" by partial id returns gcC1 res = collection.get_CREs(external_id="12%", partial=True) self.assertEqual(len(expected), len(res)) self.assertCountEqual(expected[0].todict(), res[0].todict()) + # getting "group cre 1" by partial name returns gcC1, gcC2 and gcC3 res = collection.get_CREs(name="gcC%", partial=True) + self.assertEqual(3, len(res)) + self.assertCountEqual( + [expected[0].todict(), cd2.todict(), cd3.todict()], + [r.todict() for r in res], + ) + # getting "group cre 1" by partial name and partial id returns gcC1 res = collection.get_CREs(external_id="1%", name="gcC%", partial=True) self.assertEqual(len(expected), len(res)) self.assertCountEqual(expected[0].todict(), res[0].todict()) + # getting "group cre 1" by description returns gcC1 res = collection.get_CREs(description="gcCD1") self.assertEqual(len(expected), len(res)) self.assertCountEqual(expected[0].todict(), res[0].todict()) + # getting "group cre 1" by partial id and partial description returns gcC1 res = collection.get_CREs(external_id="1%", description="gcC%", partial=True) self.assertEqual(len(expected), len(res)) self.assertCountEqual(expected[0].todict(), res[0].todict()) + # getting all the gcC* cres by partial name and partial description returns gcC1, gcC2, gcC3 res = collection.get_CREs(description="gcC%", name="gcC%", partial=True) want = [expected[0], cd2, cd3] for el in res: @@ -611,9 +699,10 @@ def test_get_CREs(self) -> None: self.assertEqual([], collection.get_CREs(external_id="1234")) self.assertEqual([], collection.get_CREs(name="gcC5")) + # add a standard to gcC1 collection.session.add(db.Links(type="Linked To", cre=dbc1.id, node=dbs2.id)) - only_gcS2 = deepcopy(expected) + only_gcS2 = deepcopy(expected) # save a copy of the current expected expected[0].add_link( defs.Link( ltype=defs.LinkTypes.LinkedTo, @@ -626,9 +715,11 @@ def test_get_CREs(self) -> None: ), ) ) + # we can retrieve the cre with the standard res = collection.get_CREs(name="gcC1") self.assertCountEqual(expected[0].todict(), res[0].todict()) + # we can retrieve ONLY the standard res = collection.get_CREs(name="gcC1", include_only=["gcS2"]) self.assertDictEqual(only_gcS2[0].todict(), res[0].todict()) @@ -646,6 +737,8 @@ def test_get_CREs(self) -> None: ) .add_link(defs.Link(ltype=defs.LinkTypes.Contains, document=ccd3)) ] + + # if the standard is not linked, we retrieve as normal res = collection.get_CREs(name="gcC1", include_only=["gcS0"]) self.assertEqual(no_standards, res) @@ -733,7 +826,13 @@ def test_get_nodes_with_pagination(self) -> None: collection.session.commit() for cre, standard in links: - collection.session.add(db.Links(cre=docs[cre].id, node=docs[standard].id,type=defs.LinkTypes.LinkedTo)) + collection.session.add( + db.Links( + cre=docs[cre].id, + node=docs[standard].id, + type=defs.LinkTypes.LinkedTo, + ) + ) collection.session.commit() expected = [ @@ -774,7 +873,9 @@ def test_get_nodes_with_pagination(self) -> None: version="4", links=[ defs.Link( - document=defs.CRE(name="C1", description="CD1", id="123-123"),ltype=defs.LinkTypes.LinkedTo) + document=defs.CRE(name="C1", description="CD1", id="123-123"), + ltype=defs.LinkTypes.LinkedTo, + ) ], ) ] @@ -813,27 +914,6 @@ def test_add_internal_link(self) -> None: higher=cres["dbca"], lower=cres["dbcb"], ltype=defs.LinkTypes.Related ) - # no cycle, free to insert - self.collection.add_internal_link( - higher=cres["dbcb"], lower=cres["dbcc"], ltype=defs.LinkTypes.Related - ) - - # introdcues a cycle, should not be inserted - self.collection.add_internal_link( - higher=cres["dbcc"], lower=cres["dbca"], ltype=defs.LinkTypes.Related - ) - - # cycles are not inserted branch - none_res = ( - self.collection.session.query(db.InternalLinks) - .filter( - db.InternalLinks.group == cres["dbcc"].id, - db.InternalLinks.cre == cres["dbca"].id, - ) - .one_or_none() - ) - self.assertIsNone(none_res) - # "happy path, internal link exists" res = ( self.collection.session.query(db.InternalLinks) @@ -845,6 +925,10 @@ def test_add_internal_link(self) -> None: ) self.assertEqual((res.group, res.cre), (cres["dbca"].id, cres["dbcb"].id)) + # no cycle, free to insert + self.collection.add_internal_link( + higher=cres["dbcb"], lower=cres["dbcc"], ltype=defs.LinkTypes.Related + ) res = ( self.collection.session.query(db.InternalLinks) .filter( @@ -855,6 +939,22 @@ def test_add_internal_link(self) -> None: ) self.assertEqual((res.group, res.cre), (cres["dbcb"].id, cres["dbcc"].id)) + # introdcues a cycle, should not be inserted + self.collection.add_internal_link( + higher=cres["dbcc"], lower=cres["dbca"], ltype=defs.LinkTypes.Related + ) + + # cycles are not inserted branch + none_res = ( + self.collection.session.query(db.InternalLinks) + .filter( + db.InternalLinks.group == cres["dbcc"].id, + db.InternalLinks.cre == cres["dbca"].id, + ) + .one_or_none() + ) + self.assertIsNone(none_res) + def test_text_search(self) -> None: """Given: a cre(id="111-111"23-456,name=foo,description='lorem ipsum foo+bar') @@ -942,17 +1042,7 @@ def test_text_search(self) -> None: self.maxDiff = None for k, val in expected.items(): res = self.collection.text_search(k) - try: - self.assertCountEqual(res, val) - except Exception as e: - pprint(k) - pprint("|" * 99) - pprint(res) - pprint("|" * 99) - pprint(val) - pprint("|" * 99) - input() - raise e + self.assertCountEqual(res, val) def test_dbNodeFromNode(self) -> None: data = { @@ -1125,9 +1215,6 @@ def test_get_root_cres(self): cres[0].add_link( defs.Link(document=cres[2].shallow_copy(), ltype=defs.LinkTypes.Contains) ) - cres[0].add_link( - defs.Link(document=cres[5].shallow_copy(), ltype=defs.LinkTypes.Related) - ) cres[1].add_link( defs.Link(document=cres[3].shallow_copy(), ltype=defs.LinkTypes.Contains) ) @@ -1138,7 +1225,6 @@ def test_get_root_cres(self): cres[3].add_link( defs.Link(document=cres[5].shallow_copy(), ltype=defs.LinkTypes.Contains) ) - cres[6].add_link( defs.Link(document=cres[7].shallow_copy(), ltype=defs.LinkTypes.PartOf) ) @@ -1151,18 +1237,12 @@ def test_get_root_cres(self): collection.add_internal_link( higher=dbcres[2], lower=dbcres[4], ltype=defs.LinkTypes.Contains ) - collection.add_internal_link( - higher=dbcres[5], lower=dbcres[0], ltype=defs.LinkTypes.Related - ) collection.add_internal_link( higher=dbcres[3], lower=dbcres[5], ltype=defs.LinkTypes.Contains ) collection.add_internal_link( - lower=dbcres[6], higher=dbcres[7], ltype=defs.LinkTypes.Contains + higher=dbcres[7], lower=dbcres[6], ltype=defs.LinkTypes.Contains ) - - collection.session.commit() - cres[7].add_link( defs.Link(document=cres[6].shallow_copy(), ltype=defs.LinkTypes.Contains) ) @@ -2152,7 +2232,13 @@ def test_all_cres_with_pagination(self): self.assertEqual(total_pages, 4) def test_get_cre_hierarchy(self) -> None: - collection = db.Node_collection().with_graph() + # this needs a clean database and a clean graph so reinit everything + # sqla.session.remove() + # sqla.drop_all() + # sqla.create_all() + collection = self.collection # db.Node_collection().with_graph() + # collection.graph.with_graph(graph=nx.DiGraph(), graph_data=[]) + _, inputDocs = export_format_data() importItems = [] for name, items in inputDocs.items(): @@ -2183,7 +2269,6 @@ def test_get_cre_hierarchy(self) -> None: collection.add_internal_link( cre=linked_item, node=dbitem, ltype=link.ltype ) - cres = inputDocs[defs.Credoctypes.CRE] c0 = [c for c in cres if c.name == "C0"][0] self.assertEqual(collection.get_cre_hierarchy(c0), 0) diff --git a/application/tests/defs_test.py b/application/tests/defs_test.py index fe16e6db..f99c08c0 100644 --- a/application/tests/defs_test.py +++ b/application/tests/defs_test.py @@ -63,7 +63,10 @@ def test_document_todict(self) -> None: id="500-500", description="desc", name="name", - links=[defs.Link(document=cre,ltype=defs.LinkTypes.Related), defs.Link(document=standard2,ltype=defs.LinkTypes.LinkedTo)], + links=[ + defs.Link(document=cre, ltype=defs.LinkTypes.Related), + defs.Link(document=standard2, ltype=defs.LinkTypes.LinkedTo), + ], tags=["tag1", "t2"], ) group_output = { diff --git a/application/tests/juiceshop_test.py b/application/tests/juiceshop_test.py index f8a111a2..474f7b2f 100644 --- a/application/tests/juiceshop_test.py +++ b/application/tests/juiceshop_test.py @@ -78,7 +78,10 @@ class fakeRequest: embeddings_text="Sensitive Data Exposure", hyperlink="https://demo.owasp-juice.shop//#/score-board?searchQuery=Access%20Log", links=[ - defs.Link(document=defs.CRE(name="CRE-123", id="123-123"),ltype=defs.LinkTypes.LinkedTo), + defs.Link( + document=defs.CRE(name="CRE-123", id="123-123"), + ltype=defs.LinkTypes.LinkedTo, + ), ], name="OWASP Juice Shop", section="Access Log", diff --git a/application/tests/oscal_utils_test.py b/application/tests/oscal_utils_test.py index 69b861be..5b51dbed 100644 --- a/application/tests/oscal_utils_test.py +++ b/application/tests/oscal_utils_test.py @@ -43,8 +43,8 @@ def test_cre_document_to_oscal(self) -> None: name=f"standard-{i}", section=f"{i}", hyperlink=f"https://example.com/{i}", - - ),ltype=defs.LinkTypes.LinkedTo, + ), + ltype=defs.LinkTypes.LinkedTo, ) ) else: @@ -54,7 +54,8 @@ def test_cre_document_to_oscal(self) -> None: name=f"tool-{i}", sectionID=f"{i}", hyperlink=f"https://example.com/{i}", - ),ltype=defs.LinkTypes.LinkedTo, + ), + ltype=defs.LinkTypes.LinkedTo, ) ) diff --git a/application/tests/spreadsheet_test.py b/application/tests/spreadsheet_test.py index 1a3d204c..996291ff 100644 --- a/application/tests/spreadsheet_test.py +++ b/application/tests/spreadsheet_test.py @@ -49,7 +49,9 @@ def test_prepare_spreadsheet_one_cre(self) -> None: ) collection.add_internal_link( - collection.add_cre(cc), collection.add_cre(cd), ltype=defs.LinkTypes.Contains + collection.add_cre(cc), + collection.add_cre(cd), + ltype=defs.LinkTypes.Contains, ) result = ExportSheet().prepare_spreadsheet(storage=collection, docs=[cc, cd]) diff --git a/application/tests/web_main_test.py b/application/tests/web_main_test.py index ab8d0728..b1142bb3 100644 --- a/application/tests/web_main_test.py +++ b/application/tests/web_main_test.py @@ -12,6 +12,7 @@ import redis import rq import os +import networkx as nx from application import create_app, sqla # type: ignore from application.tests.utils import data_gen @@ -44,6 +45,11 @@ def setUp(self) -> None: os.environ["INSECURE_REQUESTS"] = "True" sqla.create_all() + self.collection = db.Node_collection().with_graph() + self.collection.graph.with_graph( + graph=nx.DiGraph(), graph_data=[] + ) # initialize the graph singleton for the tests to be unique + def test_extend_cre_with_tag_links(self) -> None: """ Given: @@ -112,11 +118,11 @@ def test_extend_cre_with_tag_links(self) -> None: self.assertEqual(res, v) def test_find_by_id(self) -> None: - collection = db.Node_collection().with_graph() + collection = self.collection cres = { "ca": defs.CRE(id="111-111", description="CA", name="CA", tags=["ta"]), - "cd": defs.CRE(id="222-222", description="CD", name="CD", tags=["td"]), + "cd": defs.CRE(id="222-223", description="CD", name="CD", tags=["td"]), "cb": defs.CRE(id="333-333", description="CB", name="CB", tags=["tb"]), } cres["ca"].add_link( @@ -137,9 +143,11 @@ def test_find_by_id(self) -> None: ) self.maxDiff = None with self.app.test_client() as client: + # id does not exist case response = client.get(f"/rest/v1/id/9999999999") self.assertEqual(404, response.status_code) + # single id case expected = {"data": cres["ca"].todict()} response = client.get( f"/rest/v1/id/{cres['ca'].id}", @@ -148,7 +156,8 @@ def test_find_by_id(self) -> None: self.assertEqual(json.loads(response.data.decode()), expected) self.assertEqual(200, response.status_code) - md_expected = "
CRE---[222-222CD](https://www.opencre.org/cre/222-222),[111-111CA](https://www.opencre.org/cre/111-111),[333-333CB](https://www.opencre.org/cre/333-333)
" + # change the format to markdown + md_expected = "
CRE---[222-223CD](https://www.opencre.org/cre/222-223),[111-111CA](https://www.opencre.org/cre/111-111),[333-333CB](https://www.opencre.org/cre/333-333)
" md_response = client.get( f"/rest/v1/id/{cres['cd'].id}?format=md", headers={"Content-Type": "application/json"}, @@ -159,7 +168,7 @@ def test_find_by_name(self) -> None: collection = db.Node_collection().with_graph() cres = { "ca": defs.CRE(id="111-111", description="CA", name="CA", tags=["ta"]), - "cd": defs.CRE(id="222-222", description="CD", name="CD", tags=["td"]), + "cd": defs.CRE(id="222-224", description="CD", name="CD", tags=["td"]), "cb": defs.CRE(id="333-333", description="CB", name="CB", tags=["tb"]), "cc": defs.CRE(id="444-444", description="CC", name="CC", tags=["tc"]), } @@ -199,7 +208,7 @@ def test_find_by_name(self) -> None: self.assertEqual(200, response.status_code) self.assertEqual(json.loads(response.data.decode()), expected) - md_expected = "
CRE---[222-222CD](https://www.opencre.org/cre/222-222),[111-111CA](https://www.opencre.org/cre/111-111),[333-333CB](https://www.opencre.org/cre/333-333),[444-444CC](https://www.opencre.org/cre/444-444)
" + md_expected = "
CRE---[222-224CD](https://www.opencre.org/cre/222-224),[111-111CA](https://www.opencre.org/cre/111-111),[333-333CB](https://www.opencre.org/cre/333-333),[444-444CC](https://www.opencre.org/cre/444-444)
" md_response = client.get( f"/rest/v1/name/{cres['cd'].name}?format=md", headers={"Content-Type": "application/json"}, @@ -444,9 +453,9 @@ def test_find_root_cres(self) -> None: self.assertEqual(404, response.status_code) cres = { - "ca": defs.CRE(id="111-111", description="CA", name="CA", tags=["ta"]), - "cd": defs.CRE(id="222-222", description="CD", name="CD", tags=["td"]), - "cb": defs.CRE(id="333-333", description="CB", name="CB", tags=["tb"]), + "ca": defs.CRE(id="111-115", description="CA", name="CA", tags=["ta"]), + "cd": defs.CRE(id="222-225", description="CD", name="CD", tags=["td"]), + "cb": defs.CRE(id="333-335", description="CB", name="CB", tags=["tb"]), } cres["ca"].add_link( defs.Link( @@ -711,8 +720,12 @@ def test_deeplink(self) -> None: ltype=defs.LinkTypes.Contains, document=cres["cd"].shallow_copy() ) ) - cres["cd"].add_link(defs.Link(document=standards["cwe0"],ltype=defs.LinkTypes.LinkedTo)) - cres["cb"].add_link(defs.Link(document=standards["ASVS"],ltype=defs.LinkTypes.LinkedTo)) + cres["cd"].add_link( + defs.Link(document=standards["cwe0"], ltype=defs.LinkTypes.LinkedTo) + ) + cres["cb"].add_link( + defs.Link(document=standards["ASVS"], ltype=defs.LinkTypes.LinkedTo) + ) dca = collection.add_cre(cres["ca"]) dcb = collection.add_cre(cres["cb"]) @@ -726,8 +739,8 @@ def test_deeplink(self) -> None: higher=dcb, lower=dcd, ltype=defs.LinkTypes.Contains ) - collection.add_link(dcb, dasvs,ltype=defs.LinkTypes.LinkedTo) - collection.add_link(dcd, dcwe,ltype=defs.LinkTypes.LinkedTo) + collection.add_link(dcb, dasvs, ltype=defs.LinkTypes.LinkedTo) + collection.add_link(dcd, dcwe, ltype=defs.LinkTypes.LinkedTo) response = client.get("/rest/v1/deeplink/CWE?sectionid=456") self.assertEqual(404, response.status_code) diff --git a/application/web/web_main.py b/application/web/web_main.py index 2ac44d05..beab7f0e 100644 --- a/application/web/web_main.py +++ b/application/web/web_main.py @@ -102,7 +102,6 @@ def find_cre(creid: str = None, crename: str = None) -> Any: # refer # opt_osib = request.args.get("osib") opt_format = request.args.get("format") cres = database.get_CREs(external_id=creid, name=crename, include_only=include_only) - if cres: if len(cres) > 1: logger.error("get by id returned more than one results? This looks buggy")