import sqlalchemy as sa
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy.orm import backref
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.orm import relationship
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table


class M2MTest(fixtures.MappedTest):
    """Many-to-many mapping tests built on a place/transition schema.

    ``place`` and ``transition`` rows are linked through the
    ``place_input`` and ``place_output`` association tables, ``place`` is
    linked to itself through ``place_place``, and ``place_thingy`` carries
    a foreign key to ``place``.
    """

    @classmethod
    def define_tables(cls, metadata):
        """Register the entity and association tables on ``metadata``."""

        def autoincrement_pk(column_name):
            # A fresh Column is built per call so each table gets its own.
            return Column(
                column_name,
                Integer,
                primary_key=True,
                test_needs_autoincrement=True,
            )

        def required_name():
            return Column("name", String(30), nullable=False)

        def place_fk(column_name):
            return Column(column_name, Integer, ForeignKey("place.place_id"))

        Table(
            "place",
            metadata,
            autoincrement_pk("place_id"),
            required_name(),
            test_needs_acid=True,
        )
        Table(
            "transition",
            metadata,
            autoincrement_pk("transition_id"),
            required_name(),
            test_needs_acid=True,
        )
        Table(
            "place_thingy",
            metadata,
            autoincrement_pk("thingy_id"),
            Column(
                "place_id",
                Integer,
                ForeignKey("place.place_id"),
                nullable=False,
            ),
            required_name(),
            test_needs_acid=True,
        )

        # association tables #1 and #2 have the same column layout
        for association_name in ("place_input", "place_output"):
            Table(
                association_name,
                metadata,
                place_fk("place_id"),
                Column(
                    "transition_id",
                    Integer,
                    ForeignKey("transition.transition_id"),
                ),
                test_needs_acid=True,
            )

        # association of place rows with other place rows
        Table(
            "place_place",
            metadata,
            place_fk("pl1_id"),
            place_fk("pl2_id"),
            test_needs_acid=True,
        )

    @classmethod
    def setup_classes(cls):
        """Declare the plain classes that the tests map imperatively."""

        class Place(cls.Basic):
            """Mapped against the ``place`` table by the tests."""

            def __init__(self, name):
                self.name = name

        class PlaceThingy(cls.Basic):
            """Mapped against the ``place_thingy`` table by the tests."""

            def __init__(self, name):
                self.name = name

        class Transition(cls.Basic):
            """Mapped against the ``transition`` table by the tests."""

            def __init__(self, name):
                self.name = name

    def test_overlapping_attribute_error(self):
        """Backrefs that collide with explicit properties fail to configure.

        ``Place.transitions`` asks for a ``places`` backref while
        ``Transition.places`` asks for a ``transitions`` backref, so
        ``configure_mappers`` is expected to raise ``ArgumentError``.
        """
        Place = self.classes.Place
        Transition = self.classes.Transition
        tables = self.tables
        registry = self.mapper_registry

        registry.map_imperatively(
            Place,
            tables.place,
            properties=dict(
                transitions=relationship(
                    Transition,
                    secondary=tables.place_input,
                    backref="places",
                )
            ),
        )
        registry.map_imperatively(
            Transition,
            tables.transition,
            properties=dict(
                places=relationship(
                    Place,
                    secondary=tables.place_input,
                    backref="transitions",
                )
            ),
        )

        assert_raises_message(
            sa.exc.ArgumentError,
            "property of that name exists",
            sa.orm.configure_mappers,
        )

    def test_self_referential_roundtrip(self):
        """Persist a self-referential collection and read it back."""
        Place = self.classes.Place
        place_table = self.tables.place
        link_table = self.tables.place_place

        self.mapper_registry.map_imperatively(
            Place,
            place_table,
            properties=dict(
                places=relationship(
                    Place,
                    secondary=link_table,
                    primaryjoin=place_table.c.place_id == link_table.c.pl1_id,
                    secondaryjoin=(
                        place_table.c.place_id == link_table.c.pl2_id
                    ),
                    order_by=link_table.c.pl2_id,
                )
            ),
        )

        session = fixture_session()
        p1, p2, p3, p4, p5, p6, p7 = [
            Place("place%d" % number) for number in range(1, 8)
        ]
        session.add_all((p1, p2, p3, p4, p5, p6, p7))

        links = (
            (p1, p2),
            (p1, p3),
            (p5, p6),
            (p6, p1),
            (p7, p1),
            (p1, p5),
            (p4, p3),
            (p3, p4),
        )
        for source, target in links:
            source.places.append(target)
        session.commit()

        expectations = (
            (p1, [p2, p3, p5]),
            (p5, [p6]),
            (p7, [p1]),
            (p6, [p1]),
            (p4, [p3]),
            (p3, [p4]),
            (p2, []),
        )
        for owner, expected_places in expectations:
            eq_(owner.places, expected_places)

    def test_self_referential_bidirectional_mutation(self):
        """Mutate the backref side of a self-referential relationship."""
        Place = self.classes.Place
        place_table = self.tables.place
        link_table = self.tables.place_place

        self.mapper_registry.map_imperatively(
            Place,
            place_table,
            properties=dict(
                child_places=relationship(
                    Place,
                    secondary=link_table,
                    primaryjoin=place_table.c.place_id == link_table.c.pl1_id,
                    secondaryjoin=(
                        place_table.c.place_id == link_table.c.pl2_id
                    ),
                    order_by=link_table.c.pl2_id,
                    backref="parent_places",
                )
            ),
        )

        session = fixture_session()
        first = Place("place1")
        second = Place("place2")
        second.parent_places = [first]
        session.add_all([first, second])
        first.parent_places.append(second)
        session.commit()

        assert first in second.parent_places
        assert second in first.parent_places

    def test_joinedload_on_double(self):
        """test that a mapper can have two eager relationships to the same
        table, via two different association tables.  aliases are required.
        """
        tables = self.tables
        Place = self.classes.Place
        PlaceThingy = self.classes.PlaceThingy
        Transition = self.classes.Transition
        registry = self.mapper_registry

        registry.map_imperatively(PlaceThingy, tables.place_thingy)
        registry.map_imperatively(
            Place,
            tables.place,
            properties=dict(
                thingies=relationship(PlaceThingy, lazy="joined")
            ),
        )
        registry.map_imperatively(
            Transition,
            tables.transition,
            properties={
                "inputs": relationship(
                    Place, secondary=tables.place_output, lazy="joined"
                ),
                "outputs": relationship(
                    Place, secondary=tables.place_input, lazy="joined"
                ),
            },
        )

        transition_obj = Transition("transition1")
        transition_obj.inputs.append(Place("place1"))
        for output_name in ("place2", "place3"):
            transition_obj.outputs.append(Place(output_name))

        session = fixture_session()
        session.add(transition_obj)
        session.commit()

        loaded = session.query(Transition).all()
        self.assert_unordered_result(
            loaded,
            Transition,
            {
                "name": "transition1",
                "inputs": (Place, [{"name": "place1"}]),
                "outputs": (Place, [{"name": "place2"}, {"name": "place3"}]),
            },
        )

    def test_bidirectional(self):
        """Populate two many-to-many relationships from both directions."""
        tables = self.tables
        Place = self.classes.Place
        Transition = self.classes.Transition
        registry = self.mapper_registry

        # Place is mapped first; the relationships below order by
        # Place.place_id.
        registry.map_imperatively(Place, tables.place)

        by_transition_id = tables.transition.c.transition_id
        registry.map_imperatively(
            Transition,
            tables.transition,
            properties={
                "inputs": relationship(
                    Place,
                    secondary=tables.place_output,
                    backref=backref("inputs", order_by=by_transition_id),
                    order_by=Place.place_id,
                ),
                "outputs": relationship(
                    Place,
                    secondary=tables.place_input,
                    backref=backref("outputs", order_by=by_transition_id),
                    order_by=Place.place_id,
                ),
            },
        )

        t1, t2, t3 = [Transition("transition%d" % n) for n in (1, 2, 3)]
        p1, p2, p3 = [Place("place%d" % n) for n in (1, 2, 3)]

        session = fixture_session()
        session.add_all([p3, p1, t1, t2, p2, t3])

        # mix appends made from the Transition side and the Place side
        t1.inputs.append(p1)
        t1.inputs.append(p2)
        t1.outputs.append(p3)
        t2.inputs.append(p1)
        p2.inputs.append(t2)
        p3.inputs.append(t2)
        p1.outputs.append(t1)
        session.commit()

        expected_outputs = [{"name": "place3"}, {"name": "place1"}]
        self.assert_result(
            [t1],
            Transition,
            {"outputs": (Place, expected_outputs)},
        )

        expected_inputs = [{"name": "transition1"}, {"name": "transition2"}]
        self.assert_result(
            [p2],
            Place,
            {"inputs": (Transition, expected_inputs)},
        )

    @testing.requires.updateable_autoincrement_pks
    @testing.requires.sane_multi_rowcount
    def test_stale_conditions(self):
        """Association rows removed out-of-band raise ``StaleDataError``.

        Covers both the UPDATE issued after a primary key change (the
        relationship uses ``passive_updates=False``) and the DELETE issued
        after removing an item from the collection.
        """
        Place = self.classes.Place
        Transition = self.classes.Transition
        association = self.tables.place_input
        registry = self.mapper_registry

        registry.map_imperatively(
            Place,
            self.tables.place,
            properties=dict(
                transitions=relationship(
                    Transition,
                    secondary=association,
                    passive_updates=False,
                )
            ),
        )
        registry.map_imperatively(Transition, self.tables.transition)

        place_obj = Place("place1")
        transition_obj = Transition("t1")
        place_obj.transitions.append(transition_obj)
        session = fixture_session()
        session.add_all([place_obj, transition_obj])
        session.commit()

        # Touch both attributes before the rows are deleted directly;
        # presumably this loads their post-commit state.
        place_obj.place_id
        place_obj.transitions

        session.execute(association.delete())
        place_obj.place_id = 7

        assert_raises_message(
            orm_exc.StaleDataError,
            r"UPDATE statement on table 'place_input' expected to "
            r"update 1 row\(s\); Only 0 were matched.",
            session.commit,
        )
        session.rollback()

        # Same attribute accesses again after the rollback.
        place_obj.place_id
        place_obj.transitions

        session.execute(association.delete())
        place_obj.transitions.remove(transition_obj)

        assert_raises_message(
            orm_exc.StaleDataError,
            r"DELETE statement on table 'place_input' expected to "
            r"delete 1 row\(s\); Only 0 were matched.",
            session.commit,
        )


