"""Unit test suite for docx.opc.pkgreader module."""

import pytest

from docx.opc.constants import CONTENT_TYPE as CT
from docx.opc.constants import RELATIONSHIP_TARGET_MODE as RTM
from docx.opc.packuri import PackURI
from docx.opc.phys_pkg import _ZipPkgReader
from docx.opc.pkgreader import (
    PackageReader,
    _ContentTypeMap,
    _SerializedPart,
    _SerializedRelationship,
    _SerializedRelationships,
)

from ..unitutil.mock import (
    ANY,
    Mock,
    call,
    class_mock,
    function_mock,
    initializer_mock,
    instance_mock,
    loose_mock,
    method_mock,
    patch,
)
from .unitdata.types import a_Default, a_Types, an_Override


class DescribePackageReader:
    """Unit-test suite for `docx.opc.pkgreader.PackageReader`."""

    def it_can_construct_from_pkg_file(
        self, _init_, PhysPkgReader_, from_xml, _srels_for, _load_serialized_parts
    ):
        """`from_file()` hands the result of each mocked step on to the next.

        Every collaborator is mocked, so only the wiring is verified: the physical
        reader is built from `pkg_file`, supplies the content-types XML and the
        package-level rels, is closed, and the loaded pieces reach `__init__()`.
        """
        phys_reader = PhysPkgReader_.return_value
        content_types = from_xml.return_value
        pkg_srels = _srels_for.return_value
        sparts = _load_serialized_parts.return_value
        pkg_file = Mock(name="pkg_file")

        pkg_reader = PackageReader.from_file(pkg_file)

        PhysPkgReader_.assert_called_once_with(pkg_file)
        from_xml.assert_called_once_with(phys_reader.content_types_xml)
        # -- "/" is the source URI used for the package-level relationships --
        _srels_for.assert_called_once_with(phys_reader, "/")
        _load_serialized_parts.assert_called_once_with(
            phys_reader, pkg_srels, content_types
        )
        phys_reader.close.assert_called_once_with()
        # -- ANY stands in for the `self` argument of the mocked initializer --
        _init_.assert_called_once_with(ANY, content_types, pkg_srels, sparts)
        assert isinstance(pkg_reader, PackageReader)

    def it_can_iterate_over_the_serialized_parts(self, iter_sparts_fixture):
        """`iter_sparts()` generates one 4-tuple of attributes per serialized part."""
        pkg_reader, expected_iter_spart_items = iter_sparts_fixture
        iter_spart_items = list(pkg_reader.iter_sparts())
        assert iter_spart_items == expected_iter_spart_items

    def it_can_iterate_over_all_the_srels(self):
        """`iter_srels()` generates (source_uri, srel) pairs.

        Package-level rels come first under source "/", followed by the rels of
        each part under that part's partname.
        """
        # mockery ----------------------
        pkg_srels = ["srel1", "srel2"]
        sparts = [
            Mock(name="spart1", partname="pn1", srels=["srel3", "srel4"]),
            Mock(name="spart2", partname="pn2", srels=["srel5", "srel6"]),
        ]
        # -- content types (first argument) play no part in this behavior --
        pkg_reader = PackageReader(None, pkg_srels, sparts)
        # exercise ---------------------
        generated_tuples = list(pkg_reader.iter_srels())
        # verify -----------------------
        expected_tuples = [
            ("/", "srel1"),
            ("/", "srel2"),
            ("pn1", "srel3"),
            ("pn1", "srel4"),
            ("pn2", "srel5"),
            ("pn2", "srel6"),
        ]
        assert generated_tuples == expected_tuples

    def it_can_load_serialized_parts(self, _SerializedPart_, _walk_phys_parts):
        """`_load_serialized_parts()` builds one `_SerializedPart` per walked part.

        The content type of each part is looked up by partname in `content_types`
        and the result is returned as a tuple of the constructed parts.
        """
        # test data --------------------
        # -- each row is (partname, content_type, reltype, blob, srels) --
        test_data = (
            ("/part/name1.xml", "app/vnd.type_1", "reltype1", "<Part_1/>", "srels_1"),
            ("/part/name2.xml", "app/vnd.type_2", "reltype2", "<Part_2/>", "srels_2"),
        )
        # -- the part walker yields (partname, blob, reltype, srels), so the blob
        # -- (t[3]) must precede the reltype (t[2]) for each value to land in the
        # -- slot its name describes --
        iter_vals = [(t[0], t[3], t[2], t[4]) for t in test_data]
        content_types = {t[0]: t[1] for t in test_data}
        # mockery ----------------------
        phys_reader = Mock(name="phys_reader")
        pkg_srels = Mock(name="pkg_srels")
        _walk_phys_parts.return_value = iter_vals
        # -- an iterable side_effect makes successive calls return successive items --
        _SerializedPart_.side_effect = expected_sparts = (
            Mock(name="spart_1"),
            Mock(name="spart_2"),
        )
        # exercise ---------------------
        retval = PackageReader._load_serialized_parts(
            phys_reader, pkg_srels, content_types
        )
        # verify -----------------------
        # -- _SerializedPart takes (partname, content_type, reltype, blob, srels) --
        expected_calls = [
            call(
                "/part/name1.xml", "app/vnd.type_1", "reltype1", "<Part_1/>", "srels_1"
            ),
            call(
                "/part/name2.xml", "app/vnd.type_2", "reltype2", "<Part_2/>", "srels_2"
            ),
        ]
        assert _SerializedPart_.call_args_list == expected_calls
        assert retval == expected_sparts

    def it_can_walk_phys_pkg_parts(self, _srels_for):
        """`_walk_phys_parts()` generates each internally-related part once.

        The relationship graph below contains an external relationship and a
        back-reference from part_2 to part_1; the expected output contains only
        the three parts, one tuple each.
        """
        # test data --------------------
        # +----------+       +--------+
        # | pkg_rels |-----> | part_1 |
        # +----------+       +--------+
        #      |               |    ^
        #      v               v    |
        #   external         +--------+     +--------+
        #                    | part_2 |---> | part_3 |
        #                    +--------+     +--------+
        partname_1, partname_2, partname_3 = (
            "/part/name1.xml",
            "/part/name2.xml",
            "/part/name3.xml",
        )
        part_1_blob, part_2_blob, part_3_blob = ("<Part_1/>", "<Part_2/>", "<Part_3/>")
        reltype1, reltype2, reltype3 = ("reltype1", "reltype2", "reltype3")
        srels = [
            Mock(name="rId1", is_external=True),
            Mock(
                name="rId2",
                is_external=False,
                reltype=reltype1,
                target_partname=partname_1,
            ),
            Mock(
                name="rId3",
                is_external=False,
                reltype=reltype2,
                target_partname=partname_2,
            ),
            Mock(
                name="rId4",
                is_external=False,
                reltype=reltype1,
                target_partname=partname_1,
            ),
            Mock(
                name="rId5",
                is_external=False,
                reltype=reltype3,
                target_partname=partname_3,
            ),
        ]
        # -- package: external rel + rel to part_1 --
        pkg_srels = srels[:2]
        # -- part_1: rel to part_2 --
        part_1_srels = srels[2:3]
        # -- part_2: rel back to part_1 + rel to part_3 --
        part_2_srels = srels[3:5]
        part_3_srels = []
        # mockery ----------------------
        phys_reader = Mock(name="phys_reader")
        # -- successive calls return the rels and blob of parts 1, 2 and 3 in turn --
        _srels_for.side_effect = [part_1_srels, part_2_srels, part_3_srels]
        phys_reader.blob_for.side_effect = [part_1_blob, part_2_blob, part_3_blob]
        # exercise ---------------------
        generated_tuples = list(PackageReader._walk_phys_parts(phys_reader, pkg_srels))
        # verify -----------------------
        expected_tuples = [
            (partname_1, part_1_blob, reltype1, part_1_srels),
            (partname_2, part_2_blob, reltype2, part_2_srels),
            (partname_3, part_3_blob, reltype3, part_3_srels),
        ]
        assert generated_tuples == expected_tuples

    def it_can_retrieve_srels_for_a_source_uri(self, _SerializedRelationships_):
        """`_srels_for()` loads the rels XML for `source_uri` relative to its base URI."""
        # mockery ----------------------
        phys_reader = Mock(name="phys_reader")
        source_uri = Mock(name="source_uri")
        rels_xml = phys_reader.rels_xml_for.return_value
        load_from_xml = _SerializedRelationships_.load_from_xml
        srels = load_from_xml.return_value
        # exercise ---------------------
        retval = PackageReader._srels_for(phys_reader, source_uri)
        # verify -----------------------
        phys_reader.rels_xml_for.assert_called_once_with(source_uri)
        load_from_xml.assert_called_once_with(source_uri.baseURI, rels_xml)
        assert retval == srels

    # fixtures -------------------------------------------------------

    @pytest.fixture
    def blobs_(self, request):
        """Pair of `str`-spec'd mocks used as the blobs of the two mock sparts."""
        blob_ = loose_mock(request, spec=str, name="blob_")
        blob_2_ = loose_mock(request, spec=str, name="blob_2_")
        return blob_, blob_2_

    @pytest.fixture
    def content_types_(self, request):
        """Pair of `str`-spec'd mocks used as content types of the two mock sparts."""
        content_type_ = loose_mock(request, spec=str, name="content_type_")
        content_type_2_ = loose_mock(request, spec=str, name="content_type_2_")
        return content_type_, content_type_2_

    @pytest.fixture
    def from_xml(self, request):
        """Mock standing in for `_ContentTypeMap.from_xml`."""
        return method_mock(request, _ContentTypeMap, "from_xml", autospec=False)

    @pytest.fixture
    def _init_(self, request):
        """Mock standing in for `PackageReader.__init__`."""
        return initializer_mock(request, PackageReader)

    @pytest.fixture
    def iter_sparts_fixture(
        self, sparts_, partnames_, content_types_, reltypes_, blobs_
    ):
        """Return a `PackageReader` over two mock sparts and its expected tuples.

        Each expected item is (partname, content_type, reltype, blob) for the
        spart at the same position.
        """
        pkg_reader = PackageReader(None, None, sparts_)
        expected_iter_spart_items = [
            (partnames_[0], content_types_[0], reltypes_[0], blobs_[0]),
            (partnames_[1], content_types_[1], reltypes_[1], blobs_[1]),
        ]
        return pkg_reader, expected_iter_spart_items

    @pytest.fixture
    def _load_serialized_parts(self, request):
        """Mock standing in for `PackageReader._load_serialized_parts`."""
        return method_mock(
            request, PackageReader, "_load_serialized_parts", autospec=False
        )

    @pytest.fixture
    def partnames_(self, request):
        """Pair of `str`-spec'd mocks used as partnames of the two mock sparts."""
        partname_ = loose_mock(request, spec=str, name="partname_")
        partname_2_ = loose_mock(request, spec=str, name="partname_2_")
        return partname_, partname_2_

    @pytest.fixture
    def PhysPkgReader_(self):
        """Patch `PhysPkgReader` in the module under test for the test's duration.

        The patch uses `_ZipPkgReader` as its `spec_set` and is stopped at fixture
        teardown (the code after `yield`).
        """
        p = patch("docx.opc.pkgreader.PhysPkgReader", spec_set=_ZipPkgReader)
        yield p.start()
        p.stop()

    @pytest.fixture
    def reltypes_(self, request):
        """Pair of `str` instance mocks used as reltypes of the two mock sparts."""
        reltype_ = instance_mock(request, str, name="reltype_")
        reltype_2_ = instance_mock(request, str, name="reltype_2")
        return reltype_, reltype_2_

    @pytest.fixture
    def _SerializedPart_(self, request):
        """Mock standing in for the `_SerializedPart` class in the module under test."""
        return class_mock(request, "docx.opc.pkgreader._SerializedPart")

    @pytest.fixture
    def _SerializedRelationships_(self, request):
        """Mock standing in for `_SerializedRelationships` in the module under test."""
        return class_mock(request, "docx.opc.pkgreader._SerializedRelationships")

    @pytest.fixture
    def sparts_(self, request, partnames_, content_types_, reltypes_, blobs_):
        """Return two `_SerializedPart` instance mocks built from the paired fixtures."""
        sparts_ = []
        for idx in range(2):
            # -- produces "spart_" for the first mock and "spart_2_" for the second --
            name = "spart_%s" % (("%d_" % (idx + 1)) if idx else "")
            spart_ = instance_mock(
                request,
                _SerializedPart,
                name=name,
                partname=partnames_[idx],
                content_type=content_types_[idx],
                reltype=reltypes_[idx],
                blob=blobs_[idx],
            )
            sparts_.append(spart_)
        return sparts_

    @pytest.fixture
    def _srels_for(self, request):
        """Mock standing in for `PackageReader._srels_for`."""
        return method_mock(request, PackageReader, "_srels_for", autospec=False)

    @pytest.fixture
    def _walk_phys_parts(self, request):
        """Mock standing in for `PackageReader._walk_phys_parts`."""
        return method_mock(request, PackageReader, "_walk_phys_parts", autospec=False)


