from sqlalchemy import Boolean
from sqlalchemy import case
from sqlalchemy import column
from sqlalchemy import exc
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Integer
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy import text
from sqlalchemy.orm import backref
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import mapper
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
from sqlalchemy.orm import synonym
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing import mock
from sqlalchemy.testing.assertsql import CompiledSQL
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table


class UpdateDeleteTest(fixtures.MappedTest):
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "users",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("name", String(32)),
            Column("age_int", Integer),
        )
        Table(
            "addresses",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("user_id", ForeignKey("users.id")),
        )

    @classmethod
    def setup_classes(cls):
        class User(cls.Comparable):
            pass

        class Address(cls.Comparable):
            pass

    @classmethod
    def insert_data(cls, connection):
        users = cls.tables.users

        connection.execute(
            users.insert(),
            [
                dict(id=1, name="john", age_int=25),
                dict(id=2, name="jack", age_int=47),
                dict(id=3, name="jill", age_int=29),
                dict(id=4, name="jane", age_int=37),
            ],
        )

    @classmethod
    def setup_mappers(cls):
        User = cls.classes.User
        users = cls.tables.users

        Address = cls.classes.Address
        addresses = cls.tables.addresses

        mapper(
            User,
            users,
            properties={
                "age": users.c.age_int,
                "addresses": relationship(Address),
            },
        )
        mapper(Address, addresses)

    def test_illegal_eval(self):
        User = self.classes.User
        s = Session()
        assert_raises_message(
            exc.ArgumentError,
            "Valid strategies for session synchronization "
            "are 'evaluate', 'fetch', False",
            s.query(User).update,
            {},
            synchronize_session="fake",
        )

    def test_illegal_operations(self):
        User = self.classes.User
        Address = self.classes.Address

        s = Session()

        for q, mname in (
            (s.query(User).limit(2), r"limit\(\)"),
            (s.query(User).offset(2), r"offset\(\)"),
            (s.query(User).limit(2).offset(2), r"limit\(\)"),
            (s.query(User).order_by(User.id), r"order_by\(\)"),
            (s.query(User).group_by(User.id), r"group_by\(\)"),
            (s.query(User).distinct(), r"distinct\(\)"),
            (
                s.query(User).join(User.addresses),
                r"join\(\), outerjoin\(\), select_from\(\), or from_self\(\)",
            ),
            (
                s.query(User).outerjoin(User.addresses),
                r"join\(\), outerjoin\(\), select_from\(\), or from_self\(\)",
            ),
            (
                s.query(User).select_from(Address),
                r"join\(\), outerjoin\(\), select_from\(\), or from_self\(\)",
            ),
            (
                s.query(User).from_self(),
                r"join\(\), outerjoin\(\), select_from\(\), or from_self\(\)",
            ),
        ):
            assert_raises_message(
                exc.InvalidRequestError,
                r"Can't call Query.update\(\) or Query.delete\(\) when "
                "%s has been called" % mname,
                q.update,
                {"name": "ed"},
            )
            assert_raises_message(
                exc.InvalidRequestError,
                r"Can't call Query.update\(\) or Query.delete\(\) when "
                "%s has been called" % mname,
                q.delete,
            )

    def test_evaluate_clauseelement(self):
        User = self.classes.User

        class Thing(object):
            def __clause_element__(self):
                return User.name.__clause_element__()

        s = Session()
        jill = s.query(User).get(3)
        s.query(User).update(
            {Thing(): "moonbeam"}, synchronize_session="evaluate"
        )
        eq_(jill.name, "moonbeam")

    def test_evaluate_invalid(self):
        User = self.classes.User

        class Thing(object):
            def __clause_element__(self):
                return 5

        s = Session()

        assert_raises_message(
            exc.InvalidRequestError,
            "Invalid expression type: 5",
            s.query(User).update,
            {Thing(): "moonbeam"},
            synchronize_session="evaluate",
        )

    def test_evaluate_unmapped_col(self):
        User = self.classes.User

        s = Session()
        jill = s.query(User).get(3)
        s.query(User).update(
            {column("name"): "moonbeam"}, synchronize_session="evaluate"
        )
        eq_(jill.name, "jill")
        s.expire(jill)
        eq_(jill.name, "moonbeam")

    def test_evaluate_synonym_string(self):
        class Foo(object):
            pass

        mapper(Foo, self.tables.users, properties={"uname": synonym("name")})

        s = Session()
        jill = s.query(Foo).get(3)
        s.query(Foo).update(
            {"uname": "moonbeam"}, synchronize_session="evaluate"
        )
        eq_(jill.uname, "moonbeam")

    def test_evaluate_synonym_attr(self):
        class Foo(object):
            pass

        mapper(Foo, self.tables.users, properties={"uname": synonym("name")})

        s = Session()
        jill = s.query(Foo).get(3)
        s.query(Foo).update(
            {Foo.uname: "moonbeam"}, synchronize_session="evaluate"
        )
        eq_(jill.uname, "moonbeam")

    def test_evaluate_double_synonym_attr(self):
        class Foo(object):
            pass

        mapper(
            Foo,
            self.tables.users,
            properties={"uname": synonym("name"), "ufoo": synonym("uname")},
        )

        s = Session()
        jill = s.query(Foo).get(3)
        s.query(Foo).update(
            {Foo.ufoo: "moonbeam"}, synchronize_session="evaluate"
        )
        eq_(jill.ufoo, "moonbeam")

    @testing.combinations(
        (False, False),
        (False, True),
        (True, False),
        (True, True),
    )
    def test_evaluate_dont_refresh_expired_objects(
        self, expire_jane_age, add_filter_criteria
    ):
        User = self.classes.User

        sess = Session()

        john, jack, jill, jane = sess.query(User).order_by(User.id).all()

        sess.expire(john)
        sess.expire(jill)

        if expire_jane_age:
            sess.expire(jane, ["name", "age"])
        else:
            sess.expire(jane, ["name"])

        with self.sql_execution_asserter() as asserter:
            # using 1.x style for easier backport
            if add_filter_criteria:
                sess.query(User).filter(User.name != None).update(
                    {"age": User.age + 10}, synchronize_session="evaluate"
                )
            else:
                sess.query(User).update(
                    {"age": User.age + 10}, synchronize_session="evaluate"
                )

        if add_filter_criteria:
            if expire_jane_age:
                asserter.assert_(
                    # it has to unexpire jane.name, because jane is not fully
                    # expired and the critiera needs to look at this particular
                    # key
                    CompiledSQL(
                        "SELECT users.age_int AS users_age_int, "
                        "users.name AS users_name FROM users "
                        "WHERE users.id = :param_1",
                        [{"param_1": 4}],
                    ),
                    CompiledSQL(
                        "UPDATE users "
                        "SET age_int=(users.age_int + :age_int_1) "
                        "WHERE users.name IS NOT NULL",
                        [{"age_int_1": 10}],
                    ),
                )
            else:
                asserter.assert_(
                    # it has to unexpire jane.name, because jane is not fully
                    # expired and the critiera needs to look at this particular
                    # key
                    CompiledSQL(
                        "SELECT users.name AS users_name FROM users "
                        "WHERE users.id = :param_1",
                        [{"param_1": 4}],
                    ),
                    CompiledSQL(
                        "UPDATE users SET "
                        "age_int=(users.age_int + :age_int_1) "
                        "WHERE users.name IS NOT NULL",
                        [{"age_int_1": 10}],
                    ),
                )
        else:
            asserter.assert_(
                CompiledSQL(
                    "UPDATE users SET age_int=(users.age_int + :age_int_1)",
                    [{"age_int_1": 10}],
                ),
            )

        with self.sql_execution_asserter() as asserter:
            eq_(john.age, 35)  # needs refresh
            eq_(jack.age, 57)  # no SQL needed
            eq_(jill.age, 39)  # needs refresh
            eq_(jane.age, 47)  # needs refresh

        to_assert = [
            # refresh john
            CompiledSQL(
                "SELECT users.age_int AS users_age_int, "
                "users.id AS users_id, users.name AS users_name FROM users "
                "WHERE users.id = :param_1",
                [{"param_1": 1}],
            ),
            # refresh jill
            CompiledSQL(
                "SELECT users.age_int AS users_age_int, "
                "users.id AS users_id, users.name AS users_name FROM users "
                "WHERE users.id = :param_1",
                [{"param_1": 3}],
            ),
        ]

        if expire_jane_age and not add_filter_criteria:
            to_assert.append(
                # refresh jane
                CompiledSQL(
                    "SELECT users.age_int AS users_age_int, "
                    "users.name AS users_name FROM users "
                    "WHERE users.id = :param_1",
                    [{"param_1": 4}],
                )
            )
        asserter.assert_(*to_assert)

    def test_fetch_dont_refresh_expired_objects(self):
        User = self.classes.User

        sess = Session()

        john, jack, jill, jane = sess.query(User).order_by(User.id).all()

        sess.expire(john)
        sess.expire(jill)
        sess.expire(jane, ["name"])

        with self.sql_execution_asserter() as asserter:
            # using 1.x style for easier backport
            sess.query(User).filter(User.name != None).update(
                {"age": User.age + 10}, synchronize_session="fetch"
            )

        asserter.assert_(
            CompiledSQL(
                "SELECT users.id AS users_id FROM users "
                "WHERE users.name IS NOT NULL"
            ),
            CompiledSQL(
                "UPDATE users SET age_int=(users.age_int + :age_int_1) "
                "WHERE users.name IS NOT NULL",
                [{"age_int_1": 10}],
            ),
        )

        with self.sql_execution_asserter() as asserter:
            eq_(john.age, 35)  # needs refresh
            eq_(jack.age, 57)  # refreshes in 1.3
            eq_(jill.age, 39)  # needs refresh
            eq_(jane.age, 47)  # refreshes in 1.3

        asserter.assert_(
            # refresh john
            CompiledSQL(
                "SELECT users.age_int AS users_age_int, "
                "users.id AS users_id, users.name AS users_name FROM users "
                "WHERE users.id = :param_1",
                [{"param_1": 1}],
            ),
            # refresh jack.age in 1.3 only
            CompiledSQL(
                "SELECT users.age_int AS users_age_int FROM users "
                "WHERE users.id = :param_1",
                [{"param_1": 2}],
            ),
            # refresh jill
            CompiledSQL(
                "SELECT users.age_int AS users_age_int, "
                "users.id AS users_id, users.name AS users_name FROM users "
                "WHERE users.id = :param_1",
                [{"param_1": 3}],
            ),
            # refresh jane, seems to be a full refresh in 1.3.
            CompiledSQL(
                "SELECT users.age_int AS users_age_int, "
                "users.name AS users_name FROM users "
                "WHERE users.id = :param_1",
                [{"param_1": 4}],
            ),
        )

    def test_delete(self):
        User = self.classes.User

        sess = Session()

        john, jack, jill, jane = sess.query(User).order_by(User.id).all()
        sess.query(User).filter(
            or_(User.name == "john", User.name == "jill")
        ).delete()

        assert john not in sess and jill not in sess

        eq_(sess.query(User).order_by(User.id).all(), [jack, jane])

    def test_delete_against_metadata(self):
        User = self.classes.User
        users = self.tables.users

        sess = Session()
        sess.query(users).delete(synchronize_session=False)
        eq_(sess.query(User).count(), 0)

    def test_delete_with_bindparams(self):
        User = self.classes.User

        sess = Session()

        john, jack, jill, jane = sess.query(User).order_by(User.id).all()
        sess.query(User).filter(text("name = :name")).params(
            name="john"
        ).delete("fetch")
        assert john not in sess

        eq_(sess.query(User).order_by(User.id).all(), [jack, jill, jane])

    def test_delete_rollback(self):
        User = self.classes.User

        sess = Session()
        john, jack, jill, jane = sess.query(User).order_by(User.id).all()
        sess.query(User).filter(
            or_(User.name == "john", User.name == "jill")
        ).delete(synchronize_session="evaluate")
        assert john not in sess and jill not in sess
        sess.rollback()
        assert john in sess and jill in sess

    def test_delete_rollback_with_fetch(self):
        User = self.classes.User

        sess = Session()
        john, jack, jill, jane = sess.query(User).order_by(User.id).all()
        sess.query(User).filter(
            or_(User.name == "john", User.name == "jill")
        ).delete(synchronize_session="fetch")
        assert john not in sess and jill not in sess
        sess.rollback()
        assert john in sess and jill in sess

    def test_delete_without_session_sync(self):
        User = self.classes.User

        sess = Session()

        john, jack, jill, jane = sess.query(User).order_by(User.id).all()
        sess.query(User).filter(
            or_(User.name == "john", User.name == "jill")
        ).delete(synchronize_session=False)

        assert john in sess and jill in sess

        eq_(sess.query(User).order_by(User.id).all(), [jack, jane])

    def test_delete_with_fetch_strategy(self):
        User = self.classes.User

        sess = Session()

        john, jack, jill, jane = sess.query(User).order_by(User.id).all()
        sess.query(User).filter(
            or_(User.name == "john", User.name == "jill")
        ).delete(synchronize_session="fetch")

        assert john not in sess and jill not in sess

        eq_(sess.query(User).order_by(User.id).all(), [jack, jane])

    @testing.requires.update_where_target_in_subquery
    def test_delete_invalid_evaluation(self):
        User = self.classes.User

        sess = Session()

        john, jack, jill, jane = sess.query(User).order_by(User.id).all()

        assert_raises(
            exc.InvalidRequestError,
            sess.query(User)
            .filter(User.name == select([func.max(User.name)]))
            .delete,
            synchronize_session="evaluate",
        )

        sess.query(User).filter(
            User.name == select([func.max(User.name)])
        ).delete(synchronize_session="fetch")

        assert john not in sess

        eq_(sess.query(User).order_by(User.id).all(), [jack, jill, jane])

    def test_update(self):
        User, users = self.classes.User, self.tables.users

        sess = Session()

        john, jack, jill, jane = sess.query(User).order_by(User.id).all()
        sess.query(User).filter(User.age > 29).update(
            {"age": User.age - 10}, synchronize_session="evaluate"
        )

        eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27])
        eq_(
            sess.query(User.age).order_by(User.id).all(),
            list(zip([25, 37, 29, 27])),
        )

        sess.query(User).filter(User.age > 29).update(
            {User.age: User.age - 10}, synchronize_session="evaluate"
        )
        eq_([john.age, jack.age, jill.age, jane.age], [25, 27, 29, 27])
        eq_(
            sess.query(User.age).order_by(User.id).all(),
            list(zip([25, 27, 29, 27])),
        )

        sess.query(User).filter(User.age > 27).update(
            {users.c.age_int: User.age - 10}, synchronize_session="evaluate"
        )
        eq_([john.age, jack.age, jill.age, jane.age], [25, 27, 19, 27])
        eq_(
            sess.query(User.age).order_by(User.id).all(),
            list(zip([25, 27, 19, 27])),
        )

        sess.query(User).filter(User.age == 25).update(
            {User.age: User.age - 10}, synchronize_session="fetch"
        )
        eq_([john.age, jack.age, jill.age, jane.age], [15, 27, 19, 27])
        eq_(
            sess.query(User.age).order_by(User.id).all(),
            list(zip([15, 27, 19, 27])),
        )

    def test_update_against_table_col(self):
        User, users = self.classes.User, self.tables.users

        sess = Session()
        john, jack, jill, jane = sess.query(User).order_by(User.id).all()
        eq_([john.age, jack.age, jill.age, jane.age], [25, 47, 29, 37])
        sess.query(User).filter(User.age > 27).update(
            {users.c.age_int: User.age - 10}, synchronize_session="evaluate"
        )
        eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 19, 27])

    def test_update_against_metadata(self):
        User, users = self.classes.User, self.tables.users

        sess = Session()

        sess.query(users).update(
            {users.c.age_int: 29}, synchronize_session=False
        )
        eq_(
            sess.query(User.age).order_by(User.id).all(),
            list(zip([29, 29, 29, 29])),
        )

    def test_update_with_bindparams(self):
        User = self.classes.User

        sess = Session()

        john, jack, jill, jane = sess.query(User).order_by(User.id).all()

        sess.query(User).filter(text("age_int > :x")).params(x=29).update(
            {"age": User.age - 10}, synchronize_session="fetch"
        )

        eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27])
        eq_(
            sess.query(User.age).order_by(User.id).all(),
            list(zip([25, 37, 29, 27])),
        )

    def test_update_without_load(self):
        User = self.classes.User

        sess = Session()

        sess.query(User).filter(User.id == 3).update(
            {"age": 44}, synchronize_session="fetch"
        )
        eq_(
            sess.query(User.age).order_by(User.id).all(),
            list(zip([25, 47, 44, 37])),
        )

    def test_update_changes_resets_dirty(self):
        User = self.classes.User

        sess = Session(autoflush=False)

        john, jack, jill, jane = sess.query(User).order_by(User.id).all()

        john.age = 50
        jack.age = 37

        # autoflush is false.  therefore our '50' and '37' are getting
        # blown away by this operation.

        sess.query(User).filter(User.age > 29).update(
            {"age": User.age - 10}, synchronize_session="evaluate"
        )

        for x in (john, jack, jill, jane):
            assert not sess.is_modified(x)

        eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27])

        john.age = 25
        assert john in sess.dirty
        assert jack in sess.dirty
        assert jill not in sess.dirty
        assert not sess.is_modified(john)
        assert not sess.is_modified(jack)

    def test_update_changes_with_autoflush(self):
        User = self.classes.User

        sess = Session()

        john, jack, jill, jane = sess.query(User).order_by(User.id).all()

        john.age = 50
        jack.age = 37

        sess.query(User).filter(User.age > 29).update(
            {"age": User.age - 10}, synchronize_session="evaluate"
        )

        for x in (john, jack, jill, jane):
            assert not sess.is_modified(x)

        eq_([john.age, jack.age, jill.age, jane.age], [40, 27, 29, 27])

        john.age = 25
        assert john in sess.dirty
        assert jack not in sess.dirty
        assert jill not in sess.dirty
        assert sess.is_modified(john)
        assert not sess.is_modified(jack)

    def test_update_with_expire_strategy(self):
        User = self.classes.User

        sess = Session()

        john, jack, jill, jane = sess.query(User).order_by(User.id).all()
        sess.query(User).filter(User.age > 29).update(
            {"age": User.age - 10}, synchronize_session="fetch"
        )

        eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27])
        eq_(
            sess.query(User.age).order_by(User.id).all(),
            list(zip([25, 37, 29, 27])),
        )

    @testing.fails_if(lambda: not testing.db.dialect.supports_sane_rowcount)
    def test_update_returns_rowcount(self):
        User = self.classes.User

        sess = Session()

        rowcount = (
            sess.query(User)
            .filter(User.age > 29)
            .update({"age": User.age + 0})
        )
        eq_(rowcount, 2)

        rowcount = (
            sess.query(User)
            .filter(User.age > 29)
            .update({"age": User.age - 10})
        )
        eq_(rowcount, 2)

    @testing.fails_if(lambda: not testing.db.dialect.supports_sane_rowcount)
    def test_delete_returns_rowcount(self):
        User = self.classes.User

        sess = Session()

        rowcount = (
            sess.query(User)
            .filter(User.age > 26)
            .delete(synchronize_session=False)
        )
        eq_(rowcount, 3)

    def test_update_all(self):
        User = self.classes.User

        sess = Session()

        john, jack, jill, jane = sess.query(User).order_by(User.id).all()
        sess.query(User).update({"age": 42}, synchronize_session="evaluate")

        eq_([john.age, jack.age, jill.age, jane.age], [42, 42, 42, 42])
        eq_(
            sess.query(User.age).order_by(User.id).all(),
            list(zip([42, 42, 42, 42])),
        )

    def test_delete_all(self):
        User = self.classes.User

        sess = Session()

        john, jack, jill, jane = sess.query(User).order_by(User.id).all()
        sess.query(User).delete(synchronize_session="evaluate")

        assert not (
            john in sess or jack in sess or jill in sess or jane in sess
        )
        eq_(sess.query(User).count(), 0)

    def test_autoflush_before_evaluate_update(self):
        User = self.classes.User

        sess = Session()
        john = sess.query(User).filter_by(name="john").one()
        john.name = "j2"

        sess.query(User).filter_by(name="j2").update(
            {"age": 42}, synchronize_session="evaluate"
        )
        eq_(john.age, 42)

    def test_autoflush_before_fetch_update(self):
        User = self.classes.User

        sess = Session()
        john = sess.query(User).filter_by(name="john").one()
        john.name = "j2"

        sess.query(User).filter_by(name="j2").update(
            {"age": 42}, synchronize_session="fetch"
        )
        eq_(john.age, 42)

    def test_autoflush_before_evaluate_delete(self):
        User = self.classes.User

        sess = Session()
        john = sess.query(User).filter_by(name="john").one()
        john.name = "j2"

        sess.query(User).filter_by(name="j2").delete(
            synchronize_session="evaluate"
        )
        assert john not in sess

    def test_autoflush_before_fetch_delete(self):
        User = self.classes.User

        sess = Session()
        john = sess.query(User).filter_by(name="john").one()
        john.name = "j2"

        sess.query(User).filter_by(name="j2").delete(
            synchronize_session="fetch"
        )
        assert john not in sess

    def test_evaluate_before_update(self):
        User = self.classes.User

        sess = Session()
        john = sess.query(User).filter_by(name="john").one()
        sess.expire(john, ["age"])

        # eval must be before the update.  otherwise
        # we eval john, age has been expired and doesn't
        # match the new value coming in
        sess.query(User).filter_by(name="john").filter_by(age=25).update(
            {"name": "j2", "age": 40}, synchronize_session="evaluate"
        )
        eq_(john.name, "j2")
        eq_(john.age, 40)

    def test_fetch_before_update(self):
        User = self.classes.User

        sess = Session()
        john = sess.query(User).filter_by(name="john").one()
        sess.expire(john, ["age"])

        sess.query(User).filter_by(name="john").filter_by(age=25).update(
            {"name": "j2", "age": 40}, synchronize_session="fetch"
        )
        eq_(john.name, "j2")
        eq_(john.age, 40)

    def test_evaluate_before_delete(self):
        User = self.classes.User

        sess = Session()
        john = sess.query(User).filter_by(name="john").one()
        sess.expire(john, ["age"])

        sess.query(User).filter_by(name="john").filter_by(age=25).delete(
            synchronize_session="evaluate"
        )
        assert john not in sess

    def test_fetch_before_delete(self):
        User = self.classes.User

        sess = Session()
        john = sess.query(User).filter_by(name="john").one()
        sess.expire(john, ["age"])

        sess.query(User).filter_by(name="john").filter_by(age=25).delete(
            synchronize_session="fetch"
        )
        assert john not in sess

    def test_update_unordered_dict(self):
        User = self.classes.User
        session = Session()

        # Do an update using unordered dict and check that the parameters used
        # are ordered in table order
        q = session.query(User)
        with mock.patch.object(q, "_execute_crud") as exec_:
            q.filter(User.id == 15).update({"name": "foob", "id": 123})
            # Confirm that parameters are a dict instead of tuple or list
            params_type = type(exec_.mock_calls[0][1][0].parameters)
            is_(params_type, dict)

    def test_update_preserve_parameter_order(self):
        User = self.classes.User
        session = Session()

        # Do update using a tuple and check that order is preserved
        q = session.query(User)
        with mock.patch.object(q, "_execute_crud") as exec_:
            q.filter(User.id == 15).update(
                (("id", 123), ("name", "foob")),
                update_args={"preserve_parameter_order": True},
            )
            cols = [
                c.key for c in exec_.mock_calls[0][1][0]._parameter_ordering
            ]
            eq_(["id", "name"], cols)

        # Now invert the order and use a list instead, and check that order is
        # also preserved
        q = session.query(User)
        with mock.patch.object(q, "_execute_crud") as exec_:
            q.filter(User.id == 15).update(
                [("name", "foob"), ("id", 123)],
                update_args={"preserve_parameter_order": True},
            )
            cols = [
                c.key for c in exec_.mock_calls[0][1][0]._parameter_ordering
            ]
            eq_(["name", "id"], cols)


