# This file is part of CycloneDX Python Library
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OWASP Foundation. All Rights Reserved.

from datetime import datetime, timedelta
from decimal import Decimal
from unittest import TestCase

from cyclonedx.model import XsUri
from cyclonedx.model.impact_analysis import (
    ImpactAnalysisAffectedStatus,
    ImpactAnalysisJustification,
    ImpactAnalysisResponse,
    ImpactAnalysisState,
)
from cyclonedx.model.vulnerability import (
    BomTarget,
    BomTargetVersionRange,
    Vulnerability,
    VulnerabilityAdvisory,
    VulnerabilityAnalysis,
    VulnerabilityRating,
    VulnerabilityReference,
    VulnerabilityScoreSource,
    VulnerabilitySeverity,
    VulnerabilitySource,
)
from tests import reorder


class TestModelVulnerabilitySeverity(TestCase):
    """Checks ``VulnerabilitySeverity.get_from_cvss_scores`` for single scores and tuples of scores."""

    def test_v_severity_from_cvss_scores_single_critical(self) -> None:
        severity = VulnerabilitySeverity.get_from_cvss_scores(9.1)
        self.assertEqual(severity, VulnerabilitySeverity.CRITICAL)

    def test_v_severity_from_cvss_scores_multiple_critical(self) -> None:
        scores = (9.1, 9.5)
        severity = VulnerabilitySeverity.get_from_cvss_scores(scores)
        self.assertEqual(severity, VulnerabilitySeverity.CRITICAL)

    def test_v_severity_from_cvss_scores_single_high(self) -> None:
        severity = VulnerabilitySeverity.get_from_cvss_scores(8.9)
        self.assertEqual(severity, VulnerabilitySeverity.HIGH)

    def test_v_severity_from_cvss_scores_single_medium(self) -> None:
        severity = VulnerabilitySeverity.get_from_cvss_scores(4.2)
        self.assertEqual(severity, VulnerabilitySeverity.MEDIUM)

    def test_v_severity_from_cvss_scores_single_low(self) -> None:
        severity = VulnerabilitySeverity.get_from_cvss_scores(1.1)
        self.assertEqual(severity, VulnerabilitySeverity.LOW)

    def test_v_severity_from_cvss_scores_single_none(self) -> None:
        severity = VulnerabilitySeverity.get_from_cvss_scores(0.0)
        self.assertEqual(severity, VulnerabilitySeverity.NONE)

    def test_v_severity_from_cvss_scores_multiple_high(self) -> None:
        # mixed scores whose expected overall severity is HIGH
        scores = (1.2, 8.9, 2.2, 5.6)
        severity = VulnerabilitySeverity.get_from_cvss_scores(scores)
        self.assertEqual(severity, VulnerabilitySeverity.HIGH)