class AssortedPersistenceTests(fixtures.MappedTest):
    """Persistence tests for a ``left`` / ``right`` many-to-many pair.

    ``A`` is mapped to ``left`` and ``B`` to ``right``; the two are joined
    through the ``secondary`` table, which has a composite primary key.
    """

    @classmethod
    def define_tables(cls, metadata):
        """Register ``left``, ``right`` and the ``secondary`` link table."""
        # left and right share the same two-column layout
        for entity_name in ("left", "right"):
            Table(
                entity_name,
                metadata,
                Column(
                    "id",
                    Integer,
                    primary_key=True,
                    test_needs_autoincrement=True,
                ),
                Column("data", String(30)),
            )

        Table(
            "secondary",
            metadata,
            Column(
                "left_id",
                Integer,
                ForeignKey("left.id"),
                primary_key=True,
            ),
            Column(
                "right_id",
                Integer,
                ForeignKey("right.id"),
                primary_key=True,
            ),
        )

    @classmethod
    def setup_classes(cls):
        """Declare the two comparable classes used by the fixtures."""

        class A(cls.Comparable):
            pass

        class B(cls.Comparable):
            pass

    def _map_a_with_bs(self, backref_spec):
        """Map ``A`` with a ``bs`` collection, then map ``B``.

        ``backref_spec`` is passed straight through as the ``backref``
        argument of the ``bs`` relationship.
        """
        tables = self.tables
        A, B = self.classes.A, self.classes.B
        registry = self.mapper_registry

        registry.map_imperatively(
            A,
            tables.left,
            properties=dict(
                bs=relationship(
                    B,
                    secondary=tables.secondary,
                    backref=backref_spec,
                    order_by=tables.right.c.id,
                )
            ),
        )
        registry.map_imperatively(B, tables.right)

    def _standard_bidirectional_fixture(self):
        """Map ``A.bs`` with a plain ``as`` backref on ``B``."""
        self._map_a_with_bs("as")

    def _bidirectional_onescalar_fixture(self):
        """Map ``A.bs`` with a scalar (``uselist=False``) ``a`` backref."""
        self._map_a_with_bs(backref("a", uselist=False))

    def test_session_delete(self):
        """Deleting an ``A`` removes its row from ``secondary``."""
        self._standard_bidirectional_fixture()
        A, B = self.classes.A, self.classes.B
        link_table = self.tables.secondary

        session = fixture_session()
        session.add_all(
            [A(data="a%d" % n, bs=[B(data="b%d" % n)]) for n in (1, 2)]
        )
        session.commit()

        # delete a1, then a2; one association row should vanish each time
        for label, rows_left in (("a1", 1), ("a2", 0)):
            doomed = session.query(A).filter_by(data=label).one()
            session.delete(doomed)
            session.flush()
            eq_(session.query(link_table).count(), rows_left)

    def test_remove_scalar(self):
        """test setting a uselist=False to None"""
        self._bidirectional_onescalar_fixture()
        A, B = self.classes.A, self.classes.B
        link_table = self.tables.secondary

        session = fixture_session()
        parent = A(data="a1", bs=[B(data="b1"), B(data="b2")])
        session.add_all([parent])
        session.commit()

        loaded_a = session.query(A).filter_by(data="a1").one()
        loaded_b = session.query(B).filter_by(data="b2").one()
        assert loaded_b.a is loaded_a

        loaded_b.a = None
        session.commit()

        eq_(loaded_a.bs, [B(data="b1")])
        eq_(loaded_b.a, None)
        eq_(session.query(link_table).count(), 1)
