diff --git a/lib/artifacts/scw_artifact.py b/lib/artifacts/scw_artifact.py index e1c7a698..96628dd6 100644 --- a/lib/artifacts/scw_artifact.py +++ b/lib/artifacts/scw_artifact.py @@ -1,8 +1,8 @@ import json import requests -import utils.utils -BASE_SCW_URL = 'https://integration-api.securecodewarrior.com/api/v1/trial?id=bugcrowd&mappingList=vrt&mappingKey=' +BASE_SCW_URL = 'https://integration-api.securecodewarrior.com\ +/api/v1/trial?id=bugcrowd&mappingList=vrt&mappingKey=' OUTPUT_FILENAME = 'scw_links.json' @@ -23,7 +23,9 @@ def scw_mapping(vrt_id): def join_vrt_id(parent_id, child_id): - return '.'.join([parent_id, child_id]) if parent_id is not None else child_id + return '.'.join( + [parent_id, child_id] + ) if parent_id is not None else child_id def generate_urls(vrt, content, parent_id=None): diff --git a/lib/tests/test_artifact_format.py b/lib/tests/test_artifact_format.py index aa3bc2e9..68354c7e 100644 --- a/lib/tests/test_artifact_format.py +++ b/lib/tests/test_artifact_format.py @@ -2,20 +2,22 @@ import os import unittest + class TestArtifactFormat(unittest.TestCase): - def setUp(self): - print("\n`---{}---`".format(self._testMethodName)) - self.scw_artifact_path = os.path.join( - utils.THIRD_PARTY_MAPPING_DIR, - utils.SCW_DIR, - utils.SCW_FILENAME - ) + def setUp(self): + print("\n`---{}---`".format(self._testMethodName)) + self.scw_artifact_path = os.path.join( + utils.THIRD_PARTY_MAPPING_DIR, + utils.SCW_DIR, + utils.SCW_FILENAME + ) + + def test_artifact_loads_valid_json(self): + self.assertTrue( + utils.get_json(self.scw_artifact_path), + self.scw_artifact_path + ' is not valid JSON.' + ) - def test_artifact_loads_valid_json(self): - self.assertTrue( - utils.get_json(self.scw_artifact_path), - self.scw_artifact_path + ' is not valid JSON.' - ) if __name__ == "__main__": unittest.main() \ No newline at end of file diff --git a/lib/tests/test_deprecated_mapping.py b/lib/tests/test_deprecated_mapping.py index 1ca25a6a..d501af91 100644 --- a/lib/tests/test_deprecated_mapping.py +++ b/lib/tests/test_deprecated_mapping.py @@ -7,8 +7,15 @@ class TestDeprecatedMapping(unittest.TestCase): def setUp(self): print("\n`---{}---`".format(self._testMethodName)) self.vrt_versions = utils.all_versions(utils.VRT_FILENAME) - self.last_tagged_version = max([Version.coerce(x) for x in self.vrt_versions.keys() if x != 'current']) - self.deprecated_json = utils.get_json(utils.DEPRECATED_MAPPING_FILENAME) + self.last_tagged_version = max( + [ + Version.coerce(x) for x in self.vrt_versions.keys() + if x != 'current' + ] + ) + self.deprecated_json = utils.get_json( + utils.DEPRECATED_MAPPING_FILENAME + ) def test_old_vrt_ids_have_current_node(self): for version, vrt in self.vrt_versions.items(): @@ -17,17 +24,28 @@ def test_old_vrt_ids_have_current_node(self): for id_list in utils.all_id_lists(vrt): vrt_id = '.'.join(id_list) if vrt_id in self.deprecated_json: - max_ver = sorted(self.deprecated_json[vrt_id].keys(), key=lambda s: map(int, s.split('.')))[-1] + max_ver = sorted( + self.deprecated_json[vrt_id].keys(), + key=lambda s: map(int, s.split('.')) + )[-1] vrt_id = self.deprecated_json[vrt_id][max_ver] id_list = vrt_id.split('.') - self.assertTrue(vrt_id == 'other' or self.check_mapping(id_list), - '%s from v%s has no mapping' % (vrt_id, version)) + self.assertTrue( + vrt_id == 'other' or self.check_mapping(id_list), + '%s from v%s has no mapping' % (vrt_id, version) + ) def test_deprecated_nodes_map_valid_node(self): for old_id, mapping in self.deprecated_json.items(): for new_version, new_id in mapping.items(): - self.assertTrue(new_id == 'other' or utils.id_valid(self.vrt_version(new_version), new_id.split('.')), - new_id + ' is not valid') + self.assertTrue( + new_id == 'other' or utils.id_valid( + self.vrt_version( + new_version + ), new_id.split('.') + ), + new_id + ' is not valid' + ) def check_mapping(self, id_list): if utils.id_valid(self.vrt_versions['current'], id_list): @@ -45,5 +63,6 @@ def vrt_version(self, version): else: self.fail('Unknown version: %s' % version) + if __name__ == "__main__": unittest.main() diff --git a/lib/tests/test_vrt.py b/lib/tests/test_vrt.py index 66976d16..97e0240a 100644 --- a/lib/tests/test_vrt.py +++ b/lib/tests/test_vrt.py @@ -5,13 +5,17 @@ import glob import os + class TestVrt(unittest.TestCase): def setUp(self): print("\n`---{}---`".format(self._testMethodName)) self.vrt = utils.get_json(utils.VRT_FILENAME) self.mappings = [ - { 'filename': f, 'name': os.path.splitext(os.path.basename(f))[0] } - for f in glob.glob(utils.MAPPING_DIR + '/**/*.json', recursive=True) if 'schema' not in f + {'filename': f, 'name': os.path.splitext(os.path.basename(f))[0]} + for f in glob.glob( + utils.MAPPING_DIR + '/**/*.json', recursive=True + ) + if 'schema' not in f ] @unittest.skip('need to decide the best way to handle this') @@ -20,7 +24,10 @@ def test_changelog_updated(self): Checks if CHANGELOG.md is being updated with the current commit and prompts the user if it isn't """ - p = subprocess.Popen('git diff HEAD --stat --staged CHANGELOG.md | wc -l', shell=True, stdout=subprocess.PIPE) + p = subprocess.Popen( + 'git diff HEAD --stat --staged CHANGELOG.md | wc -l', + shell=True, stdout=subprocess.PIPE + ) out, _err = p.communicate() self.assertGreater(int(out), 0, 'CHANGELOG.md not updated') @@ -28,7 +35,9 @@ def validate_schema(self, schema_file, data_file): schema = utils.get_json(schema_file) data = utils.get_json(data_file) jsonschema.Draft4Validator.check_schema(schema) - error = jsonschema.exceptions.best_match(jsonschema.Draft4Validator(schema).iter_errors(data)) + error = jsonschema.exceptions.best_match( + jsonschema.Draft4Validator(schema).iter_errors(data) + ) if error: raise error @@ -41,19 +50,30 @@ def test_mapping_schemas(self): f'{utils.MAPPING_DIR}/**/{mapping["name"]}.schema.json', recursive=True )[0] - self.assertTrue(os.path.isfile(schema_file), 'Missing schema file for %s mapping' % mapping['name']) + self.assertTrue( + os.path.isfile(schema_file), + 'Missing schema file for %s mapping' % mapping['name'] + ) self.validate_schema(schema_file, mapping['filename']) def all_vrt_ids_have_mapping(self, mappping_filename, key): mapping = utils.get_json(mappping_filename) keyed_mapping = utils.key_by_id(mapping['content']) - for vrt_id_list in utils.all_id_lists(self.vrt, include_internal=False): + for vrt_id_list in utils.all_id_lists( + self.vrt, include_internal=False + ): result = utils.has_mapping(keyed_mapping, vrt_id_list, key) if key == 'cwe' and not result: - print('WARNING: no ' + key + ' mapping for ' + '.'.join(vrt_id_list)) + print('WARNING: no ' + key + ' mapping for ' + '.'.join( + vrt_id_list + )) else: - self.assertTrue(utils.has_mapping(keyed_mapping, vrt_id_list, key), - 'no ' + key + ' mapping for ' + '.'.join(vrt_id_list)) + self.assertTrue( + utils.has_mapping( + keyed_mapping, vrt_id_list, key + ), + 'no ' + key + ' mapping for ' + '.'.join(vrt_id_list) + ) def test_all_vrt_ids_have_all_mappings(self): for mapping in self.mappings: @@ -63,7 +83,11 @@ def only_map_valid_ids(self, mapping_filename): vrt_ids = utils.all_id_lists(self.vrt) mapping_ids = utils.all_id_lists(utils.get_json(mapping_filename)) for id_list in mapping_ids: - self.assertIn(id_list, vrt_ids, 'invalid id in ' + mapping_filename + ' - ' + '.'.join(id_list)) + self.assertIn( + id_list, + vrt_ids, + 'invalid id in ' + mapping_filename + ' - ' + '.'.join(id_list) + ) def test_only_map_valid_ids(self): for mapping in self.mappings: diff --git a/lib/utils/utils.py b/lib/utils/utils.py index 70497821..b70676ae 100644 --- a/lib/utils/utils.py +++ b/lib/utils/utils.py @@ -10,13 +10,16 @@ SCW_DIR = 'remediation_training' THIRD_PARTY_MAPPING_DIR = 'third-party-mappings' + def get_json(filename): with open(filename) as f: return json.loads(f.read()) + def all_versions(filename): """ - Find, open and parse all tagged versions of a json file, including the current version + Find, open and parse all tagged versions of a json file, + including the current version :param filename: The filename to find :return: a dictionary of all the versions, in the form @@ -41,10 +44,12 @@ def id_valid(vrt, id_list): Check if a vrt id is valid :param vrt: The vrt object - :param id_list: The vrt id, split into components, eg ['category', 'subcategory', 'variant'] + :param id_list: The vrt id, split into components, + eg ['category', 'subcategory', 'variant'] :return: True/False """ - # this is not particularly efficient, but it's more readable than other options so until we need to care... + # this is not particularly efficient, but it's more readable than other + # options so until we need to care... return id_list in all_id_lists(vrt) @@ -53,7 +58,8 @@ def has_mapping(mapping, id_list, key): Check if a vrt id has a mapping :param mapping: The mapping object, keyed by id - :param id_list: The vrt id, split into components, eg ['category', 'subcategory', 'variant'] + :param id_list: The vrt id, split into components, + eg ['category', 'subcategory', 'variant'] :param key: The mapping key to look for, eg 'cvss_v3' :return: True/False """ @@ -72,9 +78,16 @@ def key_by_id(mapping): Converts arrays to hashes keyed by the id attribute for easier lookup. So [{'id': 'one', 'foo': 'bar'}, {'id': 'two', 'foo': 'baz'}] becomes - {'one': {'id': 'one', 'foo': 'bar'}, 'two': {'id': 'two', 'foo': 'baz'}} + { + 'one': {'id': 'one', 'foo': 'bar'}, + 'two': {'id': 'two', 'foo': 'baz'} + } """ - if isinstance(mapping, list) and isinstance(mapping[0], dict) and 'id' in mapping[0]: + if isinstance( + mapping, list + ) and isinstance( + mapping[0], dict + ) and 'id' in mapping[0]: return {x['id']: key_by_id(x) for x in mapping} elif isinstance(mapping, dict): return {k: key_by_id(v) for k, v in mapping.items()} @@ -84,10 +97,12 @@ def key_by_id(mapping): def all_id_lists(vrt, include_internal=True): """ - Get all valid vrt ids for a given vrt object, including internal nodes by default + Get all valid vrt ids for a given vrt object, including internal nodes + by default :param vrt: The vrt object - :param include_internal: Whether to include internal nodes or only leaf nodes + :param include_internal: Whether to include internal nodes or only + leaf nodes :return: ids in the form [ ['category'], @@ -98,7 +113,10 @@ def all_id_lists(vrt, include_internal=True): """ def _all_id_lists(sub_vrt, prefix): if isinstance(sub_vrt, list): - return [vrt_id for entry in sub_vrt for vrt_id in _all_id_lists(entry, prefix)] + return [ + vrt_id for entry in sub_vrt + for vrt_id in _all_id_lists(entry, prefix) + ] elif isinstance(sub_vrt, dict): if 'children' in sub_vrt: new_prefix = prefix + [sub_vrt['id']] diff --git a/lib/validate_artifacts.py b/lib/validate_artifacts.py index 355a5784..8259778c 100644 --- a/lib/validate_artifacts.py +++ b/lib/validate_artifacts.py @@ -5,7 +5,11 @@ from artifacts import scw_artifact artifact_json = utils.get_json(scw_artifact.OUTPUT_FILENAME) -repo_path = os.path.join(utils.THIRD_PARTY_MAPPING_DIR, utils.SCW_DIR, utils.SCW_FILENAME) +repo_path = os.path.join( + utils.THIRD_PARTY_MAPPING_DIR, + utils.SCW_DIR, + utils.SCW_FILENAME +) print(os.path.abspath(repo_path)) repo_json = utils.get_json(repo_path) @@ -16,5 +20,8 @@ print('SCW Document is valid!') sys.exit(0) else: - print('SCW Document is invalid, copy the artifact to the remediation training') + print( + 'SCW Document is invalid, copy the artifact to the remediation\ + training' + ) sys.exit(1)