from sqlalchemy import Integer, String, ForeignKey, and_, or_, func, \
    literal, update, table, bindparam, column, select, exc
from sqlalchemy import testing
from sqlalchemy.dialects import mysql
from sqlalchemy.engine import default
from sqlalchemy.testing import AssertsCompiledSQL, eq_, fixtures, \
    assert_raises_message
from sqlalchemy.testing.schema import Table, Column
from sqlalchemy import util


class _UpdateFromTestBase(object):

    @classmethod
    def define_tables(cls, metadata):
        Table('mytable', metadata,
              Column('myid', Integer),
              Column('name', String(30)),
              Column('description', String(50)))
        Table('myothertable', metadata,
              Column('otherid', Integer),
              Column('othername', String(30)))
        Table('users', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('name', String(30), nullable=False))
        Table('addresses', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('user_id', None, ForeignKey('users.id')),
              Column('name', String(30), nullable=False),
              Column('email_address', String(50), nullable=False))
        Table('dingalings', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('address_id', None, ForeignKey('addresses.id')),
              Column('data', String(30)))
        Table('update_w_default', metadata,
              Column('id', Integer, primary_key=True),
              Column('x', Integer),
              Column('ycol', Integer, key='y'),
              Column('data', String(30), onupdate=lambda: "hi"))

    @classmethod
    def fixtures(cls):
        return dict(
            users=(
                ('id', 'name'),
                (7, 'jack'),
                (8, 'ed'),
                (9, 'fred'),
                (10, 'chuck')
            ),
            addresses = (
                ('id', 'user_id', 'name', 'email_address'),
                (1, 7, 'x', 'jack@bean.com'),
                (2, 8, 'x', 'ed@wood.com'),
                (3, 8, 'x', 'ed@bettyboop.com'),
                (4, 8, 'x', 'ed@lala.com'),
                (5, 9, 'x', 'fred@fred.com')
            ),
            dingalings = (
                ('id', 'address_id', 'data'),
                (1, 2, 'ding 1/2'),
                (2, 5, 'ding 2/5')
            ),
        )


