diff --git a/CveXplore/VERSION b/CveXplore/VERSION index 698c77ba8..071e17fe5 100644 --- a/CveXplore/VERSION +++ b/CveXplore/VERSION @@ -1 +1 @@ -0.3.20.dev16 \ No newline at end of file +0.3.20.dev18 \ No newline at end of file diff --git a/CveXplore/common/data_source_connection.py b/CveXplore/common/data_source_connection.py index 2dffa303d..999e60ecd 100644 --- a/CveXplore/common/data_source_connection.py +++ b/CveXplore/common/data_source_connection.py @@ -15,18 +15,39 @@ class DatasourceConnection(CveXploreObject): objects and generic database functions """ - # hack for documentation building - if json.loads(os.getenv("DOC_BUILD"))["DOC_BUILD"] != "YES": - __DATA_SOURCE_CONNECTION = DatabaseConnection( - database_type="dummy", - database_init_parameters={}, - ).database_connection + def __init__(self, collection: str): + """ + Create a DatasourceConnection object + """ + super().__init__() + self._collection = collection + + @property + def datasource_connection(self): + # hack for documentation building + if json.loads(os.getenv("DOC_BUILD"))["DOC_BUILD"] == "YES": + return DatabaseConnection( + database_type="dummy", + database_init_parameters={}, + ).database_connection + else: + return DatabaseConnection( + database_type=self.config.DATASOURCE_TYPE, + database_init_parameters=self.config.DATASOURCE_CONNECTION_DETAILS, + ).database_connection + + @property + def datasource_collection_connection(self): + return getattr(self.datasource_connection, f"store_{self.collection}") + + @property + def collection(self): + return self._collection def to_dict(self, *print_keys: str) -> dict: """ Method to convert the entire object to a dictionary """ - if len(print_keys) != 0: full_dict = { k: v @@ -40,30 +61,8 @@ def to_dict(self, *print_keys: str) -> dict: return full_dict - def __init__(self, collection: str): - """ - Create a DatasourceConnection object - """ - super().__init__() - self.__collection = collection - def __eq__(self, other): return self.__dict__ == other.__dict__ def __ne__(self, other): return self.__dict__ != other.__dict__ - - @property - def _datasource_connection(self): - return DatasourceConnection.__DATA_SOURCE_CONNECTION - - @property - def _datasource_collection_connection(self): - return getattr( - DatasourceConnection.__DATA_SOURCE_CONNECTION, - f"store_{self.__collection}", - ) - - @property - def _collection(self): - return self.__collection diff --git a/CveXplore/database/connection/dummy/dummy.py b/CveXplore/database/connection/dummy/dummy.py index 1e671cdab..6cbe632fb 100644 --- a/CveXplore/database/connection/dummy/dummy.py +++ b/CveXplore/database/connection/dummy/dummy.py @@ -10,3 +10,6 @@ def __init__(self, **kwargs): @property def dbclient(self): return self._dbclient + + def set_handlers_for_collections(self): + pass diff --git a/CveXplore/database/helpers/generic_db.py b/CveXplore/database/helpers/generic_db.py index f0a171551..4aadd4d12 100644 --- a/CveXplore/database/helpers/generic_db.py +++ b/CveXplore/database/helpers/generic_db.py @@ -56,15 +56,13 @@ def __init__(self, collection: str): } total_fields_list = ( - self.__default_fields + self.__fields_mapping[self._collection] + self.__default_fields + self.__fields_mapping[self.collection] ) for field in total_fields_list: setattr( self, field, - GenericDatabaseFieldsFunctions( - field=field, collection=self._collection - ), + GenericDatabaseFieldsFunctions(field=field, collection=self.collection), ) def get_by_id(self, doc_id: str): @@ -78,7 +76,7 @@ def get_by_id(self, doc_id: str): except ValueError: return "Provided value is not a string nor can it be cast to one" - return self._datasource_collection_connection.find_one({"id": doc_id}) + return self.datasource_collection_connection.find_one({"id": doc_id}) def mget_by_id(self, *doc_ids: str) -> Union[Iterable[CveXploreObject], Iterable]: """ @@ -106,7 +104,7 @@ def _field_list(self, doc_id: str) -> list: map( lambda d: d.to_dict(), [ - self._datasource_collection_connection.find_one( + self.datasource_collection_connection.find_one( {"id": doc_id} ) ], @@ -139,7 +137,7 @@ def mapped_fields(self, collection: str) -> list: def __repr__(self): """String representation of object""" - return f"<< {self.__class__.__name__}:{self._collection} >>" + return f"<< {self.__class__.__name__}:{self.collection} >>" class GenericDatabaseFieldsFunctions(DatasourceConnection): @@ -164,7 +162,7 @@ def search(self, value: str): query = {self.__field: {"$regex": regex}} - return self._datasource_collection_connection.find(query) + return self.datasource_collection_connection.find(query) def find(self, value: str | dict = None): """ @@ -176,8 +174,8 @@ def find(self, value: str | dict = None): else: query = None - return self._datasource_collection_connection.find(query) + return self.datasource_collection_connection.find(query) def __repr__(self): """String representation of object""" - return f"<< GenericDatabaseFieldsFunctions:{self._collection} >>" + return f"<< GenericDatabaseFieldsFunctions:{self.collection} >>" diff --git a/CveXplore/database/helpers/specific_db.py b/CveXplore/database/helpers/specific_db.py index 93eca0d8e..594da4498 100644 --- a/CveXplore/database/helpers/specific_db.py +++ b/CveXplore/database/helpers/specific_db.py @@ -30,7 +30,7 @@ def get_cves_for_vendor( """ the_result = list( - self._datasource_collection_connection.find({"vendors": vendor}) + self.datasource_collection_connection.find({"vendors": vendor}) .limit(limit) .sort("cvss", DESCENDING) ) @@ -69,7 +69,7 @@ def get_by_id(self, doc_id: str): except ValueError: return "Provided value is not a string nor can it be cast to one" - return self._datasource_collection_connection.find_one({"id": doc_id}) + return self.datasource_collection_connection.find_one({"id": doc_id}) def _field_list(self, doc_id: str) -> list: """ @@ -84,7 +84,7 @@ def _field_list(self, doc_id: str) -> list: map( lambda d: d.to_dict(), [ - self._datasource_collection_connection.find_one( + self.datasource_collection_connection.find_one( {"id": doc_id} ) ], @@ -116,7 +116,7 @@ def search_active_cpes( query = {"$and": [{field: {"$regex": regex}}, {"deprecated": False}]} the_result = list( - self._datasource_collection_connection.find(query) + self.datasource_collection_connection.find(query) .limit(limit) .sort(field, sorting) ) @@ -136,7 +136,7 @@ def find_active_cpes( query = {"$and": [{field: value}, {"deprecated": False}]} the_result = list( - self._datasource_collection_connection.find(query) + self.datasource_collection_connection.find(query) .limit(limit) .sort(field, sorting) ) diff --git a/CveXplore/objects/capec.py b/CveXplore/objects/capec.py index 9c024565d..9d083f698 100644 --- a/CveXplore/objects/capec.py +++ b/CveXplore/objects/capec.py @@ -26,7 +26,7 @@ def iter_related_weaknessess(self): if hasattr(self, "related_weakness"): if len(self.related_weakness) != 0: for each in self.related_weakness: - cwe_doc = self._datasource_connection.store_cwe.find_one( + cwe_doc = self.datasource_connection.store_cwe.find_one( {"id": each} ) @@ -42,7 +42,7 @@ def iter_related_capecs(self): if hasattr(self, "related_capecs"): if len(self.related_capecs) != 0: for each in self.related_capecs: - capec_doc = self._datasource_connection.store_capec.find_one( + capec_doc = self.datasource_connection.store_capec.find_one( {"id": each} ) diff --git a/CveXplore/objects/cpe.py b/CveXplore/objects/cpe.py index c201988c8..9bac28e62 100644 --- a/CveXplore/objects/cpe.py +++ b/CveXplore/objects/cpe.py @@ -32,7 +32,7 @@ def iter_cves_matching_cpe(self, vuln_prod_search: bool = False): cpe_regex_string = create_cpe_regex_string(self.cpeName) - results = self._datasource_connection.store_cves.find( + results = self.datasource_connection.store_cves.find( {cpe_searchField: {"$regex": cpe_regex_string}} ).sort("cvss", DESCENDING) diff --git a/CveXplore/objects/cves.py b/CveXplore/objects/cves.py index 7e49bab9d..360fa6bbe 100644 --- a/CveXplore/objects/cves.py +++ b/CveXplore/objects/cves.py @@ -26,20 +26,20 @@ def __init__(self, **kwargs): try: if int(cwe_id): results = getattr( - self._datasource_connection, "store_cwe" + self.datasource_connection, "store_cwe" ).find_one({"id": cwe_id}) if results is not None: self.cwe = results except ValueError: pass - capecs = self._datasource_connection.store_capec.find( + capecs = self.datasource_connection.store_capec.find( {"related_weakness": {"$in": [cwe_id]}} ) setattr(self, "capec", list(capecs)) - via4s = self._datasource_connection.store_via4.find_one({"id": self.id}) + via4s = self.datasource_connection.store_via4.find_one({"id": self.id}) if via4s is not None: setattr(self, "via4_references", via4s) diff --git a/CveXplore/objects/cvexplore_object.py b/CveXplore/objects/cvexplore_object.py index 7dbe7d796..00b209fb0 100644 --- a/CveXplore/objects/cvexplore_object.py +++ b/CveXplore/objects/cvexplore_object.py @@ -2,6 +2,7 @@ CveXploreObject =============== """ +from CveXplore.common.config import Configuration class CveXploreObject(object): @@ -10,7 +11,7 @@ class CveXploreObject(object): """ def __init__(self): - pass + self.config = Configuration def __repr__(self) -> str: return f"<< {self.__class__.__name__} >>" diff --git a/CveXplore/objects/cwe.py b/CveXplore/objects/cwe.py index e57a46517..34e473848 100644 --- a/CveXplore/objects/cwe.py +++ b/CveXplore/objects/cwe.py @@ -26,7 +26,7 @@ def iter_related_weaknessess(self): if hasattr(self, "related_weaknesses"): if len(self.related_weaknesses) != 0: for each in self.related_weaknesses: - cwe_doc = self._datasource_connection.store_cwe.find_one( + cwe_doc = self.datasource_connection.store_cwe.find_one( {"id": each} ) @@ -40,7 +40,7 @@ def iter_related_capecs(self): :rtype: Capec """ - related_capecs = self._datasource_connection.store_capec.find( + related_capecs = self.datasource_connection.store_capec.find( {"related_weakness": self.id} )