from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy.orm import class_mapper
from sqlalchemy.orm import polymorphic_union
from sqlalchemy.orm import relationship
from sqlalchemy.orm.interfaces import MANYTOONE
from sqlalchemy.orm.interfaces import ONETOMANY
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table


def _combinations():
    """Yield one ``(name, parent, child, direction)`` tuple per variant.

    Every ordered pairing of the table names "a", "b" and "c" is produced
    twice, once for ONETOMANY and once for MANYTOONE.  ``name`` is built
    from the two table names plus an "O2M" / "M2O" suffix that reflects
    the direction.
    """
    table_names = ("a", "b", "c")
    # each direction paired with the suffix used in the generated name
    directions = ((ONETOMANY, "O2M"), (MANYTOONE, "M2O"))

    for parent in table_names:
        for child in table_names:
            for direction, suffix in directions:
                name = "Test%sTo%s%s" % (parent, child, suffix)
                yield (name, parent, child, direction)


@testing.combinations(
    *_combinations(), argnames="name,parent,child,direction", id_="saaa"
)
class ABCTest(fixtures.MappedTest):
    """Round-trip a relationship between each pairing of A, B and C.

    ``A``, ``B`` and ``C`` form an inheritance chain (``C`` extends ``B``
    extends ``A``) mapped to the tables ``a``, ``b`` and ``c``.  Each
    parametrized variant configures a single relationship, named
    ``collection``, from the ``parent`` class to the ``child`` class in
    the given ``direction``, then persists and reloads objects through it.
    """

    @classmethod
    def define_tables(cls, metadata):
        """Create the tables ``a``, ``b`` and ``c`` on ``metadata``.

        ``b.id`` is a foreign key to ``a.id`` and ``c.id`` is a foreign key
        to ``b.id``.  In addition, the table backing the relationship under
        test receives its foreign key column: ``child_id`` on the parent
        table for MANYTOONE, or ``parent_id`` on the child table for
        ONETOMANY.
        """
        parent, child, direction = cls.parent, cls.child, cls.direction

        def relationship_columns(table_name):
            """Return the relationship FK column for ``table_name``.

            The result is a list holding zero or one ``Column`` so that it
            can be unpacked straight into the ``Table`` arguments.
            """
            if table_name == parent and direction == MANYTOONE:
                column_name, target = "child_id", child
            elif table_name == child and direction == ONETOMANY:
                column_name, target = "parent_id", parent
            else:
                return []
            return [
                Column(
                    column_name,
                    Integer,
                    # use_alter: the target may be this same table or one
                    # that already depends on it (e.g. a.child_id -> c.id
                    # while c.id -> b.id -> a.id), which forms a cycle.
                    ForeignKey("%s.id" % target, use_alter=True, name="foo"),
                )
            ]

        # Table() registers itself with ``metadata``; the return values
        # are not needed here.
        Table(
            "a",
            metadata,
            Column(
                "id",
                Integer,
                primary_key=True,
                test_needs_autoincrement=True,
            ),
            Column("a_data", String(30)),
            *relationship_columns("a")
        )

        Table(
            "b",
            metadata,
            Column("id", Integer, ForeignKey("a.id"), primary_key=True),
            Column("b_data", String(30)),
            *relationship_columns("b")
        )

        Table(
            "c",
            metadata,
            Column("id", Integer, ForeignKey("b.id"), primary_key=True),
            Column("c_data", String(30)),
            *relationship_columns("c")
        )

    @classmethod
    def setup_mappers(cls):
        """Map A, B and C and attach ``collection`` to the parent mapper.

        ``A`` and ``B`` load polymorphically through ``polymorphic_union``
        selectables that carry a ``type`` discriminator column; ``C`` loads
        through a plain join of all three tables.  The ``collection``
        relationship is added afterwards, with its join condition and
        foreign keys chosen according to ``direction``.
        """
        parent, child, direction = cls.parent, cls.child, cls.direction
        ta, tb, tc = cls.tables("a", "b", "c")
        tables_by_name = {"a": ta, "b": tb, "c": tc}
        parent_table = tables_by_name[parent]
        child_table = tables_by_name[child]

        atob = ta.c.id == tb.c.id
        btoc = tc.c.id == tb.c.id

        remote_side = None
        if direction == MANYTOONE:
            foreign_keys = [parent_table.c.child_id]
            relationshipjoin = parent_table.c.child_id == child_table.c.id
            # Self-referential many-to-one: the join alone does not say
            # which side is remote, so state it explicitly.  Compare by
            # value; identity of two equal strings is not guaranteed.
            if parent == child:
                remote_side = [child_table.c.id]
        elif direction == ONETOMANY:
            foreign_keys = [child_table.c.parent_id]
            relationshipjoin = parent_table.c.id == child_table.c.parent_id

        def b_only():
            """Build a subquery of "b" rows that have no "c" row."""
            return (
                ta.join(tb, onclause=atob)
                .outerjoin(tc, onclause=btoc)
                .select()
                .where(tc.c.id == None)  # noqa: E711
                .reduce_columns()
                .subquery()
            )

        def c_join():
            """Build the join of all three tables, i.e. the "c" rows."""
            return tc.join(tb, onclause=btoc).join(ta, onclause=atob)

        # A fresh selectable is built for each use below rather than
        # sharing one object between the two unions and the C mapper.
        abcjoin = polymorphic_union(
            {
                # "a" rows that have no "b" row
                "a": ta.select()
                .where(tb.c.id == None)  # noqa: E711
                .select_from(ta.outerjoin(tb, onclause=atob))
                .subquery(),
                "b": b_only(),
                "c": c_join(),
            },
            "type",
            "abcjoin",
        )

        bcjoin = polymorphic_union(
            {
                "b": b_only(),
                "c": c_join(),
            },
            "type",
            "bcjoin",
        )

        class A(cls.Comparable):
            def __init__(self, name):
                self.a_data = name

        class B(A):
            pass

        class C(B):
            pass

        cls.mapper_registry.map_imperatively(
            A,
            ta,
            polymorphic_on=abcjoin.c.type,
            with_polymorphic=("*", abcjoin),
            polymorphic_identity="a",
        )
        cls.mapper_registry.map_imperatively(
            B,
            tb,
            polymorphic_on=bcjoin.c.type,
            with_polymorphic=("*", bcjoin),
            polymorphic_identity="b",
            inherits=A,
            inherit_condition=atob,
        )
        cls.mapper_registry.map_imperatively(
            C,
            tc,
            polymorphic_identity="c",
            with_polymorphic=("*", c_join()),
            inherits=B,
            inherit_condition=btoc,
        )

        classes_by_table = {ta: A, tb: B, tc: C}
        parent_mapper = class_mapper(classes_by_table[parent_table])
        child_mapper = class_mapper(classes_by_table[child_table])

        # The attribute is list-based (uselist=True) in both directions,
        # so the test can always append to and index ``collection``.
        parent_mapper.add_property(
            "collection",
            relationship(
                child_mapper,
                primaryjoin=relationshipjoin,
                foreign_keys=foreign_keys,
                order_by=child_mapper.c.id,
                remote_side=remote_side,
                uselist=True,
            ),
        )

    def test_roundtrip(self):
        """Persist related objects, then reload and check them two ways.

        The parent is reloaded first with ``Session.get()`` on its own
        class and then with a query against the base class ``A``; in both
        cases ``collection`` must contain the expected children.
        """
        parent, child, direction = self.parent, self.child, self.direction
        A, B, C = self.classes("A", "B", "C")
        classes_by_name = {"a": A, "b": B, "c": C}
        parent_class = classes_by_name[parent]
        child_class = classes_by_name[child]

        # expire_on_commit=False: the ``.id`` attributes of the original
        # objects are read below, after the session is committed and
        # closed.
        sess = fixture_session(autoflush=False, expire_on_commit=False)

        parent_obj = parent_class("parent1")
        child_obj = child_class("child1")
        # unrelated rows of every type, present alongside the related ones
        somea = A("somea")
        someb = B("someb")
        somec = C("somec")

        sess.add(parent_obj)
        parent_obj.collection.append(child_obj)
        if direction == ONETOMANY:
            # one parent with two children
            child2 = child_class("child2")
            parent_obj.collection.append(child2)
            sess.add(child2)
        elif direction == MANYTOONE:
            # two parents referring to the same child
            parent2 = parent_class("parent2")
            parent2.collection.append(child_obj)
            sess.add(parent2)
        sess.add(somea)
        sess.add(someb)
        sess.add(somec)
        sess.commit()
        sess.close()

        # assert result via direct get() of parent object
        result = sess.get(parent_class, parent_obj.id)
        assert result.id == parent_obj.id
        assert result.collection[0].id == child_obj.id
        if direction == ONETOMANY:
            assert result.collection[1].id == child2.id
        elif direction == MANYTOONE:
            result2 = sess.get(parent_class, parent2.id)
            assert result2.id == parent2.id
            assert result2.collection[0].id == child_obj.id

        sess.expunge_all()

        # assert result via polymorphic load of parent object
        result = sess.query(A).filter_by(id=parent_obj.id).one()
        assert result.id == parent_obj.id
        assert result.collection[0].id == child_obj.id
        if direction == ONETOMANY:
            assert result.collection[1].id == child2.id
        elif direction == MANYTOONE:
            result2 = sess.query(A).filter_by(id=parent2.id).one()
            assert result2.id == parent2.id
            assert result2.collection[0].id == child_obj.id
