import random

from sqlalchemy import exc
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import lambda_stmt
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy import update
from sqlalchemy.future import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import relationship
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from sqlalchemy.orm import subqueryload
from sqlalchemy.sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
from .inheritance import _poly_fixtures
from .test_query import QueryTest


class LambdaTest(QueryTest, AssertsCompiledSQL):
    __dialect__ = "default"

    # we want to test the lambda expiration logic so use backend
    # to exercise that

    __backend__ = True
    run_setup_mappers = None

    @testing.fixture
    def plain_fixture(self):
        """Map ``User`` and ``Address`` with a two-way relationship.

        ``User.addresses`` and ``Address.user`` are linked to each other
        through ``back_populates``.

        :return: the ``(User, Address)`` pair of mapped classes.
        """
        User = self.classes.User
        Address = self.classes.Address
        registry = self.mapper_registry

        user_props = {
            "addresses": relationship(Address, back_populates="user")
        }
        registry.map_imperatively(
            User, self.tables.users, properties=user_props
        )

        address_props = {
            "user": relationship(User, back_populates="addresses")
        }
        registry.map_imperatively(
            Address, self.tables.addresses, properties=address_props
        )

        return User, Address

    def test_user_cols_single_lambda(self, plain_fixture):
        """Compile a select() whose columns and FROM are given as lambdas."""
        User, _ = plain_fixture

        stmt = select(lambda: (User.id, User.name))
        stmt = stmt.select_from(lambda: User)

        expected = "SELECT users.id, users.name FROM users"
        self.assert_compile(stmt, expected)

    def test_user_cols_single_lambda_query(self, plain_fixture):
        """Compile a Query whose columns and FROM are given as lambdas."""
        User, _ = plain_fixture

        sess = fixture_session()
        q = sess.query(lambda: (User.id, User.name))
        q = q.select_from(lambda: User)

        expected = (
            "SELECT users.id AS users_id, "
            "users.name AS users_name FROM users"
        )
        self.assert_compile(q, expected)

    def test_multiple_entities_single_lambda(self, plain_fixture):
        """Compile a two-entity select() joined along a lambda relationship."""
        User, Address = plain_fixture

        stmt = select(lambda: (User, Address))
        stmt = stmt.join(lambda: User.addresses)

        expected = (
            "SELECT users.id, users.name, addresses.id AS id_1, "
            "addresses.user_id, addresses.email_address "
            "FROM users JOIN addresses ON users.id = addresses.user_id"
        )
        self.assert_compile(stmt, expected)

    def test_cols_round_trip(self, plain_fixture):
        """Round-trip a column-returning lambda statement on the backend.

        The same two lambdas are executed five times with one of two
        different ``names`` lists, and the rows returned are checked each
        time.
        """
        User, Address = plain_fixture

        # use the harness-managed session, as the Query test above does, so
        # the session is closed at teardown instead of being left to the
        # garbage collector with its connection checked out.
        s = fixture_session(future=True)

        # note this does a traversal + _clone of the InstrumentedAttribute
        # for the first time ever
        def query(names):
            stmt = lambda_stmt(
                lambda: select(User.name, Address.email_address)
                .where(User.name.in_(names))
                .join(User.addresses)
            ) + (lambda s: s.order_by(User.id, Address.id))

            return s.execute(stmt)

        def go1():
            r1 = query(["ed"])
            eq_(
                r1.all(),
                [
                    ("ed", "ed@wood.com"),
                    ("ed", "ed@bettyboop.com"),
                    ("ed", "ed@lala.com"),
                ],
            )

        def go2():
            r1 = query(["ed", "fred"])
            eq_(
                r1.all(),
                [
                    ("ed", "ed@wood.com"),
                    ("ed", "ed@bettyboop.com"),
                    ("ed", "ed@lala.com"),
                    ("fred", "fred@fred.com"),
                ],
            )

        # still five runs in random order, but both variants are guaranteed
        # to run at least once; five independent random picks would select
        # the same variant every time in 1 out of 16 test runs.
        runs = [go1, go2] + [random.choice([go1, go2]) for _ in range(3)]
        random.shuffle(runs)
        for fn in runs:
            fn()

    @testing.combinations(
        (True, True),
        (True, False),
        (False, False),
        argnames="use_aliased,use_indirect_access",
    )
    def test_entity_round_trip(
        self, plain_fixture, use_aliased, use_indirect_access
    ):
        """Round-trip an entity-returning lambda statement with selectinload.

        :param use_aliased: select from ``aliased(User)`` rather than
         ``User`` itself.
        :param use_indirect_access: only consulted when ``use_aliased`` is
         true; the alias is then reached as an attribute of a local object
         (``f1.u1``) and passed explicitly through ``track_on``.

        Each execution is asserted to emit exactly two SQL statements.
        """
        User, Address = plain_fixture

        # use the harness-managed session, as the Query test above does, so
        # the session is closed at teardown instead of being left to the
        # garbage collector with its connection checked out.
        s = fixture_session(future=True)

        if use_aliased:
            if use_indirect_access:

                def query(names):
                    class Foo:
                        def __init__(self):
                            self.u1 = aliased(User)

                    f1 = Foo()

                    stmt = lambda_stmt(
                        lambda: select(f1.u1)
                        .where(f1.u1.name.in_(names))
                        .options(selectinload(f1.u1.addresses)),
                        track_on=[f1.u1],
                    ).add_criteria(
                        lambda s: s.order_by(f1.u1.id), track_on=[f1.u1]
                    )

                    return s.execute(stmt)

            else:

                def query(names):
                    u1 = aliased(User)
                    stmt = lambda_stmt(
                        lambda: select(u1)
                        .where(u1.name.in_(names))
                        .options(selectinload(u1.addresses))
                    ) + (lambda s: s.order_by(u1.id))

                    return s.execute(stmt)

        else:

            def query(names):
                stmt = lambda_stmt(
                    lambda: select(User)
                    .where(User.name.in_(names))
                    .options(selectinload(User.addresses))
                ) + (lambda s: s.order_by(User.id))

                return s.execute(stmt)

        def go1():
            r1 = query(["ed"])
            eq_(
                r1.scalars().all(),
                [User(name="ed", addresses=[Address(), Address(), Address()])],
            )

        def go2():
            r1 = query(["ed", "fred"])
            eq_(
                r1.scalars().all(),
                [
                    User(
                        name="ed", addresses=[Address(), Address(), Address()]
                    ),
                    User(name="fred", addresses=[Address()]),
                ],
            )

        # still five runs in random order, but both variants are guaranteed
        # to run at least once; five independent random picks would select
        # the same variant every time in 1 out of 16 test runs.
        runs = [go1, go2] + [random.choice([go1, go2]) for _ in range(3)]
        random.shuffle(runs)
        for fn in runs:
            self.assert_sql_count(testing.db, fn, 2)

    def test_lambdas_rejected_in_options(self, plain_fixture):
        """Passing a lambda to ``.options()`` raises ArgumentError."""
        User, _ = plain_fixture

        stmt = select(lambda: User)

        assert_raises_message(
            exc.ArgumentError,
            "ExecutionOption Core or ORM object expected, got",
            stmt.options,
            lambda: subqueryload(User.addresses),
        )

    def test_subqueryload_internal_lambda(self, plain_fixture):
        """Round-trip a select() built from per-clause lambdas + subqueryload.

        The entity, WHERE criteria and ORDER BY are each passed as their own
        lambda to a plain ``select()``; the loader option is passed as a
        regular object.  Each execution is asserted to emit exactly two SQL
        statements.
        """
        User, Address = plain_fixture

        # use the harness-managed session, as the Query test above does, so
        # the session is closed at teardown instead of being left to the
        # garbage collector with its connection checked out.
        s = fixture_session(future=True)

        def query(names):
            stmt = (
                select(lambda: User)
                .where(lambda: User.name.in_(names))
                .options(subqueryload(User.addresses))
                .order_by(lambda: User.id)
            )

            return s.execute(stmt)

        def go1():
            r1 = query(["ed"])
            eq_(
                r1.scalars().all(),
                [User(name="ed", addresses=[Address(), Address(), Address()])],
            )

        def go2():
            r1 = query(["ed", "fred"])
            eq_(
                r1.scalars().all(),
                [
                    User(
                        name="ed", addresses=[Address(), Address(), Address()]
                    ),
                    User(name="fred", addresses=[Address()]),
                ],
            )

        # still five runs in random order, but both variants are guaranteed
        # to run at least once; five independent random picks would select
        # the same variant every time in 1 out of 16 test runs.
        runs = [go1, go2] + [random.choice([go1, go2]) for _ in range(3)]
        random.shuffle(runs)
        for fn in runs:
            self.assert_sql_count(testing.db, fn, 2)

    def test_subqueryload_external_lambda_caveats(self, plain_fixture):
        """Round-trip subqueryload inside a ``lambda_stmt()``, with warning.

        Here the whole statement, including the ``subqueryload()`` option,
        is produced inside a ``lambda_stmt()``.  Every execution is expected
        to emit the "must invoke lambda callable" subqueryloader warning and
        exactly two SQL statements.
        """
        User, Address = plain_fixture

        # use the harness-managed session, as the Query test above does, so
        # the session is closed at teardown instead of being left to the
        # garbage collector with its connection checked out.
        s = fixture_session(future=True)

        def query(names):
            stmt = lambda_stmt(
                lambda: select(User)
                .where(User.name.in_(names))
                .options(subqueryload(User.addresses))
            ) + (lambda s: s.order_by(User.id))

            return s.execute(stmt)

        def go1():
            r1 = query(["ed"])
            eq_(
                r1.scalars().all(),
                [User(name="ed", addresses=[Address(), Address(), Address()])],
            )

        def go2():
            r1 = query(["ed", "fred"])
            eq_(
                r1.scalars().all(),
                [
                    User(
                        name="ed", addresses=[Address(), Address(), Address()]
                    ),
                    User(name="fred", addresses=[Address()]),
                ],
            )

        # still five runs in random order, but both variants are guaranteed
        # to run at least once; five independent random picks would select
        # the same variant every time in 1 out of 16 test runs.
        runs = [go1, go2] + [random.choice([go1, go2]) for _ in range(3)]
        random.shuffle(runs)
        for fn in runs:
            with testing.expect_warnings(
                'subqueryloader for "User.addresses" must invoke lambda '
                r"callable at .*LambdaElement\(<code object <lambda> "
                r".*test_lambdas.py.* in order to produce a new query, "
                r"decreasing the efficiency of caching"
            ):
                self.assert_sql_count(testing.db, fn, 2)

    @testing.combinations(
        lambda s, User, Address: s.query(lambda: User).join(lambda: Address),
        lambda s, User, Address: s.query(lambda: User).join(
            lambda: User.addresses
        ),
        lambda s, User, Address: s.query(lambda: User).join(
            lambda: Address, lambda: User.addresses
        ),
        lambda s, User, Address: s.query(lambda: User).join(
            Address, lambda: User.addresses
        ),
        lambda s, User, Address: s.query(lambda: User).join(
            lambda: Address, User.addresses
        ),
        lambda User, Address: select(lambda: User)
        .join(lambda: Address)
        .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL),
        lambda User, Address: select(lambda: User)
        .join(lambda: User.addresses)
        .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL),
        lambda User, Address: select(lambda: User)
        .join(lambda: Address, lambda: User.addresses)
        .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL),
        lambda User, Address: select(lambda: User)
        .join(Address, lambda: User.addresses)
        .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL),
        lambda User, Address: select(lambda: User)
        .join(lambda: Address, User.addresses)
        .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL),
        argnames="test_case",
    )
    def test_join_entity_arg(self, plain_fixture, test_case):
        """Compile joins whose target and/or ON clause are lambdas.

        :param test_case: a lambda taking ``User`` and ``Address`` (and, for
         the Query variants, the session ``s``) that returns the statement
         to compile.  All variants are expected to render the same SQL.
        """
        User, Address = plain_fixture

        # use the harness-managed session, as the Query test above does, so
        # the session is closed at teardown.
        s = fixture_session(future=True)

        # pass exactly the names the test-case lambdas declare rather than
        # the whole of locals(), so it is visible what they depend on.
        stmt = testing.resolve_lambda(
            test_case, s=s, User=User, Address=Address
        )
        self.assert_compile(
            stmt,
            "SELECT users.id AS users_id, users.name AS users_name "
            "FROM users JOIN addresses ON users.id = addresses.user_id",
        )