class UpdateDeleteIgnoresLoadersTest(fixtures.MappedTest):
    @classmethod
    def define_tables(cls, metadata):
        Table(
            "users",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("name", String(32)),
            Column("age", Integer),
        )

        Table(
            "documents",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("user_id", None, ForeignKey("users.id")),
            Column("title", String(32)),
        )

    @classmethod
    def setup_classes(cls):
        class User(cls.Comparable):
            pass

        class Document(cls.Comparable):
            pass

    @classmethod
    def insert_data(cls, connection):
        users = cls.tables.users

        connection.execute(
            users.insert(),
            [
                dict(id=1, name="john", age=25),
                dict(id=2, name="jack", age=47),
                dict(id=3, name="jill", age=29),
                dict(id=4, name="jane", age=37),
            ],
        )

        documents = cls.tables.documents

        connection.execute(
            documents.insert(),
            [
                dict(id=1, user_id=1, title="foo"),
                dict(id=2, user_id=1, title="bar"),
                dict(id=3, user_id=2, title="baz"),
            ],
        )

    @classmethod
    def setup_mappers(cls):
        documents, Document, User, users = (
            cls.tables.documents,
            cls.classes.Document,
            cls.classes.User,
            cls.tables.users,
        )

        mapper(User, users)
        mapper(
            Document,
            documents,
            properties={
                "user": relationship(
                    User,
                    lazy="joined",
                    backref=backref("documents", lazy="select"),
                )
            },
        )

    def test_update_with_eager_relationships(self):
        Document = self.classes.Document

        sess = Session()

        foo, bar, baz = sess.query(Document).order_by(Document.id).all()
        sess.query(Document).filter(Document.user_id == 1).update(
            {"title": Document.title + Document.title},
            synchronize_session="fetch",
        )

        eq_([foo.title, bar.title, baz.title], ["foofoo", "barbar", "baz"])
        eq_(
            sess.query(Document.title).order_by(Document.id).all(),
            list(zip(["foofoo", "barbar", "baz"])),
        )

    def test_update_with_explicit_joinedload(self):
        User = self.classes.User

        sess = Session()

        john, jack, jill, jane = sess.query(User).order_by(User.id).all()
        sess.query(User).options(joinedload(User.documents)).filter(
            User.age > 29
        ).update({"age": User.age - 10}, synchronize_session="fetch")

        eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27])
        eq_(
            sess.query(User.age).order_by(User.id).all(),
            list(zip([25, 37, 29, 27])),
        )

    def test_delete_with_eager_relationships(self):
        Document = self.classes.Document

        sess = Session()

        sess.query(Document).filter(Document.user_id == 1).delete(
            synchronize_session=False
        )

        eq_(sess.query(Document.title).all(), list(zip(["baz"])))