class TestModelVulnerabilityScoreSource(TestCase):
    """Checks scoring-method detection from vectors and removal of scheme prefixes from vectors."""

    # --- get_from_vector: scoring method detection ---

    def test_v_source_parse_other(self) -> None:
        # an unrecognised vector maps to OTHER
        self.assertEqual(
            VulnerabilityScoreSource.get_from_vector('loremIpsum'),
            VulnerabilityScoreSource.OTHER
        )

    def test_v_source_parse_cvss4_0(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.get_from_vector(
                'CVSS:4.0/AV:N/AC:L/AT:P/PR:N/UI:P/VC:H/VI:H/VA:H/SC:N/SI:N/SA:N/E:U'),
            VulnerabilityScoreSource.CVSS_V4
        )

    def test_v_source_parse_cvss3_1(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.get_from_vector(
                'CVSS:3.1/AV:N/AC:L/PR:N/UI:N/S:U/C:N/I:N/A:H'),
            VulnerabilityScoreSource.CVSS_V3_1
        )

    def test_v_source_parse_cvss3_0(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.get_from_vector(
                'CVSS:3.0/AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'),
            VulnerabilityScoreSource.CVSS_V3
        )

    def test_v_source_parse_cvss2_0(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.get_from_vector(
                'CVSS:2.0/AV:N/AC:L/Au:N/C:N/I:N/A:C'),
            VulnerabilityScoreSource.CVSS_V2
        )

    def test_v_source_parse_owasp_1(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.get_from_vector(
                'OWASP/K9:M1:O0:Z2/D1:X1:W1:L3/C2:I1:A1:T1/F1:R1:S2:P3/50'),
            VulnerabilityScoreSource.OWASP
        )

    # --- get_localised_vector: scheme prefix removal ---
    # Each scheme is exercised with a "/"-separated prefix, a prefix with no
    # separator, and no prefix at all (the vector must come back unchanged).

    def test_v_source_get_localised_vector_cvss4_slash(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.CVSS_V4.get_localised_vector(
                'CVSS:4.0/AV:N/AC:L/AT:P/PR:N/UI:P/VC:H/VI:H/VA:H/SC:N/SI:N/SA:N'),
            'AV:N/AC:L/AT:P/PR:N/UI:P/VC:H/VI:H/VA:H/SC:N/SI:N/SA:N'
        )

    def test_v_source_get_localised_vector_cvss4_noslash(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.CVSS_V4.get_localised_vector(
                'CVSS:4.0AV:N/AC:L/AT:P/PR:N/UI:P/VC:H/VI:H/VA:H/SC:N/SI:N/SA:N'),
            'AV:N/AC:L/AT:P/PR:N/UI:P/VC:H/VI:H/VA:H/SC:N/SI:N/SA:N'
        )

    def test_v_source_get_localised_vector_cvss4_none(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.CVSS_V4.get_localised_vector(
                'AV:N/AC:L/AT:P/PR:N/UI:P/VC:H/VI:H/VA:H/SC:N/SI:N/SA:N'),
            'AV:N/AC:L/AT:P/PR:N/UI:P/VC:H/VI:H/VA:H/SC:N/SI:N/SA:N'
        )

    def test_v_source_get_localised_vector_cvss3_1_slash(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.CVSS_V3_1.get_localised_vector(
                'CVSS:3.1/AV:N/AC:H/PR:N/UI:N/S:U/C:H/I:H/A:H'),
            'AV:N/AC:H/PR:N/UI:N/S:U/C:H/I:H/A:H'
        )

    def test_v_source_get_localised_vector_cvss3_1_noslash(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.CVSS_V3_1.get_localised_vector(
                'CVSS:3.1AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'),
            'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'
        )

    def test_v_source_get_localised_vector_cvss3_1_none(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.CVSS_V3_1.get_localised_vector(
                'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'),
            'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'
        )

    def test_v_source_get_localised_vector_cvss3_strips_cvss3_1_prefix(self) -> None:
        # cross-version: the CVSS_V3 source also strips a ``CVSS:3.1/`` prefix
        self.assertEqual(
            VulnerabilityScoreSource.CVSS_V3.get_localised_vector(
                'CVSS:3.1/AV:N/AC:H/PR:N/UI:N/S:U/C:H/I:H/A:H'),
            'AV:N/AC:H/PR:N/UI:N/S:U/C:H/I:H/A:H'
        )

    def test_v_source_get_localised_vector_cvss3_1_strips_cvss3_0_prefix(self) -> None:
        # cross-version: the CVSS_V3_1 source also strips a ``CVSS:3.0`` prefix
        self.assertEqual(
            VulnerabilityScoreSource.CVSS_V3_1.get_localised_vector(
                'CVSS:3.0AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'),
            'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'
        )

    def test_v_source_get_localised_vector_cvss3_slash(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.CVSS_V3.get_localised_vector(
                'CVSS:3.0/AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'),
            'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'
        )

    def test_v_source_get_localised_vector_cvss3_noslash(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.CVSS_V3.get_localised_vector(
                'CVSS:3.0AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'),
            'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'
        )

    def test_v_source_get_localised_vector_cvss3_none(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.CVSS_V3.get_localised_vector(
                'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'),
            'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'
        )

    # NOTE(review): the CVSS 2 cases below reuse a CVSS-3-style vector body;
    # only the prefix handling is under test, so the body content is arbitrary.

    def test_v_source_get_localised_vector_cvss2_slash(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.CVSS_V2.get_localised_vector(
                'CVSS:2.0/AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'),
            'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'
        )

    def test_v_source_get_localised_vector_cvss2_noslash(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.CVSS_V2.get_localised_vector(
                'CVSS:2.0AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'),
            'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'
        )

    def test_v_source_get_localised_vector_cvss2_none(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.CVSS_V2.get_localised_vector(
                'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'),
            'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'
        )

    def test_v_source_get_localised_vector_owasp_slash(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.OWASP.get_localised_vector(
                'OWASP/AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'),
            'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'
        )

    def test_v_source_get_localised_vector_owasp_noslash(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.OWASP.get_localised_vector(
                'OWASPAV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'),
            'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'
        )

    def test_v_source_get_localised_vector_owasp_none(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.OWASP.get_localised_vector(
                'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'),
            'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'
        )

    def test_v_source_get_localised_vector_other(self) -> None:
        # OTHER leaves the vector untouched
        self.assertEqual(
            VulnerabilityScoreSource.OTHER.get_localised_vector(
                'SOMETHING_OR_OTHER'),
            'SOMETHING_OR_OTHER'
        )