class UpdateTest(_UpdateFromTestBase, fixtures.TablesTest, AssertsCompiledSQL):
    """Compilation tests for UPDATE constructs.

    Most tests build an UPDATE against the fixture tables and compare the
    rendered SQL (and, where given, the bound parameters) through
    ``assert_compile``; the rest assert on errors raised while building
    or compiling the statement.

    """

    __dialect__ = 'default'

    def test_update_1(self):
        """WHERE criterion, SET column supplied through compile params
        keyed by a Column object."""
        table1 = self.tables.mytable

        self.assert_compile(
            update(table1, table1.c.myid == 7),
            'UPDATE mytable SET name=:name WHERE mytable.myid = :myid_1',
            params={table1.c.name: 'fred'})

    def test_update_2(self):
        """generative where() / values(); the SET bind and the WHERE bind
        for the same column get distinct names."""
        table1 = self.tables.mytable

        self.assert_compile(
            table1.update().
            where(table1.c.myid == 7).
            values({table1.c.myid: 5}),
            'UPDATE mytable SET myid=:myid WHERE mytable.myid = :myid_1',
            checkparams={'myid': 5, 'myid_1': 7})

    def test_update_3(self):
        """same as test_update_1, with compile params keyed by string."""
        table1 = self.tables.mytable

        self.assert_compile(
            update(table1, table1.c.myid == 7),
            'UPDATE mytable SET name=:name WHERE mytable.myid = :myid_1',
            params={'name': 'fred'})

    def test_update_4(self):
        """SET a column to another column of the same table, no WHERE."""
        table1 = self.tables.mytable

        self.assert_compile(
            update(table1, values={table1.c.name: table1.c.myid}),
            'UPDATE mytable SET name=mytable.myid')

    def test_update_5(self):
        """an explicit bindparam() in the WHERE clause receives its value
        from the compile params."""
        table1 = self.tables.mytable

        self.assert_compile(
            update(table1,
                   whereclause=table1.c.name == bindparam('crit'),
                   values={table1.c.name: 'hi'}),
            'UPDATE mytable SET name=:name WHERE mytable.name = :crit',
            params={'crit': 'notthere'},
            checkparams={'crit': 'notthere', 'name': 'hi'})

    def test_update_6(self):
        """a column named only in the compile params is rendered in SET
        next to the column given in values."""
        table1 = self.tables.mytable

        self.assert_compile(
            update(table1,
                   table1.c.myid == 12,
                   values={table1.c.name: table1.c.myid}),
            'UPDATE mytable '
            'SET name=mytable.myid, description=:description '
            'WHERE mytable.myid = :myid_1',
            params={'description': 'test'},
            checkparams={'description': 'test', 'myid_1': 12})

    def test_update_7(self):
        """literal value in values plus additional compile params."""
        table1 = self.tables.mytable

        self.assert_compile(
            update(table1, table1.c.myid == 12, values={table1.c.myid: 9}),
            'UPDATE mytable '
            'SET myid=:myid, description=:description '
            'WHERE mytable.myid = :myid_1',
            params={'myid_1': 12, 'myid': 9, 'description': 'test'})

    def test_update_8(self):
        """no values on the statement; the SET column comes from the
        compile params alone."""
        table1 = self.tables.mytable

        self.assert_compile(
            update(table1, table1.c.myid == 12),
            'UPDATE mytable SET myid=:myid WHERE mytable.myid = :myid_1',
            params={'myid': 18}, checkparams={'myid': 18, 'myid_1': 12})

    def test_update_9(self):
        """compiling with column_keys=['id', 'name'] stringifies the same
        as the statement itself."""
        table1 = self.tables.mytable

        s = table1.update(table1.c.myid == 12, values={table1.c.name: 'lala'})
        c = s.compile(column_keys=['id', 'name'])
        eq_(str(s), str(c))

    def test_update_10(self):
        """a later values() call for the same column wins over the value
        passed to update()."""
        table1 = self.tables.mytable

        v1 = {table1.c.name: table1.c.myid}
        v2 = {table1.c.name: table1.c.name + 'foo'}
        self.assert_compile(
            update(table1, table1.c.myid == 12, values=v1).values(v2),
            'UPDATE mytable '
            'SET '
            'name=(mytable.name || :name_1), '
            'description=:description '
            'WHERE mytable.myid = :myid_1',
            params={'description': 'test'})

    def test_update_11(self):
        """SQL expressions on both sides; the SET clause lists myid
        before name regardless of the dict given."""
        table1 = self.tables.mytable

        values = {
            table1.c.name: table1.c.name + 'lala',
            table1.c.myid: func.do_stuff(table1.c.myid, literal('hoho'))
        }

        self.assert_compile(
            update(
                table1,
                (table1.c.myid == func.hoho(4)) & (
                    table1.c.name == literal('foo') +
                    table1.c.name +
                    literal('lala')),
                values=values),
            'UPDATE mytable '
            'SET '
            'myid=do_stuff(mytable.myid, :param_1), '
            'name=(mytable.name || :name_1) '
            'WHERE '
            'mytable.myid = hoho(:hoho_1) AND '
            'mytable.name = :param_2 || mytable.name || :param_3')

    def test_unconsumed_names_kwargs(self):
        """a values() keyword that is not a column of the table raises
        CompileError on compile."""
        t = table("t", column("x"), column("y"))

        assert_raises_message(
            exc.CompileError,
            "Unconsumed column names: z",
            t.update().values(x=5, z=5).compile,
        )

    def test_unconsumed_names_values_dict(self):
        """the unknown name 'j' is reported even though the statement also
        sets a column of a second table."""
        t = table("t", column("x"), column("y"))
        t2 = table("t2", column("q"), column("z"))

        assert_raises_message(
            exc.CompileError,
            "Unconsumed column names: j",
            t.update().values(x=5, j=7).values({t2.c.z: 5}).
            where(t.c.x == t2.c.q).compile,
        )

    def test_unconsumed_names_kwargs_w_keys(self):
        """passing the unknown name via column_keys as well still raises
        CompileError."""
        t = table("t", column("x"), column("y"))

        assert_raises_message(
            exc.CompileError,
            "Unconsumed column names: j",
            t.update().values(x=5, j=7).compile,
            column_keys=['j']
        )

    def test_update_ordered_parameters_1(self):
        """preserve_parameter_order=True with values passed to update()."""
        table1 = self.tables.mytable

        # Confirm that we can pass values as list value pairs
        # note these are ordered *differently* from table.c
        values = [
            (table1.c.name, table1.c.name + 'lala'),
            (table1.c.myid, func.do_stuff(table1.c.myid, literal('hoho'))),
        ]
        self.assert_compile(
            update(
                table1,
                (table1.c.myid == func.hoho(4)) & (
                    table1.c.name == literal('foo') +
                    table1.c.name +
                    literal('lala')),
                preserve_parameter_order=True,
                values=values),
            'UPDATE mytable '
            'SET '
            'name=(mytable.name || :name_1), '
            'myid=do_stuff(mytable.myid, :param_1) '
            'WHERE '
            'mytable.myid = hoho(:hoho_1) AND '
            'mytable.name = :param_2 || mytable.name || :param_3')

    def test_update_ordered_parameters_2(self):
        """preserve_parameter_order=True with values passed to values(),
        mixing Column keys and a string key."""
        table1 = self.tables.mytable

        # Confirm that we can pass values as list value pairs
        # note these are ordered *differently* from table.c
        values = [
            (table1.c.name, table1.c.name + 'lala'),
            ('description', 'some desc'),
            (table1.c.myid, func.do_stuff(table1.c.myid, literal('hoho')))
        ]
        self.assert_compile(
            update(
                table1,
                (table1.c.myid == func.hoho(4)) & (
                    table1.c.name == literal('foo') +
                    table1.c.name +
                    literal('lala')),
                preserve_parameter_order=True).values(values),
            'UPDATE mytable '
            'SET '
            'name=(mytable.name || :name_1), '
            'description=:description, '
            'myid=do_stuff(mytable.myid, :param_1) '
            'WHERE '
            'mytable.myid = hoho(:hoho_1) AND '
            'mytable.name = :param_2 || mytable.name || :param_3')

    def test_update_ordered_parameters_fire_onupdate(self):
        """with ordered parameters, the onupdate column that was not named
        is rendered after the explicitly ordered ones."""
        table = self.tables.update_w_default

        values = [
            (table.c.y, table.c.x + 5),
            ('x', 10)
        ]

        self.assert_compile(
            table.update(preserve_parameter_order=True).values(values),
            "UPDATE update_w_default SET ycol=(update_w_default.x + :x_1), "
            "x=:x, data=:data"
        )

    def test_update_ordered_parameters_override_onupdate(self):
        """an explicit value for the onupdate column is rendered at its
        position in the ordered list."""
        table = self.tables.update_w_default

        values = [
            (table.c.y, table.c.x + 5),
            (table.c.data, table.c.x + 10),
            ('x', 10)
        ]

        self.assert_compile(
            table.update(preserve_parameter_order=True).values(values),
            "UPDATE update_w_default SET ycol=(update_w_default.x + :x_1), "
            "data=(update_w_default.x + :x_2), x=:x"
        )

    def test_update_preserve_order_reqs_listtups(self):
        """values() rejects a dict when preserve_parameter_order=True."""
        table1 = self.tables.mytable
        # The message is matched as a regular expression, so the
        # parentheses are escaped; a raw string keeps the backslashes
        # without relying on an invalid string escape sequence.
        assert_raises_message(
            ValueError,
            r"When preserve_parameter_order is True, values\(\) "
            r"only accepts a list of 2-tuples",
            table1.update(preserve_parameter_order=True).values,
            {"description": "foo", "name": "bar"}
        )

    def test_update_ordereddict(self):
        """an OrderedDict does not switch on parameter ordering."""
        table1 = self.tables.mytable

        # Confirm that ordered dicts are treated as normal dicts,
        # columns sorted in table order
        values = util.OrderedDict((
            (table1.c.name, table1.c.name + 'lala'),
            (table1.c.myid, func.do_stuff(table1.c.myid, literal('hoho')))))

        self.assert_compile(
            update(
                table1,
                (table1.c.myid == func.hoho(4)) & (
                    table1.c.name == literal('foo') +
                    table1.c.name +
                    literal('lala')),
                values=values),
            'UPDATE mytable '
            'SET '
            'myid=do_stuff(mytable.myid, :param_1), '
            'name=(mytable.name || :name_1) '
            'WHERE '
            'mytable.myid = hoho(:hoho_1) AND '
            'mytable.name = :param_2 || mytable.name || :param_3')

    def test_where_empty(self):
        """an empty and_() / or_() passed to where() renders no WHERE."""
        table1 = self.tables.mytable
        self.assert_compile(
            table1.update().where(
                and_()),
            "UPDATE mytable SET myid=:myid, name=:name, "
            "description=:description")
        self.assert_compile(
            table1.update().where(
                or_()),
            "UPDATE mytable SET myid=:myid, name=:name, "
            "description=:description")

    def test_prefix_with(self):
        """dialect-specific prefixes render only for that dialect."""
        table1 = self.tables.mytable

        stmt = table1.update().\
            prefix_with('A', 'B', dialect='mysql').\
            prefix_with('C', 'D')

        self.assert_compile(stmt,
                            'UPDATE C D mytable SET myid=:myid, name=:name, '
                            'description=:description')

        self.assert_compile(
            stmt,
            'UPDATE A B C D mytable SET myid=%s, name=%s, description=%s',
            dialect=mysql.dialect())

    def test_update_to_expression(self):
        """test update from an expression.

        this logic is triggered currently by a left side that doesn't
        have a key.  The current supported use case is updating the index
        of a Postgresql ARRAY type.

        """
        table1 = self.tables.mytable
        expr = func.foo(table1.c.myid)
        eq_(expr.key, None)
        self.assert_compile(table1.update().values({expr: 'bar'}),
                            'UPDATE mytable SET foo(myid)=:param_1')

    def test_update_bound_ordering(self):
        """test that bound parameters between the UPDATE and FROM clauses
        order correctly in different SQL compilation scenarios.

        """
        table1 = self.tables.mytable
        table2 = self.tables.myothertable
        sel = select([table2]).where(table2.c.otherid == 5).alias()
        upd = table1.update().\
            where(table1.c.name == sel.c.othername).\
            values(name='foo')

        # default dialect in positional mode: SET precedes FROM, so the
        # SET value is the first positional parameter
        dialect = default.DefaultDialect()
        dialect.positional = True
        self.assert_compile(
            upd,
            "UPDATE mytable SET name=:name FROM (SELECT "
            "myothertable.otherid AS otherid, "
            "myothertable.othername AS othername "
            "FROM myothertable "
            "WHERE myothertable.otherid = :otherid_1) AS anon_1 "
            "WHERE mytable.name = anon_1.othername",
            checkpositional=('foo', 5),
            dialect=dialect
        )

        # mysql renders the subquery in the table list ahead of SET, so
        # the positional order is reversed
        self.assert_compile(
            upd,
            "UPDATE mytable, (SELECT myothertable.otherid AS otherid, "
            "myothertable.othername AS othername "
            "FROM myothertable "
            "WHERE myothertable.otherid = %s) AS anon_1 SET mytable.name=%s "
            "WHERE mytable.name = anon_1.othername",
            checkpositional=(5, 'foo'),
            dialect=mysql.dialect()
        )


