from xsdata.codegen.handlers import MergeAttributes
from xsdata.codegen.models import Restrictions
from xsdata.models.enums import DataType
from xsdata.utils.testing import (
    AttrFactory,
    AttrTypeFactory,
    ClassFactory,
    FactoryTestCase,
)


class MergeAttributesTests(FactoryTestCase):
    """Tests for the ``MergeAttributes`` class handler.

    Every test builds a class holding duplicate attributes, runs the
    handler over it and inspects which attributes survive and what their
    restrictions look like afterwards.
    """

    def setUp(self) -> None:
        super().setUp()
        self.processor = MergeAttributes

    def _assert_single_merged(
        self, target, name: str, min_occurs: int, max_occurs: int
    ) -> None:
        """Assert ``target`` holds exactly one attr with the given values."""
        self.assertEqual(1, len(target.attrs))
        merged = target.attrs[0]
        self.assertEqual(name, merged.name)
        self.assertEqual(min_occurs, merged.restrictions.min_occurs)
        self.assertEqual(max_occurs, merged.restrictions.max_occurs)

    def test_process_with_enumeration(self) -> None:
        """Repeated enumeration members collapse to one per default value."""
        target = ClassFactory.create()
        target.attrs = [
            AttrFactory.enumeration(default=value) for value in (1, 1, 2, 2)
        ]

        self.processor.process(target)

        surviving_defaults = [attr.default for attr in target.attrs]
        self.assertEqual([1, 2], surviving_defaults)

    def test_process_with_non_enumeration(self) -> None:
        """Duplicate non-enumeration attrs are merged into their first copy.

        The duplicated element is expected to end up with the smallest
        min_occurs (4) and the summed max_occurs (15 + 5 + 4 = 24), with
        ``fixed`` cleared. The INT element, which has a sequential FLOAT
        duplicate, is expected to end up with sequence=1, min_occurs=0,
        max_occurs=3 and both types.
        """
        # Attribute duplicated verbatim.
        attribute = AttrFactory.attribute(fixed=True)
        attribute_copy = attribute.clone()

        # Element duplicated twice, each copy with its own occurrence bounds.
        element = AttrFactory.element(
            restrictions=Restrictions(min_occurs=10, max_occurs=15), fixed=True
        )
        element_copies = []
        for occurs in (5, 4):
            duplicate = element.clone()
            duplicate.restrictions.min_occurs = occurs
            duplicate.restrictions.max_occurs = occurs
            element_copies.append(duplicate)

        # Element without any duplicate.
        lone_element = AttrFactory.element()

        # Enumeration member duplicated verbatim.
        member = AttrFactory.enumeration()
        member_copy = member.clone()

        # INT element duplicated verbatim, plus one more copy marked as a
        # sequence that also allows FLOAT.
        numeric = AttrFactory.element()
        numeric.types = [AttrTypeFactory.native(DataType.INT)]
        numeric_copy = numeric.clone()
        numeric_sequential = numeric.clone()
        numeric_sequential.restrictions.sequence = 1
        numeric_sequential.types.append(AttrTypeFactory.native(DataType.FLOAT))

        target = ClassFactory.create(
            attrs=[
                attribute,
                attribute_copy,
                element,
                *element_copies,
                lone_element,
                member,
                member_copy,
                numeric,
                numeric_copy,
                numeric_sequential,
            ]
        )

        self.processor.process(target)

        # Expected: only the first occurrence of each attribute remains.
        self.assertEqual(
            [attribute, element, lone_element, member, numeric], target.attrs
        )

        self.assertTrue(attribute.fixed)
        self.assertFalse(element.fixed)

        # These attrs are expected to keep unset occurrence bounds.
        for attr in (attribute, lone_element, member):
            self.assertIsNone(attr.restrictions.min_occurs)
            self.assertIsNone(attr.restrictions.max_occurs)

        self.assertEqual(4, element.restrictions.min_occurs)
        self.assertEqual(24, element.restrictions.max_occurs)

        self.assertEqual(1, numeric.restrictions.sequence)
        self.assertEqual(0, numeric.restrictions.min_occurs)
        self.assertEqual(3, numeric.restrictions.max_occurs)
        self.assertEqual(["int", "float"], [t.name for t in numeric.types])

    def test_process_elements_in_different_choice_groups(self) -> None:
        """Elements sitting in different choice groups combine via max, not sum."""
        target = ClassFactory.create(
            attrs=[
                AttrFactory.element(
                    name="configOption",
                    index=index,
                    restrictions=Restrictions(
                        min_occurs=0, max_occurs=1, choice=choice_id
                    ),
                )
                for index, choice_id in ((5, 12345), (8, 67890))
            ]
        )

        self.processor.process(target)

        # A single merged attr whose max_occurs is 1, not the sum 2.
        self._assert_single_merged(
            target, "configOption", min_occurs=0, max_occurs=1
        )

    def test_process_elements_in_same_choice_different_branches(self) -> None:
        """Elements in separate branches of one choice combine via max.

        When the same element shows up in several branches of a single
        choice, only one branch is picked at runtime. The occurrences are
        therefore mutually exclusive, and max_occurs should be combined with
        max() rather than sum().
        """
        choice_id = 12345
        target = ClassFactory.create(
            attrs=[
                AttrFactory.element(
                    name="b",
                    index=index,
                    restrictions=Restrictions(
                        min_occurs=min_occurs, max_occurs=1, choice=choice_id
                    ),
                )
                for index, min_occurs in ((10, 0), (11, 1))
            ]
        )

        self.processor.process(target)

        # Both branches belong to the same choice, so max_occurs stays 1
        # instead of becoming 2.
        self._assert_single_merged(target, "b", min_occurs=0, max_occurs=1)
