# This file is part of CycloneDX Python Library
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) OWASP Foundation. All Rights Reserved.

from datetime import datetime, timedelta
from decimal import Decimal
from unittest import TestCase

from cyclonedx.model import XsUri
from cyclonedx.model.impact_analysis import ImpactAnalysisAffectedStatus
from cyclonedx.model.vulnerability import (
    BomTarget,
    BomTargetVersionRange,
    Vulnerability,
    VulnerabilityAdvisory,
    VulnerabilityRating,
    VulnerabilityReference,
    VulnerabilityScoreSource,
    VulnerabilitySeverity,
    VulnerabilitySource,
)
from tests import reorder


class TestModelVulnerability(TestCase):
    """Tests for ``Vulnerability`` and the severity / score-source helpers it relies on.

    Covers:
      * ``VulnerabilitySeverity.get_from_cvss_scores`` for single and multiple scores,
      * ``VulnerabilityScoreSource.get_from_vector`` vector-prefix detection,
      * ``VulnerabilityScoreSource.get_localised_vector`` prefix stripping,
      * default state of an empty ``Vulnerability``,
      * ordering of ``Vulnerability`` instances.
    """

    # --- VulnerabilitySeverity.get_from_cvss_scores -------------------------------

    def test_v_severity_from_cvss_scores_single_critical(self) -> None:
        self.assertEqual(
            VulnerabilitySeverity.get_from_cvss_scores(9.1),
            VulnerabilitySeverity.CRITICAL
        )

    def test_v_severity_from_cvss_scores_multiple_critical(self) -> None:
        self.assertEqual(
            VulnerabilitySeverity.get_from_cvss_scores((9.1, 9.5)),
            VulnerabilitySeverity.CRITICAL
        )

    def test_v_severity_from_cvss_scores_single_high(self) -> None:
        self.assertEqual(
            VulnerabilitySeverity.get_from_cvss_scores(8.9),
            VulnerabilitySeverity.HIGH
        )

    def test_v_severity_from_cvss_scores_single_medium(self) -> None:
        self.assertEqual(
            VulnerabilitySeverity.get_from_cvss_scores(4.2),
            VulnerabilitySeverity.MEDIUM
        )

    def test_v_severity_from_cvss_scores_single_low(self) -> None:
        self.assertEqual(
            VulnerabilitySeverity.get_from_cvss_scores(1.1),
            VulnerabilitySeverity.LOW
        )

    def test_v_severity_from_cvss_scores_single_none(self) -> None:
        self.assertEqual(
            VulnerabilitySeverity.get_from_cvss_scores(0.0),
            VulnerabilitySeverity.NONE
        )

    def test_v_severity_from_cvss_scores_multiple_high(self) -> None:
        # With several scores, the highest one (8.9) is expected to decide the severity.
        self.assertEqual(
            VulnerabilitySeverity.get_from_cvss_scores((1.2, 8.9, 2.2, 5.6)),
            VulnerabilitySeverity.HIGH
        )

    # --- VulnerabilityScoreSource.get_from_vector ---------------------------------

    def test_v_source_parse_cvss3_1(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.get_from_vector('CVSS:3.0/AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'),
            VulnerabilityScoreSource.CVSS_V3
        )

    def test_v_source_parse_cvss2_1(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.get_from_vector('CVSS:2.0/AV:N/AC:L/Au:N/C:N/I:N/A:C'),
            VulnerabilityScoreSource.CVSS_V2
        )

    def test_v_source_parse_owasp_1(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.get_from_vector('OWASP/K9:M1:O0:Z2/D1:X1:W1:L3/C2:I1:A1:T1/F1:R1:S2:P3/50'),
            VulnerabilityScoreSource.OWASP
        )

    # --- VulnerabilityScoreSource.get_localised_vector ----------------------------
    # Each source is checked with: prefix + '/', prefix without '/', and no prefix.

    def test_v_source_get_localised_vector_cvss3_1(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.CVSS_V3.get_localised_vector(
                vector='CVSS:3.0/AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'
            ),
            'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'
        )

    def test_v_source_get_localised_vector_cvss3_2(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.CVSS_V3.get_localised_vector(vector='CVSS:3.0AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'),
            'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'
        )

    def test_v_source_get_localised_vector_cvss3_3(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.CVSS_V3.get_localised_vector(vector='AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'),
            'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'
        )

    def test_v_source_get_localised_vector_cvss2_1(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.CVSS_V2.get_localised_vector(
                vector='CVSS:2.0/AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'),
            'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'
        )

    def test_v_source_get_localised_vector_cvss2_2(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.CVSS_V2.get_localised_vector(vector='CVSS:2.1AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'),
            'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'
        )

    def test_v_source_get_localised_vector_cvss2_3(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.CVSS_V2.get_localised_vector(vector='AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'),
            'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'
        )

    def test_v_source_get_localised_vector_owasp_1(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.OWASP.get_localised_vector(vector='OWASP/AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'),
            'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'
        )

    def test_v_source_get_localised_vector_owasp_2(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.OWASP.get_localised_vector(vector='OWASPAV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'),
            'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'
        )

    def test_v_source_get_localised_vector_owasp_3(self) -> None:
        self.assertEqual(
            VulnerabilityScoreSource.OWASP.get_localised_vector(vector='AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'),
            'AV:L/AC:L/PR:N/UI:R/S:C/C:L/I:N/A:N'
        )

    def test_v_source_get_localised_vector_other_2(self) -> None:
        # OTHER is expected to return the vector unchanged.
        self.assertEqual(
            VulnerabilityScoreSource.OTHER.get_localised_vector(vector='SOMETHING_OR_OTHER'),
            'SOMETHING_OR_OTHER'
        )

    # --- Vulnerability ------------------------------------------------------------

    def test_empty_vulnerability(self) -> None:
        """A default-constructed Vulnerability has no scalar values and empty collections."""
        v = Vulnerability()
        self.assertIsNone(v.bom_ref.value)
        self.assertIsNone(v.id)
        self.assertIsNone(v.source)
        self.assertFalse(v.references)
        self.assertFalse(v.ratings)
        self.assertFalse(v.cwes)
        self.assertIsNone(v.description)
        self.assertIsNone(v.detail)
        self.assertIsNone(v.recommendation)
        self.assertIsNone(v.workaround)
        self.assertFalse(v.advisories)
        self.assertIsNone(v.created)
        self.assertIsNone(v.published)
        self.assertIsNone(v.updated)
        self.assertIsNone(v.credits)
        self.assertFalse(v.tools)
        self.assertIsNone(v.analysis)
        self.assertFalse(v.affects)

    def test_sort(self) -> None:
        """Sorting Vulnerability instances yields the pinned ``expected_order``."""
        source1 = VulnerabilitySource(name='a')
        source2 = VulnerabilitySource(name='b')
        # A fixed timestamp keeps the fixture deterministic and avoids
        # ``datetime.utcnow()``, which is deprecated since Python 3.12.
        # Only the relative order of the two instants matters here.
        datetime1 = datetime(2023, 1, 1, 12, 0, 0)
        datetime2 = datetime1 + timedelta(seconds=5)

        # The expected order equals sorting by id, then by bom_ref compared as a
        # string: '10' lands between '1' and '2', and only bom_ref '11' has id 'b'.
        # NOTE(review): the other fields set below (description, detail, source,
        # created, published) do not explain this order on their own. Confirm the
        # actual key in Vulnerability's comparison before relying on them.
        expected_order = [0, 1, 10, 2, 3, 4, 5, 6, 7, 8, 9, 11]
        vulnerabilities = [
            Vulnerability(bom_ref='0', id='a', description='a', detail='a',
                          source=source1, created=datetime1, published=datetime1),
            Vulnerability(bom_ref='1', id='a', description='a', detail='a',
                          source=source1, created=datetime1),
            Vulnerability(bom_ref='2', id='a', description='a', detail='a',
                          source=source1),
            Vulnerability(bom_ref='3', id='a', description='a', detail='a'),
            Vulnerability(bom_ref='4', id='a', description='a'),
            Vulnerability(bom_ref='5', id='a'),
            Vulnerability(bom_ref='6', id='a', description='a', detail='a',
                          source=source1, created=datetime1, published=datetime2),
            Vulnerability(bom_ref='7', id='a', description='a', detail='a',
                          source=source1, created=datetime2, published=datetime1),
            Vulnerability(bom_ref='8', id='a', description='a', detail='a',
                          source=source2, created=datetime1, published=datetime1),
            Vulnerability(bom_ref='9', id='a', description='a', detail='b',
                          source=source1, created=datetime1, published=datetime1),
            Vulnerability(bom_ref='10', id='a', description='b', detail='b',
                          source=source1, created=datetime1, published=datetime1),
            Vulnerability(bom_ref='11', id='b', description='a', detail='a',
                          source=source1, created=datetime1, published=datetime1),
        ]
        sorted_vulnerabilities = sorted(vulnerabilities)
        expected_vulnerabilities = reorder(vulnerabilities, expected_order)
        self.assertListEqual(sorted_vulnerabilities, expected_vulnerabilities)


