diff --git a/entrypoint/entrypoint/data_model.py b/entrypoint/entrypoint/data_model.py new file mode 100644 index 0000000..f7028ad --- /dev/null +++ b/entrypoint/entrypoint/data_model.py @@ -0,0 +1,162 @@ +from enum import Enum + + +def parse_comma_list(value): + if not value or value == "''": + return None + result = [] + parts = value.split(',') + for part in parts: + clean_part = part.strip() + if clean_part: + result.append(clean_part) + return result + + +class ArtifactType(Enum): + REPOSITORY = "repository" + CONTAINER = "container" + BINARY = "binary" + ARCHIVE = "archive" + + +class ScanConfig: + def __init__(self, artifact_type=None, artifact_path=None, sbomgen_version=None, timeout=None, platform=None, scanners=None, skip_scanners=None, skip_files=None): + self.artifact_type = artifact_type + self.artifact_path = artifact_path + self.sbomgen_version = sbomgen_version + self.timeout = timeout + self.platform = platform + self.scanners = scanners + self.skip_scanners = skip_scanners + self.skip_files = skip_files + + @classmethod + def from_args(ScanConfig, args): + return ScanConfig( + artifact_type=ArtifactType(args.artifact_type), + artifact_path=args.artifact_path, + sbomgen_version=args.sbomgen_version, + timeout=int(args.timeout), + platform=args.platform, + scanners=parse_comma_list(args.scanners), + skip_scanners=parse_comma_list(args.skip_scanners), + skip_files=parse_comma_list(args.skip_files) + ) + + +class OutputConfig: + def __init__(self, + display_vulnerability_findings="disabled", + show_only_fixable_vulns=False, + output_sbom_path="sbom.json", + output_inspector_scan_path="inspector-scan.json", + output_inspector_scan_path_csv="inspector-scan.csv", + output_inspector_scan_path_markdown="inspector-scan.md", + output_dockerfile_scan_csv="inspector-dockerfile-scan.csv", + output_dockerfile_scan_markdown="inspector-dockerfile-scan.md", + thresholds=False, + critical_threshold=0, + high_threshold=0, + medium_threshold=0, + low_threshold=0, + other_threshold=0, + threshold_fixable_only=False): + # Convert string to boolean for type safety + if display_vulnerability_findings == "enabled": + self.display_vulnerability_findings = True + else: + self.display_vulnerability_findings = False + + self.show_only_fixable_vulns = show_only_fixable_vulns + self.output_sbom_path = output_sbom_path + self.output_inspector_scan_path = output_inspector_scan_path + self.output_inspector_scan_path_csv = output_inspector_scan_path_csv + self.output_inspector_scan_path_markdown = output_inspector_scan_path_markdown + self.output_dockerfile_scan_csv = output_dockerfile_scan_csv + self.output_dockerfile_scan_markdown = output_dockerfile_scan_markdown + self.thresholds = thresholds + self.critical_threshold = critical_threshold + self.high_threshold = high_threshold + self.medium_threshold = medium_threshold + self.low_threshold = low_threshold + self.other_threshold = other_threshold + self.threshold_fixable_only = threshold_fixable_only + + @classmethod + def from_args(OutputConfig, args): + return OutputConfig( + display_vulnerability_findings=args.display_vuln_findings, + show_only_fixable_vulns=args.show_only_fixable_vulns, + output_sbom_path=args.out_sbom, + output_inspector_scan_path=args.out_scan, + output_inspector_scan_path_csv=args.out_scan_csv, + output_inspector_scan_path_markdown=args.out_scan_markdown, + output_dockerfile_scan_csv=args.out_dockerfile_scan_csv, + output_dockerfile_scan_markdown=args.out_dockerfile_scan_md, + thresholds=args.thresholds, + critical_threshold=args.critical, + high_threshold=args.high, + medium_threshold=args.medium, + low_threshold=args.low, + other_threshold=args.other, + threshold_fixable_only=args.threshold_fixable_only + ) + + +class SBOMOutput: + def __init__(self, + file_path=None, + generation_success=False, + return_code=None, + generation_time=None, + file_size=None, + error_message=None): + self.file_path = file_path + self.generation_success = generation_success + self.return_code = return_code + self.generation_time = generation_time + self.file_size = file_size + self.error_message = error_message + + +class VulnScanOutput: + def __init__(self, + # Core scan results + scan_success=False, + return_code=None, + scan_results_file_path=None, + + # Performance/timing data + scan_time=None, + results_file_size=None, + + # Vulnerability counts + total_vulnerabilities=None, + critical_count=None, + high_count=None, + medium_count=None, + low_count=None, + other_count=None, + + # Error handling + error_message=None): + # Core scan results + self.scan_success = scan_success + self.return_code = return_code + self.scan_results_file_path = scan_results_file_path + + # Performance/timing data + self.scan_time = scan_time + self.results_file_size = results_file_size + + # Vulnerability counts + self.total_vulnerabilities = total_vulnerabilities + self.critical_count = critical_count + self.high_count = high_count + self.medium_count = medium_count + self.low_count = low_count + self.other_count = other_count + + # Error handling + self.error_message = error_message diff --git a/entrypoint/entrypoint/orchestrator.py b/entrypoint/entrypoint/orchestrator.py index aeeeab4..7b8ac43 100644 --- a/entrypoint/entrypoint/orchestrator.py +++ b/entrypoint/entrypoint/orchestrator.py @@ -13,52 +13,60 @@ def execute(args) -> int: - logging.info(f"downloading and installing inspector-sbomgen version {args.sbomgen_version}") - ret = install_sbomgen(args) + # NEW: Create structured config from legacy args (strangler fig pattern) + from entrypoint.data_model import ScanConfig, ArtifactType, OutputConfig + config = ScanConfig.from_args(args) + output_config = OutputConfig.from_args(args) + logging.info(f"Created config: artifact_type={config.artifact_type.value}, artifact_path={config.artifact_path}, sbomgen_version={config.sbomgen_version}, timeout={config.timeout}s") + logging.info(f"Created output_config: display_findings={output_config.display_vulnerability_findings}, sbom_path={output_config.output_sbom_path}") + + # Use structured configs for type-safe, maintainable code + logging.info(f"downloading and installing inspector-sbomgen version {config.sbomgen_version}") + ret = install_sbomgen(config.sbomgen_version) require_true((ret == 0), "unable to download and install inspector-sbomgen") logging.info("generating SBOM from artifact") - ret = invoke_sbomgen(args) + ret = invoke_sbomgen(args, config, output_config) require_true(ret == 0, "unable to generate SBOM with inspector-sbomgen") logging.info("scanning SBOM contents with Amazon Inspector") - ret = invoke_inspector_scan(args.out_sbom, args.out_scan) + ret = invoke_inspector_scan(output_config.output_sbom_path, output_config.output_inspector_scan_path) require_true(ret == 0, "unable to scan SBOM contents with Amazon Inspector") - set_github_actions_output('inspector_scan_results', args.out_scan) + set_github_actions_output('inspector_scan_results', output_config.output_inspector_scan_path) logging.info("tallying vulnerabilities") - succeeded, scan_result, fixed_vuln_counts = get_scan_result(args) + succeeded, scan_result, fixed_vuln_counts = get_scan_result(args, config, output_config) require_true(succeeded, "unable to tally vulnerabilities") print_vuln_count_summary(scan_result) - vuln_counts = fixed_vuln_counts if args.threshold_fixable_only else scan_result - set_env_var_if_vuln_threshold_exceeded(args, vuln_counts) + vuln_counts = fixed_vuln_counts if output_config.threshold_fixable_only else scan_result + set_env_var_if_vuln_threshold_exceeded(output_config, vuln_counts) - write_pkg_vuln_report_csv(args.out_scan_csv, scan_result) - set_github_actions_output('inspector_scan_results_csv', args.out_scan_csv) + write_pkg_vuln_report_csv(output_config.output_inspector_scan_path_csv, scan_result) + set_github_actions_output('inspector_scan_results_csv', output_config.output_inspector_scan_path_csv) - pkg_vuln_markdown = write_pkg_vuln_report_markdown(args.out_scan_markdown, scan_result) - post_pkg_vuln_github_actions_step_summary(args, pkg_vuln_markdown) - set_github_actions_output('inspector_scan_results_markdown', args.out_scan_markdown) + pkg_vuln_markdown = write_pkg_vuln_report_markdown(output_config.output_inspector_scan_path_markdown, scan_result) + post_pkg_vuln_github_actions_step_summary(output_config, pkg_vuln_markdown) + set_github_actions_output('inspector_scan_results_markdown', output_config.output_inspector_scan_path_markdown) - dockerfile.write_dockerfile_report_csv(args.out_scan, args.out_dockerfile_scan_csv) - set_github_actions_output('inspector_dockerile_scan_results_csv', args.out_dockerfile_scan_csv) + dockerfile.write_dockerfile_report_csv(output_config.output_inspector_scan_path, output_config.output_dockerfile_scan_csv) + set_github_actions_output('inspector_dockerile_scan_results_csv', output_config.output_dockerfile_scan_csv) - dockerfile.write_dockerfile_report_md(args.out_scan, args.out_dockerfile_scan_md) - set_github_actions_output('inspector_dockerile_scan_results_markdown', args.out_dockerfile_scan_md) - post_dockerfile_step_summary(args, scan_result.total_vulns()) + dockerfile.write_dockerfile_report_md(output_config.output_inspector_scan_path, output_config.output_dockerfile_scan_markdown) + set_github_actions_output('inspector_dockerile_scan_results_markdown', output_config.output_dockerfile_scan_markdown) + post_dockerfile_step_summary(output_config, scan_result.total_vulns()) return 0 -def post_dockerfile_step_summary(args, total_vulns): - if args.display_vuln_findings == "enabled" and total_vulns > 0: +def post_dockerfile_step_summary(output_config, total_vulns): + if output_config.display_vulnerability_findings and total_vulns > 0: logging.info("posting Inspector Dockerfile scan findings to GitHub Actions step summary page") dockerfile_markdown = "" try: - with open(args.out_dockerfile_scan_md, "r") as f: + with open(output_config.output_dockerfile_scan_markdown, "r") as f: dockerfile_markdown = f.read() except Exception as e: logging.debug(e) # can be spammy, so set as debug log @@ -146,7 +154,7 @@ def get_sbomgen_arch(host_cpu): return None -def invoke_sbomgen(args) -> int: +def invoke_sbomgen(args, config, output_config) -> int: sbomgen = installer.get_sbomgen_install_path() if sbomgen == "": logging.error("expected path to inspector-sbomgen but received empty string") @@ -154,57 +162,61 @@ def invoke_sbomgen(args) -> int: # marshall arguments between action.yml and cli.py path_arg = "" - if args.artifact_type.lower() == "repository": - args.artifact_type = "directory" + sbom_artifact_type = "" + if config.artifact_type == ArtifactType.REPOSITORY: + sbom_artifact_type = "directory" path_arg = "--path" - elif "container" in args.artifact_type.lower(): - args.artifact_type = "container" + elif config.artifact_type == ArtifactType.CONTAINER: + sbom_artifact_type = "container" path_arg = "--image" - elif "binary" in args.artifact_type.lower(): - args.artifact_type = "binary" + elif config.artifact_type == ArtifactType.BINARY: + sbom_artifact_type = "binary" path_arg = "--path" - elif "archive" in args.artifact_type.lower(): - args.artifact_type = "archive" + elif config.artifact_type == ArtifactType.ARCHIVE: + sbom_artifact_type = "archive" path_arg = "--path" else: logging.error( - f"expected artifact type to be 'repository', 'container', 'binary' or 'archive' but received {args.artifact_type}") + f"expected artifact type to be 'repository', 'container', 'binary' or 'archive' but received {config.artifact_type.value}") return 1 # invoke sbomgen with arguments - sbomgen_args = [args.artifact_type, - path_arg, args.artifact_path, - "--outfile", args.out_sbom, + sbomgen_args = [sbom_artifact_type, + path_arg, config.artifact_path, + "--outfile", output_config.output_sbom_path, "--disable-progress-bar", - "--timeout", args.timeout, + "--timeout", str(config.timeout), ] - if args.scanners != "''": - logging.info(f"setting --scanners: {args.scanners}") + if config.scanners: + scanners_str = ",".join(config.scanners) + logging.info(f"setting --scanners: {scanners_str}") sbomgen_args.append("--scanners") - sbomgen_args.append(args.scanners) - elif args.skip_scanners != "''": - logging.info(f"setting --skip-scanners: {args.skip_scanners}") + sbomgen_args.append(scanners_str) + elif config.skip_scanners: + skip_scanners_str = ",".join(config.skip_scanners) + logging.info(f"setting --skip-scanners: {skip_scanners_str}") sbomgen_args.append("--skip-scanners") - sbomgen_args.append(args.skip_scanners) + sbomgen_args.append(skip_scanners_str) else: pass - if args.skip_files != "''": - logging.info(f"setting --skip-files: {args.skip_files}") + if config.skip_files: + skip_files_str = ",".join(config.skip_files) + logging.info(f"setting --skip-files: {skip_files_str}") sbomgen_args.append("--skip-files") - sbomgen_args.append(args.skip_files) + sbomgen_args.append(skip_files_str) - if args.artifact_type == "container": + if config.artifact_type == ArtifactType.CONTAINER: - if args.platform: - platform_arg = args.platform.lower() + if config.platform: + platform_arg = config.platform.lower() if not is_valid_container_platform(platform_arg): logging.fatal( - f"received invalid container image platform: '{args.platform}'. Platform should be of the form 'os/cpu/variant' such as 'linux/amd64' or 'linux/arm64/v8'") + f"received invalid container image platform: '{config.platform}'. Platform should be of the form 'os/cpu/variant' such as 'linux/amd64' or 'linux/arm64/v8'") sbomgen_args.append("--platform") sbomgen_args.append(platform_arg) @@ -214,9 +226,9 @@ def invoke_sbomgen(args) -> int: # make scan results readable by any user so # github actions can upload the file as a job artifact - os.system(f"chmod 444 {args.out_sbom}") + os.system(f"chmod 444 {output_config.output_sbom_path}") - set_github_actions_output('artifact_sbom', args.out_sbom) + set_github_actions_output('artifact_sbom', output_config.output_sbom_path) return ret @@ -235,21 +247,21 @@ def invoke_inspector_scan(src_sbom, dst_scan): return ret -def get_scan_result(args) -> tuple[bool, exporter.InspectorScanResult, fixed_vulns.FixedVulns]: +def get_scan_result(args, config, output_config) -> tuple[bool, exporter.InspectorScanResult, fixed_vulns.FixedVulns]: scan_result = exporter.InspectorScanResult(vulnerabilities=[pkg_vuln.Vulnerability()]) fixed_vulns_counts = fixed_vulns.FixedVulns(criticals=0, highs=0, mediums=0, lows=0, others=0) succeeded, fixed_vulns_counts = get_fixed_vuln_counts( - args.out_scan) + output_config.output_inspector_scan_path) if succeeded is False: return False, scan_result, fixed_vulns_counts - succeeded, criticals, highs, mediums, lows, others = get_vuln_counts(args.out_scan) + succeeded, criticals, highs, mediums, lows, others = get_vuln_counts(output_config.output_inspector_scan_path) if succeeded is False: return False, scan_result, fixed_vulns_counts try: - with open(args.out_scan, "r") as f: + with open(output_config.output_inspector_scan_path, "r") as f: inspector_scan = json.load(f) vulns = pkg_vuln.parse_inspector_scan_result(inspector_scan) @@ -257,15 +269,15 @@ def get_scan_result(args) -> tuple[bool, exporter.InspectorScanResult, fixed_vul logging.error(e) return False, scan_result, fixed_vulns_counts - if args.show_only_fixable_vulns: + if output_config.show_only_fixable_vulns: for vuln in vulns: if vuln.fixed_ver == "null": vulns.remove(vuln) scan_result = exporter.InspectorScanResult( vulnerabilities=vulns, - artifact_name=args.artifact_path, - artifact_type=args.artifact_type, + artifact_name=config.artifact_path, + artifact_type=config.artifact_type.value, criticals=str(criticals), highs=str(highs), mediums=str(mediums), @@ -430,10 +442,10 @@ def get_fixed_vuln_counts(inspector_scan_path: str) -> tuple[bool, fixed_vulns.F return True, fixed_vulns_counts -def install_sbomgen(args): +def install_sbomgen(sbomgen_version): os_name = platform.system() if "Linux" in os_name: - ret = download_install_sbomgen(args.sbomgen_version, "/usr/local/bin/inspector-sbomgen") + ret = download_install_sbomgen(sbomgen_version, "/usr/local/bin/inspector-sbomgen") if not ret: return 1 @@ -469,16 +481,16 @@ def write_pkg_vuln_report_markdown(out_scan_markdown, scan_result: exporter.Insp return markdown -def set_env_var_if_vuln_threshold_exceeded(args, +def set_env_var_if_vuln_threshold_exceeded(output_config, vuln_counts: typing.Union[ exporter.InspectorScanResult, fixed_vulns.FixedVulns]): - is_exceeded = exceeds_threshold(vuln_counts.criticals, args.critical, - vuln_counts.highs, args.high, - vuln_counts.mediums, args.medium, - vuln_counts.lows, args.low, - vuln_counts.others, args.other) + is_exceeded = exceeds_threshold(vuln_counts.criticals, output_config.critical_threshold, + vuln_counts.highs, output_config.high_threshold, + vuln_counts.mediums, output_config.medium_threshold, + vuln_counts.lows, output_config.low_threshold, + vuln_counts.others, output_config.other_threshold) - if is_exceeded and args.thresholds: + if is_exceeded and output_config.thresholds: set_github_actions_output('vulnerability_threshold_exceeded', 1) else: set_github_actions_output('vulnerability_threshold_exceeded', 0) @@ -541,8 +553,8 @@ def get_summarized_findings(scan_result: exporter.InspectorScanResult): return results -def post_pkg_vuln_github_actions_step_summary(args, markdown): - if args.display_vuln_findings == "enabled": +def post_pkg_vuln_github_actions_step_summary(output_config, markdown): + if output_config.display_vulnerability_findings: logging.info("posting Inspector scan findings to GitHub Actions step summary page") exporter.post_github_step_summary(markdown) diff --git a/entrypoint/tests/test_data_model.py b/entrypoint/tests/test_data_model.py new file mode 100644 index 0000000..fecbd56 --- /dev/null +++ b/entrypoint/tests/test_data_model.py @@ -0,0 +1,488 @@ +import unittest +from entrypoint.data_model import ArtifactType, ScanConfig, OutputConfig, SBOMOutput, VulnScanOutput, parse_comma_list + + +class MockArgs: + artifact_type = 'repository' + artifact_path = './test' + sbomgen_version = 'latest' + timeout = '600' + platform = 'linux/amd64' + scanners = 'dpkg,npm,python-requirements' + skip_scanners = 'binaries,alpine-apk' + skip_files = './media,/tmp/foo' + + +class MockOutputArgs: + display_vuln_findings = 'enabled' + show_only_fixable_vulns = True + out_sbom = 'test_sbom.json' + out_scan = 'test_scan.json' + out_scan_csv = 'test_scan.csv' + out_scan_markdown = 'test_scan.md' + out_dockerfile_scan_csv = 'test_dockerfile.csv' + out_dockerfile_scan_md = 'test_dockerfile.md' + thresholds = True + critical = 5 + high = 10 + medium = 15 + low = 20 + other = 25 + threshold_fixable_only = True + + +class TestDataModel(unittest.TestCase): + + def test_artifact_type_repository_exists(self): + self.assertEqual(ArtifactType.REPOSITORY.value, "repository") + + def test_artifact_type_container_exists(self): + self.assertEqual(ArtifactType.CONTAINER.value, "container") + + def test_artifact_type_binary_exists(self): + self.assertEqual(ArtifactType.BINARY.value, "binary") + + def test_artifact_type_archive_exists(self): + self.assertEqual(ArtifactType.ARCHIVE.value, "archive") + + def test_artifact_type_from_string(self): + self.assertEqual(ArtifactType("repository"), ArtifactType.REPOSITORY) + + def test_artifact_type_invalid_string_raises_exception(self): + with self.assertRaises(ValueError): + ArtifactType("invalid") + + def test_artifact_type_empty_string_raises_exception(self): + with self.assertRaises(ValueError): + ArtifactType("") + + def test_scan_config_can_be_created(self): + config = ScanConfig() + self.assertIsNotNone(config) + + def test_scan_config_has_artifact_type(self): + config = ScanConfig(artifact_type=ArtifactType.REPOSITORY) + self.assertEqual(config.artifact_type, ArtifactType.REPOSITORY) + + def test_scan_config_has_artifact_path(self): + config = ScanConfig(artifact_path="./test") + self.assertEqual(config.artifact_path, "./test") + + def test_scan_config_from_args_exists(self): + mock_args = MockArgs() + config = ScanConfig.from_args(mock_args) + self.assertIsNotNone(config) + + def test_scan_config_from_args_converts_artifact_type(self): + mock_args = MockArgs() + config = ScanConfig.from_args(mock_args) + self.assertEqual(config.artifact_type, ArtifactType.REPOSITORY) + + def test_scan_config_from_args_converts_artifact_path(self): + mock_args = MockArgs() + config = ScanConfig.from_args(mock_args) + self.assertEqual(config.artifact_path, './test') + + def test_scan_config_repository_type_comparison(self): + config = ScanConfig(artifact_type=ArtifactType.REPOSITORY) + self.assertEqual(config.artifact_type, ArtifactType.REPOSITORY) + + def test_scan_config_container_type_comparison(self): + config = ScanConfig(artifact_type=ArtifactType.CONTAINER) + self.assertEqual(config.artifact_type, ArtifactType.CONTAINER) + + def test_scan_config_binary_type_comparison(self): + config = ScanConfig(artifact_type=ArtifactType.BINARY) + self.assertEqual(config.artifact_type, ArtifactType.BINARY) + + def test_scan_config_archive_type_comparison(self): + config = ScanConfig(artifact_type=ArtifactType.ARCHIVE) + self.assertEqual(config.artifact_type, ArtifactType.ARCHIVE) + + def test_scan_config_container_platform_check(self): + config = ScanConfig(artifact_type=ArtifactType.CONTAINER) + self.assertEqual(config.artifact_type, ArtifactType.CONTAINER) + + def test_scan_config_artifact_type_value(self): + config = ScanConfig(artifact_type=ArtifactType.REPOSITORY) + self.assertEqual(config.artifact_type.value, "repository") + + def test_scan_config_display_mapping(self): + config = ScanConfig(artifact_type=ArtifactType.REPOSITORY) + display_type = "repository" if config.artifact_type == ArtifactType.REPOSITORY else config.artifact_type.value + self.assertEqual(display_type, "repository") + + def test_scan_config_has_sbomgen_version(self): + config = ScanConfig(sbomgen_version="1.8.0") + self.assertEqual(config.sbomgen_version, "1.8.0") + + def test_scan_config_from_args_converts_sbomgen_version(self): + mock_args = MockArgs() + config = ScanConfig.from_args(mock_args) + self.assertEqual(config.sbomgen_version, "latest") + + def test_scan_config_has_timeout(self): + config = ScanConfig(timeout=300) + self.assertEqual(config.timeout, 300) + + def test_scan_config_from_args_converts_timeout(self): + mock_args = MockArgs() + config = ScanConfig.from_args(mock_args) + self.assertEqual(config.timeout, 600) + + def test_scan_config_has_platform(self): + config = ScanConfig(platform="linux/amd64") + self.assertEqual(config.platform, "linux/amd64") + + def test_scan_config_from_args_converts_platform(self): + mock_args = MockArgs() + config = ScanConfig.from_args(mock_args) + self.assertEqual(config.platform, "linux/amd64") + + def test_scan_config_has_scanners(self): + config = ScanConfig(scanners=["dpkg", "npm"]) + self.assertEqual(config.scanners, ["dpkg", "npm"]) + + def test_scan_config_from_args_converts_scanners(self): + mock_args = MockArgs() + config = ScanConfig.from_args(mock_args) + self.assertEqual(config.scanners, ["dpkg", "npm", "python-requirements"]) + + def test_parse_comma_list_with_valid_string(self): + result = parse_comma_list("dpkg,npm,python-requirements") + self.assertEqual(result, ["dpkg", "npm", "python-requirements"]) + + def test_parse_comma_list_with_whitespace(self): + result = parse_comma_list("dpkg, npm , python-requirements ") + self.assertEqual(result, ["dpkg", "npm", "python-requirements"]) + + def test_parse_comma_list_with_empty_string(self): + result = parse_comma_list("") + self.assertIsNone(result) + + def test_parse_comma_list_with_quoted_empty_string(self): + result = parse_comma_list("''") + self.assertIsNone(result) + + def test_parse_comma_list_with_none(self): + result = parse_comma_list(None) + self.assertIsNone(result) + + def test_scan_config_has_skip_scanners(self): + config = ScanConfig(skip_scanners=["binaries", "alpine-apk"]) + self.assertEqual(config.skip_scanners, ["binaries", "alpine-apk"]) + + def test_scan_config_from_args_converts_skip_scanners(self): + mock_args = MockArgs() + config = ScanConfig.from_args(mock_args) + self.assertEqual(config.skip_scanners, ["binaries", "alpine-apk"]) + + def test_scan_config_sbomgen_version_access(self): + config = ScanConfig(sbomgen_version="1.8.0") + self.assertEqual(config.sbomgen_version, "1.8.0") + + def test_scan_config_timeout_access(self): + config = ScanConfig(timeout=300) + self.assertEqual(config.timeout, 300) + + def test_scan_config_scanners_list_access(self): + config = ScanConfig(scanners=["dpkg", "npm"]) + self.assertEqual(config.scanners, ["dpkg", "npm"]) + + def test_scan_config_skip_scanners_list_access(self): + config = ScanConfig(skip_scanners=["binaries", "alpine-apk"]) + self.assertEqual(config.skip_scanners, ["binaries", "alpine-apk"]) + + def test_scan_config_skip_files_list_access(self): + config = ScanConfig(skip_files=["./media", "/tmp/foo"]) + self.assertEqual(config.skip_files, ["./media", "/tmp/foo"]) + + def test_scan_config_platform_access(self): + config = ScanConfig(platform="linux/amd64") + self.assertEqual(config.platform, "linux/amd64") + + def test_scan_config_comprehensive_access(self): + config = ScanConfig( + artifact_type=ArtifactType.CONTAINER, + artifact_path="./test", + sbomgen_version="1.8.0", + timeout=300, + platform="linux/amd64" + ) + self.assertEqual(config.artifact_type, ArtifactType.CONTAINER) + self.assertEqual(config.artifact_path, "./test") + self.assertEqual(config.sbomgen_version, "1.8.0") + self.assertEqual(config.timeout, 300) + self.assertEqual(config.platform, "linux/amd64") + + def test_output_config_can_be_created(self): + config = OutputConfig() + self.assertIsNotNone(config) + + def test_output_config_has_display_vulnerability_findings(self): + config = OutputConfig(display_vulnerability_findings="enabled") + self.assertEqual(config.display_vulnerability_findings, True) + + def test_output_config_has_show_only_fixable_vulns(self): + config = OutputConfig(show_only_fixable_vulns=True) + self.assertEqual(config.show_only_fixable_vulns, True) + + def test_output_config_has_output_sbom_path(self): + config = OutputConfig(output_sbom_path="./sbom_123.json") + self.assertEqual(config.output_sbom_path, "./sbom_123.json") + + def test_output_config_has_output_inspector_scan_path(self): + config = OutputConfig(output_inspector_scan_path="inspector_scan_123.json") + self.assertEqual(config.output_inspector_scan_path, "inspector_scan_123.json") + + def test_output_config_uses_action_yml_defaults(self): + config = OutputConfig() + self.assertEqual(config.display_vulnerability_findings, False) + self.assertEqual(config.show_only_fixable_vulns, False) + self.assertEqual(config.output_sbom_path, "sbom.json") + self.assertEqual(config.output_inspector_scan_path, "inspector-scan.json") + self.assertEqual(config.output_inspector_scan_path_csv, "inspector-scan.csv") + self.assertEqual(config.output_inspector_scan_path_markdown, "inspector-scan.md") + self.assertEqual(config.output_dockerfile_scan_csv, "inspector-dockerfile-scan.csv") + self.assertEqual(config.output_dockerfile_scan_markdown, "inspector-dockerfile-scan.md") + self.assertEqual(config.thresholds, False) + self.assertEqual(config.critical_threshold, 0) + self.assertEqual(config.high_threshold, 0) + self.assertEqual(config.medium_threshold, 0) + self.assertEqual(config.low_threshold, 0) + self.assertEqual(config.other_threshold, 0) + self.assertEqual(config.threshold_fixable_only, False) + + def test_output_config_converts_enabled_to_true(self): + config = OutputConfig(display_vulnerability_findings="enabled") + self.assertEqual(config.display_vulnerability_findings, True) + + def test_output_config_converts_disabled_to_false(self): + config = OutputConfig(display_vulnerability_findings="disabled") + self.assertEqual(config.display_vulnerability_findings, False) + + def test_output_config_converts_invalid_string_to_false(self): + config = OutputConfig(display_vulnerability_findings="invalid") + self.assertEqual(config.display_vulnerability_findings, False) + + def test_output_config_has_output_inspector_scan_path_csv(self): + config = OutputConfig(output_inspector_scan_path_csv="inspector_scan_123.csv") + self.assertEqual(config.output_inspector_scan_path_csv, "inspector_scan_123.csv") + + def test_output_config_has_output_inspector_scan_path_markdown(self): + config = OutputConfig(output_inspector_scan_path_markdown="inspector_scan_123.md") + self.assertEqual(config.output_inspector_scan_path_markdown, "inspector_scan_123.md") + + def test_output_config_has_output_dockerfile_scan_csv(self): + config = OutputConfig(output_dockerfile_scan_csv="dockerfile_scan_123.csv") + self.assertEqual(config.output_dockerfile_scan_csv, "dockerfile_scan_123.csv") + + def test_output_config_has_output_dockerfile_scan_markdown(self): + config = OutputConfig(output_dockerfile_scan_markdown="dockerfile_scan_123.md") + self.assertEqual(config.output_dockerfile_scan_markdown, "dockerfile_scan_123.md") + + def test_output_config_has_thresholds(self): + config = OutputConfig(thresholds=True) + self.assertEqual(config.thresholds, True) + + def test_output_config_has_critical_threshold(self): + config = OutputConfig(critical_threshold=5) + self.assertEqual(config.critical_threshold, 5) + + def test_output_config_has_threshold_fixable_only(self): + config = OutputConfig(threshold_fixable_only=True) + self.assertEqual(config.threshold_fixable_only, True) + + def test_output_config_from_args_exists(self): + mock_args = MockOutputArgs() + config = OutputConfig.from_args(mock_args) + self.assertIsNotNone(config) + + def test_output_config_from_args_converts_display_vuln_findings(self): + mock_args = MockOutputArgs() + config = OutputConfig.from_args(mock_args) + self.assertEqual(config.display_vulnerability_findings, True) + + def test_output_config_from_args_converts_output_paths(self): + mock_args = MockOutputArgs() + config = OutputConfig.from_args(mock_args) + self.assertEqual(config.output_sbom_path, 'test_sbom.json') + self.assertEqual(config.output_inspector_scan_path, 'test_scan.json') + self.assertEqual(config.output_inspector_scan_path_csv, 'test_scan.csv') + + def test_output_config_from_args_converts_thresholds(self): + mock_args = MockOutputArgs() + config = OutputConfig.from_args(mock_args) + self.assertEqual(config.thresholds, True) + self.assertEqual(config.critical_threshold, 5) + self.assertEqual(config.high_threshold, 10) + self.assertEqual(config.threshold_fixable_only, True) + + def test_output_config_from_args_converts_boolean_fields(self): + mock_args = MockOutputArgs() + config = OutputConfig.from_args(mock_args) + self.assertEqual(config.show_only_fixable_vulns, True) + self.assertEqual(config.threshold_fixable_only, True) + + def test_sbom_output_can_be_created(self): + output = SBOMOutput() + self.assertIsNotNone(output) + + def test_sbom_output_has_file_path(self): + output = SBOMOutput(file_path="/tmp/test.json") + self.assertEqual(output.file_path, "/tmp/test.json") + + def test_sbom_output_has_generation_success(self): + output = SBOMOutput(generation_success=True) + self.assertEqual(output.generation_success, True) + + def test_sbom_output_has_return_code(self): + output = SBOMOutput(return_code=0) + self.assertEqual(output.return_code, 0) + + def test_sbom_output_has_generation_time(self): + output = SBOMOutput(generation_time=5.2) + self.assertEqual(output.generation_time, 5.2) + + def test_sbom_output_has_file_size(self): + output = SBOMOutput(file_size=1024) + self.assertEqual(output.file_size, 1024) + + def test_sbom_output_has_error_message(self): + output = SBOMOutput(error_message="Generation failed") + self.assertEqual(output.error_message, "Generation failed") + + def test_sbom_output_defaults(self): + output = SBOMOutput() + self.assertIsNone(output.file_path) + self.assertEqual(output.generation_success, False) + self.assertIsNone(output.return_code) + self.assertIsNone(output.generation_time) + self.assertIsNone(output.file_size) + self.assertIsNone(output.error_message) + + def test_sbom_output_success_scenario(self): + output = SBOMOutput( + file_path="/tmp/sbom.json", + generation_success=True, + return_code=0, + generation_time=3.5, + file_size=2048 + ) + self.assertEqual(output.file_path, "/tmp/sbom.json") + self.assertEqual(output.generation_success, True) + self.assertEqual(output.return_code, 0) + self.assertEqual(output.generation_time, 3.5) + self.assertEqual(output.file_size, 2048) + self.assertIsNone(output.error_message) + + def test_sbom_output_failure_scenario(self): + output = SBOMOutput( + generation_success=False, + return_code=1, + error_message="Timeout exceeded" + ) + self.assertEqual(output.generation_success, False) + self.assertEqual(output.return_code, 1) + self.assertEqual(output.error_message, "Timeout exceeded") + self.assertIsNone(output.file_path) + self.assertIsNone(output.generation_time) + + def test_vuln_scan_output_can_be_created(self): + output = VulnScanOutput() + self.assertIsNotNone(output) + + def test_vuln_scan_output_has_scan_success(self): + output = VulnScanOutput(scan_success=True) + self.assertEqual(output.scan_success, True) + + def test_vuln_scan_output_has_return_code(self): + output = VulnScanOutput(return_code=0) + self.assertEqual(output.return_code, 0) + + def test_vuln_scan_output_has_scan_results_file_path(self): + output = VulnScanOutput(scan_results_file_path="/tmp/scan.json") + self.assertEqual(output.scan_results_file_path, "/tmp/scan.json") + + def test_vuln_scan_output_has_scan_time(self): + output = VulnScanOutput(scan_time=12.5) + self.assertEqual(output.scan_time, 12.5) + + def test_vuln_scan_output_has_results_file_size(self): + output = VulnScanOutput(results_file_size=4096) + self.assertEqual(output.results_file_size, 4096) + + def test_vuln_scan_output_has_vulnerability_counts(self): + output = VulnScanOutput( + total_vulnerabilities=25, + critical_count=5, + high_count=8, + medium_count=7, + low_count=3, + other_count=2 + ) + self.assertEqual(output.total_vulnerabilities, 25) + self.assertEqual(output.critical_count, 5) + self.assertEqual(output.high_count, 8) + self.assertEqual(output.medium_count, 7) + self.assertEqual(output.low_count, 3) + self.assertEqual(output.other_count, 2) + + def test_vuln_scan_output_has_error_message(self): + output = VulnScanOutput(error_message="API timeout") + self.assertEqual(output.error_message, "API timeout") + + def test_vuln_scan_output_defaults(self): + output = VulnScanOutput() + self.assertEqual(output.scan_success, False) + self.assertIsNone(output.return_code) + self.assertIsNone(output.scan_results_file_path) + self.assertIsNone(output.scan_time) + self.assertIsNone(output.results_file_size) + self.assertIsNone(output.total_vulnerabilities) + self.assertIsNone(output.critical_count) + self.assertIsNone(output.high_count) + self.assertIsNone(output.medium_count) + self.assertIsNone(output.low_count) + self.assertIsNone(output.other_count) + self.assertIsNone(output.error_message) + + def test_vuln_scan_output_success_scenario(self): + output = VulnScanOutput( + scan_success=True, + return_code=0, + scan_results_file_path="/tmp/results.json", + scan_time=8.3, + results_file_size=2048, + total_vulnerabilities=15, + critical_count=2, + high_count=5, + medium_count=6, + low_count=2, + other_count=0 + ) + self.assertEqual(output.scan_success, True) + self.assertEqual(output.return_code, 0) + self.assertEqual(output.scan_results_file_path, "/tmp/results.json") + self.assertEqual(output.scan_time, 8.3) + self.assertEqual(output.results_file_size, 2048) + self.assertEqual(output.total_vulnerabilities, 15) + self.assertEqual(output.critical_count, 2) + self.assertEqual(output.high_count, 5) + self.assertEqual(output.medium_count, 6) + self.assertEqual(output.low_count, 2) + self.assertEqual(output.other_count, 0) + self.assertIsNone(output.error_message) + + def test_vuln_scan_output_failure_scenario(self): + output = VulnScanOutput( + scan_success=False, + return_code=1, + error_message="Inspector API unavailable" + ) + self.assertEqual(output.scan_success, False) + self.assertEqual(output.return_code, 1) + self.assertEqual(output.error_message, "Inspector API unavailable") + self.assertIsNone(output.scan_results_file_path) + self.assertIsNone(output.scan_time) + self.assertIsNone(output.total_vulnerabilities) diff --git a/entrypoint/tests/test_orchestrator.py b/entrypoint/tests/test_orchestrator.py index c822649..c1955fb 100644 --- a/entrypoint/tests/test_orchestrator.py +++ b/entrypoint/tests/test_orchestrator.py @@ -119,13 +119,30 @@ def test_system_against_dockerfile_findings(self): "args", [ "out_scan", - "artifact_path", + "artifact_path", "artifact_type", "out_scan_csv", "out_scan_markdown", "out_dockerfile_scan_csv", "out_dockerfile_scan_md", "show_only_fixable_vulns", + # ScanConfig required fields + "sbomgen_version", + "timeout", + "platform", + "scanners", + "skip_scanners", + "skip_files", + # OutputConfig required fields + "display_vuln_findings", + "out_sbom", + "thresholds", + "critical", + "high", + "medium", + "low", + "other", + "threshold_fixable_only", ], ) args = ArgMock( @@ -137,9 +154,31 @@ def test_system_against_dockerfile_findings(self): out_dockerfile_scan_csv="/tmp/out_dockerfile_scan.csv", out_dockerfile_scan_md="/tmp/out_dockerfile_scan.md", show_only_fixable_vulns=False, + # ScanConfig defaults + sbomgen_version="latest", + timeout="600", + platform=None, + scanners="''", + skip_scanners="''", + skip_files="''", + # OutputConfig defaults + display_vuln_findings="disabled", + out_sbom="/tmp/sbom.json", + thresholds=False, + critical=0, + high=0, + medium=0, + low=0, + other=0, + threshold_fixable_only=False, ) - succeeded, scan_result, fixed_vuln_counts = orchestrator.get_scan_result(args) + # Create config objects for new function signature + from entrypoint.data_model import ScanConfig, OutputConfig + config = ScanConfig.from_args(args) + output_config = OutputConfig.from_args(args) + + succeeded, scan_result, fixed_vuln_counts = orchestrator.get_scan_result(args, config, output_config) self.assertTrue(succeeded) orchestrator.write_pkg_vuln_report_csv(args.out_scan_csv, scan_result) @@ -262,7 +301,17 @@ def test_threshold_exceeded_on_fixable_vulns(self): # Given a scan containing fixable and unfixable vulns, # threshold should be exceeded vulns_with_fixes = fixed_vulns.FixedVulns(criticals=10, highs=0, mediums=0, lows=0, others=0) - orchestrator.set_env_var_if_vuln_threshold_exceeded(threshold_args, vulns_with_fixes) + from entrypoint.data_model import OutputConfig + threshold_output_config = OutputConfig( + critical_threshold=threshold_args.critical, + high_threshold=threshold_args.high, + medium_threshold=threshold_args.medium, + low_threshold=threshold_args.low, + other_threshold=threshold_args.other, + thresholds=threshold_args.thresholds, + threshold_fixable_only=threshold_args.threshold_fixable_only + ) + orchestrator.set_env_var_if_vuln_threshold_exceeded(threshold_output_config, vulns_with_fixes) want = "1" got = os.environ.get("vulnerability_threshold_exceeded") self.assertEqual(want, got) @@ -270,7 +319,7 @@ def test_threshold_exceeded_on_fixable_vulns(self): # Given a scan containing NO fixable vulns, # threshold exceeded should NOT be set no_vulns_with_fix = fixed_vulns.FixedVulns(criticals=0, highs=0, mediums=0, lows=0, others=0) - orchestrator.set_env_var_if_vuln_threshold_exceeded(threshold_args, no_vulns_with_fix) + orchestrator.set_env_var_if_vuln_threshold_exceeded(threshold_output_config, no_vulns_with_fix) want = "0" got = os.environ.get("vulnerability_threshold_exceeded") self.assertEqual(want, got) @@ -287,7 +336,16 @@ def test_threshold_exceeded_on_fixable_vulns(self): threshold_fixable_only=True ) vulns_with_fixes = fixed_vulns.FixedVulns(criticals=10, highs=10, mediums=10, lows=10, others=10) - orchestrator.set_env_var_if_vuln_threshold_exceeded(disable_threshold_args, vulns_with_fixes) + disable_threshold_output_config = OutputConfig( + critical_threshold=disable_threshold_args.critical, + high_threshold=disable_threshold_args.high, + medium_threshold=disable_threshold_args.medium, + low_threshold=disable_threshold_args.low, + other_threshold=disable_threshold_args.other, + thresholds=disable_threshold_args.thresholds, + threshold_fixable_only=disable_threshold_args.threshold_fixable_only + ) + orchestrator.set_env_var_if_vuln_threshold_exceeded(disable_threshold_output_config, vulns_with_fixes) want = "0" got = os.environ.get("vulnerability_threshold_exceeded") self.assertEqual(want, got)