From 66b26d549e0810384040cdc64a99df49b4dee091 Mon Sep 17 00:00:00 2001 From: Dylan Pulver Date: Fri, 5 Jul 2024 00:21:21 -0400 Subject: [PATCH] safety/scan --- safety/scan/command.py | 160 +++++++++++--------- safety/scan/decorators.py | 57 +++---- safety/scan/main.py | 115 ++++++++++---- safety/scan/models.py | 78 ++++++++-- safety/scan/render.py | 309 ++++++++++++++++++++++++++++++-------- safety/scan/util.py | 129 ++++++++++++++-- safety/scan/validators.py | 105 ++++++++++--- 7 files changed, 718 insertions(+), 235 deletions(-) diff --git a/safety/scan/command.py b/safety/scan/command.py index 433e8742..f272868b 100644 --- a/safety/scan/command.py +++ b/safety/scan/command.py @@ -1,5 +1,3 @@ - -from datetime import datetime from enum import Enum import itertools import logging @@ -45,16 +43,32 @@ class ScannableEcosystems(Enum): + """Enum representing scannable ecosystems.""" PYTHON = Ecosystem.PYTHON.value -def process_report(obj: Any, console: Console, report: ReportModel, output: str, - save_as: Optional[Tuple[str, Path]], **kwargs): - +def process_report( + obj: Any, console: Console, report: ReportModel, output: str, + save_as: Optional[Tuple[str, Path]], **kwargs +) -> Optional[str]: + """ + Processes and outputs the report based on the given parameters. + + Args: + obj (Any): The context object. + console (Console): The console object. + report (ReportModel): The report model. + output (str): The output format. + save_as (Optional[Tuple[str, Path]]): The save-as format and path. + kwargs: Additional keyword arguments. + + Returns: + Optional[str]: The URL of the report if uploaded, otherwise None. + """ wait_msg = "Processing report" with console.status(wait_msg, spinner=DEFAULT_SPINNER) as status: json_format = report.as_v30().json() - + export_type, export_path = None, None if save_as: @@ -74,12 +88,12 @@ def process_report(obj: Any, console: Console, report: ReportModel, output: str, spdx_version = None if export_type: spdx_version = export_type.version if export_type.version and ScanExport.is_format(export_type, ScanExport.SPDX) else None - + if not spdx_version and output: spdx_version = output.version if output.version and ScanOutput.is_format(output, ScanOutput.SPDX) else None spdx_format = render_scan_spdx(report, obj, spdx_version=spdx_version) - + if export_type is ScanExport.HTML or output is ScanOutput.HTML: html_format = render_scan_html(report, obj) @@ -89,7 +103,7 @@ def process_report(obj: Any, console: Console, report: ReportModel, output: str, ScanExport.SPDX: spdx_format, ScanExport.SPDX_2_3: spdx_format, ScanExport.SPDX_2_2: spdx_format, - } + } output_format_mapping = { ScanOutput.JSON: json_format, @@ -106,7 +120,7 @@ def process_report(obj: Any, console: Console, report: ReportModel, output: str, msg = f"Saving {export_type} report at: {export_path}" status.update(msg) LOG.debug(msg) - save_report_as(report.metadata.scan_type, export_type, Path(export_path), + save_report_as(report.metadata.scan_type, export_type, Path(export_path), report_to_export) report_url = None @@ -131,7 +145,7 @@ def process_report(obj: Any, console: Console, report: ReportModel, output: str, f"[link]{project_url}[/link]") elif report.metadata.scan_type is ScanType.system_scan: lines.append(f"System scan report: [link]{report_url}[/link]") - + for line in lines: console.print(line, emoji=True) @@ -142,25 +156,30 @@ def process_report(obj: Any, console: Console, report: ReportModel, output: str, if output is ScanOutput.JSON: kwargs = {"json": report_to_output} else: - kwargs = 
{"data": report_to_output} + kwargs = {"data": report_to_output} console.print_json(**kwargs) else: console.print(report_to_output) console.quiet = True - + return report_url -def generate_updates_arguments() -> list: - """Generates a list of file types and update limits for apply fixes.""" +def generate_updates_arguments() -> List: + """ + Generates a list of file types and update limits for apply fixes. + + Returns: + List: A list of file types and update limits. + """ fixes = [] limit_type = SecurityUpdates.UpdateLevel.PATCH - DEFAULT_FILE_TYPES = [FileType.REQUIREMENTS_TXT, FileType.PIPENV_LOCK, + DEFAULT_FILE_TYPES = [FileType.REQUIREMENTS_TXT, FileType.PIPENV_LOCK, FileType.POETRY_LOCK, FileType.VIRTUAL_ENVIRONMENT] fixes.extend([(default_file_type, limit_type) for default_file_type in DEFAULT_FILE_TYPES]) - + return fixes @@ -197,7 +216,7 @@ def scan(ctx: typer.Context, ] = ScanOutput.SCREEN, detailed_output: Annotated[bool, typer.Option("--detailed-output", - help=SCAN_DETAILED_OUTPUT, + help=SCAN_DETAILED_OUTPUT, show_default=False) ] = False, save_as: Annotated[Optional[Tuple[ScanExport, Path]], @@ -221,7 +240,7 @@ def scan(ctx: typer.Context, )] = None, apply_updates: Annotated[bool, typer.Option("--apply-fixes", - help=SCAN_APPLY_FIXES, + help=SCAN_APPLY_FIXES, show_default=False) ] = False ): @@ -229,10 +248,12 @@ def scan(ctx: typer.Context, Scans a project (defaulted to the current directory) for supply-chain security and configuration issues """ + # Generate update arguments if apply updates option is enabled fixes_target = [] if apply_updates: fixes_target = generate_updates_arguments() + # Ensure save_as params are correctly set if not all(save_as): ctx.params["save_as"] = None @@ -240,19 +261,21 @@ def scan(ctx: typer.Context, ecosystems = [Ecosystem(member.value) for member in list(ScannableEcosystems)] to_include = {file_type: paths for file_type, paths in ctx.obj.config.scan.include_files.items() if file_type.ecosystem in ecosystems} - file_finder = FileFinder(target=target, ecosystems=ecosystems, + # Initialize file finder + file_finder = FileFinder(target=target, ecosystems=ecosystems, max_level=ctx.obj.config.scan.max_depth, - exclude=ctx.obj.config.scan.ignore, + exclude=ctx.obj.config.scan.ignore, include_files=to_include, console=console) + # Download necessary assets for each handler for handler in file_finder.handlers: if handler.ecosystem: wait_msg = "Fetching Safety's vulnerability database..." 
with console.status(wait_msg, spinner=DEFAULT_SPINNER): handler.download_required_assets(ctx.obj.auth.client) - + # Start scanning the project directory wait_msg = "Scanning project directory" path = None @@ -260,7 +283,7 @@ def scan(ctx: typer.Context, with console.status(wait_msg, spinner=DEFAULT_SPINNER): path, file_paths = file_finder.search() - print_detected_ecosystems_section(console, file_paths, + print_detected_ecosystems_section(console, file_paths, include_safety_prjs=True) target_ecosystems = ", ".join([member.value for member in ecosystems]) @@ -274,7 +297,7 @@ def scan(ctx: typer.Context, count = 0 ignored = set() - + affected_count = 0 dependency_vuln_detected = False @@ -287,18 +310,21 @@ def scan(ctx: typer.Context, requirements_txt_found = False display_apply_fix_suggestion = False + # Process each file for dependencies and vulnerabilities with console.status(wait_msg, spinner=DEFAULT_SPINNER) as status: - for path, analyzed_file in process_files(paths=file_paths, + for path, analyzed_file in process_files(paths=file_paths, config=config): count += len(analyzed_file.dependency_results.dependencies) + # Update exit code if vulnerabilities are found if exit_code == 0 and analyzed_file.dependency_results.failed: exit_code = EXIT_CODE_VULNERABILITIES_FOUND + # Handle ignored vulnerabilities for detailed output if detailed_output: vulns_ignored = analyzed_file.dependency_results.ignored_vulns_data \ .values() - ignored_vulns_data = itertools.chain(vulns_ignored, + ignored_vulns_data = itertools.chain(vulns_ignored, ignored_vulns_data) ignored.update(analyzed_file.dependency_results.ignored_vulns.keys()) @@ -309,7 +335,7 @@ def scan(ctx: typer.Context, def sort_vulns_by_score(vuln: Vulnerability) -> int: if vuln.severity and vuln.severity.cvssv3: return vuln.severity.cvssv3.get("base_score", 0) - + return 0 to_fix_spec = [] @@ -327,10 +353,10 @@ def sort_vulns_by_score(vuln: Vulnerability) -> int: for spec in affected_specifications: if file_matched_for_fix: to_fix_spec.append(spec) - + console.print() vulns_to_report = sorted( - [vuln for vuln in spec.vulnerabilities if not vuln.ignored], + [vuln for vuln in spec.vulnerabilities if not vuln.ignored], key=sort_vulns_by_score, reverse=True) @@ -346,14 +372,14 @@ def sort_vulns_by_score(vuln: Vulnerability) -> int: console.print(Padding(f"{msg}]", (0, 0, 0, 1)), emoji=True, overflow="crop") - + if detailed_output or vulns_found < 3: for vuln in vulns_to_report: - render_to_console(vuln, console, - rich_kwargs={"emoji": True, + render_to_console(vuln, console, + rich_kwargs={"emoji": True, "overflow": "crop"}, detailed_output=detailed_output) - + lines = [] # Put remediation here @@ -381,16 +407,16 @@ def sort_vulns_by_score(vuln: Vulnerability) -> int: console.print(Padding(line, (0, 0, 0, 1)), emoji=True) console.print( - Padding(f"Learn more: [link]{spec.remediation.more_info_url}[/link]", - (0, 0, 0, 1)), emoji=True) + Padding(f"Learn more: [link]{spec.remediation.more_info_url}[/link]", + (0, 0, 0, 1)), emoji=True) else: console.print() console.print(f":white_check_mark: [file_title]{path.relative_to(target)}: No issues found.[/file_title]", emoji=True) if(ctx.obj.auth.stage == Stage.development - and analyzed_file.ecosystem == Ecosystem.PYTHON - and analyzed_file.file_type == FileType.REQUIREMENTS_TXT + and analyzed_file.ecosystem == Ecosystem.PYTHON + and analyzed_file.file_type == FileType.REQUIREMENTS_TXT and any(affected_specifications) and not apply_updates): display_apply_fix_suggestion = True @@ -405,12 +431,12 @@ def 
sort_vulns_by_score(vuln: Vulnerability) -> int: if file_matched_for_fix: to_fix_files.append((file, to_fix_spec)) - files.append(file) + files.append(file) if display_apply_fix_suggestion: console.print() print_fixes_section(console, requirements_txt_found, detailed_output) - + console.print() print_brief(console, ctx.obj.project, count, affected_count, fixes_count) @@ -418,18 +444,18 @@ def sort_vulns_by_score(vuln: Vulnerability) -> int: is_detailed_output=detailed_output, ignored_vulns_data=ignored_vulns_data) - + version = ctx.obj.schema metadata = ctx.obj.metadata telemetry = ctx.obj.telemetry ctx.obj.project.files = files report = ReportModel(version=version, - metadata=metadata, + metadata=metadata, telemetry=telemetry, files=[], projects=[ctx.obj.project]) - + report_url = process_report(ctx.obj, console, report, **{**ctx.params}) project_url = f"{SAFETY_PLATFORM_URL}{ctx.obj.project.url_path}" @@ -440,7 +466,7 @@ def sort_vulns_by_score(vuln: Vulnerability) -> int: no_output = output is not ScanOutput.SCREEN prompt = output is ScanOutput.SCREEN - + # TODO: rename that 'no_output' confusing name if not no_output: console.print() @@ -462,11 +488,11 @@ def sort_vulns_by_score(vuln: Vulnerability) -> int: if any(policy_limits): update_limits = [policy_limit.value for policy_limit in policy_limits] - - fixes = process_fixes_scan(file_to_fix, + + fixes = process_fixes_scan(file_to_fix, specs_to_fix, update_limits, output, no_output=no_output, prompt=prompt) - + if not no_output: console.print("-" * console.size.width) @@ -484,7 +510,7 @@ def sort_vulns_by_score(vuln: Vulnerability) -> int: @scan_system_app.command( cls=SafetyCLICommand, help=CLI_SYSTEM_SCAN_COMMAND_HELP, - options_metavar="[COMMAND-OPTIONS]", + options_metavar="[COMMAND-OPTIONS]", name=CMD_SYSTEM_NAME, epilog=DEFAULT_EPILOG) @handle_cmd_exception @inject_metadata @@ -521,7 +547,7 @@ def system_scan(ctx: typer.Context, typer.Option( help=SYSTEM_SCAN_OUTPUT_HELP, show_default=False) - ] = SystemScanOutput.SCREEN, + ] = SystemScanOutput.SCREEN, save_as: Annotated[Optional[Tuple[SystemScanExport, Path]], typer.Option( help=SYSTEM_SCAN_SAVE_AS_HELP, @@ -575,9 +601,9 @@ def system_scan(ctx: typer.Context, for file_type, paths in target_paths.items(): current = file_paths.get(file_type, set()) current.update(paths) - file_paths[file_type] = current + file_paths[file_type] = current - scan_project_command = get_command_for(name=CMD_PROJECT_NAME, + scan_project_command = get_command_for(name=CMD_PROJECT_NAME, typer_instance=scan_project_app) projects_dirs = set() @@ -587,12 +613,12 @@ def system_scan(ctx: typer.Context, with console.status(":mag:", spinner=DEFAULT_SPINNER) as status: # Handle projects first if FileType.SAFETY_PROJECT.value in file_paths.keys(): - projects_file_paths = file_paths[FileType.SAFETY_PROJECT.value] + projects_file_paths = file_paths[FileType.SAFETY_PROJECT.value] basic_params = ctx.params.copy() basic_params.pop("targets", None) prjs_console = Console(quiet=True) - + for project_path in projects_file_paths: projects_dirs.add(project_path.parent) project_dir = str(project_path.parent) @@ -607,7 +633,7 @@ def system_scan(ctx: typer.Context, if not project or not project.id: LOG.warn(f"{project_path} parsed but project id is not defined or valid.") continue - + if not ctx.obj.platform_enabled: msg = f"project found and skipped, navigate to `{project.project_path}` and scan this project with ‘safety scan’" console.print(f"{project.id}: {msg}") @@ -615,8 +641,8 @@ def system_scan(ctx: typer.Context, msg = 
f"Existing project found at {project_dir}" console.print(f"{project.id}: {msg}") - project_data[project.id] = {"path": project_dir, - "report_url": None, + project_data[project.id] = {"path": project_dir, + "report_url": None, "project_url": None, "failed_exception": None} @@ -642,7 +668,7 @@ def system_scan(ctx: typer.Context, "save_as": (None, None), "upload_request_id": upload_request_id, "local_policy": local_policy_file, "console": prjs_console} try: - # TODO: Refactor to avoid calling invoke, also, launch + # TODO: Refactor to avoid calling invoke, also, launch # this on background. console.print( Padding(f"Running safety scan for {project.id} project", @@ -660,7 +686,7 @@ def system_scan(ctx: typer.Context, (0, 0, 0, 1)), emoji=True) LOG.exception(f"Failed to run scan on project {project.id}, " \ f"Upload request ID: {upload_request_id}. Reason {e}") - + console.print() file_paths.pop(FileType.SAFETY_PROJECT.value, None) @@ -670,18 +696,18 @@ def system_scan(ctx: typer.Context, status.update(":mag: Finishing projects processing.") for k, f_paths in file_paths.items(): - file_paths[k] = {fp for fp in f_paths - if not should_exclude(excludes=projects_dirs, + file_paths[k] = {fp for fp in f_paths + if not should_exclude(excludes=projects_dirs, to_analyze=fp)} - + pkgs_count = 0 file_count = 0 venv_count = 0 for path, analyzed_file in process_files(paths=file_paths, config=config): status.update(f":mag: {path}") - files.append(FileModel(location=path, - file_type=analyzed_file.file_type, + files.append(FileModel(location=path, + file_type=analyzed_file.file_type, results=analyzed_file.dependency_results)) file_pkg_count = len(analyzed_file.dependency_results.dependencies) @@ -718,7 +744,7 @@ def system_scan(ctx: typer.Context, pkgs_count += file_pkg_count console.print(f":package: {file_pkg_count} {msg} in {path}", emoji=True) - + if affected_pkgs_count <= 0: msg = "No vulnerabilities found" else: @@ -738,7 +764,7 @@ def system_scan(ctx: typer.Context, telemetry=telemetry, files=files, projects=projects) - + console.print() total_count = sum([finder.file_count for finder in file_finders], 0) console.print(f"Searched {total_count:,} files for dependency security issues") @@ -749,16 +775,16 @@ def system_scan(ctx: typer.Context, console.print() proccessed = dict(filter( - lambda item: item[1]["report_url"] and item[1]["project_url"], + lambda item: item[1]["report_url"] and item[1]["project_url"], project_data.items())) - + if proccessed: run_word = "runs" if len(proccessed) == 1 else "run" console.print(f"Project {pluralize('scan', len(proccessed))} {run_word} on {len(proccessed)} existing {pluralize('project', len(proccessed))}:") for prj, data in proccessed.items(): console.print(f"[bold]{prj}[/bold] at {data['path']}") - for detail in [f"{prj} dashboard: {data['project_url']}"]: + for detail in [f"{prj} dashboard: {data['project_url']}"]: console.print(Padding(detail, (0, 0, 0, 1)), emoji=True, overflow="crop") process_report(ctx.obj, console, report, **{**ctx.params}) diff --git a/safety/scan/decorators.py b/safety/scan/decorators.py index 2f41a7c8..29c9e7c9 100644 --- a/safety/scan/decorators.py +++ b/safety/scan/decorators.py @@ -4,7 +4,7 @@ from pathlib import Path from random import randint import sys -from typing import List, Optional +from typing import Any, List, Optional from rich.padding import Padding from safety_schemas.models import ConfigModel, ProjectModel @@ -29,7 +29,10 @@ LOG = logging.getLogger(__name__) -def initialize_scan(ctx, console): +def initialize_scan(ctx: 
Any, console: Console) -> None: + """ + Initializes the scan by setting platform_enabled based on the response from the server. + """ data = None try: @@ -48,7 +51,7 @@ def initialize_scan(ctx, console): def scan_project_command_init(func): """ - Make general verifications before each scan command. + Decorator to make general verifications before each project scan command. """ @wraps(func) def inner(ctx, policy_file_path: Optional[Path], target: Path, @@ -62,7 +65,7 @@ def inner(ctx, policy_file_path: Optional[Path], target: Path, console.quiet = True if not ctx.obj.auth.is_valid(): - process_auth_status_not_ready(console=console, + process_auth_status_not_ready(console=console, auth=ctx.obj.auth, ctx=ctx) upload_request_id = kwargs.pop("upload_request_id", None) @@ -109,12 +112,12 @@ def inner(ctx, policy_file_path: Optional[Path], target: Path, cloud_policy = None if ctx.obj.platform_enabled: - cloud_policy = print_wait_policy_download(console, (download_policy, - {"session": session, + cloud_policy = print_wait_policy_download(console, (download_policy, + {"session": session, "project_id": ctx.obj.project.id, "stage": stage, "branch": branch})) - + ctx.obj.project.policy = resolve_policy(local_policy, cloud_policy) config = ctx.obj.project.policy.config \ if ctx.obj.project.policy and ctx.obj.project.policy.config \ @@ -145,10 +148,10 @@ def inner(ctx, policy_file_path: Optional[Path], target: Path, details = {"Account": f"{content} {render_email_note(ctx.obj.auth)}"} else: details = {"Account": f"Offline - {os.getenv('SAFETY_DB_DIR')}"} - + if ctx.obj.project.id: details["Project"] = ctx.obj.project.id - + if ctx.obj.project.git: details[" Git branch"] = ctx.obj.project.git.branch @@ -156,7 +159,7 @@ def inner(ctx, policy_file_path: Optional[Path], target: Path, msg = "None, using Safety CLI default policies" - if ctx.obj.project.policy: + if ctx.obj.project.policy: if ctx.obj.project.policy.source is PolicySource.cloud: msg = f"fetched from Safety Platform, " \ "ignoring any local Safety CLI policy files" @@ -170,7 +173,7 @@ def inner(ctx, policy_file_path: Optional[Path], target: Path, for k,v in details.items(): console.print(f"[scan_meta_title]{k}[/scan_meta_title]: {v}") - + print_announcements(console=console, ctx=ctx) console.print() @@ -185,10 +188,10 @@ def inner(ctx, policy_file_path: Optional[Path], target: Path, def scan_system_command_init(func): """ - Make general verifications before each system scan command. + Decorator to make general verifications before each system scan command. 
""" @wraps(func) - def inner(ctx, policy_file_path: Optional[Path], targets: List[Path], + def inner(ctx, policy_file_path: Optional[Path], targets: List[Path], output: SystemScanOutput, console: Console = main_console, *args, **kwargs): ctx.obj.console = console @@ -198,8 +201,8 @@ def inner(ctx, policy_file_path: Optional[Path], targets: List[Path], console.quiet = True if not ctx.obj.auth.is_valid(): - process_auth_status_not_ready(console=console, - auth=ctx.obj.auth, ctx=ctx) + process_auth_status_not_ready(console=console, + auth=ctx.obj.auth, ctx=ctx) initialize_scan(ctx, console) @@ -229,12 +232,12 @@ def inner(ctx, policy_file_path: Optional[Path], targets: List[Path], ctx.obj.config = config - if not any(targets): + if not any(targets): if any(config.scan.system_targets): targets = [Path(t).expanduser().absolute() for t in config.scan.system_targets] else: targets = [Path("/")] - + ctx.obj.metadata.scan_locations = targets console.print() @@ -244,8 +247,8 @@ def inner(ctx, policy_file_path: Optional[Path], targets: List[Path], details = {"Account": f"{ctx.obj.auth.name}, {ctx.obj.auth.email}", "Scan stage": ctx.obj.auth.stage} - - if ctx.obj.system_scan_policy: + + if ctx.obj.system_scan_policy: if ctx.obj.system_scan_policy.source is PolicySource.cloud: policy_type = "remote" else: @@ -259,9 +262,9 @@ def inner(ctx, policy_file_path: Optional[Path], targets: List[Path], for k,v in details.items(): console.print(f"[bold]{k}[/bold]: {v}") - + if ctx.obj.system_scan_policy: - + dirs = [ign for ign in ctx.obj.config.scan.ignore if Path(ign).is_dir()] policy_details = [ @@ -273,17 +276,17 @@ def inner(ctx, policy_file_path: Optional[Path], targets: List[Path], console.print( Padding(policy_detail, (0, 0, 0, 1)), emoji=True) - + print_announcements(console=console, ctx=ctx) console.print() - + kwargs.update({"targets": targets}) result = func(ctx, *args, **kwargs) return result - return inner - + return inner + def inject_metadata(func): """ @@ -304,7 +307,7 @@ def inner(ctx, *args, **kwargs): if not scan_type: raise SafetyException("Missing scan_type.") - + if scan_type is ScanType.scan: if not target: raise SafetyException("Missing target.") @@ -319,7 +322,7 @@ def inner(ctx, *args, **kwargs): telemetry=telemetry, schema_version=ReportSchemaVersion.v3_0 ) - + ctx.obj.schema = ReportSchemaVersion.v3_0 ctx.obj.metadata = metadata ctx.obj.telemetry = telemetry diff --git a/safety/scan/main.py b/safety/scan/main.py index 52508276..3b6d2e71 100644 --- a/safety/scan/main.py +++ b/safety/scan/main.py @@ -25,11 +25,20 @@ PROJECT_CONFIG_NAME = "name" -def download_policy(session: SafetyAuthSession, - project_id: str, - stage: Stage, - branch: Optional[str]) -> Optional[PolicyFileModel]: - result = session.download_policy(project_id=project_id, stage=stage, +def download_policy(session: SafetyAuthSession, project_id: str, stage: Stage, branch: Optional[str]) -> Optional[PolicyFileModel]: + """ + Downloads the policy file from the cloud for the given project and stage. + + Args: + session (SafetyAuthSession): SafetyAuthSession object for authentication. + project_id (str): The ID of the project. + stage (Stage): The stage of the project. + branch (Optional[str]): The branch of the project (optional). + + Returns: + Optional[PolicyFileModel]: PolicyFileModel object if successful, otherwise None. 
+ """ + result = session.download_policy(project_id=project_id, stage=stage, branch=branch) if result and "uuid" in result and result["uuid"]: @@ -62,28 +71,44 @@ def download_policy(session: SafetyAuthSession, source=PolicySource.cloud, location=None, config=config) - + return None def load_unverified_project_from_config(project_root: Path) -> UnverifiedProjectModel: + """ + Loads an unverified project from the configuration file located at the project root. + + Args: + project_root (Path): The root directory of the project. + + Returns: + UnverifiedProjectModel: An instance of UnverifiedProjectModel. + """ config = configparser.ConfigParser() project_path = project_root / PROJECT_CONFIG config.read(project_path) id = config.get(PROJECT_CONFIG_SECTION, PROJECT_CONFIG_ID, fallback=None) id = config.get(PROJECT_CONFIG_SECTION, PROJECT_CONFIG_ID, fallback=None) url = config.get(PROJECT_CONFIG_SECTION, PROJECT_CONFIG_URL, fallback=None) - name = config.get(PROJECT_CONFIG_SECTION, PROJECT_CONFIG_NAME, fallback=None) + name = config.get(PROJECT_CONFIG_SECTION, PROJECT_CONFIG_NAME, fallback=None) created = True if id: created = False - - return UnverifiedProjectModel(id=id, url_path=url, - name=name, project_path=project_path, + + return UnverifiedProjectModel(id=id, url_path=url, + name=name, project_path=project_path, created=created) -def save_project_info(project: ProjectModel, project_path: Path): +def save_project_info(project: ProjectModel, project_path: Path) -> None: + """ + Saves the project information to the configuration file. + + Args: + project (ProjectModel): The ProjectModel object containing project information. + project_path (Path): The path to the configuration file. + """ config = configparser.ConfigParser() config.read(project_path) @@ -95,12 +120,21 @@ def save_project_info(project: ProjectModel, project_path: Path): config[PROJECT_CONFIG_SECTION][PROJECT_CONFIG_URL] = project.url_path if project.name: config[PROJECT_CONFIG_SECTION][PROJECT_CONFIG_NAME] = project.name - + with open(project_path, 'w') as configfile: - config.write(configfile) + config.write(configfile) def load_policy_file(path: Path) -> Optional[PolicyFileModel]: + """ + Loads a policy file from the specified path. + + Args: + path (Path): The path to the policy file. + + Returns: + Optional[PolicyFileModel]: PolicyFileModel object if successful, otherwise None. + """ config = None if not path or not path.exists(): @@ -118,13 +152,21 @@ def load_policy_file(path: Path) -> Optional[PolicyFileModel]: LOG.error(f"Wrong YML file for policy file {path}.", exc_info=True) raise SafetyError(f"{err}, details: {e}") - return PolicyFileModel(id=str(path), source=PolicySource.local, + return PolicyFileModel(id=str(path), source=PolicySource.local, location=path, config=config) -def resolve_policy(local_policy: Optional[PolicyFileModel], - cloud_policy: Optional[PolicyFileModel]) \ - -> Optional[PolicyFileModel]: +def resolve_policy(local_policy: Optional[PolicyFileModel], cloud_policy: Optional[PolicyFileModel]) -> Optional[PolicyFileModel]: + """ + Resolves the policy to be used, preferring cloud policy over local policy. + + Args: + local_policy (Optional[PolicyFileModel]): The local policy file model (optional). + cloud_policy (Optional[PolicyFileModel]): The cloud policy file model (optional). + + Returns: + Optional[PolicyFileModel]: The resolved PolicyFileModel object. 
+ """ policy = None if cloud_policy: @@ -135,20 +177,37 @@ def resolve_policy(local_policy: Optional[PolicyFileModel], return policy -def save_report_as(scan_type: ScanType, export_type: ScanExport, at: Path, report: Any): - tag = int(time.time()) +def save_report_as(scan_type: ScanType, export_type: ScanExport, at: Path, report: Any) -> None: + """ + Saves the scan report to the specified location. + + Args: + scan_type (ScanType): The type of scan. + export_type (ScanExport): The type of export. + at (Path): The path to save the report. + report (Any): The report content. + """ + tag = int(time.time()) + + if at.is_dir(): + at = at / Path( + f"{scan_type.value}-{export_type.get_default_file_name(tag=tag)}") + + with open(at, 'w+') as report_file: + report_file.write(report) - if at.is_dir(): - at = at / Path( - f"{scan_type.value}-{export_type.get_default_file_name(tag=tag)}") - with open(at, 'w+') as report_file: - report_file.write(report) +def process_files(paths: Dict[str, Set[Path]], config: Optional[ConfigModel] = None) -> Generator[Tuple[Path, InspectableFile], None, None]: + """ + Processes the files and yields each file path along with its inspectable file. + Args: + paths (Dict[str, Set[Path]]): A dictionary of file paths by file type. + config (Optional[ConfigModel]): The configuration model (optional). -def process_files(paths: Dict[str, Set[Path]], - config: Optional[ConfigModel] = None) -> \ - Generator[Tuple[Path, InspectableFile], None, None]: + Yields: + Tuple[Path, InspectableFile]: A tuple of file path and inspectable file. + """ if not config: config = ConfigModel() @@ -158,7 +217,7 @@ def process_files(paths: Dict[str, Set[Path]], continue for f_path in f_paths: with InspectableFileContext(f_path, file_type=file_type) as inspectable_file: - if inspectable_file and inspectable_file.file_type: + if inspectable_file and inspectable_file.file_type: inspectable_file.inspect(config=config) inspectable_file.remediate() yield f_path, inspectable_file diff --git a/safety/scan/models.py b/safety/scan/models.py index 54a32895..86ffc21a 100644 --- a/safety/scan/models.py +++ b/safety/scan/models.py @@ -5,10 +5,22 @@ from pydantic.dataclasses import dataclass class FormatMixin: + """ + Mixin class providing format-related utilities for Enum classes. + """ @classmethod - def is_format(cls, format_sub: Optional[Enum], format_instance: Enum): - """ Check if the value is a variant of the specified format. """ + def is_format(cls, format_sub: Optional[Enum], format_instance: Enum) -> bool: + """ + Check if the value is a variant of the specified format. + + Args: + format_sub (Optional[Enum]): The format to check. + format_instance (Enum): The instance of the format to compare against. + + Returns: + bool: True if the format matches, otherwise False. + """ if not format_sub: return False @@ -17,19 +29,27 @@ def is_format(cls, format_sub: Optional[Enum], format_instance: Enum): prefix = format_sub.value.split('@')[0] return prefix == format_instance.value - + @property - def version(self): - """ Return the version of the format. """ + def version(self) -> Optional[str]: + """ + Return the version of the format. + + Returns: + Optional[str]: The version of the format if available, otherwise None. + """ result = self.value.split('@') if len(result) == 2: return result[1] - + return None class ScanOutput(FormatMixin, str, Enum): + """ + Enum representing different scan output formats. 
+ """ JSON = "json" SPDX = "spdx" SPDX_2_3 = "spdx@2.3" @@ -39,19 +59,36 @@ class ScanOutput(FormatMixin, str, Enum): SCREEN = "screen" NONE = "none" - def is_silent(self): + def is_silent(self) -> bool: + """ + Check if the output format is silent. + + Returns: + bool: True if the output format is silent, otherwise False. + """ return self in (ScanOutput.JSON, ScanOutput.SPDX, ScanOutput.SPDX_2_3, ScanOutput.SPDX_2_2, ScanOutput.HTML) class ScanExport(FormatMixin, str, Enum): + """ + Enum representing different scan export formats. + """ JSON = "json" SPDX = "spdx" SPDX_2_3 = "spdx@2.3" SPDX_2_2 = "spdx@2.2" - HTML = "html" + HTML = "html" + + def get_default_file_name(self, tag: int) -> str: + """ + Get the default file name for the export format. - def get_default_file_name(self, tag: int): - + Args: + tag (int): A unique tag to include in the file name. + + Returns: + str: The default file name. + """ if self is ScanExport.JSON: return f"safety-report-{tag}.json" elif self in [ScanExport.SPDX, ScanExport.SPDX_2_3, ScanExport.SPDX_2_2]: @@ -63,19 +100,34 @@ def get_default_file_name(self, tag: int): class SystemScanOutput(str, Enum): + """ + Enum representing different system scan output formats. + """ JSON = "json" SCREEN = "screen" - def is_silent(self): - return self in (SystemScanOutput.JSON,) + def is_silent(self) -> bool: + """ + Check if the output format is silent. + + Returns: + bool: True if the output format is silent, otherwise False. + """ + return self in (SystemScanOutput.JSON,) class SystemScanExport(str, Enum): + """ + Enum representing different system scan export formats. + """ JSON = "json" @dataclass class UnverifiedProjectModel(): + """ + Data class representing an unverified project model. + """ id: Optional[str] project_path: Path created: bool name: Optional[str] = None - url_path: Optional[str] = None + url_path: Optional[str] = None diff --git a/safety/scan/render.py b/safety/scan/render.py index f5f1da2f..9eef7642 100644 --- a/safety/scan/render.py +++ b/safety/scan/render.py @@ -5,7 +5,7 @@ import logging from pathlib import Path import time -from typing import Any, Dict, List, Optional, Set +from typing import Any, Dict, List, Optional, Set, Tuple from rich.prompt import Prompt from rich.text import Text from rich.console import Console @@ -28,6 +28,16 @@ import datetime def render_header(targets: List[Path], is_system_scan: bool) -> Text: + """ + Render the header text for the scan. + + Args: + targets (List[Path]): List of target paths for the scan. + is_system_scan (bool): Indicates if the scan is a system scan. + + Returns: + Text: Rendered header text. + """ version = get_safety_version() scan_datetime = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z") @@ -38,14 +48,29 @@ def render_header(targets: List[Path], is_system_scan: bool) -> Text: return Text.from_markup( f"[bold]Safety[/bold] {version} {action}\n{scan_datetime}") -def print_header(console, targets: List[Path], is_system_scan: bool = False): +def print_header(console, targets: List[Path], is_system_scan: bool = False) -> None: + """ + Print the header for the scan. + + Args: + console (Console): The console for output. + targets (List[Path]): List of target paths for the scan. + is_system_scan (bool): Indicates if the scan is a system scan. + """ console.print(render_header(targets, is_system_scan), markup=True) -def print_announcements(console, ctx): +def print_announcements(console: Console, ctx: typer.Context): + """ + Print announcements from Safety. 
+ + Args: + console (Console): The console for output. + ctx (typer.Context): The context of the Typer command. + """ colors = {"error": "red", "warning": "yellow", "info": "default"} - announcements = safety.get_announcements(ctx.obj.auth.client, - telemetry=ctx.obj.config.telemetry_enabled, + announcements = safety.get_announcements(ctx.obj.auth.client, + telemetry=ctx.obj.config.telemetry_enabled, with_telemetry=ctx.obj.telemetry) basic_announcements = get_basic_announcements(announcements, False) @@ -53,12 +78,19 @@ def print_announcements(console, ctx): console.print() console.print("[bold]Safety Announcements:[/bold]") console.print() - for announcement in announcements: + for announcement in announcements: color = colors.get(announcement.get('type', "info"), "default") console.print(f"[{color}]* {announcement.get('message')}[/{color}]") -def print_detected_ecosystems_section(console, file_paths: Dict[str, Set[Path]], - include_safety_prjs: bool = True): +def print_detected_ecosystems_section(console: Console, file_paths: Dict[str, Set[Path]], include_safety_prjs: bool = True) -> None: + """ + Print detected ecosystems section. + + Args: + console (Console): The console for output. + file_paths (Dict[str, Set[Path]]): Dictionary of file paths by type. + include_safety_prjs (bool): Whether to include safety projects. + """ detected: Dict[Ecosystem, Dict[FileType, int]] = {} for file_type_key, f_paths in file_paths.items(): @@ -75,24 +107,33 @@ def print_detected_ecosystems_section(console, file_paths: Dict[str, Set[Path]], brief = "Found " file_types = [] - + for f_type, count in f_type_count.items(): file_types.append(f"{count} {f_type.human_name(plural=count>1)}") - + if len(file_types) > 1: brief += ", ".join(file_types[:-1]) + " and " + file_types[-1] else: brief += file_types[0] - + msg = f"{ecosystem.name.replace('_', ' ').title()} detected. {brief}" - + console.print(msg) -def print_brief(console, project: ProjectModel, dependencies_count: int = 0, - affected_count: int = 0, fixes_count: int = 0): +def print_brief(console: Console, project: ProjectModel, dependencies_count: int = 0, affected_count: int = 0, fixes_count: int = 0) -> None: + """ + Print a brief summary of the scan results. + + Args: + console (Console): The console for output. + project (ProjectModel): The project model. + dependencies_count (int): Number of dependencies tested. + affected_count (int): Number of security issues found. + fixes_count (int): Number of fixes suggested. + """ from ..util import pluralize - if project.policy: + if project.policy: if project.policy.source is PolicySource.cloud: policy_msg = f"policy fetched from Safety Platform" else: @@ -107,8 +148,16 @@ def print_brief(console, project: ProjectModel, dependencies_count: int = 0, f"issues using {policy_msg}") console.print( f"[number]{affected_count}[/number] security {pluralize('issue', affected_count)} found, [number]{fixes_count}[/number] {pluralize('fix', fixes_count)} suggested") - -def print_fixes_section(console, requirements_txt_found: bool = False, is_detailed_output: bool = False): + +def print_fixes_section(console: Console, requirements_txt_found: bool = False, is_detailed_output: bool = False) -> None: + """ + Print the section on applying fixes. + + Args: + console (Console): The console for output. + requirements_txt_found (bool): Indicates if a requirements.txt file was found. + is_detailed_output (bool): Indicates if detailed output is enabled. 
+ """ console.print("-" * console.size.width) console.print("Apply Fixes") console.print("-" * console.size.width) @@ -131,8 +180,17 @@ def print_fixes_section(console, requirements_txt_found: bool = False, is_detail console.print("-" * console.size.width) -def print_ignore_details(console, project: ProjectModel, ignored, - is_detailed_output: bool = False, ignored_vulns_data = None): +def print_ignore_details(console: Console, project: ProjectModel, ignored: Set[str], is_detailed_output: bool = False, ignored_vulns_data: Optional[Dict[str, Vulnerability]] = None) -> None: + """ + Print details about ignored vulnerabilities. + + Args: + console (Console): The console for output. + project (ProjectModel): The project model. + ignored (Set[str]): Set of ignored vulnerabilities. + is_detailed_output (bool): Indicates if detailed output is enabled. + ignored_vulns_data (Optional[Dict[str, Vulnerability]]): Data of ignored vulnerabilities. + """ from ..util import pluralize if is_detailed_output: @@ -146,7 +204,7 @@ def print_ignore_details(console, project: ProjectModel, ignored, unpinned_ignored = {} unpinned_ignored_pkgs = set() environment_ignored = {} - environment_ignored_pkgs = set() + environment_ignored_pkgs = set() for vuln_data in ignored_vulns_data: code = IgnoreCodes(vuln_data.ignored_code) @@ -160,7 +218,7 @@ def print_ignore_details(console, project: ProjectModel, ignored, unpinned_ignored_pkgs.add(vuln_data.package_name) elif code is IgnoreCodes.environment_dependency: environment_ignored[vuln_data.vulnerability_id] = vuln_data - environment_ignored_pkgs.add(vuln_data.package_name) + environment_ignored_pkgs.add(vuln_data.package_name) if manual_ignored: count = len(manual_ignored) @@ -168,7 +226,7 @@ def print_ignore_details(console, project: ProjectModel, ignored, f"[number]{count}[/number] were manually ignored due to the project policy:") for vuln in manual_ignored.values(): render_to_console(vuln, console, - rich_kwargs={"emoji": True, "overflow": "crop"}, + rich_kwargs={"emoji": True, "overflow": "crop"}, detailed_output=is_detailed_output) if cvss_severity_ignored: count = len(cvss_severity_ignored) @@ -197,7 +255,19 @@ def print_ignore_details(console, project: ProjectModel, ignored, "project policy)") -def print_wait_project_verification(console, project_id, closure, on_error_delay=1): +def print_wait_project_verification(console: Console, project_id: str, closure: Tuple[Any, Dict[str, Any]], on_error_delay: int = 1) -> Any: + """ + Print a waiting message while verifying a project. + + Args: + console (Console): The console for output. + project_id (str): The project ID. + closure (Tuple[Any, Dict[str, Any]]): The function and its arguments to call. + on_error_delay (int): Delay in seconds on error. + + Returns: + Any: The status of the project verification. + """ status = None wait_msg = f"Verifying project {project_id} with Safety Platform." @@ -215,10 +285,17 @@ def print_wait_project_verification(console, project_id, closure, on_error_delay if not status: wait_msg = f'Unable to verify "{project_id}". Starting again...' time.sleep(on_error_delay) - + return status -def print_project_info(console, project: ProjectModel): +def print_project_info(console: Console, project: ProjectModel): + """ + Print information about the project. + + Args: + console (Console): The console for output. + project (ProjectModel): The project model. + """ config_msg = "loaded without policies or custom configuration." 
if project.policy: @@ -229,11 +306,21 @@ def print_project_info(console, project: ProjectModel): else: config_msg = " policies fetched " \ "from Safety Platform." - + msg = f"[bold]{project.id} project found[/bold] - {config_msg}" console.print(msg) -def print_wait_policy_download(console, closure) -> Optional[PolicyFileModel]: +def print_wait_policy_download(console: Console, closure: Tuple[Any, Dict[str, Any]]) -> Optional[PolicyFileModel]: + """ + Print a waiting message while downloading a policy from the cloud. + + Args: + console (Console): The console for output. + closure (Tuple[Any, Dict[str, Any]]): The function and its arguments to call. + + Returns: + Optional[PolicyFileModel]: The downloaded policy file model. + """ policy = None wait_msg = "Looking for a policy from cloud..." @@ -253,9 +340,19 @@ def print_wait_policy_download(console, closure) -> Optional[PolicyFileModel]: return policy -def prompt_project_id(console, stage: Stage, - prj_root_name: Optional[str], - do_not_exit=True) -> str: +def prompt_project_id(console: Console, stage: Stage, prj_root_name: Optional[str], do_not_exit: bool = True) -> Optional[str}: + """ + Prompt the user to set a project ID for the scan. + + Args: + console (Console): The console for output. + stage (Stage): The current stage. + prj_root_name (Optional[str]): The root name of the project. + do_not_exit (bool): Indicates if the function should not exit on failure. + + Returns: + Optional[str]: The project ID. + """ from safety.util import clean_project_id default_prj_id = clean_project_id(prj_root_name) if prj_root_name else None @@ -264,10 +361,10 @@ def prompt_project_id(console, stage: Stage, # Fail here console.print("The scan needs to be linked to a project.") raise typer.Exit(code=1) - + hint = "" - if default_prj_id: - hint = f" If empty Safety will use [bold]{default_prj_id}[/bold]" + if default_prj_id: + hint = f" If empty Safety will use [bold]{default_prj_id}[/bold]" prompt_text = f"Set a project id for this scan (no spaces).{hint}" def ask(): @@ -290,27 +387,55 @@ def ask(): return project_id -def prompt_link_project(console, prj_name: str, prj_admin_email: str) -> bool: +def prompt_link_project(console: Console, prj_name: str, prj_admin_email: str) -> bool: + """ + Prompt the user to link the scan with an existing project. + + Args: + console (Console): The console for output. + prj_name (str): The project name. + prj_admin_email (str): The project admin email. + + Returns: + bool: True if the user wants to link the scan, False otherwise. + """ console.print("[bold]Safety found an existing project with this name in your organization:[/bold]") - for detail in (f"[bold]Project name:[/bold] {prj_name}", + for detail in (f"[bold]Project name:[/bold] {prj_name}", f"[bold]Project admin:[/bold] {prj_admin_email}"): console.print(Padding(detail, (0, 0, 0, 2)), emoji=True) prompt_question = "Do you want to link this scan with this existing project?" - - answer = Prompt.ask(prompt=prompt_question, choices=["y", "n"], + + answer = Prompt.ask(prompt=prompt_question, choices=["y", "n"], default="y", show_default=True, console=console).lower() - + return answer == "y" -def render_to_console(cls: Vulnerability, console: Console, rich_kwargs, - detailed_output: bool = False): +def render_to_console(cls: Vulnerability, console: Console, rich_kwargs: Dict[str, Any], detailed_output: bool = False) -> None: + """ + Render a vulnerability to the console. + + Args: + cls (Vulnerability): The vulnerability instance. 
+ console (Console): The console for output. + rich_kwargs (Dict[str, Any]): Additional arguments for rendering. + detailed_output (bool): Indicates if detailed output is enabled. + """ cls.__render__(console, detailed_output, rich_kwargs) -def get_render_console(entity_type): +def get_render_console(entity_type: Any) -> Any: + """ + Get the render function for a specific entity type. + + Args: + entity_type (Any): The entity type. + + Returns: + Any: The render function. + """ if entity_type is Vulnerability: def __render__(self, console: Console, detailed_output: bool, rich_kwargs): @@ -330,12 +455,12 @@ def __render__(self, console: Console, detailed_output: bool, rich_kwargs): console.print( Padding( - f"->{pre} Vuln ID [vuln_id]{self.vulnerability_id}[/vuln_id]: {severity_detail if severity_detail else ''}", + f"->{pre} Vuln ID [vuln_id]{self.vulnerability_id}[/vuln_id]: {severity_detail if severity_detail else ''}", (0, 0, 0, 2) ), **rich_kwargs) console.print( Padding( - f"{self.advisory[:advisory_length]}{'...' if len(self.advisory) > advisory_length else ''}", + f"{self.advisory[:advisory_length]}{'...' if len(self.advisory) > advisory_length else ''}", (0, 0, 0, 5) ), **rich_kwargs) @@ -347,7 +472,17 @@ def __render__(self, console: Console, detailed_output: bool, rich_kwargs): return __render__ -def render_scan_html(report: ReportModel, obj) -> str: +def render_scan_html(report: ReportModel, obj: Any) -> str: + """ + Render the scan report to HTML. + + Args: + report (ReportModel): The scan report model. + obj (Any): The object containing additional settings. + + Returns: + str: The rendered HTML report. + """ from safety.scan.command import ScannableEcosystems project = report.projects[0] if any(report.projects) else None @@ -376,30 +511,40 @@ def render_scan_html(report: ReportModel, obj) -> str: ignored_packages += len(file.results.ignored_vulns) # TODO: Get this information for the report model (?) - summary = {"scanned_packages": scanned_packages, - "affected_packages": affected_packages, + summary = {"scanned_packages": scanned_packages, + "affected_packages": affected_packages, "remediations_recommended": remediations_recommended, "ignored_vulnerabilities": ignored_vulnerabilities, "vulnerabilities": vulnerabilities} - + vulnerabilities = [] - - + + # TODO: This should be based on the configs per command ecosystems = [(f"{ecosystem.name.title()}", [file_type.human_name(plural=True) for file_type in ecosystem.file_types]) for ecosystem in [Ecosystem(member.value) for member in list(ScannableEcosystems)]] - + settings ={"audit_and_monitor": True, "platform_url": SAFETY_PLATFORM_URL, "ecosystems": ecosystems} - template_context = {"report": report, "summary": summary, "announcements": [], - "project": project, + template_context = {"report": report, "summary": summary, "announcements": [], + "project": project, "platform_enabled": obj.platform_enabled, "settings": settings, "vulns_per_file": vulns_per_file, "remed_per_file": remed_per_file} - + return parse_html(kwargs=template_context, template="scan/index.html") -def generate_spdx_creation_info(*, spdx_version: str, project_identifier: str) -> Any: +def generate_spdx_creation_info(spdx_version: str, project_identifier: str) -> Any: + """ + Generate SPDX creation information. + + Args: + spdx_version (str): The SPDX version. + project_identifier (str): The project identifier. + + Returns: + Any: The SPDX creation information. 
+ """ from spdx_tools.spdx.model import ( Actor, ActorType, @@ -439,7 +584,17 @@ def generate_spdx_creation_info(*, spdx_version: str, project_identifier: str) - return creation_info -def create_pkg_ext_ref(*, package: PythonDependency, version: Optional[str]): +def create_pkg_ext_ref(*, package: PythonDependency, version: Optional[str]) -> Any: + """ + Create an external package reference for SPDX. + + Args: + package (PythonDependency): The package dependency. + version (Optional[str]): The package version. + + Returns: + Any: The external package reference. + """ from spdx_tools.spdx.model import ( ExternalPackageRef, ExternalPackageRefCategory, @@ -455,11 +610,20 @@ def create_pkg_ext_ref(*, package: PythonDependency, version: Optional[str]): def create_packages(dependencies: List[PythonDependency]) -> List[Any]: + """ + Create a list of SPDX packages. + + Args: + dependencies (List[PythonDependency]): List of Python dependencies. + + Returns: + List[Any]: List of SPDX packages. + """ from spdx_tools.spdx.model.spdx_no_assertion import SpdxNoAssertion from spdx_tools.spdx.model import ( Package, - ) + ) doc_pkgs = [] pkgs_added = set([]) @@ -471,7 +635,7 @@ def create_packages(dependencies: List[PythonDependency]) -> List[Any]: if pkg_id in pkgs_added: continue pkg_ref = create_pkg_ext_ref(package=dependency, version=pkg_version) - + pkg = Package( spdx_id=pkg_id, name=f"pip:{dep_name}", @@ -491,6 +655,16 @@ def create_packages(dependencies: List[PythonDependency]) -> List[Any]: def create_spdx_document(*, report: ReportModel, spdx_version: str) -> Optional[Any]: + """ + Create an SPDX document. + + Args: + report (ReportModel): The scan report model. + spdx_version (str): The SPDX version. + + Returns: + Optional[Any]: The SPDX document. + """ from spdx_tools.spdx.model import ( Document, Relationship, @@ -501,13 +675,13 @@ def create_spdx_document(*, report: ReportModel, spdx_version: str) -> Optional[ if not project: return None - + prj_id = project.id - + if not prj_id: parent_name = project.project_path.parent.name prj_id = parent_name if parent_name else str(int(time.time())) - + creation_info = generate_spdx_creation_info(spdx_version=spdx_version, project_identifier=prj_id) depedencies = iter([]) @@ -534,12 +708,23 @@ def create_spdx_document(*, report: ReportModel, spdx_version: str) -> Optional[ return spdx_doc -def render_scan_spdx(report: ReportModel, obj, spdx_version: Optional[str]) -> Optional[Any]: +def render_scan_spdx(report: ReportModel, obj: Any, spdx_version: Optional[str]) -> Optional[Any]: + """ + Render the scan report to SPDX format. + + Args: + report (ReportModel): The scan report model. + obj (Any): The object containing additional settings. + spdx_version (Optional[str]): The SPDX version. + + Returns: + Optional[Any]: The rendered SPDX document in JSON format. + """ from spdx_tools.spdx.writer.write_utils import ( convert, validate_and_deduplicate ) - + # Set to latest supported if a version is not specified if not spdx_version: spdx_version = "2.3" diff --git a/safety/scan/util.py b/safety/scan/util.py index 388b3f98..3fea1a5c 100644 --- a/safety/scan/util.py +++ b/safety/scan/util.py @@ -13,11 +13,20 @@ LOG = logging.getLogger(__name__) class Language(str, Enum): + """ + Enum representing supported programming languages. + """ python = "python" javascript = "javascript" safety_project = "safety_project" def handler(self) -> FileHandler: + """ + Get the appropriate file handler for the language. 
+ + Returns: + FileHandler: The file handler for the language. + """ if self is Language.python: return PythonFileHandler() if self is Language.safety_project: @@ -26,20 +35,35 @@ def handler(self) -> FileHandler: return PythonFileHandler() class Output(Enum): + """ + Enum representing output formats. + """ json = "json" class AuthenticationType(str, Enum): + """ + Enum representing authentication types. + """ token = "token" api_key = "api_key" none = "unauthenticated" def is_allowed_in(self, stage: Stage = Stage.development) -> bool: + """ + Check if the authentication type is allowed in the given stage. + + Args: + stage (Stage): The current stage. + + Returns: + bool: True if the authentication type is allowed, otherwise False. + """ if self is AuthenticationType.none: - return False - + return False + if stage == Stage.development and self is AuthenticationType.api_key: return False - + if (not stage == Stage.development) and self is AuthenticationType.token: return False @@ -47,64 +71,137 @@ def is_allowed_in(self, stage: Stage = Stage.development) -> bool: class GIT: + """ + Class representing Git operations. + """ ORIGIN_CMD: Tuple[str, ...] = ("remote", "get-url", "origin") BRANCH_CMD: Tuple[str, ...] = ("symbolic-ref", "--short", "-q", "HEAD") TAG_CMD: Tuple[str, ...] = ("describe", "--tags", "--exact-match") - DESCRIBE_CMD: Tuple[str, ...] = ("describe", '--match=""', '--always', + DESCRIBE_CMD: Tuple[str, ...] = ("describe", '--match=""', '--always', '--abbrev=40', '--dirty') GIT_CHECK_CMD: Tuple[str, ...] = ("rev-parse", "--is-inside-work-tree") - + def __init__(self, root: Path = Path(".")) -> None: + """ + Initialize the GIT class with the given root directory. + + Args: + root (Path): The root directory for Git operations. + """ self.git = ("git", "-C", root.resolve()) def __run__(self, cmd: Tuple[str, ...], env_var: Optional[str] = None) -> Optional[str]: + """ + Run a Git command. + + Args: + cmd (Tuple[str, ...]): The Git command to run. + env_var (Optional[str]): An optional environment variable to check for the command result. + + Returns: + Optional[str]: The result of the Git command, or None if an error occurred. + """ if env_var and os.environ.get(env_var): return os.environ.get(env_var) try: - return subprocess.run(self.git + cmd, stdout=subprocess.PIPE, + return subprocess.run(self.git + cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL).stdout.decode('utf-8').strip() except Exception as e: LOG.exception(e) - + return None def origin(self) -> Optional[str]: + """ + Get the Git origin URL. + + Returns: + Optional[str]: The Git origin URL, or None if an error occurred. + """ return self.__run__(self.ORIGIN_CMD, env_var="SAFETY_GIT_ORIGIN") - + def branch(self) -> Optional[str]: + """ + Get the current Git branch. + + Returns: + Optional[str]: The current Git branch, or None if an error occurred. + """ return self.__run__(self.BRANCH_CMD, env_var="SAFETY_GIT_BRANCH") def tag(self) -> Optional[str]: + """ + Get the current Git tag. + + Returns: + Optional[str]: The current Git tag, or None if an error occurred. + """ return self.__run__(self.TAG_CMD, env_var="SAFETY_GIT_TAG") - + def describe(self) -> Optional[str]: + """ + Get the Git describe output. + + Returns: + Optional[str]: The Git describe output, or None if an error occurred. + """ return self.__run__(self.DESCRIBE_CMD) - + def dirty(self, raw_describe: str) -> bool: + """ + Check if the working directory is dirty. + + Args: + raw_describe (str): The raw describe output. 
+ + Returns: + bool: True if the working directory is dirty, otherwise False. + """ if os.environ.get("SAFETY_GIT_DIRTY") in ["0", "1"]: return bool(int(os.environ.get("SAFETY_GIT_DIRTY"))) - + return raw_describe.endswith('-dirty') def commit(self, raw_describe: str) -> Optional[str]: + """ + Get the current Git commit hash. + + Args: + raw_describe (str): The raw describe output. + + Returns: + Optional[str]: The current Git commit hash, or None if an error occurred. + """ if os.environ.get("SAFETY_GIT_COMMIT"): return os.environ.get("SAFETY_GIT_COMMIT") - try: + try: return raw_describe.split("-dirty")[0] except Exception: pass def is_git(self) -> bool: + """ + Check if the current directory is a Git repository. + + Returns: + bool: True if the current directory is a Git repository, otherwise False. + """ result = self.__run__(self.GIT_CHECK_CMD) if result == "true": return True - + return False def build_git_data(self): + """ + Build a GITModel object with Git data. + + Returns: + Optional[GITModel]: The GITModel object with Git data, or None if the directory is not a Git repository. + """ from safety_schemas.models import GITModel if self.is_git(): @@ -114,8 +211,8 @@ def build_git_data(self): if raw_describe: commit = self.commit(raw_describe) dirty = self.dirty(raw_describe) - return GITModel(branch=self.branch(), - tag=self.tag(), commit=commit, dirty=dirty, + return GITModel(branch=self.branch(), + tag=self.tag(), commit=commit, dirty=dirty, origin=self.origin()) - + return None diff --git a/safety/scan/validators.py b/safety/scan/validators.py index 12aa777f..0bfb81ea 100644 --- a/safety/scan/validators.py +++ b/safety/scan/validators.py @@ -8,19 +8,32 @@ from safety.scan.render import print_wait_project_verification, prompt_project_id, prompt_link_project from safety_schemas.models import AuthenticationType, ProjectModel, Stage +from safety.auth.utils import SafetyAuthSession MISSING_SPDX_EXTENSION_MSG = "spdx extra is not installed, please install it with: pip install safety[spdx]" -def raise_if_not_spdx_extension_installed(): +def raise_if_not_spdx_extension_installed() -> None: + """ + Raises an error if the spdx extension is not installed. + """ try: import spdx_tools.spdx except Exception as e: - raise typer.BadParameter(MISSING_SPDX_EXTENSION_MSG) + raise typer.BadParameter(MISSING_SPDX_EXTENSION_MSG) -def save_as_callback(save_as: Optional[Tuple[ScanExport, Path]]): +def save_as_callback(save_as: Optional[Tuple[ScanExport, Path]]) -> Tuple[Optional[str], Optional[Path]]: + """ + Callback function to handle save_as parameter and validate if spdx extension is installed. + + Args: + save_as (Optional[Tuple[ScanExport, Path]]): The export type and path. + + Returns: + Tuple[Optional[str], Optional[Path]]: The validated export type and path. + """ export_type, export_path = save_as if save_as else (None, None) if ScanExport.is_format(export_type, ScanExport.SPDX): @@ -28,18 +41,32 @@ def save_as_callback(save_as: Optional[Tuple[ScanExport, Path]]): return (export_type.value, export_path) if export_type and export_path else (export_type, export_path) -def output_callback(output: ScanOutput): +def output_callback(output: ScanOutput) -> str: + """ + Callback function to handle output parameter and validate if spdx extension is installed. + + Args: + output (ScanOutput): The output format. + Returns: + str: The validated output format. 
+ """ if ScanOutput.is_format(output, ScanExport.SPDX): raise_if_not_spdx_extension_installed() - + return output.value def fail_if_not_allowed_stage(ctx: typer.Context): + """ + Fail the command if the authentication type is not allowed in the current stage. + + Args: + ctx (typer.Context): The context of the Typer command. + """ if ctx.resilient_parsing: return - + stage = ctx.obj.auth.stage auth_type: AuthenticationType = ctx.obj.auth.client.get_authentication_type() @@ -51,7 +78,17 @@ def fail_if_not_allowed_stage(ctx: typer.Context): f"the '{stage}' stage.") -def save_verified_project(ctx, slug, name, project_path, url_path): +def save_verified_project(ctx: typer.Context, slug: str, name: Optional[str], project_path: Path, url_path: Optional[str]): + """ + Save the verified project information to the context and project info file. + + Args: + ctx (typer.Context): The context of the Typer command. + slug (str): The project slug. + name (Optional[str]): The project name. + project_path (Path): The project path. + url_path (Optional[str]): The project URL path. + """ ctx.obj.project = ProjectModel( id=slug, name=name, @@ -59,14 +96,28 @@ def save_verified_project(ctx, slug, name, project_path, url_path): url_path=url_path ) if ctx.obj.auth.stage is Stage.development: - save_project_info(project=ctx.obj.project, + save_project_info(project=ctx.obj.project, project_path=project_path) -def check_project(console, ctx, session, - unverified_project: UnverifiedProjectModel, - stage, - git_origin, ask_project_id=False): +def check_project(console, ctx: typer.Context, session: SafetyAuthSession, + unverified_project: UnverifiedProjectModel, stage: Stage, + git_origin: Optional[str], ask_project_id: bool = False) -> dict: + """ + Check the project against the session and stage, verifying the project if necessary. + + Args: + console: The console for output. + ctx (typer.Context): The context of the Typer command. + session (SafetyAuthSession): The authentication session. + unverified_project (UnverifiedProjectModel): The unverified project model. + stage (Stage): The current stage. + git_origin (Optional[str]): The Git origin URL. + ask_project_id (bool): Whether to prompt for the project ID. + + Returns: + dict: The result of the project check. + """ stage = ctx.obj.auth.stage source = ctx.obj.telemetry.safety_source if ctx.obj.telemetry else None data = {"scan_stage": stage, "safety_source": source} @@ -91,17 +142,27 @@ def check_project(console, ctx, session, data[PRJ_SLUG_KEY] = unverified_project.id data[PRJ_SLUG_SOURCE_KEY] = "user" - status = print_wait_project_verification(console, data[PRJ_SLUG_KEY] if data.get(PRJ_SLUG_KEY, None) else "-", + status = print_wait_project_verification(console, data[PRJ_SLUG_KEY] if data.get(PRJ_SLUG_KEY, None) else "-", (session.check_project, data), on_error_delay=1) return status -def verify_project(console, ctx, session, - unverified_project: UnverifiedProjectModel, - stage, - git_origin): - +def verify_project(console, ctx: typer.Context, session: SafetyAuthSession, + unverified_project: UnverifiedProjectModel, stage: Stage, + git_origin: Optional[str]): + """ + Verify the project, linking it if necessary and saving the verified project information. + + Args: + console: The console for output. + ctx (typer.Context): The context of the Typer command. + session (SafetyAuthSession): The authentication session. + unverified_project (UnverifiedProjectModel): The unverified project model. + stage (Stage): The current stage. 
+ git_origin (Optional[str]): The Git origin URL. + """ + verified_prj = False link_prj = True @@ -122,17 +183,17 @@ def verify_project(console, ctx, session, link_prj = prompt_link_project(prj_name=prj_name, prj_admin_email=prj_admin_email, console=console) - + if not link_prj: continue verified_prj = print_wait_project_verification( - console, unverified_slug, (session.project, + console, unverified_slug, (session.project, {"project_id": unverified_slug}), on_error_delay=1) - + if verified_prj and isinstance(verified_prj, dict) and verified_prj.get("slug", None): - save_verified_project(ctx, verified_prj["slug"], verified_prj.get("name", None), + save_verified_project(ctx, verified_prj["slug"], verified_prj.get("name", None), unverified_project.project_path, verified_prj.get("url", None)) else: verified_prj = False
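# Illustrative sketch (editor's example, not part of the patch itself): how a
# `--save-as` pair such as (ScanExport.HTML, Path("./reports")) is resolved to a
# concrete file, mirroring save_report_as() in safety/scan/main.py above.
# The helper name resolve_export_path is hypothetical.
import time
from pathlib import Path


def resolve_export_path(export_type, at: Path, scan_type_value: str = "scan") -> Path:
    # A directory target falls back to a timestamped default file name,
    # e.g. something like "scan-safety-report-1720156881.html".
    if at.is_dir():
        tag = int(time.time())
        return at / f"{scan_type_value}-{export_type.get_default_file_name(tag=tag)}"
    return at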
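# Illustrative sketch (editor's example, not part of the patch itself): the
# behaviour of the FormatMixin helpers added in safety/scan/models.py for a
# versioned format value such as "spdx@2.3".
from enum import Enum
from typing import Optional


class FormatMixin:
    @classmethod
    def is_format(cls, format_sub: Optional[Enum], format_instance: Enum) -> bool:
        if not format_sub:
            return False
        # "spdx@2.3" still counts as a variant of the plain "spdx" format.
        prefix = format_sub.value.split('@')[0]
        return prefix == format_instance.value

    @property
    def version(self) -> Optional[str]:
        result = self.value.split('@')
        return result[1] if len(result) == 2 else None


class ScanExport(FormatMixin, str, Enum):
    SPDX = "spdx"
    SPDX_2_3 = "spdx@2.3"


assert ScanExport.is_format(ScanExport.SPDX_2_3, ScanExport.SPDX)
assert ScanExport.SPDX_2_3.version == "2.3"
assert ScanExport.SPDX.version is None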
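# Illustrative sketch (editor's example, not part of the patch itself): the
# (callable, kwargs) "closure" convention accepted by
# print_wait_project_verification and print_wait_policy_download in
# safety/scan/render.py -- the caller packages the work to run behind the
# spinner, and the render helper invokes it. The helper name run_closure is
# hypothetical.
from typing import Any, Callable, Dict, Tuple


def run_closure(closure: Tuple[Callable[..., Any], Dict[str, Any]]) -> Any:
    func, kwargs = closure
    return func(**kwargs)


# Example, matching the call site in decorators.py:
# cloud_policy = run_closure((download_policy, {"session": session,
#                                               "project_id": project_id,
#                                               "stage": stage,
#                                               "branch": branch}))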
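# Illustrative sketch (editor's example, not part of the patch itself): the
# environment-variable override convention used by the GIT helper in
# safety/scan/util.py. SAFETY_GIT_ORIGIN, SAFETY_GIT_BRANCH and SAFETY_GIT_TAG
# short-circuit the corresponding `git` subprocess call, which is handy in CI.
# The helper name run_git is hypothetical.
import os
import subprocess
from pathlib import Path
from typing import Optional, Tuple


def run_git(root: Path, cmd: Tuple[str, ...], env_var: Optional[str] = None) -> Optional[str]:
    if env_var and os.environ.get(env_var):
        return os.environ.get(env_var)
    try:
        out = subprocess.run(("git", "-C", str(root.resolve())) + cmd,
                             stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
        return out.stdout.decode("utf-8").strip()
    except Exception:
        return None


# Example: the current branch, overridable via SAFETY_GIT_BRANCH.
# run_git(Path("."), ("symbolic-ref", "--short", "-q", "HEAD"), env_var="SAFETY_GIT_BRANCH")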