Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

minor refactor #254

Merged
merged 2 commits into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CveXplore/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.3.20.dev16
0.3.20.dev18
57 changes: 28 additions & 29 deletions CveXplore/common/data_source_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,39 @@ class DatasourceConnection(CveXploreObject):
objects and generic database functions
"""

# hack for documentation building
if json.loads(os.getenv("DOC_BUILD"))["DOC_BUILD"] != "YES":
__DATA_SOURCE_CONNECTION = DatabaseConnection(
database_type="dummy",
database_init_parameters={},
).database_connection
def __init__(self, collection: str):
"""
Create a DatasourceConnection object
"""
super().__init__()
self._collection = collection

@property
def datasource_connection(self):
# hack for documentation building
if json.loads(os.getenv("DOC_BUILD"))["DOC_BUILD"] == "YES":
return DatabaseConnection(
database_type="dummy",
database_init_parameters={},
).database_connection
else:
return DatabaseConnection(
database_type=self.config.DATASOURCE_TYPE,
database_init_parameters=self.config.DATASOURCE_CONNECTION_DETAILS,
).database_connection

@property
def datasource_collection_connection(self):
return getattr(self.datasource_connection, f"store_{self.collection}")

@property
def collection(self):
return self._collection

def to_dict(self, *print_keys: str) -> dict:
"""
Method to convert the entire object to a dictionary
"""

if len(print_keys) != 0:
full_dict = {
k: v
Expand All @@ -40,30 +61,8 @@ def to_dict(self, *print_keys: str) -> dict:

return full_dict

def __init__(self, collection: str):
"""
Create a DatasourceConnection object
"""
super().__init__()
self.__collection = collection

def __eq__(self, other):
return self.__dict__ == other.__dict__

def __ne__(self, other):
return self.__dict__ != other.__dict__

@property
def _datasource_connection(self):
return DatasourceConnection.__DATA_SOURCE_CONNECTION

@property
def _datasource_collection_connection(self):
return getattr(
DatasourceConnection.__DATA_SOURCE_CONNECTION,
f"store_{self.__collection}",
)

@property
def _collection(self):
return self.__collection
3 changes: 3 additions & 0 deletions CveXplore/database/connection/dummy/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ def __init__(self, **kwargs):
@property
def dbclient(self):
return self._dbclient

def set_handlers_for_collections(self):
pass
18 changes: 8 additions & 10 deletions CveXplore/database/helpers/generic_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,13 @@ def __init__(self, collection: str):
}

total_fields_list = (
self.__default_fields + self.__fields_mapping[self._collection]
self.__default_fields + self.__fields_mapping[self.collection]
)
for field in total_fields_list:
setattr(
self,
field,
GenericDatabaseFieldsFunctions(
field=field, collection=self._collection
),
GenericDatabaseFieldsFunctions(field=field, collection=self.collection),
)

def get_by_id(self, doc_id: str):
Expand All @@ -78,7 +76,7 @@ def get_by_id(self, doc_id: str):
except ValueError:
return "Provided value is not a string nor can it be cast to one"

return self._datasource_collection_connection.find_one({"id": doc_id})
return self.datasource_collection_connection.find_one({"id": doc_id})

def mget_by_id(self, *doc_ids: str) -> Union[Iterable[CveXploreObject], Iterable]:
"""
Expand Down Expand Up @@ -106,7 +104,7 @@ def _field_list(self, doc_id: str) -> list:
map(
lambda d: d.to_dict(),
[
self._datasource_collection_connection.find_one(
self.datasource_collection_connection.find_one(
{"id": doc_id}
)
],
Expand Down Expand Up @@ -139,7 +137,7 @@ def mapped_fields(self, collection: str) -> list:

def __repr__(self):
"""String representation of object"""
return f"<< {self.__class__.__name__}:{self._collection} >>"
return f"<< {self.__class__.__name__}:{self.collection} >>"


class GenericDatabaseFieldsFunctions(DatasourceConnection):
Expand All @@ -164,7 +162,7 @@ def search(self, value: str):

query = {self.__field: {"$regex": regex}}

return self._datasource_collection_connection.find(query)
return self.datasource_collection_connection.find(query)

def find(self, value: str | dict = None):
"""
Expand All @@ -176,8 +174,8 @@ def find(self, value: str | dict = None):
else:
query = None

return self._datasource_collection_connection.find(query)
return self.datasource_collection_connection.find(query)

def __repr__(self):
"""String representation of object"""
return f"<< GenericDatabaseFieldsFunctions:{self._collection} >>"
return f"<< GenericDatabaseFieldsFunctions:{self.collection} >>"
10 changes: 5 additions & 5 deletions CveXplore/database/helpers/specific_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_cves_for_vendor(
"""

the_result = list(
self._datasource_collection_connection.find({"vendors": vendor})
self.datasource_collection_connection.find({"vendors": vendor})
.limit(limit)
.sort("cvss", DESCENDING)
)
Expand Down Expand Up @@ -69,7 +69,7 @@ def get_by_id(self, doc_id: str):
except ValueError:
return "Provided value is not a string nor can it be cast to one"

return self._datasource_collection_connection.find_one({"id": doc_id})
return self.datasource_collection_connection.find_one({"id": doc_id})

def _field_list(self, doc_id: str) -> list:
"""
Expand All @@ -84,7 +84,7 @@ def _field_list(self, doc_id: str) -> list:
map(
lambda d: d.to_dict(),
[
self._datasource_collection_connection.find_one(
self.datasource_collection_connection.find_one(
{"id": doc_id}
)
],
Expand Down Expand Up @@ -116,7 +116,7 @@ def search_active_cpes(
query = {"$and": [{field: {"$regex": regex}}, {"deprecated": False}]}

the_result = list(
self._datasource_collection_connection.find(query)
self.datasource_collection_connection.find(query)
.limit(limit)
.sort(field, sorting)
)
Expand All @@ -136,7 +136,7 @@ def find_active_cpes(
query = {"$and": [{field: value}, {"deprecated": False}]}

the_result = list(
self._datasource_collection_connection.find(query)
self.datasource_collection_connection.find(query)
.limit(limit)
.sort(field, sorting)
)
Expand Down
4 changes: 2 additions & 2 deletions CveXplore/objects/capec.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def iter_related_weaknessess(self):
if hasattr(self, "related_weakness"):
if len(self.related_weakness) != 0:
for each in self.related_weakness:
cwe_doc = self._datasource_connection.store_cwe.find_one(
cwe_doc = self.datasource_connection.store_cwe.find_one(
{"id": each}
)

Expand All @@ -42,7 +42,7 @@ def iter_related_capecs(self):
if hasattr(self, "related_capecs"):
if len(self.related_capecs) != 0:
for each in self.related_capecs:
capec_doc = self._datasource_connection.store_capec.find_one(
capec_doc = self.datasource_connection.store_capec.find_one(
{"id": each}
)

Expand Down
2 changes: 1 addition & 1 deletion CveXplore/objects/cpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def iter_cves_matching_cpe(self, vuln_prod_search: bool = False):

cpe_regex_string = create_cpe_regex_string(self.cpeName)

results = self._datasource_connection.store_cves.find(
results = self.datasource_connection.store_cves.find(
{cpe_searchField: {"$regex": cpe_regex_string}}
).sort("cvss", DESCENDING)

Expand Down
6 changes: 3 additions & 3 deletions CveXplore/objects/cves.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,20 @@ def __init__(self, **kwargs):
try:
if int(cwe_id):
results = getattr(
self._datasource_connection, "store_cwe"
self.datasource_connection, "store_cwe"
).find_one({"id": cwe_id})
if results is not None:
self.cwe = results
except ValueError:
pass

capecs = self._datasource_connection.store_capec.find(
capecs = self.datasource_connection.store_capec.find(
{"related_weakness": {"$in": [cwe_id]}}
)

setattr(self, "capec", list(capecs))

via4s = self._datasource_connection.store_via4.find_one({"id": self.id})
via4s = self.datasource_connection.store_via4.find_one({"id": self.id})

if via4s is not None:
setattr(self, "via4_references", via4s)
Expand Down
3 changes: 2 additions & 1 deletion CveXplore/objects/cvexplore_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
CveXploreObject
===============
"""
from CveXplore.common.config import Configuration


class CveXploreObject(object):
Expand All @@ -10,7 +11,7 @@ class CveXploreObject(object):
"""

def __init__(self):
pass
self.config = Configuration

def __repr__(self) -> str:
return f"<< {self.__class__.__name__} >>"
4 changes: 2 additions & 2 deletions CveXplore/objects/cwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def iter_related_weaknessess(self):
if hasattr(self, "related_weaknesses"):
if len(self.related_weaknesses) != 0:
for each in self.related_weaknesses:
cwe_doc = self._datasource_connection.store_cwe.find_one(
cwe_doc = self.datasource_connection.store_cwe.find_one(
{"id": each}
)

Expand All @@ -40,7 +40,7 @@ def iter_related_capecs(self):
:rtype: Capec
"""

related_capecs = self._datasource_connection.store_capec.find(
related_capecs = self.datasource_connection.store_capec.find(
{"related_weakness": self.id}
)

Expand Down
Loading