class UpdateDeleteFromTest(fixtures.MappedTest):
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "users",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("samename", String(10)),
        )
        Table(
            "documents",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("user_id", None, ForeignKey("users.id")),
            Column("title", String(32)),
            Column("flag", Boolean),
            Column("samename", String(10)),
        )

    @classmethod
    def setup_classes(cls):
        class User(cls.Comparable):
            pass

        class Document(cls.Comparable):
            pass

    @classmethod
    def insert_data(cls, connection):
        users = cls.tables.users

        connection.execute(
            users.insert(), [dict(id=1), dict(id=2), dict(id=3), dict(id=4)]
        )

        documents = cls.tables.documents

        connection.execute(
            documents.insert(),
            [
                dict(id=1, user_id=1, title="foo"),
                dict(id=2, user_id=1, title="bar"),
                dict(id=3, user_id=2, title="baz"),
                dict(id=4, user_id=2, title="hoho"),
                dict(id=5, user_id=3, title="lala"),
                dict(id=6, user_id=3, title="bleh"),
            ],
        )

    @classmethod
    def setup_mappers(cls):
        documents, Document, User, users = (
            cls.tables.documents,
            cls.classes.Document,
            cls.classes.User,
            cls.tables.users,
        )

        mapper(User, users)
        mapper(
            Document,
            documents,
            properties={"user": relationship(User, backref="documents")},
        )

    @testing.requires.update_from
    def test_update_from_joined_subq_test(self):
        Document = self.classes.Document
        s = Session()

        subq = (
            s.query(func.max(Document.title).label("title"))
            .group_by(Document.user_id)
            .subquery()
        )

        s.query(Document).filter(Document.title == subq.c.title).update(
            {"flag": True}, synchronize_session=False
        )

        eq_(
            set(s.query(Document.id, Document.flag)),
            set(
                [
                    (1, True),
                    (2, None),
                    (3, None),
                    (4, True),
                    (5, True),
                    (6, None),
                ]
            ),
        )

    @testing.requires.delete_from
    def test_delete_from_joined_subq_test(self):
        Document = self.classes.Document
        s = Session()

        subq = (
            s.query(func.max(Document.title).label("title"))
            .group_by(Document.user_id)
            .subquery()
        )

        s.query(Document).filter(Document.title == subq.c.title).delete(
            synchronize_session=False
        )

        eq_(
            set(s.query(Document.id, Document.flag)),
            set([(2, None), (3, None), (6, None)]),
        )

    def test_no_eval_against_multi_table_criteria(self):
        User = self.classes.User
        Document = self.classes.Document

        s = Session()

        q = s.query(User).filter(User.id == Document.user_id)
        assert_raises_message(
            exc.InvalidRequestError,
            "Could not evaluate current criteria in Python.",
            q.update,
            {"name": "ed"},
        )

    @testing.requires.update_where_target_in_subquery
    def test_update_using_in(self):
        Document = self.classes.Document
        s = Session()

        subq = (
            s.query(func.max(Document.title).label("title"))
            .group_by(Document.user_id)
            .subquery()
        )

        s.query(Document).filter(Document.title.in_(subq)).update(
            {"flag": True}, synchronize_session=False
        )

        eq_(
            set(s.query(Document.id, Document.flag)),
            set(
                [
                    (1, True),
                    (2, None),
                    (3, None),
                    (4, True),
                    (5, True),
                    (6, None),
                ]
            ),
        )

    @testing.requires.update_where_target_in_subquery
    @testing.requires.standalone_binds
    def test_update_using_case(self):
        Document = self.classes.Document
        s = Session()

        subq = (
            s.query(func.max(Document.title).label("title"))
            .group_by(Document.user_id)
            .subquery()
        )

        # this would work with Firebird if you do literal_column('1')
        # instead
        case_stmt = case([(Document.title.in_(subq), True)], else_=False)
        s.query(Document).update(
            {"flag": case_stmt}, synchronize_session=False
        )

        eq_(
            set(s.query(Document.id, Document.flag)),
            set(
                [
                    (1, True),
                    (2, False),
                    (3, False),
                    (4, True),
                    (5, True),
                    (6, False),
                ]
            ),
        )

    @testing.only_on("mysql", "Multi table update")
    def test_update_from_multitable_same_names(self):
        Document = self.classes.Document
        User = self.classes.User

        s = Session()

        s.query(Document).filter(User.id == Document.user_id).filter(
            User.id == 2
        ).update(
            {Document.samename: "d_samename", User.samename: "u_samename"},
            synchronize_session=False,
        )
        eq_(
            s.query(User.id, Document.samename, User.samename)
            .filter(User.id == Document.user_id)
            .order_by(User.id)
            .all(),
            [
                (1, None, None),
                (1, None, None),
                (2, "d_samename", "u_samename"),
                (2, "d_samename", "u_samename"),
                (3, None, None),
                (3, None, None),
            ],
        )