class UpdateFromCompileTest(_UpdateFromTestBase, fixtures.TablesTest,
                            AssertsCompiledSQL):
    """Compile-only checks for UPDATE statements that involve an alias or
    additional tables.

    """

    __dialect__ = 'default'

    # All three fixture hooks are set to None; presumably nothing here
    # needs real tables or rows since every test only compiles SQL.
    run_create_tables = None
    run_inserts = None
    run_deletes = None

    def test_alias_one(self):
        mytable = self.tables.mytable
        t1 = mytable.alias('t1')

        # A nonsensical statement: the UPDATE targets only the alias, yet
        # the values dict is keyed by the column of the un-aliased table.
        # What should happen here isn't really defined.
        stmt = update(t1, t1.c.myid == 7).values({mytable.c.name: "fred"})

        self.assert_compile(
            stmt,
            'UPDATE mytable AS t1 '
            'SET name=:name '
            'WHERE t1.myid = :myid_1')

    def test_alias_two(self):
        mytable = self.tables.mytable
        t1 = mytable.alias('t1')

        # In contrast to test_alias_one(), the criterion refers to the
        # plain table, so this is an UPDATE..FROM.  The "table1.c.name"
        # key is then handled as belonging to an "extra table", which is
        # why the full table name shows up in the rendering.
        stmt = update(t1, mytable.c.myid == 7).values(
            {mytable.c.name: 'fred'})
        expected_params = {'mytable_name': 'fred', 'myid_1': 7}

        self.assert_compile(
            stmt,
            'UPDATE mytable AS t1 '
            'SET name=:mytable_name '
            'FROM mytable '
            'WHERE mytable.myid = :myid_1',
            checkparams=expected_params,
        )

    def test_alias_two_mysql(self):
        mytable = self.tables.mytable
        t1 = mytable.alias('t1')

        stmt = update(t1, mytable.c.myid == 7).values(
            {mytable.c.name: 'fred'})
        expected_params = {'mytable_name': 'fred', 'myid_1': 7}

        self.assert_compile(
            stmt,
            "UPDATE mytable AS t1, mytable SET mytable.name=%s "
            "WHERE mytable.myid = %s",
            checkparams=expected_params,
            dialect='mysql')

    def test_update_from_multitable_same_name_mysql(self):
        users = self.tables.users
        addresses = self.tables.addresses

        # both tables have a 'name' column; each gets its own bind
        stmt = users.update()
        stmt = stmt.values(name='newname')
        stmt = stmt.values({addresses.c.name: "new address"})
        stmt = stmt.where(users.c.id == addresses.c.user_id)

        self.assert_compile(
            stmt,
            "UPDATE users, addresses SET addresses.name=%s, "
            "users.name=%s WHERE users.id = addresses.user_id",
            checkparams={'addresses_name': 'new address', 'name': 'newname'},
            dialect='mysql'
        )

    def test_render_table(self):
        users = self.tables.users
        addresses = self.tables.addresses

        stmt = users.update()
        stmt = stmt.values(name='newname')
        stmt = stmt.where(users.c.id == addresses.c.user_id)
        stmt = stmt.where(addresses.c.email_address == 'e1')

        self.assert_compile(
            stmt,
            'UPDATE users '
            'SET name=:name FROM addresses '
            'WHERE '
            'users.id = addresses.user_id AND '
            'addresses.email_address = :email_address_1',
            checkparams={'email_address_1': 'e1', 'name': 'newname'})

    def test_render_multi_table(self):
        users = self.tables.users
        addresses = self.tables.addresses
        dingalings = self.tables.dingalings

        expected_params = dict(
            email_address_1='e1',
            id_1=2,
            name='newname',
        )

        stmt = users.update()
        stmt = stmt.values(name='newname')
        stmt = stmt.where(users.c.id == addresses.c.user_id)
        stmt = stmt.where(addresses.c.email_address == 'e1')
        stmt = stmt.where(addresses.c.id == dingalings.c.address_id)
        stmt = stmt.where(dingalings.c.id == 2)

        self.assert_compile(
            stmt,
            'UPDATE users '
            'SET name=:name '
            'FROM addresses, dingalings '
            'WHERE '
            'users.id = addresses.user_id AND '
            'addresses.email_address = :email_address_1 AND '
            'addresses.id = dingalings.address_id AND '
            'dingalings.id = :id_1',
            checkparams=expected_params)

    def test_render_table_mysql(self):
        users = self.tables.users
        addresses = self.tables.addresses

        stmt = users.update()
        stmt = stmt.values(name='newname')
        stmt = stmt.where(users.c.id == addresses.c.user_id)
        stmt = stmt.where(addresses.c.email_address == 'e1')

        self.assert_compile(
            stmt,
            'UPDATE users, addresses '
            'SET users.name=%s '
            'WHERE '
            'users.id = addresses.user_id AND '
            'addresses.email_address = %s',
            checkparams={'email_address_1': 'e1', 'name': 'newname'},
            dialect=mysql.dialect())

    def test_render_subquery(self):
        users = self.tables.users
        addresses = self.tables.addresses

        expected_params = dict(
            email_address_1='e1',
            id_1=7,
            name='newname',
        )

        # anonymous aliased SELECT used as the extra FROM element
        subq = select([
            addresses.c.id,
            addresses.c.user_id,
            addresses.c.email_address,
        ]).where(addresses.c.id == 7).alias()

        stmt = users.update()
        stmt = stmt.values(name='newname')
        stmt = stmt.where(users.c.id == subq.c.user_id)
        stmt = stmt.where(subq.c.email_address == 'e1')

        self.assert_compile(
            stmt,
            'UPDATE users '
            'SET name=:name FROM ('
            'SELECT '
            'addresses.id AS id, '
            'addresses.user_id AS user_id, '
            'addresses.email_address AS email_address '
            'FROM addresses '
            'WHERE addresses.id = :id_1'
            ') AS anon_1 '
            'WHERE users.id = anon_1.user_id '
            'AND anon_1.email_address = :email_address_1',
            checkparams=expected_params)