class Describe_ContentTypeMap:
    """Unit-test suite for `docx.opc.pkgreader._ContentTypeMap`."""

    def it_can_construct_from_ct_item_xml(self, from_xml_fixture):
        """`from_xml()` populates both the defaults and the overrides mappings."""
        xml, defaults, overrides = from_xml_fixture

        content_type_map = _ContentTypeMap.from_xml(xml)

        assert content_type_map._defaults == defaults
        assert content_type_map._overrides == overrides

    def it_matches_an_override_on_case_insensitive_partname(
        self, match_override_fixture
    ):
        """A lookup finds an override whose partname differs only by case."""
        content_type_map, lookup_partname, expected = match_override_fixture
        actual = content_type_map[lookup_partname]
        assert actual == expected

    def it_falls_back_to_case_insensitive_extension_default_match(
        self, match_default_fixture
    ):
        """A lookup with no matching override resolves through the extension."""
        content_type_map, lookup_partname, expected = match_default_fixture
        actual = content_type_map[lookup_partname]
        assert actual == expected

    def it_should_raise_on_partname_not_found(self):
        """Looking up a partname in an empty map raises `KeyError`."""
        empty_map = _ContentTypeMap()
        with pytest.raises(KeyError):
            empty_map[PackURI("/!blat/rhumba.1x&")]

    def it_should_raise_on_key_not_instance_of_PackURI(self):
        """A plain `str` key raises `KeyError` even when its text has an override."""
        partname_str = "/part/name1.xml"
        content_type_map = _ContentTypeMap()
        content_type_map._overrides = {PackURI(partname_str): "app/vnd.type1"}
        with pytest.raises(KeyError):
            content_type_map[partname_str]

    # fixtures ---------------------------------------------

    @pytest.fixture
    def from_xml_fixture(self):
        """Return content-types XML and the mappings expected after parsing it.

        Expected default keys are the lower-cased extensions; expected override
        keys are the partnames exactly as given.
        """
        entries = (
            ("Default", "xml", CT.XML),
            ("Default", "PNG", CT.PNG),
            ("Override", "/ppt/presentation.xml", CT.PML_PRESENTATION_MAIN),
        )
        defaults = {
            ext.lower(): content_type
            for kind, ext, content_type in entries
            if kind == "Default"
        }
        overrides = {
            partname: content_type
            for kind, partname, content_type in entries
            if kind == "Override"
        }
        return self._xml_from(entries), defaults, overrides

    @pytest.fixture(
        params=[
            ("/foo/bar.xml", "xml", "application/xml"),
            ("/foo/bar.PNG", "png", "image/png"),
            ("/foo/bar.jpg", "JPG", "image/jpeg"),
        ]
    )
    def match_default_fixture(self, request):
        """Return a map with one unrelated override and one extension default.

        Each param is (partname to look up, extension registered, content type).
        """
        partname_str, ext, content_type = request.param
        content_type_map = _ContentTypeMap()
        # -- an override for a different part, which the lookup should not match --
        content_type_map._add_override(PackURI("/bar/foo.xyz"), "application/xyz")
        content_type_map._add_default(ext, content_type)
        return content_type_map, PackURI(partname_str), content_type

    @pytest.fixture(
        params=[
            ("/foo/bar.xml", "/foo/bar.xml"),
            ("/foo/bar.xml", "/FOO/Bar.XML"),
            ("/FoO/bAr.XmL", "/foo/bar.xml"),
        ]
    )
    def match_override_fixture(self, request):
        """Return a map holding one override plus the partname to look it up by.

        Each param is (partname registered, partname used for the lookup).
        """
        registered_str, lookup_str = request.param
        content_type = "appl/vnd-foobar"
        content_type_map = _ContentTypeMap()
        content_type_map._add_override(PackURI(registered_str), content_type)
        return content_type_map, PackURI(lookup_str), content_type

    def _xml_from(self, entries):
        """Return [Content_Types].xml text built from the items in `entries`.

        Each entry is a sequence starting with "Default" or "Override", followed
        by the extension or partname and then the content type. Entries of any
        other kind are skipped.
        """
        types_bldr = a_Types().with_nsdecls()
        for entry in entries:
            kind = entry[0]
            if kind not in ("Default", "Override"):
                continue
            key, content_type = entry[1:]
            if kind == "Default":
                child_bldr = a_Default()
                child_bldr.with_Extension(key)
            else:
                child_bldr = an_Override()
                child_bldr.with_PartName(key)
            child_bldr.with_ContentType(content_type)
            types_bldr.with_child(child_bldr)
        return types_bldr.xml()