class ExpressionUpdateTest(fixtures.MappedTest):
    @classmethod
    def define_tables(cls, metadata):
        Table(
            "data",
            metadata,
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("counter", Integer, nullable=False, default=0),
        )

    @classmethod
    def setup_classes(cls):
        class Data(cls.Comparable):
            pass

    @classmethod
    def setup_mappers(cls):
        data = cls.tables.data
        mapper(cls.classes.Data, data, properties={"cnt": data.c.counter})

    @testing.provide_metadata
    def test_update_attr_names(self):
        Data = self.classes.Data

        d1 = Data()
        sess = Session()
        sess.add(d1)
        sess.commit()
        eq_(d1.cnt, 0)

        sess.query(Data).update({Data.cnt: Data.cnt + 1})
        sess.flush()

        eq_(d1.cnt, 1)

        sess.query(Data).update({Data.cnt: Data.cnt + 1}, "fetch")
        sess.flush()

        eq_(d1.cnt, 2)
        sess.close()

    def test_update_args(self):
        Data = self.classes.Data
        session = testing.mock.Mock(wraps=Session())
        update_args = {"mysql_limit": 1}

        q = session.query(Data)
        with testing.mock.patch.object(q, "_execute_crud") as exec_:
            q.update({Data.cnt: Data.cnt + 1}, update_args=update_args)
        eq_(exec_.call_count, 1)
        args, kwargs = exec_.mock_calls[0][1:3]
        eq_(len(args), 2)
        update_stmt = args[0]
        eq_(update_stmt.dialect_kwargs, update_args)