class UpdateFromRoundTripTest(_UpdateFromTestBase, fixtures.TablesTest):
    """Execute UPDATE statements that reference more than one table on
    ``testing.db`` and compare the resulting rows.

    """

    __backend__ = True

    @testing.requires.update_from
    def test_exec_two_table(self):
        users = self.tables.users
        addresses = self.tables.addresses

        stmt = addresses.update()
        stmt = stmt.values(email_address=users.c.name)
        stmt = stmt.where(users.c.id == addresses.c.user_id)
        stmt = stmt.where(users.c.name == 'ed')
        testing.db.execute(stmt)

        # every address belonging to user 'ed' (id 8) now holds his name
        self._assert_addresses(addresses, [
            (1, 7, 'x', 'jack@bean.com'),
            (2, 8, 'x', 'ed'),
            (3, 8, 'x', 'ed'),
            (4, 8, 'x', 'ed'),
            (5, 9, 'x', 'fred@fred.com'),
        ])

    @testing.requires.update_from
    def test_exec_two_table_plus_alias(self):
        users = self.tables.users
        addresses = self.tables.addresses
        address_alias = addresses.alias()

        # users is joined to the alias, and the alias back to the
        # table being updated
        stmt = addresses.update()
        stmt = stmt.values(email_address=users.c.name)
        stmt = stmt.where(users.c.id == address_alias.c.user_id)
        stmt = stmt.where(users.c.name == 'ed')
        stmt = stmt.where(address_alias.c.id == addresses.c.id)
        testing.db.execute(stmt)

        self._assert_addresses(addresses, [
            (1, 7, 'x', 'jack@bean.com'),
            (2, 8, 'x', 'ed'),
            (3, 8, 'x', 'ed'),
            (4, 8, 'x', 'ed'),
            (5, 9, 'x', 'fred@fred.com'),
        ])

    @testing.requires.update_from
    def test_exec_three_table(self):
        users = self.tables.users
        addresses = self.tables.addresses
        dingalings = self.tables.dingalings

        stmt = addresses.update()
        stmt = stmt.values(email_address=users.c.name)
        stmt = stmt.where(users.c.id == addresses.c.user_id)
        stmt = stmt.where(users.c.name == 'ed')
        stmt = stmt.where(addresses.c.id == dingalings.c.address_id)
        stmt = stmt.where(dingalings.c.id == 1)
        testing.db.execute(stmt)

        # dingaling 1 points at address 2, so only that row changes
        self._assert_addresses(addresses, [
            (1, 7, 'x', 'jack@bean.com'),
            (2, 8, 'x', 'ed'),
            (3, 8, 'x', 'ed@bettyboop.com'),
            (4, 8, 'x', 'ed@lala.com'),
            (5, 9, 'x', 'fred@fred.com'),
        ])

    @testing.only_on('mysql', 'Multi table update')
    def test_exec_multitable(self):
        users = self.tables.users
        addresses = self.tables.addresses

        # one statement assigns columns on both tables
        new_values = {
            addresses.c.email_address: 'updated',
            users.c.name: 'ed2'
        }

        stmt = addresses.update()
        stmt = stmt.values(new_values)
        stmt = stmt.where(users.c.id == addresses.c.user_id)
        stmt = stmt.where(users.c.name == 'ed')
        testing.db.execute(stmt)

        self._assert_addresses(addresses, [
            (1, 7, 'x', 'jack@bean.com'),
            (2, 8, 'x', 'updated'),
            (3, 8, 'x', 'updated'),
            (4, 8, 'x', 'updated'),
            (5, 9, 'x', 'fred@fred.com'),
        ])

        self._assert_users(users, [
            (7, 'jack'),
            (8, 'ed2'),
            (9, 'fred'),
            (10, 'chuck'),
        ])

    @testing.only_on('mysql', 'Multi table update')
    def test_exec_multitable_same_name(self):
        users = self.tables.users
        addresses = self.tables.addresses

        # both target columns are called 'name'
        new_values = {
            addresses.c.name: 'ad_ed2',
            users.c.name: 'ed2'
        }

        stmt = addresses.update()
        stmt = stmt.values(new_values)
        stmt = stmt.where(users.c.id == addresses.c.user_id)
        stmt = stmt.where(users.c.name == 'ed')
        testing.db.execute(stmt)

        self._assert_addresses(addresses, [
            (1, 7, 'x', 'jack@bean.com'),
            (2, 8, 'ad_ed2', 'ed@wood.com'),
            (3, 8, 'ad_ed2', 'ed@bettyboop.com'),
            (4, 8, 'ad_ed2', 'ed@lala.com'),
            (5, 9, 'x', 'fred@fred.com'),
        ])

        self._assert_users(users, [
            (7, 'jack'),
            (8, 'ed2'),
            (9, 'fred'),
            (10, 'chuck'),
        ])

    def _assert_addresses(self, addresses, expected):
        """Compare all ``addresses`` rows, ordered by id, to ``expected``."""
        query = addresses.select().order_by(addresses.c.id)
        rows = testing.db.execute(query).fetchall()
        eq_(rows, expected)

    def _assert_users(self, users, expected):
        """Compare all ``users`` rows, ordered by id, to ``expected``."""
        query = users.select().order_by(users.c.id)
        rows = testing.db.execute(query).fetchall()
        eq_(rows, expected)