class TestModelVulnerabilityAdvisory(TestCase):
    """Ordering tests for ``VulnerabilityAdvisory``."""

    def test_sort(self) -> None:
        """Advisories sort by title, then URL; advisories without a title come last."""
        wanted = [0, 1, 2, 3]  # positions in ``unsorted``, in the order sorting must produce
        unsorted = [
            VulnerabilityAdvisory(url=XsUri('a'), title='a'),
            VulnerabilityAdvisory(url=XsUri('b'), title='a'),
            VulnerabilityAdvisory(url=XsUri('a')),
            VulnerabilityAdvisory(url=XsUri('b')),
        ]
        actual = sorted(unsorted)
        self.assertListEqual(actual, reorder(unsorted, wanted))


class TestModelVulnerabilitySource(TestCase):
    """Ordering tests for ``VulnerabilitySource``."""

    def test_sort(self) -> None:
        """Sources sort by name, then URL; in this data a missing field sorts after a present one."""
        wanted = [0, 1, 4, 5, 2, 3]  # positions in ``unsorted``, in expected sorted order
        unsorted = [
            VulnerabilitySource(url=XsUri('a'), name='a'),
            VulnerabilitySource(url=XsUri('b'), name='a'),
            VulnerabilitySource(url=XsUri('a')),
            VulnerabilitySource(url=XsUri('b')),
            VulnerabilitySource(name='a'),
            VulnerabilitySource(name='b'),
        ]
        actual = sorted(unsorted)
        self.assertListEqual(actual, reorder(unsorted, wanted))