class InheritTest(fixtures.DeclarativeMappedTest):

    run_inserts = "each"

    run_deletes = "each"
    __backend__ = True

    @classmethod
    def setup_classes(cls):
        Base = cls.DeclarativeBasic

        class Person(Base):
            __tablename__ = "person"
            id = Column(
                Integer, primary_key=True, test_needs_autoincrement=True
            )
            type = Column(String(50))
            name = Column(String(50))

        class Engineer(Person):
            __tablename__ = "engineer"
            id = Column(Integer, ForeignKey("person.id"), primary_key=True)
            engineer_name = Column(String(50))

        class Manager(Person):
            __tablename__ = "manager"
            id = Column(Integer, ForeignKey("person.id"), primary_key=True)
            manager_name = Column(String(50))

    @classmethod
    def insert_data(cls, connection):
        Engineer, Person, Manager = (
            cls.classes.Engineer,
            cls.classes.Person,
            cls.classes.Manager,
        )
        s = Session(connection)
        s.add_all(
            [
                Engineer(name="e1", engineer_name="e1"),
                Manager(name="m1", manager_name="m1"),
                Engineer(name="e2", engineer_name="e2"),
                Person(name="p1"),
            ]
        )
        s.commit()

    def test_illegal_metadata(self):
        person = self.classes.Person.__table__
        engineer = self.classes.Engineer.__table__

        sess = Session()
        assert_raises_message(
            exc.InvalidRequestError,
            "This operation requires only one Table or entity be "
            "specified as the target.",
            sess.query(person.join(engineer)).update,
            {},
        )

    def test_update_subtable_only(self):
        Engineer = self.classes.Engineer
        s = Session(testing.db)
        s.query(Engineer).update({"engineer_name": "e5"})

        eq_(s.query(Engineer.engineer_name).all(), [("e5",), ("e5",)])

    @testing.requires.update_from
    def test_update_from(self):
        Engineer = self.classes.Engineer
        Person = self.classes.Person
        s = Session(testing.db)
        s.query(Engineer).filter(Engineer.id == Person.id).filter(
            Person.name == "e2"
        ).update({"engineer_name": "e5"})

        eq_(
            set(s.query(Person.name, Engineer.engineer_name)),
            set([("e1", "e1"), ("e2", "e5")]),
        )

    @testing.requires.delete_from
    def test_delete_from(self):
        Engineer = self.classes.Engineer
        Person = self.classes.Person
        s = Session(testing.db)
        s.query(Engineer).filter(Engineer.id == Person.id).filter(
            Person.name == "e2"
        ).delete()

        eq_(
            set(s.query(Person.name, Engineer.engineer_name)),
            set([("e1", "e1")]),
        )

    @testing.only_on("mysql", "Multi table update")
    def test_update_from_multitable(self):
        Engineer = self.classes.Engineer
        Person = self.classes.Person
        s = Session(testing.db)
        s.query(Engineer).filter(Engineer.id == Person.id).filter(
            Person.name == "e2"
        ).update({Person.name: "e22", Engineer.engineer_name: "e55"})

        eq_(
            set(s.query(Person.name, Engineer.engineer_name)),
            set([("e1", "e1"), ("e22", "e55")]),
        )