class UpdateFromMultiTableUpdateDefaultsTest(_UpdateFromTestBase,
                                             fixtures.TablesTest):
    """Round-trip tests for ``onupdate`` defaults when one UPDATE
    statement involves more than one table.

    """

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        """Define users / addresses / foobar; users and foobar each carry
        a 'some_update' column with an onupdate value."""

        def autoinc_pk():
            # a new Column object per table
            return Column('id', Integer, primary_key=True,
                          test_needs_autoincrement=True)

        Table(
            'users', metadata,
            autoinc_pk(),
            Column('name', String(30), nullable=False),
            Column('some_update', String(30), onupdate='im the update'),
        )

        Table(
            'addresses', metadata,
            autoinc_pk(),
            Column('user_id', None, ForeignKey('users.id')),
            Column('email_address', String(50), nullable=False),
        )

        Table(
            'foobar', metadata,
            autoinc_pk(),
            Column('user_id', None, ForeignKey('users.id')),
            Column('data', String(30)),
            Column('some_update', String(30),
                   onupdate='im the other update'),
        )

    @classmethod
    def fixtures(cls):
        """Return seed rows keyed by table name; the first tuple of each
        value names the columns."""
        user_rows = (
            ('id', 'name', 'some_update'),
            (8, 'ed', 'value'),
            (9, 'fred', 'value'),
        )
        address_rows = (
            ('id', 'user_id', 'email_address'),
            (2, 8, 'ed@wood.com'),
            (3, 8, 'ed@bettyboop.com'),
            (4, 9, 'fred@fred.com'),
        )
        foobar_rows = (
            ('id', 'user_id', 'data'),
            (2, 8, 'd1'),
            (3, 8, 'd2'),
            (4, 9, 'd3'),
        )
        return {
            'users': user_rows,
            'addresses': address_rows,
            'foobar': foobar_rows,
        }

    @testing.only_on('mysql', 'Multi table update')
    def test_defaults_second_table(self):
        users = self.tables.users
        addresses = self.tables.addresses

        new_values = {
            addresses.c.email_address: 'updated',
            users.c.name: 'ed2'
        }

        stmt = addresses.update()
        stmt = stmt.values(new_values)
        stmt = stmt.where(users.c.id == addresses.c.user_id)
        stmt = stmt.where(users.c.name == 'ed')
        result = testing.db.execute(stmt)

        # the onupdate column of the secondary table is prefetched
        prefetched = set(result.prefetch_cols())
        eq_(prefetched, set([users.c.some_update]))

        self._assert_addresses(addresses, [
            (2, 8, 'updated'),
            (3, 8, 'updated'),
            (4, 9, 'fred@fred.com'),
        ])

        self._assert_users(users, [
            (8, 'ed2', 'im the update'),
            (9, 'fred', 'value'),
        ])

    @testing.only_on('mysql', 'Multi table update')
    def test_defaults_second_table_same_name(self):
        users = self.tables.users
        foobar = self.tables.foobar

        new_values = {
            foobar.c.data: foobar.c.data + 'a',
            users.c.name: 'ed2'
        }

        stmt = users.update()
        stmt = stmt.values(new_values)
        stmt = stmt.where(users.c.id == foobar.c.user_id)
        stmt = stmt.where(users.c.name == 'ed')
        result = testing.db.execute(stmt)

        # both tables have an identically named onupdate column and
        # both are prefetched
        prefetched = set(result.prefetch_cols())
        eq_(
            prefetched,
            set([users.c.some_update, foobar.c.some_update])
        )

        self._assert_foobar(foobar, [
            (2, 8, 'd1a', 'im the other update'),
            (3, 8, 'd2a', 'im the other update'),
            (4, 9, 'd3', None),
        ])

        self._assert_users(users, [
            (8, 'ed2', 'im the update'),
            (9, 'fred', 'value'),
        ])

    @testing.only_on('mysql', 'Multi table update')
    def test_no_defaults_second_table(self):
        users = self.tables.users
        addresses = self.tables.addresses

        stmt = addresses.update()
        stmt = stmt.values({'email_address': users.c.name})
        stmt = stmt.where(users.c.id == addresses.c.user_id)
        stmt = stmt.where(users.c.name == 'ed')
        result = testing.db.execute(stmt)

        eq_(result.prefetch_cols(), [])

        self._assert_addresses(addresses, [
            (2, 8, 'ed'),
            (3, 8, 'ed'),
            (4, 9, 'fred@fred.com'),
        ])

        # no column of users is assigned, so its onupdate does not fire
        self._assert_users(users, [
            (8, 'ed', 'value'),
            (9, 'fred', 'value'),
        ])

    def _assert_foobar(self, foobar, expected):
        """Compare all ``foobar`` rows, ordered by id, to ``expected``."""
        query = foobar.select().order_by(foobar.c.id)
        rows = testing.db.execute(query).fetchall()
        eq_(rows, expected)

    def _assert_addresses(self, addresses, expected):
        """Compare all ``addresses`` rows, ordered by id, to ``expected``."""
        query = addresses.select().order_by(addresses.c.id)
        rows = testing.db.execute(query).fetchall()
        eq_(rows, expected)

    def _assert_users(self, users, expected):
        """Compare all ``users`` rows, ordered by id, to ``expected``."""
        query = users.select().order_by(users.c.id)
        rows = testing.db.execute(query).fetchall()
        eq_(rows, expected)
