from sqlalchemy.testing import eq_, assert_raises, \
    assert_raises_message, is_
from sqlalchemy.ext import declarative as decl
import sqlalchemy as sa
from sqlalchemy import testing
from sqlalchemy import Integer, String, ForeignKey
from sqlalchemy.testing.schema import Table, Column
from sqlalchemy.orm import relationship, create_session, class_mapper, \
    configure_mappers, clear_mappers, \
    deferred, column_property, \
    Session
from sqlalchemy.util import classproperty
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.testing import fixtures

# Module-level declarative base used by the test classes below.  It starts
# out as None and is rebound by ``DeclarativeTestBase.setup``.
Base = None


class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults):
    """Fixture base whose ``setup`` rebinds the module-level ``Base`` and
    whose ``teardown`` disposes of sessions, mappers and tables."""

    def setup(self):
        """Rebind the global ``Base`` to a new declarative base built from
        ``testing.db``."""
        global Base
        Base = decl.declarative_base(testing.db)

    def teardown(self):
        """Close all sessions, clear all mappers, then drop the tables
        registered on ``Base.metadata``."""
        Session.close_all()
        clear_mappers()
        Base.metadata.drop_all()


class DeclarativeMixinTest(DeclarativeTestBase):

    def test_simple(self):
        """Round-trip a model whose primary key column and ``foo()``
        helper both come from a plain mixin class."""

        class MyMixin(object):

            id = Column(Integer, primary_key=True,
                        test_needs_autoincrement=True)

            def foo(self):
                return 'bar' + str(self.id)

        class MyModel(Base, MyMixin):

            __tablename__ = 'test'
            name = Column(String(100), nullable=False, index=True)

        Base.metadata.create_all()
        sess = create_session()
        sess.add(MyModel(name='testing'))
        sess.flush()
        # Empty the session so the query below has to load the row.
        sess.expunge_all()
        loaded = sess.query(MyModel).one()
        eq_(loaded.id, 1)
        eq_(loaded.name, 'testing')
        eq_(loaded.foo(), 'bar1')

    def test_unique_column(self):
        """The ``unique`` flag of a mixin column is set on the matching
        column of the mapped class's table."""

        class MyMixin(object):

            id = Column(Integer, primary_key=True)
            value = Column(String, unique=True)

        class MyModel(Base, MyMixin):

            __tablename__ = 'test'

        value_col = MyModel.__table__.c.value
        assert value_col.unique

    def test_hierarchical_bases(self):
        """Round-trip a model whose mixin in turn inherits a column and
        a method from a parent mixin."""

        class MyMixinParent:

            id = Column(Integer, primary_key=True,
                        test_needs_autoincrement=True)

            def foo(self):
                return 'bar' + str(self.id)

        class MyMixin(MyMixinParent):

            baz = Column(String(100), nullable=False, index=True)

        class MyModel(Base, MyMixin):

            __tablename__ = 'test'
            name = Column(String(100), nullable=False, index=True)

        Base.metadata.create_all()
        sess = create_session()
        sess.add(MyModel(name='testing', baz='fu'))
        sess.flush()
        # Empty the session so the query below has to load the row.
        sess.expunge_all()
        loaded = sess.query(MyModel).one()
        eq_(loaded.id, 1)
        eq_(loaded.name, 'testing')
        eq_(loaded.foo(), 'bar1')
        eq_(loaded.baz, 'fu')

    def test_mixin_overrides(self):
        """A mixin subclass may redefine a column of its parent mixin;
        each model gets the column type of the mixin it actually uses."""

        class MixinA(object):
            foo = Column(String(50))

        class MixinB(MixinA):
            foo = Column(Integer)

        class MyModelA(Base, MixinA):
            __tablename__ = 'testa'
            id = Column(Integer, primary_key=True)

        class MyModelB(Base, MixinB):
            __tablename__ = 'testb'
            id = Column(Integer, primary_key=True)

        for model, expected_type in ((MyModelA, String), (MyModelB, Integer)):
            eq_(model.__table__.c.foo.type.__class__, expected_type)

    def test_not_allowed(self):
        """Mapping a class against a mixin that directly holds a
        ForeignKey column, a relationship, a deferred column or a
        column_property raises InvalidRequestError."""

        class MyMixin:
            foo = Column(Integer, ForeignKey('bar.id'))

        def map_fk_column_mixin():
            class MyModel(Base, MyMixin):
                __tablename__ = 'foo'

        assert_raises(sa.exc.InvalidRequestError, map_fk_column_mixin)

        class MyRelMixin:
            foo = relationship('Bar')

        def map_relationship_mixin():
            class MyModel(Base, MyRelMixin):

                __tablename__ = 'foo'

        assert_raises(sa.exc.InvalidRequestError, map_relationship_mixin)

        class MyDefMixin:
            foo = deferred(Column('foo', String))

        def map_deferred_mixin():
            class MyModel(Base, MyDefMixin):
                __tablename__ = 'foo'

        assert_raises(sa.exc.InvalidRequestError, map_deferred_mixin)

        class MyCPropMixin:
            foo = column_property(Column('foo', String))

        def map_column_property_mixin():
            class MyModel(Base, MyCPropMixin):
                __tablename__ = 'foo'

        assert_raises(sa.exc.InvalidRequestError, map_column_property_mixin)

    def test_table_name_inherited(self):
        """A ``__tablename__`` declared_attr on a mixin supplies the
        table name, computed from the mapped class's own name."""

        class MyMixin:

            @declared_attr
            def __tablename__(cls):
                return cls.__name__.lower()

            id = Column(Integer, primary_key=True)

        class MyModel(Base, MyMixin):
            pass

        mapped_table = MyModel.__table__
        eq_(mapped_table.name, 'mymodel')

    def test_classproperty_still_works(self):
        """A mixin ``__tablename__`` written with ``classproperty`` is
        accepted, and an explicit name on the model takes precedence."""

        class MyMixin(object):

            @classproperty
            def __tablename__(cls):
                return cls.__name__.lower()

            id = Column(Integer, primary_key=True)

        class MyModel(Base, MyMixin):
            __tablename__ = 'overridden'

        mapped_table = MyModel.__table__
        eq_(mapped_table.name, 'overridden')

    def test_table_name_not_inherited(self):
        """An explicit ``__tablename__`` on the model wins over a
        ``__tablename__`` declared_attr on its mixin."""

        class MyMixin:

            @declared_attr
            def __tablename__(cls):
                return cls.__name__.lower()

            id = Column(Integer, primary_key=True)

        class MyModel(Base, MyMixin):
            __tablename__ = 'overridden'

        mapped_table = MyModel.__table__
        eq_(mapped_table.name, 'overridden')

    def test_table_name_inheritance_order(self):
        """When two mixins both declare ``__tablename__``, the one listed
        first in the model's bases is the one used."""

        class MyMixin1:

            @declared_attr
            def __tablename__(cls):
                return cls.__name__.lower() + '1'

        class MyMixin2:

            @declared_attr
            def __tablename__(cls):
                return cls.__name__.lower() + '2'

        class MyModel(Base, MyMixin1, MyMixin2):
            id = Column(Integer, primary_key=True)

        mapped_table = MyModel.__table__
        eq_(mapped_table.name, 'mymodel1')

    def test_table_name_dependent_on_subclass(self):
        """A mixin's ``__tablename__`` declared_attr can read an attribute
        (``parent_name``) that only the mapped class defines."""

        class MyHistoryMixin:

            @declared_attr
            def __tablename__(cls):
                return cls.parent_name + '_changelog'

        class MyModel(Base, MyHistoryMixin):
            parent_name = 'foo'
            id = Column(Integer, primary_key=True)

        mapped_table = MyModel.__table__
        eq_(mapped_table.name, 'foo_changelog')

    def test_table_args_inherited(self):
        """A plain ``__table_args__`` dict on a mixin becomes the keyword
        arguments of the model's table."""

        class MyMixin:
            __table_args__ = {'mysql_engine': 'InnoDB'}

        class MyModel(Base, MyMixin):
            __tablename__ = 'test'
            id = Column(Integer, primary_key=True)

        table_kwargs = MyModel.__table__.kwargs
        eq_(table_kwargs, {'mysql_engine': 'InnoDB'})

    def test_table_args_inherited_descriptor(self):
        """A ``__table_args__`` declared_attr on a mixin is evaluated with
        the mapped class, so the table's ``info`` is the model's name."""

        class MyMixin:

            @declared_attr
            def __table_args__(cls):
                return {'info': cls.__name__}

        class MyModel(Base, MyMixin):
            __tablename__ = 'test'
            id = Column(Integer, primary_key=True)

        mapped_table = MyModel.__table__
        eq_(mapped_table.info, 'MyModel')

    def test_table_args_inherited_single_table_inheritance(self):
        """Mixin ``__table_args__`` apply to the base table of a
        single-table inheritance hierarchy, which the subclass shares."""

        class MyMixin:
            __table_args__ = {'mysql_engine': 'InnoDB'}

        class General(Base, MyMixin):
            __tablename__ = 'test'
            id = Column(Integer, primary_key=True)
            type_ = Column(String(50))
            # Was misspelled ``__mapper__args``; Python name-mangles that
            # to ``_General__mapper__args`` so declarative never saw it and
            # polymorphic_on was silently left unconfigured.
            __mapper_args__ = {'polymorphic_on': type_}

        class Specific(General):
            __mapper_args__ = {'polymorphic_identity': 'specific'}

        assert Specific.__table__ is General.__table__
        eq_(General.__table__.kwargs, {'mysql_engine': 'InnoDB'})

    def test_columns_single_table_inheritance(self):
        """Test a column on a mixin with an alternate attribute name,
        mapped to a superclass and single-table inheritance subclass.
        The superclass table gets the column, the subclass shares
        the MapperProperty.

        """

        class MyMixin(object):
            foo = Column('foo', Integer)
            bar = Column('bar_newname', Integer)

        class General(Base, MyMixin):
            __tablename__ = 'test'
            id = Column(Integer, primary_key=True)
            type_ = Column(String(50))
            # Was misspelled ``__mapper__args``; Python name-mangles that
            # to ``_General__mapper__args`` so declarative never saw it and
            # polymorphic_on was silently left unconfigured.
            __mapper_args__ = {'polymorphic_on': type_}

        class Specific(General):
            __mapper_args__ = {'polymorphic_identity': 'specific'}

        assert General.bar.prop.columns[0] is General.__table__.c.bar_newname
        assert len(General.bar.prop.columns) == 1
        assert Specific.bar.prop is General.bar.prop

    @testing.skip_if(lambda: testing.against('oracle'),
                     "Test has an empty insert in it at the moment")
    def test_columns_single_inheritance_conflict_resolution(self):
        """A declared_attr may hand back a column that already exists on
        the inherited table; two single-table subclasses using the same
        mixin then both map to that one column.

        See [ticket:2472].

        """
        class Person(Base):
            __tablename__ = 'person'
            id = Column(Integer, primary_key=True)

        class Mixin(object):

            @declared_attr
            def target_id(cls):
                return cls.__table__.c.get(
                    'target_id',
                    Column(Integer, ForeignKey('other.id'))
                )

            @declared_attr
            def target(cls):
                return relationship("Other")

        class Engineer(Mixin, Person):

            """single table inheritance"""

        class Manager(Mixin, Person):

            """single table inheritance"""

        class Other(Base):
            __tablename__ = 'other'
            id = Column(Integer, primary_key=True)

        for subclass in (Engineer, Manager):
            is_(
                subclass.target_id.property.columns[0],
                Person.__table__.c.target_id
            )

        # do a brief round trip on this
        Base.metadata.create_all()
        sess = Session()
        o1, o2 = Other(), Other()
        people = [
            Engineer(target=o1),
            Manager(target=o2),
            Manager(target=o1)
        ]
        sess.add_all(people)
        sess.commit()
        eq_(sess.query(Engineer).first().target, o1)

    def test_columns_joined_table_inheritance(self):
        """Test a column on a mixin with an alternate attribute name,
        mapped to a superclass and joined-table inheritance subclass.
        The superclass table gets the column; the subclass shares the
        superclass's MapperProperty, which refers to that single
        superclass column only.

        """

        class MyMixin(object):
            foo = Column('foo', Integer)
            bar = Column('bar_newname', Integer)

        class General(Base, MyMixin):
            __tablename__ = 'test'
            id = Column(Integer, primary_key=True)
            type_ = Column(String(50))
            # Was misspelled ``__mapper__args``; Python name-mangles that
            # to ``_General__mapper__args`` so declarative never saw it and
            # polymorphic_on was silently left unconfigured.
            __mapper_args__ = {'polymorphic_on': type_}

        class Specific(General):
            __tablename__ = 'sub'
            id = Column(Integer, ForeignKey('test.id'), primary_key=True)
            __mapper_args__ = {'polymorphic_identity': 'specific'}

        assert General.bar.prop.columns[0] is General.__table__.c.bar_newname
        assert len(General.bar.prop.columns) == 1
        assert Specific.bar.prop is General.bar.prop
        eq_(len(Specific.bar.prop.columns), 1)
        assert Specific.bar.prop.columns[0] is General.__table__.c.bar_newname

    def test_column_join_checks_superclass_type(self):
        """A subclass column may reuse an attribute name that is a
        relationship on the superclass; the subclass attribute maps to
        its own column and the superclass keeps its relationship.

        """

        class General(Base):
            __tablename__ = 'test'
            id = Column(Integer, primary_key=True)
            general_id = Column(Integer, ForeignKey('test.id'))
            type_ = relationship("General")

        class Specific(General):
            __tablename__ = 'sub'
            id = Column(Integer, ForeignKey('test.id'), primary_key=True)
            type_ = Column('foob', String(50))

        assert isinstance(General.type_.property, sa.orm.RelationshipProperty)
        is_(Specific.type_.property.columns[0], Specific.__table__.c.foob)

    def test_column_join_checks_subclass_type(self):
        """A subclass relationship that reuses the attribute name of a
        superclass column is rejected with ArgumentError.

        """

        def declare_conflicting_classes():
            class General(Base):
                __tablename__ = 'test'
                id = Column(Integer, primary_key=True)
                type_ = Column('foob', Integer)

            class Specific(General):
                __tablename__ = 'sub'
                id = Column(Integer, ForeignKey('test.id'), primary_key=True)
                specific_id = Column(Integer, ForeignKey('sub.id'))
                type_ = relationship("Specific")

        assert_raises_message(
            sa.exc.ArgumentError,
            "column 'foob' conflicts with property",
            declare_conflicting_classes
        )

    def test_table_args_overridden(self):
        """``__table_args__`` set on the model replace those provided by
        its mixin."""

        class MyMixin:
            __table_args__ = {'mysql_engine': 'Foo'}

        class MyModel(Base, MyMixin):
            __tablename__ = 'test'
            __table_args__ = {'mysql_engine': 'InnoDB'}
            id = Column(Integer, primary_key=True)

        table_kwargs = MyModel.__table__.kwargs
        eq_(table_kwargs, {'mysql_engine': 'InnoDB'})

    def test_mapper_args_declared_attr(self):
        """A ``__mapper_args__`` declared_attr on a mixin is evaluated for
        the base class and again for a subclass that does not list the
        mixin itself."""

        class ComputedMapperArgs:

            @declared_attr
            def __mapper_args__(cls):
                if cls.__name__ != 'Person':
                    return {'polymorphic_identity': cls.__name__}
                return {'polymorphic_on': cls.discriminator}

        class Person(Base, ComputedMapperArgs):
            __tablename__ = 'people'
            id = Column(Integer, primary_key=True)
            discriminator = Column('type', String(50))

        class Engineer(Person):
            pass

        configure_mappers()
        is_(class_mapper(Person).polymorphic_on, Person.__table__.c.type)
        eq_(class_mapper(Engineer).polymorphic_identity, 'Engineer')

    def test_mapper_args_declared_attr_two(self):
        """Same scenario as test_mapper_args_declared_attr, except the
        subclass also lists ComputedMapperArgs among its bases."""

        class ComputedMapperArgs:

            @declared_attr
            def __mapper_args__(cls):
                if cls.__name__ != 'Person':
                    return {'polymorphic_identity': cls.__name__}
                return {'polymorphic_on': cls.discriminator}

        class Person(Base, ComputedMapperArgs):

            __tablename__ = 'people'
            id = Column(Integer, primary_key=True)
            discriminator = Column('type', String(50))

        class Engineer(Person, ComputedMapperArgs):
            pass

        configure_mappers()
        is_(class_mapper(Person).polymorphic_on, Person.__table__.c.type)
        eq_(class_mapper(Engineer).polymorphic_identity, 'Engineer')

    def test_table_args_composite(self):
        """A model can merge the ``__table_args__`` of several mixins by
        declaring its own ``__table_args__`` declared_attr."""

        class MyMixin1:

            __table_args__ = {'info': {'baz': 'bob'}}

        class MyMixin2:

            __table_args__ = {'info': {'foo': 'bar'}}

        class MyModel(Base, MyMixin1, MyMixin2):

            __tablename__ = 'test'

            @declared_attr
            def __table_args__(self):
                merged_info = {}
                for mixin in (MyMixin1, MyMixin2):
                    merged_info.update(mixin.__table_args__['info'])
                return {'info': merged_info}

            id = Column(Integer, primary_key=True)

        eq_(MyModel.__table__.info, {'foo': 'bar', 'baz': 'bob'})

    def test_mapper_args_inherited(self):
        """A plain ``__mapper_args__`` dict on a mixin is applied to the
        model's mapper."""

        class MyMixin:

            __mapper_args__ = {'always_refresh': True}

        class MyModel(Base, MyMixin):

            __tablename__ = 'test'
            id = Column(Integer, primary_key=True)

        model_mapper = MyModel.__mapper__
        eq_(model_mapper.always_refresh, True)

    def test_mapper_args_inherited_descriptor(self):
        """A ``__mapper_args__`` declared_attr on a mixin receives the
        mapped class, so its result can depend on the class name."""

        class MyMixin:

            @declared_attr
            def __mapper_args__(cls):

                # tenuous, but illustrates the problem!

                return dict(always_refresh=(cls.__name__ == 'MyModel'))

        class MyModel(Base, MyMixin):

            __tablename__ = 'test'
            id = Column(Integer, primary_key=True)

        model_mapper = MyModel.__mapper__
        eq_(model_mapper.always_refresh, True)

    def test_mapper_args_polymorphic_on_inherited(self):
        """``polymorphic_on`` given as a mixin column ends up on the mapper
        as a column named ``type_`` that is attached to a table."""

        class MyMixin:

            type_ = Column(String(50))
            __mapper_args__ = {'polymorphic_on': type_}

        class MyModel(Base, MyMixin):

            __tablename__ = 'test'
            id = Column(Integer, primary_key=True)

        discriminator_col = MyModel.__mapper__.polymorphic_on
        eq_(discriminator_col.name, 'type_')
        assert discriminator_col.table is not None

    def test_mapper_args_overridden(self):
        """``__mapper_args__`` set on the model replace those provided by
        its mixin."""

        class MyMixin:

            __mapper_args__ = {'always_refresh': True}

        class MyModel(Base, MyMixin):

            __tablename__ = 'test'
            __mapper_args__ = {'always_refresh': False}
            id = Column(Integer, primary_key=True)

        model_mapper = MyModel.__mapper__
        eq_(model_mapper.always_refresh, False)

    def test_mapper_args_composite(self):
        """A model can merge the ``__mapper_args__`` of several mixins in
        its own declared_attr, adjusting the result for subclasses."""

        class MyMixin1:

            type_ = Column(String(50))
            __mapper_args__ = {'polymorphic_on': type_}

        class MyMixin2:

            __mapper_args__ = {'always_refresh': True}

        class MyModel(Base, MyMixin1, MyMixin2):

            __tablename__ = 'test'

            @declared_attr
            def __mapper_args__(cls):
                combined = dict(MyMixin1.__mapper_args__)
                combined.update(MyMixin2.__mapper_args__)
                if cls.__name__ != 'MyModel':
                    # Subclasses carry an identity instead of the
                    # discriminator column.
                    combined.pop('polymorphic_on')
                    combined['polymorphic_identity'] = cls.__name__

                return combined

            id = Column(Integer, primary_key=True)

        class MySubModel(MyModel):
            pass

        base_mapper = MyModel.__mapper__
        sub_mapper = MySubModel.__mapper__
        eq_(base_mapper.polymorphic_on.name, 'type_')
        assert base_mapper.polymorphic_on.table is not None
        eq_(base_mapper.always_refresh, True)
        eq_(sub_mapper.always_refresh, True)
        eq_(sub_mapper.polymorphic_identity, 'MySubModel')

    def test_mapper_args_property(self):
        """declared_attr versions of ``__tablename__``, ``__table_args__``
        and ``__mapper_args__`` on a mapped base class are evaluated for
        the base and for its joined-table subclasses."""

        class MyModel(Base):

            @declared_attr
            def __tablename__(cls):
                return cls.__name__.lower()

            @declared_attr
            def __table_args__(cls):
                return {'mysql_engine': 'InnoDB'}

            @declared_attr
            def __mapper_args__(cls):
                return {'polymorphic_identity': cls.__name__}

            id = Column(Integer, primary_key=True)

        class MySubModel(MyModel):
            id = Column(Integer, ForeignKey('mymodel.id'), primary_key=True)

        class MySubModel2(MyModel):
            __tablename__ = 'sometable'
            id = Column(Integer, ForeignKey('mymodel.id'), primary_key=True)

        eq_(MyModel.__mapper__.polymorphic_identity, 'MyModel')
        eq_(MySubModel.__mapper__.polymorphic_identity, 'MySubModel')
        for model in (MyModel, MySubModel, MySubModel2):
            eq_(model.__table__.kwargs['mysql_engine'], 'InnoDB')
        eq_(MyModel.__table__.name, 'mymodel')
        eq_(MySubModel.__table__.name, 'mysubmodel')

    def test_mapper_args_custom_base(self):
        """declared_attr members placed on a custom ``cls=`` base are
        evaluated separately for every class mapped from that base."""

        class Base(object):

            @declared_attr
            def __tablename__(cls):
                return cls.__name__.lower()

            @declared_attr
            def __table_args__(cls):
                return {'mysql_engine': 'InnoDB'}

            @declared_attr
            def id(self):
                return Column(Integer, primary_key=True)

        # Shadows the module-level Base for the rest of this test only.
        Base = decl.declarative_base(cls=Base)

        class MyClass(Base):
            pass

        class MyOtherClass(Base):
            pass

        eq_(MyClass.__table__.kwargs['mysql_engine'], 'InnoDB')
        eq_(MyClass.__table__.name, 'myclass')
        eq_(MyOtherClass.__table__.name, 'myotherclass')
        for mapped in (MyClass, MyOtherClass):
            is_(mapped.__table__.c.id.table, mapped.__table__)

    def test_single_table_no_propagation(self):
        """With single-table inheritance the subclass shares the base
        table, and the mixin's ``id`` column appears in it once."""

        class IdColumn:

            id = Column(Integer, primary_key=True)

        class Generic(Base, IdColumn):

            __tablename__ = 'base'
            discriminator = Column('type', String(50))
            __mapper_args__ = {'polymorphic_on': discriminator}
            value = Column(Integer())

        class Specific(Generic):

            __mapper_args__ = {'polymorphic_identity': 'specific'}

        base_table = Generic.__table__
        is_(Specific.__table__, base_table)
        eq_(list(base_table.c.keys()), ['id', 'type', 'value'])
        specific_mapper = class_mapper(Specific)
        is_(specific_mapper.polymorphic_on, base_table.c.type)
        eq_(specific_mapper.polymorphic_identity, 'specific')

    def test_joined_table_propagation(self):
        """With joined-table inheritance the mixin's table name and table
        args apply to the subclass table too, but its columns land only
        on the base table."""

        class CommonMixin:

            @declared_attr
            def __tablename__(cls):
                return cls.__name__.lower()

            __table_args__ = {'mysql_engine': 'InnoDB'}
            timestamp = Column(Integer)
            id = Column(Integer, primary_key=True)

        class Generic(Base, CommonMixin):

            discriminator = Column('python_type', String(50))
            __mapper_args__ = {'polymorphic_on': discriminator}

        class Specific(Generic):

            __mapper_args__ = {'polymorphic_identity': 'specific'}
            id = Column(Integer, ForeignKey('generic.id'),
                        primary_key=True)

        generic_table = Generic.__table__
        specific_table = Specific.__table__
        eq_(generic_table.name, 'generic')
        eq_(specific_table.name, 'specific')
        eq_(list(generic_table.c.keys()),
            ['timestamp', 'id', 'python_type'])
        eq_(list(specific_table.c.keys()), ['id'])
        eq_(generic_table.kwargs, {'mysql_engine': 'InnoDB'})
        eq_(specific_table.kwargs, {'mysql_engine': 'InnoDB'})

    def test_some_propagation(self):
        """One mixin serves a base class, a single-table subclass (which
        sets ``__tablename__ = None``) and a joined-table subclass."""

        class CommonMixin:

            @declared_attr
            def __tablename__(cls):
                return cls.__name__.lower()

            __table_args__ = {'mysql_engine': 'InnoDB'}
            timestamp = Column(Integer)

        class BaseType(Base, CommonMixin):

            discriminator = Column('type', String(50))
            __mapper_args__ = {'polymorphic_on': discriminator}
            id = Column(Integer, primary_key=True)
            value = Column(Integer())

        class Single(BaseType):

            __tablename__ = None
            __mapper_args__ = {'polymorphic_identity': 'type1'}

        class Joined(BaseType):

            __mapper_args__ = {'polymorphic_identity': 'type2'}
            id = Column(Integer, ForeignKey('basetype.id'),
                        primary_key=True)

        base_table = BaseType.__table__
        eq_(base_table.name, 'basetype')
        eq_(list(base_table.c.keys()),
            ['timestamp', 'type', 'id', 'value'])
        eq_(base_table.kwargs, {'mysql_engine': 'InnoDB'})
        is_(Single.__table__, base_table)
        joined_table = Joined.__table__
        eq_(joined_table.name, 'joined')
        eq_(list(joined_table.c.keys()), ['id'])
        eq_(joined_table.kwargs, {'mysql_engine': 'InnoDB'})

    def test_col_copy_vs_declared_attr_joined_propagation(self):
        """Mixin columns, whether plain or from a declared_attr, appear
        on the base table and not on a joined-table subclass's table."""

        class Mixin(object):
            a = Column(Integer)

            @declared_attr
            def b(cls):
                return Column(Integer)

        class A(Mixin, Base):
            __tablename__ = 'a'
            id = Column(Integer, primary_key=True)

        class B(A):
            __tablename__ = 'b'
            id = Column(Integer, ForeignKey('a.id'), primary_key=True)

        for col_name in ('a', 'b'):
            assert col_name in A.__table__.c
        for col_name in ('a', 'b'):
            assert col_name not in B.__table__.c

    def test_col_copy_vs_declared_attr_joined_propagation_newname(self):
        """Same as the joined propagation test, with mixin columns whose
        column name differs from the attribute name."""

        class Mixin(object):
            a = Column('a1', Integer)

            @declared_attr
            def b(cls):
                return Column('b1', Integer)

        class A(Mixin, Base):
            __tablename__ = 'a'
            id = Column(Integer, primary_key=True)

        class B(A):
            __tablename__ = 'b'
            id = Column(Integer, ForeignKey('a.id'), primary_key=True)

        for col_name in ('a1', 'b1'):
            assert col_name in A.__table__.c
        for col_name in ('a1', 'b1'):
            assert col_name not in B.__table__.c

    def test_col_copy_vs_declared_attr_single_propagation(self):
        """Mixin columns, plain or from a declared_attr, are on the base
        table when a single-table subclass is also declared."""

        class Mixin(object):
            a = Column(Integer)

            @declared_attr
            def b(cls):
                return Column(Integer)

        class A(Mixin, Base):
            __tablename__ = 'a'
            id = Column(Integer, primary_key=True)

        class B(A):
            pass

        for col_name in ('a', 'b'):
            assert col_name in A.__table__.c

    def test_non_propagating_mixin(self):
        """A ``__tablename__`` declared_attr that returns None when the
        class already inherits a table yields single-table inheritance."""

        class NoJoinedTableNameMixin:

            @declared_attr
            def __tablename__(cls):
                if not decl.has_inherited_table(cls):
                    return cls.__name__.lower()
                return None

        class BaseType(Base, NoJoinedTableNameMixin):

            discriminator = Column('type', String(50))
            __mapper_args__ = {'polymorphic_on': discriminator}
            id = Column(Integer, primary_key=True)
            value = Column(Integer())

        class Specific(BaseType):

            __mapper_args__ = {'polymorphic_identity': 'specific'}

        base_table = BaseType.__table__
        eq_(base_table.name, 'basetype')
        eq_(list(base_table.c.keys()), ['type', 'id', 'value'])
        is_(Specific.__table__, base_table)
        specific_mapper = class_mapper(Specific)
        is_(specific_mapper.polymorphic_on, base_table.c.type)
        eq_(specific_mapper.polymorphic_identity, 'specific')

    def test_non_propagating_mixin_used_for_joined(self):
        """The table-name mixin yields a name for a subclass that lists
        the mixin directly in its bases, giving it a joined table."""

        class TableNameMixin:

            @declared_attr
            def __tablename__(cls):
                if (decl.has_inherited_table(cls)
                        and TableNameMixin not in cls.__bases__):
                    return None
                return cls.__name__.lower()

        class BaseType(Base, TableNameMixin):

            discriminator = Column('type', String(50))
            __mapper_args__ = {'polymorphic_on': discriminator}
            id = Column(Integer, primary_key=True)
            value = Column(Integer())

        class Specific(BaseType, TableNameMixin):

            __mapper_args__ = {'polymorphic_identity': 'specific'}
            id = Column(Integer, ForeignKey('basetype.id'),
                        primary_key=True)

        base_table = BaseType.__table__
        specific_table = Specific.__table__
        eq_(base_table.name, 'basetype')
        eq_(list(base_table.c.keys()), ['type', 'id', 'value'])
        eq_(specific_table.name, 'specific')
        eq_(list(specific_table.c.keys()), ['id'])

    def test_single_back_propagate(self):
        """A column from a mixin used only by a single-table subclass is
        added to the base class's table."""

        class ColumnMixin:

            timestamp = Column(Integer)

        class BaseType(Base):

            __tablename__ = 'foo'
            discriminator = Column('type', String(50))
            __mapper_args__ = {'polymorphic_on': discriminator}
            id = Column(Integer, primary_key=True)

        class Specific(BaseType, ColumnMixin):

            __mapper_args__ = {'polymorphic_identity': 'specific'}

        base_column_names = list(BaseType.__table__.c.keys())
        eq_(base_column_names, ['type', 'id', 'timestamp'])

    def test_table_in_model_and_same_column_in_mixin(self):
        """A model with an explicit ``__table__`` may use a mixin column
        of the same name; the table's own column object is kept."""

        class ColumnMixin:

            data = Column(Integer)

        class Model(Base, ColumnMixin):

            __table__ = Table('foo', Base.metadata,
                              Column('data', Integer),
                              Column('id', Integer, primary_key=True))

        table_col = Model.__table__.c.data
        declared_col = ColumnMixin.data
        assert table_col is not declared_col
        eq_(table_col.name, 'data')
        is_(table_col.type.__class__, declared_col.type.__class__)

    def test_table_in_model_and_different_named_column_in_mixin(self):
        """A mixin column absent from the model's explicit ``__table__``
        is rejected with ArgumentError."""

        class ColumnMixin:
            tada = Column(Integer)

        def declare_model():

            class Model(Base, ColumnMixin):

                __table__ = Table('foo', Base.metadata,
                                  Column('data', Integer),
                                  Column('id', Integer, primary_key=True))
                foo = relationship("Dest")

        assert_raises_message(
            sa.exc.ArgumentError,
            "Can't add additional column 'tada' when "
            "specifying __table__",
            declare_model
        )

    def test_table_in_model_and_different_named_alt_key_column_in_mixin(self):
        """The explicit ``__table__`` has a 'tada' column, but the mixin
        attribute ``tada`` is a column named 'foobar'; adding 'foobar'
        is still rejected with ArgumentError."""

        class ColumnMixin:
            tada = Column('foobar', Integer)

        def declare_model():

            class Model(Base, ColumnMixin):

                __table__ = Table('foo', Base.metadata,
                                  Column('data', Integer),
                                  Column('tada', Integer),
                                  Column('id', Integer, primary_key=True))
                foo = relationship("Dest")

        assert_raises_message(
            sa.exc.ArgumentError,
            "Can't add additional column 'foobar' when "
            "specifying __table__",
            declare_model
        )

    def test_table_in_model_overrides_different_typed_column_in_mixin(self):
        """When the explicit ``__table__`` and the mixin disagree on a
        column's type, the table's column (Integer) is the one kept."""

        class ColumnMixin:

            data = Column(String)

        class Model(Base, ColumnMixin):

            __table__ = Table('foo', Base.metadata,
                              Column('data', Integer),
                              Column('id', Integer, primary_key=True))

        table_col = Model.__table__.c.data
        declared_col = ColumnMixin.data
        assert table_col is not declared_col
        eq_(table_col.name, 'data')
        is_(table_col.type.__class__, Integer)

    def test_mixin_column_ordering(self):
        """Table columns come out in the order col1, col3, col2, col4, id
        for two mixins declared one after the other and then the model."""

        class Foo(object):

            col1 = Column(Integer)
            col3 = Column(Integer)

        class Bar(object):

            col2 = Column(Integer)
            col4 = Column(Integer)

        class Model(Base, Foo, Bar):

            id = Column(Integer, primary_key=True)
            __tablename__ = 'model'

        column_names = list(Model.__table__.c.keys())
        eq_(column_names, ['col1', 'col3', 'col2', 'col4', 'id'])

    def test_honor_class_mro_one(self):
        """A declared_attr column from a mixin on the parent does not
        show up again on a joined-table child's table."""

        class HasXMixin(object):

            @declared_attr
            def x(self):
                return Column(Integer)

        class Parent(HasXMixin, Base):
            __tablename__ = 'parent'
            id = Column(Integer, primary_key=True)

        class Child(Parent):
            __tablename__ = 'child'
            id = Column(Integer, ForeignKey('parent.id'), primary_key=True)

        child_columns = Child.__table__.c
        assert "x" not in child_columns

    def test_honor_class_mro_two(self):
        """A plain method ``x`` defined on the parent takes the place of
        the mixin's declared_attr ``x`` for a joined-table subclass."""

        class HasXMixin(object):

            @declared_attr
            def x(self):
                return Column(Integer)

        class Parent(HasXMixin, Base):
            __tablename__ = 'parent'
            id = Column(Integer, primary_key=True)

            def x(self):
                return "hi"

        class C(Parent):
            __tablename__ = 'c'
            id = Column(Integer, ForeignKey('parent.id'), primary_key=True)

        instance = C()
        eq_(instance.x(), 'hi')

    def test_arbitrary_attrs_one(self):
        """A declared_attr returning a plain string is evaluated for the
        mapped class and the result is stored in the class ``__dict__``."""

        class HasMixin(object):

            @declared_attr
            def some_attr(cls):
                return cls.__name__ + "SOME ATTR"

        class Mapped(HasMixin, Base):
            __tablename__ = 't'
            id = Column(Integer, primary_key=True)

        expected = "MappedSOME ATTR"
        eq_(Mapped.some_attr, expected)
        eq_(Mapped.__dict__['some_attr'], expected)

    def test_arbitrary_attrs_two(self):
        """A mixin may supply, via declared_attr, a relationship plus an
        association proxy over it; both mapped classes can then be
        constructed with the proxied ``filters`` keyword."""
        from sqlalchemy.ext.associationproxy import association_proxy

        class FilterA(Base):
            __tablename__ = 'filter_a'
            id = Column(Integer(), primary_key=True)
            parent_id = Column(Integer(), ForeignKey('type_a.id'))
            filter = Column(String())

            def __init__(self, filter_, **kw):
                self.filter = filter_

        class FilterB(Base):
            __tablename__ = 'filter_b'
            id = Column(Integer(), primary_key=True)
            parent_id = Column(Integer(), ForeignKey('type_b.id'))
            filter = Column(String())

            def __init__(self, filter_, **kw):
                self.filter = filter_

        class FilterMixin(object):

            @declared_attr
            def _filters(cls):
                return relationship(
                    cls.filter_class,
                    cascade='all,delete,delete-orphan'
                )

            @declared_attr
            def filters(cls):
                return association_proxy('_filters', 'filter')

        class TypeA(Base, FilterMixin):
            __tablename__ = 'type_a'
            filter_class = FilterA
            id = Column(Integer(), primary_key=True)

        class TypeB(Base, FilterMixin):
            __tablename__ = 'type_b'
            filter_class = FilterB
            id = Column(Integer(), primary_key=True)

        # Construction alone is the check; no assertions follow.
        TypeA(filters=['foo'])
        TypeB(filters=['foo'])