class TestModelVulnerability(TestCase):
    """Checks the defaults of an empty :class:`Vulnerability` and the ordering of vulnerabilities."""

    def test_empty_vulnerability(self) -> None:
        v = Vulnerability()
        self.assertIsNone(v.bom_ref.value)
        self.assertIsNone(v.id)
        self.assertIsNone(v.source)
        self.assertFalse(v.references)
        self.assertFalse(v.ratings)
        self.assertFalse(v.cwes)
        self.assertIsNone(v.description)
        self.assertIsNone(v.detail)
        self.assertIsNone(v.recommendation)
        self.assertIsNone(v.workaround)
        self.assertFalse(v.advisories)
        self.assertIsNone(v.created)
        self.assertIsNone(v.published)
        self.assertIsNone(v.updated)
        self.assertIsNone(v.credits)
        self.assertFalse(v.tools)
        self.assertIsNone(v.analysis)
        self.assertFalse(v.affects)

    def test_sort(self) -> None:
        source1 = VulnerabilitySource(name='a')
        source2 = VulnerabilitySource(name='b')
        # A fixed timestamp keeps the fixture deterministic and avoids
        # ``datetime.utcnow()``, which is deprecated since Python 3.12.
        # Only the relative order of the two timestamps matters here.
        datetime1 = datetime(2023, 1, 1, 12, 0, 0)
        datetime2 = datetime1 + timedelta(seconds=5)

        # expected sort order: (id, description, detail, source, created, published)
        expected_order = [0, 1, 10, 2, 3, 4, 5, 6, 7, 8, 9, 11]
        vulnerabilities = [
            Vulnerability(bom_ref='0', id='a', description='a', detail='a',
                          source=source1, created=datetime1, published=datetime1),
            Vulnerability(bom_ref='1', id='a', description='a', detail='a',
                          source=source1, created=datetime1),
            Vulnerability(bom_ref='2', id='a', description='a', detail='a',
                          source=source1),
            Vulnerability(bom_ref='3', id='a', description='a', detail='a'),
            Vulnerability(bom_ref='4', id='a', description='a'),
            Vulnerability(bom_ref='5', id='a'),
            Vulnerability(bom_ref='6', id='a', description='a', detail='a',
                          source=source1, created=datetime1, published=datetime2),
            Vulnerability(bom_ref='7', id='a', description='a', detail='a',
                          source=source1, created=datetime2, published=datetime1),
            Vulnerability(bom_ref='8', id='a', description='a', detail='a',
                          source=source2, created=datetime1, published=datetime1),
            Vulnerability(bom_ref='9', id='a', description='a', detail='b',
                          source=source1, created=datetime1, published=datetime1),
            Vulnerability(bom_ref='10', id='a', description='b', detail='b',
                          source=source1, created=datetime1, published=datetime1),
            Vulnerability(bom_ref='11', id='b', description='a', detail='a',
                          source=source1, created=datetime1, published=datetime1),
        ]
        sorted_vulnerabilities = sorted(vulnerabilities)
        expected_vulnerabilities = reorder(vulnerabilities, expected_order)
        self.assertListEqual(sorted_vulnerabilities, expected_vulnerabilities)


class TestModelVulnerabilityAdvisory(TestCase):
    """Checks the ordering of :class:`VulnerabilityAdvisory` instances."""

    def test_sort(self) -> None:
        # expected sort order: ([title], url)
        with_title = [
            VulnerabilityAdvisory(url=XsUri('a'), title='a'),
            VulnerabilityAdvisory(url=XsUri('b'), title='a'),
        ]
        without_title = [
            VulnerabilityAdvisory(url=XsUri('a')),
            VulnerabilityAdvisory(url=XsUri('b')),
        ]
        items = with_title + without_title
        # the input is already in the expected order
        wanted = reorder(items, [0, 1, 2, 3])
        self.assertListEqual(sorted(items), wanted)


class TestModelVulnerabilitySource(TestCase):
    """Checks the ordering of :class:`VulnerabilitySource` instances."""

    def test_sort(self) -> None:
        # expected sort order: ([name], [url])
        named_with_url = [
            VulnerabilitySource(url=XsUri('a'), name='a'),
            VulnerabilitySource(url=XsUri('b'), name='a'),
        ]
        url_only = [
            VulnerabilitySource(url=XsUri('a')),
            VulnerabilitySource(url=XsUri('b')),
        ]
        name_only = [
            VulnerabilitySource(name='a'),
            VulnerabilitySource(name='b'),
        ]
        items = named_with_url + url_only + name_only
        wanted = reorder(items, [0, 1, 4, 5, 2, 3])
        self.assertListEqual(sorted(items), wanted)