class TestModelVulnerabilityReference(TestCase):
    """Ordering tests for ``VulnerabilityReference``."""

    def test_sort(self) -> None:
        """References sort by id first and by source second."""
        src_a = VulnerabilitySource(name='a')
        src_b = VulnerabilitySource(name='b')

        wanted = [2, 3, 1, 0]  # (a, a), (a, b), (b, a), (b, b) as (id, source name)
        unsorted = [
            VulnerabilityReference(id='b', source=src_b),
            VulnerabilityReference(id='b', source=src_a),
            VulnerabilityReference(id='a', source=src_a),
            VulnerabilityReference(id='a', source=src_b),
        ]
        actual = sorted(unsorted)
        self.assertListEqual(actual, reorder(unsorted, wanted))


class TestModelVulnerabilityRating(TestCase):
    """Ordering tests for ``VulnerabilityRating``."""

    def test_sort(self) -> None:
        """Ratings are compared on severity, score, source, method, vector, justification.

        In this data, the rating that sets only a severity is expected first,
        and the rating without a severity is expected last.
        """
        src = VulnerabilitySource(name='a')
        method = VulnerabilityScoreSource.CVSS_V3_1
        high = VulnerabilitySeverity.HIGH
        ten = Decimal(10)

        wanted = [5, 0, 1, 2, 3, 4, 6, 7]  # positions in ``ratings``, in expected sorted order
        ratings = [
            VulnerabilityRating(severity=high, score=ten, source=src, method=method,
                                vector='a', justification='a'),
            VulnerabilityRating(severity=high, score=ten, source=src, method=method, vector='a'),
            VulnerabilityRating(severity=high, score=ten, source=src, method=method),
            VulnerabilityRating(severity=high, score=ten, source=src),
            VulnerabilityRating(severity=high, score=ten),
            VulnerabilityRating(severity=high),
            VulnerabilityRating(severity=VulnerabilitySeverity.LOW, score=ten, source=src,
                                method=method, vector='a', justification='a'),
            VulnerabilityRating(score=ten, source=src, method=method, vector='a', justification='a'),
        ]
        actual = sorted(ratings)
        expected = reorder(ratings, wanted)
        self.maxDiff = None  # show the full diff if the lists differ
        self.assertListEqual(actual, expected)


class TestModelBomTargetVersionRange(TestCase):
    """Ordering tests for ``BomTargetVersionRange``."""

    def test_sort(self) -> None:
        """Version ranges sort by version, then range, then status."""
        affected = ImpactAnalysisAffectedStatus.AFFECTED
        unaffected = ImpactAnalysisAffectedStatus.UNAFFECTED

        wanted = [0, 3, 1, 2, 4, 5]  # positions in ``unsorted``, in expected sorted order
        unsorted = [
            BomTargetVersionRange(version='1.0.0', status=affected),
            BomTargetVersionRange(version='1.0.0'),
            BomTargetVersionRange(version='2.0.0', status=affected),
            BomTargetVersionRange(version='1.0.0', status=unaffected),
            BomTargetVersionRange(range='1.0.0 - 2.0.0', status=unaffected),
            BomTargetVersionRange(range='2.0.0 - 2.1.0', status=affected),
        ]
        actual = sorted(unsorted)
        self.assertListEqual(actual, reorder(unsorted, wanted))


class TestModelBomTarget(TestCase):
    """Ordering tests for ``BomTarget``."""

    def test_sort(self) -> None:
        """Targets sort by their ``ref``."""
        v1 = BomTargetVersionRange(version='1.0.0')
        v2 = BomTargetVersionRange(version='2.0.0')

        wanted = [1, 0, 3, 2, 4]  # refs a, b, c, d, g
        unsorted = [
            BomTarget(ref='b'),
            BomTarget(ref='a'),
            BomTarget(ref='d'),
            BomTarget(ref='c', versions=[v1, v2]),
            BomTarget(ref='g'),
        ]
        actual = sorted(unsorted)
        self.assertListEqual(actual, reorder(unsorted, wanted))