class Describe_SerializedPart:
    """Unit-test suite for `docx.opc.pkgreader._SerializedPart`."""

    def it_remembers_construction_values(self):
        """Each constructor argument can be read back from a same-named attribute."""
        # -- insertion order matches the positional parameter order --
        expected = {
            "partname": "/part/name.xml",
            "content_type": "app/vnd.type",
            "reltype": "http://rel/type",
            "blob": "<Part/>",
            "srels": "srels proxy",
        }

        spart = _SerializedPart(*expected.values())

        for attr_name, expected_value in expected.items():
            assert getattr(spart, attr_name) == expected_value


class Describe_SerializedRelationship:
    """Unit-test suite for `docx.opc.pkgreader._SerializedRelationship`."""

    def it_remembers_construction_values(self):
        """The values on the relationship element are exposed as attributes."""
        rel_elm = self._rel_elm("rId9", "ReLtYpE", "docProps/core.xml", RTM.INTERNAL)

        srel = _SerializedRelationship("/", rel_elm)

        assert srel.rId == "rId9"
        assert srel.reltype == "ReLtYpE"
        assert srel.target_ref == "docProps/core.xml"
        assert srel.target_mode == RTM.INTERNAL

    def it_knows_when_it_is_external(self):
        """`is_external` is True only for the EXTERNAL target mode."""
        cases = (
            (RTM.INTERNAL, False),
            (RTM.EXTERNAL, True),
            ("FOOBAR", False),
        )
        for target_mode, expected_value in cases:
            rel_elm = self._rel_elm(None, None, None, target_mode)
            srel = _SerializedRelationship(None, rel_elm)
            assert srel.is_external is expected_value

    def it_can_calculate_its_target_partname(self):
        """`target_partname` resolves the target reference against the base URI."""
        # -- each case is (baseURI, target_ref, expected target partname) --
        cases = (
            ("/", "docProps/core.xml", "/docProps/core.xml"),
            ("/ppt", "viewProps.xml", "/ppt/viewProps.xml"),
            (
                "/ppt/slides",
                "../slideLayouts/slideLayout1.xml",
                "/ppt/slideLayouts/slideLayout1.xml",
            ),
        )
        for baseURI, target_ref, expected_partname in cases:
            rel_elm = self._rel_elm(None, None, target_ref, RTM.INTERNAL)
            srel = _SerializedRelationship(baseURI, rel_elm)
            assert srel.target_partname == expected_partname

    def it_raises_on_target_partname_when_external(self):
        """Reading `target_partname` on an external relationship raises ValueError."""
        rel_elm = self._rel_elm("rId9", "ReLtYpE", "docProps/core.xml", RTM.EXTERNAL)
        srel = _SerializedRelationship("/", rel_elm)
        with pytest.raises(ValueError, match="target_partname attribute on Relat"):
            srel.target_partname

    # helpers ----------------------------------------------

    def _rel_elm(self, rId, reltype, target_ref, target_mode):
        """Return a mock relationship element carrying the given attribute values."""
        return Mock(
            name="rel_elm",
            rId=rId,
            reltype=reltype,
            target_ref=target_ref,
            target_mode=target_mode,
        )