class PolymorphicTest(_poly_fixtures._Polymorphic):
    """Compile-only check of lambda join arguments on the poly fixtures."""

    run_setup_mappers = "once"
    __dialect__ = "default"

    def test_join_second_prop_lambda(self):
        """Join to ``Manager`` along ``Company.employees``, both as lambdas."""
        Company, Manager = self.classes("Company", "Manager")

        # no bind is needed; the query is only compiled, never executed
        sess = Session(future=True)

        q = sess.query(Company)
        q = q.join(lambda: Manager, lambda: Company.employees)

        expected = (
            "SELECT companies.company_id AS companies_company_id, "
            "companies.name AS companies_name FROM companies "
            "JOIN (people JOIN managers ON people.person_id = "
            "managers.person_id) ON companies.company_id = people.company_id"
        )
        self.assert_compile(q, expected)


class UpdateDeleteTest(fixtures.MappedTest):
    __backend__ = True

    run_setup_mappers = "once"

    @classmethod
    def define_tables(cls, metadata):
        """Attach the ``users`` and ``addresses`` tables to ``metadata``."""
        user_columns = (
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("name", String(32)),
            Column("age_int", Integer),
        )
        Table("users", metadata, *user_columns)

        address_columns = (
            Column("id", Integer, primary_key=True),
            Column("user_id", ForeignKey("users.id")),
        )
        Table("addresses", metadata, *address_columns)

    @classmethod
    def setup_classes(cls):
        """Declare the plain ``User`` and ``Address`` classes."""

        # NOTE(review): neither class is referenced again in this method;
        # presumably subclassing cls.Comparable is what makes them available
        # as cls.classes.User / cls.classes.Address, which setup_mappers()
        # reads -- confirm against fixtures.MappedTest.
        class User(cls.Comparable):
            pass

        class Address(cls.Comparable):
            pass

    @classmethod
    def insert_data(cls, connection):
        """Insert the four ``users`` rows that ``test_update`` works on."""
        rows = [
            {"id": 1, "name": "john", "age_int": 25},
            {"id": 2, "name": "jack", "age_int": 47},
            {"id": 3, "name": "jill", "age_int": 29},
            {"id": 4, "name": "jane", "age_int": 37},
        ]

        insert_stmt = cls.tables.users.insert()
        connection.execute(insert_stmt, rows)

    @classmethod
    def setup_mappers(cls):
        """Map ``User`` and ``Address`` imperatively onto their tables.

        ``User.age`` is mapped to the ``age_int`` column and
        ``User.addresses`` is a relationship to ``Address``.
        """
        User, Address = cls.classes("User", "Address")
        users = cls.tables.users
        addresses = cls.tables.addresses
        registry = cls.mapper_registry

        user_properties = {
            "age": users.c.age_int,
            "addresses": relationship(Address),
        }
        registry.map_imperatively(User, users, properties=user_properties)
        registry.map_imperatively(Address, addresses)

    def test_update(self):
        """Run a lambda-wrapped ORM UPDATE twice with different parameters.

        The same ``lambda_stmt()`` is executed with two different id lists
        and value dictionaries; after each execution the ``users`` rows are
        selected back and compared.
        """
        User = self.classes.User

        # use a harness-managed session so that it, and the transaction
        # holding the UPDATEs below, is closed at teardown rather than
        # being left to the garbage collector.
        s = fixture_session(future=True)

        def go(ids, values):
            stmt = lambda_stmt(lambda: update(User).where(User.id.in_(ids)))
            s.execute(
                stmt,
                values,
                # note this currently just unrolls the lambda on the statement.
                # so lambda caching for updates is not actually that useful
                # unless synchronize_session is turned off.
                # evaluate is similar just doesn't work for IN yet.
                execution_options={"synchronize_session": "fetch"},
            )

        go([1, 2], {"name": "jack2"})
        eq_(
            s.execute(select(User.id, User.name).order_by(User.id)).all(),
            [(1, "jack2"), (2, "jack2"), (3, "jill"), (4, "jane")],
        )

        go([3], {"name": "jane2"})
        eq_(
            s.execute(select(User.id, User.name).order_by(User.id)).all(),
            [(1, "jack2"), (2, "jack2"), (3, "jane2"), (4, "jane")],
        )