class TestModelVulnerabilityReference(TestCase):
    """Checks the ordering of :class:`VulnerabilityReference` instances."""

    def test_sort(self) -> None:
        src_a = VulnerabilitySource(name='a')
        src_b = VulnerabilitySource(name='b')

        # expected sort order: ([id], [source])
        items = [
            VulnerabilityReference(id='b', source=src_b),
            VulnerabilityReference(id='b', source=src_a),
            VulnerabilityReference(id='a', source=src_a),
            VulnerabilityReference(id='a', source=src_b),
        ]
        wanted = reorder(items, [2, 3, 1, 0])
        self.assertListEqual(sorted(items), wanted)


class TestModelVulnerabilityRating(TestCase):
    """Checks the ordering of :class:`VulnerabilityRating` instances."""

    def test_sort(self) -> None:
        self.maxDiff = None  # show the full diff on failure

        src = VulnerabilitySource(name='a')
        cvss31 = VulnerabilityScoreSource.CVSS_V3_1
        high = VulnerabilitySeverity.HIGH
        low = VulnerabilitySeverity.LOW

        # expected sort order: ([severity], [score], [source], [method], [vector], [justification])
        items = [
            VulnerabilityRating(severity=high, score=Decimal(10),
                                source=src, method=cvss31, vector='a', justification='a'),
            VulnerabilityRating(severity=high, score=Decimal(10),
                                source=src, method=cvss31, vector='a'),
            VulnerabilityRating(severity=high, score=Decimal(10),
                                source=src, method=cvss31),
            VulnerabilityRating(severity=high, score=Decimal(10), source=src),
            VulnerabilityRating(severity=high, score=Decimal(10)),
            VulnerabilityRating(severity=high),
            VulnerabilityRating(severity=low, score=Decimal(10),
                                source=src, method=cvss31, vector='a', justification='a'),
            VulnerabilityRating(score=Decimal(10), source=src, method=cvss31, vector='a', justification='a'),
        ]
        wanted = reorder(items, [5, 0, 1, 2, 3, 4, 6, 7])
        self.assertListEqual(sorted(items), wanted)


class TestModelBomTargetVersionRange(TestCase):
    """Checks the ordering of :class:`BomTargetVersionRange` instances."""

    def test_sort(self) -> None:
        affected = ImpactAnalysisAffectedStatus.AFFECTED
        unaffected = ImpactAnalysisAffectedStatus.UNAFFECTED

        # expected sort order: ([version], [range], [status])
        items = [
            BomTargetVersionRange(version='1.0.0', status=affected),
            BomTargetVersionRange(version='1.0.0'),
            BomTargetVersionRange(version='2.0.0', status=affected),
            BomTargetVersionRange(version='1.0.0', status=unaffected),
            BomTargetVersionRange(range='1.0.0 - 2.0.0', status=unaffected),
            BomTargetVersionRange(range='2.0.0 - 2.1.0', status=affected),
        ]
        wanted = reorder(items, [0, 3, 1, 2, 4, 5])
        self.assertListEqual(sorted(items), wanted)


class TestModelBomTarget(TestCase):
    """Checks the ordering of :class:`BomTarget` instances."""

    def test_sort(self) -> None:
        v1 = BomTargetVersionRange(version='1.0.0')
        v2 = BomTargetVersionRange(version='2.0.0')

        # expected sort order: (ref)
        items = [
            BomTarget(ref='b'),
            BomTarget(ref='a'),
            BomTarget(ref='d'),
            BomTarget(ref='c', versions=[v1, v2]),
            BomTarget(ref='g'),
        ]
        wanted = reorder(items, [1, 0, 3, 2, 4])
        self.assertListEqual(sorted(items), wanted)


class TestModelVulnerabilityAnalysis(TestCase):
    """Checks the ordering of :class:`VulnerabilityAnalysis` instances."""

    def test_sort(self) -> None:
        exploitable = ImpactAnalysisState.EXPLOITABLE

        # expected sort order: ([state], [justification], [responses], [detail], [first_issued], [last_updated])
        items = [
            VulnerabilityAnalysis(state=exploitable),
            VulnerabilityAnalysis(state=exploitable,
                                  responses=[ImpactAnalysisResponse.CAN_NOT_FIX]),
            VulnerabilityAnalysis(state=ImpactAnalysisState.NOT_AFFECTED,
                                  justification=ImpactAnalysisJustification.CODE_NOT_PRESENT),
            VulnerabilityAnalysis(state=exploitable,
                                  justification=ImpactAnalysisJustification.REQUIRES_ENVIRONMENT),
            VulnerabilityAnalysis(first_issued=datetime(2024, 4, 4), last_updated=datetime(2025, 5, 5)),
            VulnerabilityAnalysis(first_issued=datetime(2023, 3, 3), last_updated=datetime(2023, 3, 3)),
        ]
        wanted = reorder(items, [3, 1, 0, 2, 5, 4])
        self.assertListEqual(sorted(items), wanted)