class Describe_SerializedRelationships:
    """Unit-test suite for `docx.opc.pkgreader._SerializedRelationships`."""

    def it_can_load_from_xml(self, parse_xml_, _SerializedRelationship_):
        """`load_from_xml()` builds one relationship per parsed element.

        The XML is parsed once and each element in `Relationship_lst` is passed,
        together with the base URI, to `_SerializedRelationship`.
        """
        # mockery ----------------------
        baseURI = Mock(name="baseURI")
        rels_item_xml = Mock(name="rels_item_xml")
        rel_elms = [Mock(name="rel_elm_1"), Mock(name="rel_elm_2")]
        expected_calls = [call(baseURI, rel_elm) for rel_elm in rel_elms]
        parse_xml_.return_value = Mock(name="rels_elm", Relationship_lst=rel_elms)

        # exercise ---------------------
        srels = _SerializedRelationships.load_from_xml(baseURI, rels_item_xml)

        # verify -----------------------
        parse_xml_.assert_called_once_with(rels_item_xml)
        assert _SerializedRelationship_.call_args_list == expected_calls
        assert isinstance(srels, _SerializedRelationships)

    def it_should_be_iterable(self):
        """Iterating a new collection completes without raising `TypeError`."""
        srels = _SerializedRelationships()
        try:
            for _ in srels:
                pass
        except TypeError:
            pytest.fail("_SerializedRelationships object is not iterable")

    # fixtures ---------------------------------------------

    @pytest.fixture
    def parse_xml_(self, request):
        """Mock standing in for `parse_xml` in the module under test."""
        return function_mock(request, "docx.opc.pkgreader.parse_xml")

    @pytest.fixture
    def _SerializedRelationship_(self, request):
        """Mock standing in for `_SerializedRelationship` in the module under test."""
        return class_mock(request, "docx.opc.pkgreader._SerializedRelationship")