class DeclarativeMixinPropertyTest(DeclarativeTestBase):
    """Tests for mixins that supply mapper-level properties.

    Covers ``column_property()``, ``deferred()`` and ``relationship()``
    returned from ``@declared_attr`` functions, docstring handling for
    such attributes, and mixin columns referenced from
    ``__mapper_args__``.
    """

    def test_column_property(self):
        """Each class using the mixin gets its own column and property."""

        class MyMixin(object):

            @declared_attr
            def prop_hoho(cls):
                prop_col = Column('prop', String(50))
                return column_property(prop_col)

        class MyModel(Base, MyMixin):

            __tablename__ = 'test'
            id = Column(Integer, primary_key=True,
                        test_needs_autoincrement=True)

        class MyOtherModel(Base, MyMixin):

            __tablename__ = 'othertest'
            id = Column(Integer, primary_key=True,
                        test_needs_autoincrement=True)

        # both tables carry a "prop" column, and they are distinct objects
        model_col = MyModel.__table__.c.prop
        assert model_col is not None
        other_col = MyOtherModel.__table__.c.prop
        assert other_col is not None
        assert model_col is not other_col

        # each property wraps the column of its own table
        eq_(MyModel.prop_hoho.property.columns, [model_col])
        eq_(MyOtherModel.prop_hoho.property.columns, [other_col])
        model_prop = MyModel.prop_hoho.property
        other_prop = MyOtherModel.prop_hoho.property
        assert model_prop is not other_prop

        # round trip: persist one row per class and query it back
        Base.metadata.create_all()
        session = create_session()
        first = MyModel(prop_hoho='foo')
        second = MyOtherModel(prop_hoho='bar')
        session.add_all([first, second])
        session.flush()

        model_query = session.query(MyModel).filter(
            MyModel.prop_hoho == 'foo')
        eq_(model_query.one(), first)

        other_query = session.query(MyOtherModel).filter(
            MyOtherModel.prop_hoho == 'bar')
        eq_(other_query.one(), second)

    def test_doc(self):
        """Check that docstrings carry over to the mapped attributes.

        Documentation on ``@declared_attr`` functions is a problematic
        area; this at least verifies that the attributes on a mapped
        subclass report the docstring written on the mixin function.

        """

        class MyMixin(object):

            @declared_attr
            def type_(cls):
                """this is a document."""

                return Column(String(50))

            @declared_attr
            def t2(cls):
                """this is another document."""

                return column_property(Column(String(50)))

        class MyModel(Base, MyMixin):

            __tablename__ = 'test'
            id = Column(Integer, primary_key=True)

        configure_mappers()

        column_doc = MyModel.type_.__doc__
        eq_(column_doc, """this is a document.""")

        property_doc = MyModel.t2.__doc__
        eq_(property_doc, """this is another document.""")

    def test_column_in_mapper_args(self):
        """A ``@declared_attr`` column may be used as ``polymorphic_on``."""

        class MyMixin(object):

            @declared_attr
            def type_(cls):
                return Column(String(50))

            # refers to the declared_attr object itself, not to a Column
            __mapper_args__ = dict(polymorphic_on=type_)

        class MyModel(Base, MyMixin):

            __tablename__ = 'test'
            id = Column(Integer, primary_key=True)

        configure_mappers()

        discriminator = MyModel.__mapper__.polymorphic_on
        eq_(discriminator.name, 'type_')
        # the column must be attached to a table
        assert discriminator.table is not None

    def test_column_in_mapper_args_used_multiple_times(self):
        """A mixin column in ``__mapper_args__`` resolves per class."""

        class MyMixin(object):

            version_id = Column(Integer)
            __mapper_args__ = dict(version_id_col=version_id)

        class ModelOne(Base, MyMixin):

            __tablename__ = 'm1'
            id = Column(Integer, primary_key=True)

        class ModelTwo(Base, MyMixin):

            __tablename__ = 'm2'
            id = Column(Integer, primary_key=True)

        # every mapper must point at the column of its own table
        for model in (ModelOne, ModelTwo):
            is_(
                model.__mapper__.version_id_col,
                model.__table__.c.version_id
            )

    def test_deferred(self):
        """A ``deferred()`` column from a mixin loads on first access."""

        class MyMixin(object):

            @declared_attr
            def data(cls):
                data_col = Column('data', String(50))
                return deferred(data_col)

        class MyModel(Base, MyMixin):

            __tablename__ = 'test'
            id = Column(Integer, primary_key=True,
                        test_needs_autoincrement=True)

        Base.metadata.create_all()
        session = create_session()
        session.add_all([MyModel(data=value) for value in ('d1', 'd2')])
        session.flush()
        session.expunge_all()

        # unpacking also requires that exactly two rows come back
        first, second = session.query(MyModel).order_by(MyModel.data)

        loaded = first.__dict__
        assert 'data' not in loaded
        # attribute access populates the instance dict
        assert first.data == 'd1'
        assert 'data' in loaded

    def _test_relationship(self, usestring):
        """Exercise a mixin-declared relationship on two mapped classes.

        :param usestring: when true, the relationship is built with a
         string ``primaryjoin`` that embeds the name of the class being
         mapped; otherwise ``relationship('Target')`` is used as is.

        """

        class RefTargetMixin(object):

            @declared_attr
            def target_id(cls):
                return Column('target_id', ForeignKey('target.id'))

            @declared_attr
            def target(cls):
                if usestring:
                    return relationship(
                        'Target',
                        primaryjoin='Target.id==%s.target_id'
                        % cls.__name__)
                return relationship('Target')

        class Foo(Base, RefTargetMixin):

            __tablename__ = 'foo'
            id = Column(Integer, primary_key=True,
                        test_needs_autoincrement=True)

        class Bar(Base, RefTargetMixin):

            __tablename__ = 'bar'
            id = Column(Integer, primary_key=True,
                        test_needs_autoincrement=True)

        class Target(Base):

            __tablename__ = 'target'
            id = Column(Integer, primary_key=True,
                        test_needs_autoincrement=True)

        Base.metadata.create_all()
        session = create_session()

        target_one, target_two = Target(), Target()
        foo_one = Foo(target=target_one)
        foo_two = Foo(target=target_two)
        bar_one = Bar(target=target_one)
        session.add_all([foo_one, foo_two, bar_one])
        session.flush()

        foo_query = session.query(Foo).filter(Foo.target == target_two)
        eq_(foo_query.one(), foo_two)

        # no Bar refers to the second target
        bar_query = session.query(Bar).filter(Bar.target == target_two)
        eq_(bar_query.first(), None)

        session.expire_all()
        eq_(foo_one.target, target_one)

    def test_relationship(self):
        """Run the relationship scenario without a string primaryjoin."""
        self._test_relationship(usestring=False)

    def test_relationship_primryjoin(self):
        """Run the relationship scenario with a string primaryjoin."""
        self._test_relationship(usestring=True)


class AbstractTest(DeclarativeTestBase):

    def test_abstract_boolean(self):
        """Only ``__abstract__ = True`` keeps a class from making a table.

        ``A`` sets ``__abstract__`` to True, ``B`` and ``C`` set it to
        False explicitly and ``D`` leaves it out; the metadata is
        expected to contain the tables of the latter three only.
        """

        class A(Base):
            __tablename__ = 'x'
            __abstract__ = True
            id = Column(Integer, primary_key=True)

        class B(Base):
            __tablename__ = 'y'
            __abstract__ = False
            id = Column(Integer, primary_key=True)

        class C(Base):
            __tablename__ = 'z'
            __abstract__ = False
            id = Column(Integer, primary_key=True)

        class D(Base):
            __tablename__ = 'q'
            id = Column(Integer, primary_key=True)

        expected_tables = set(['y', 'z', 'q'])
        created_tables = set(Base.metadata.tables)
        eq_(created_tables, expected_tables)
