from sqlalchemy.testing import eq_, assert_raises, \
    assert_raises_message
from sqlalchemy import *
from sqlalchemy.orm import *
from sqlalchemy.orm import exc as orm_exc
from sqlalchemy.testing import *
import sqlalchemy as sa
from sqlalchemy import testing
from sqlalchemy.testing import fixtures
from sqlalchemy.orm import attributes
from sqlalchemy.testing import eq_
from sqlalchemy.testing.schema import Table, Column


class Employee(object):
    """Root of the employee class hierarchy; stores only a ``name``."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        # Rendered as "<ClassName> <name>".
        return ' '.join((type(self).__name__, self.name))


class Manager(Employee):
    """Employee subtype that additionally stores ``manager_data``."""

    def __init__(self, name, manager_data):
        # Assigned left to right: name first, then manager_data.
        self.name, self.manager_data = name, manager_data

    def __repr__(self):
        # Rendered as "<ClassName> <name> <manager_data>".
        return ' '.join(
            (type(self).__name__, self.name, self.manager_data))


class Engineer(Employee):
    """Employee subtype that additionally stores ``engineer_info``."""

    def __init__(self, name, engineer_info):
        # Assigned left to right: name first, then engineer_info.
        self.name, self.engineer_info = name, engineer_info

    def __repr__(self):
        # Rendered as "<ClassName> <name> <engineer_info>".
        return ' '.join(
            (type(self).__name__, self.name, self.engineer_info))


class Hacker(Engineer):
    """Engineer subtype that additionally stores a ``nickname``."""

    def __init__(self, name, nickname, engineer_info):
        self.name = name
        self.nickname = nickname
        self.engineer_info = engineer_info

    def __repr__(self):
        # Rendered as "<ClassName> <name> '<nickname>' <engineer_info>".
        quoted_nickname = "'" + self.nickname + "'"
        return ' '.join((type(self).__name__, self.name,
                         quoted_nickname, self.engineer_info))


class Company(object):
    """Plain attribute holder; the tests map it to the ``companies`` table."""


class ConcreteTest(fixtures.MappedTest):

    @classmethod
    def define_tables(cls, metadata):
        """Build ``companies`` plus one table per employee class.

        Each table is bound to a module-level global, since the test
        methods of this class reference the tables by bare name.
        """
        global managers_table, engineers_table, hackers_table, \
            companies, employees_table

        def employee_pk():
            # all employee tables share this primary key definition
            return Column('employee_id', Integer, primary_key=True,
                          test_needs_autoincrement=True)

        def company_fk():
            # all employee tables point at companies.id
            return Column('company_id', Integer,
                          ForeignKey('companies.id'))

        companies = Table(
            'companies', metadata,
            Column('id', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('name', String(50)))

        employees_table = Table(
            'employees', metadata,
            employee_pk(),
            Column('name', String(50)),
            company_fk())

        managers_table = Table(
            'managers', metadata,
            employee_pk(),
            Column('name', String(50)),
            Column('manager_data', String(50)),
            company_fk())

        engineers_table = Table(
            'engineers', metadata,
            employee_pk(),
            Column('name', String(50)),
            Column('engineer_info', String(50)),
            company_fk())

        hackers_table = Table(
            'hackers', metadata,
            employee_pk(),
            Column('name', String(50)),
            Column('engineer_info', String(50)),
            company_fk(),
            Column('nickname', String(50)))

    def test_basic(self):
        """Persist a Manager and an Engineer, then query them back at the
        Employee, Manager and Engineer levels and refresh one expired
        attribute."""
        union = polymorphic_union(
            {'manager': managers_table, 'engineer': engineers_table},
            'type', 'pjoin')
        base_mapper = mapper(Employee, union,
                             polymorphic_on=union.c.type)
        mapper(Manager, managers_table, inherits=base_mapper,
               concrete=True, polymorphic_identity='manager')
        mapper(Engineer, engineers_table, inherits=base_mapper,
               concrete=True, polymorphic_identity='engineer')

        sess = create_session()
        sess.add(Manager('Tom', 'knows how to manage things'))
        sess.add(Engineer('Kurt', 'knows how to hack'))
        sess.flush()
        sess.expunge_all()

        tom_repr = 'Manager Tom knows how to manage things'
        kurt_repr = 'Engineer Kurt knows how to hack'

        assert set(repr(obj) for obj in sess.query(Employee)) \
            == set([kurt_repr, tom_repr])
        assert set(repr(obj) for obj in sess.query(Manager)) \
            == set([tom_repr])
        assert set(repr(obj) for obj in sess.query(Engineer)) \
            == set([kurt_repr])

        # an expired concrete-only attribute must load again on access
        loaded = sess.query(Manager).one()
        sess.expire(loaded, ['manager_data'])
        eq_(loaded.manager_data, 'knows how to manage things')

    def test_multi_level_no_base(self):
        """Three-level concrete hierarchy in which ``Employee`` is mapped
        only to the polymorphic union, not to a table of its own."""
        all_union = polymorphic_union(
            {'manager': managers_table,
             'engineer': engineers_table,
             'hacker': hackers_table},
            'type', 'pjoin')
        engineer_union = polymorphic_union(
            {'engineer': engineers_table, 'hacker': hackers_table},
            'type', 'pjoin2')

        base_mapper = mapper(Employee, all_union,
                             polymorphic_on=all_union.c.type)
        mapper(Manager, managers_table, inherits=base_mapper,
               concrete=True, polymorphic_identity='manager')
        eng_mapper = mapper(
            Engineer, engineers_table,
            with_polymorphic=('*', engineer_union),
            polymorphic_on=engineer_union.c.type,
            inherits=base_mapper,
            concrete=True,
            polymorphic_identity='engineer')
        mapper(Hacker, hackers_table, inherits=eng_mapper,
               concrete=True, polymorphic_identity='hacker')

        sess = create_session()
        readonly_msg = \
            "does not implement attribute .?'type' at the instance level."

        tom = Manager('Tom', 'knows how to manage things')
        assert_raises_message(AttributeError, readonly_msg,
                              setattr, tom, "type", "sometype")

        jerry = Engineer('Jerry', 'knows how to program')
        hacker = Hacker('Kurt', 'Badass', 'knows how to hack')
        assert_raises_message(AttributeError, readonly_msg,
                              setattr, hacker, "type", "sometype")

        sess.add_all((tom, jerry, hacker))
        sess.flush()

        # ensure "readonly" on save logic didn't pollute the
        # expired_attributes collection
        for obj in (jerry, hacker):
            for key in ('nickname', 'name'):
                assert key not in \
                    attributes.instance_state(obj).expired_attributes

        def read_loaded_attributes():
            eq_(jerry.name, 'Jerry')
            eq_(hacker.nickname, 'Badass')

        # both attributes are already present, so no SQL may be emitted
        self.assert_sql_count(testing.db, read_loaded_attributes, 0)
        sess.expunge_all()

        tom_repr = 'Manager Tom knows how to manage things'
        jerry_repr = 'Engineer Jerry knows how to program'
        kurt_repr = "Hacker Kurt 'Badass' knows how to hack"

        by_base = sess.query(Employee).filter(
            Employee.name == 'Tom').one()
        assert repr(by_base) == tom_repr
        by_subclass = sess.query(Manager).filter(
            Manager.name == 'Tom').one()
        assert repr(by_subclass) == tom_repr

        assert set(repr(obj) for obj in sess.query(Employee).all()) \
            == set([jerry_repr, tom_repr, kurt_repr])
        assert set(repr(obj) for obj in sess.query(Manager).all()) \
            == set([tom_repr])
        assert set(repr(obj) for obj in sess.query(Engineer).all()) \
            == set([jerry_repr, kurt_repr])
        assert set(repr(obj) for obj in sess.query(Hacker).all()) \
            == set([kurt_repr])

    def test_multi_level_with_base(self):
        """Three-level concrete hierarchy in which ``Employee`` is mapped
        to its own table and selects polymorphically from the union."""
        all_union = polymorphic_union(
            {
                'employee': employees_table,
                'manager': managers_table,
                'engineer': engineers_table,
                'hacker': hackers_table,
            },
            'type', 'pjoin')
        engineer_union = polymorphic_union(
            {'engineer': engineers_table, 'hacker': hackers_table},
            'type', 'pjoin2')

        base_mapper = mapper(Employee, employees_table,
                             with_polymorphic=('*', all_union),
                             polymorphic_on=all_union.c.type)
        mapper(Manager, managers_table, inherits=base_mapper,
               concrete=True, polymorphic_identity='manager')
        eng_mapper = mapper(
            Engineer, engineers_table,
            with_polymorphic=('*', engineer_union),
            polymorphic_on=engineer_union.c.type,
            inherits=base_mapper,
            concrete=True,
            polymorphic_identity='engineer')
        mapper(Hacker, hackers_table, inherits=eng_mapper,
               concrete=True, polymorphic_identity='hacker')

        sess = create_session()
        tom = Manager('Tom', 'knows how to manage things')
        jerry = Engineer('Jerry', 'knows how to program')
        hacker = Hacker('Kurt', 'Badass', 'knows how to hack')
        sess.add_all((tom, jerry, hacker))
        sess.flush()

        def read_loaded_attributes():
            eq_(jerry.name, 'Jerry')
            eq_(hacker.nickname, 'Badass')

        # both attributes are already present, so no SQL may be emitted
        self.assert_sql_count(testing.db, read_loaded_attributes, 0)
        sess.expunge_all()

        # check that we aren't getting a cartesian product in the raw
        # SQL. this requires that Engineer's polymorphic discriminator
        # is not rendered in the statement which is only against
        # Employee's "pjoin"
        raw_rows = testing.db.execute(
            sess.query(Employee).with_labels().statement).fetchall()
        assert len(raw_rows) == 3

        tom_repr = 'Manager Tom knows how to manage things'
        jerry_repr = 'Engineer Jerry knows how to program'
        kurt_repr = "Hacker Kurt 'Badass' knows how to hack"

        assert set(repr(obj) for obj in sess.query(Employee)) \
            == set([jerry_repr, tom_repr, kurt_repr])
        assert set(repr(obj) for obj in sess.query(Manager)) \
            == set([tom_repr])
        assert set(repr(obj) for obj in sess.query(Engineer)) \
            == set([jerry_repr, kurt_repr])
        assert set(repr(obj) for obj in sess.query(Hacker)) \
            == set([kurt_repr])

    def test_without_default_polymorphic(self):
        """Mappers carry no ``with_polymorphic`` setting; the unions are
        supplied per query through ``Query.with_polymorphic()``."""
        all_union = polymorphic_union(
            {
                'employee': employees_table,
                'manager': managers_table,
                'engineer': engineers_table,
                'hacker': hackers_table,
            },
            'type', 'pjoin')
        engineer_union = polymorphic_union(
            {'engineer': engineers_table, 'hacker': hackers_table},
            'type', 'pjoin2')

        base_mapper = mapper(Employee, employees_table,
                             polymorphic_identity='employee')
        mapper(Manager, managers_table, inherits=base_mapper,
               concrete=True, polymorphic_identity='manager')
        eng_mapper = mapper(Engineer, engineers_table,
                            inherits=base_mapper, concrete=True,
                            polymorphic_identity='engineer')
        mapper(Hacker, hackers_table, inherits=eng_mapper,
               concrete=True, polymorphic_identity='hacker')

        sess = create_session()
        jdoe = Employee('Jdoe')
        tom = Manager('Tom', 'knows how to manage things')
        jerry = Engineer('Jerry', 'knows how to program')
        hacker = Hacker('Kurt', 'Badass', 'knows how to hack')
        sess.add_all((jdoe, tom, jerry, hacker))
        sess.flush()

        def all_employees():
            # a fresh polymorphic Employee query on every call
            return sess.query(Employee).with_polymorphic(
                '*', all_union, all_union.c.type)

        def all_engineers():
            # a fresh polymorphic Engineer query on every call
            return sess.query(Engineer).with_polymorphic(
                '*', engineer_union, engineer_union.c.type)

        jdoe_repr = 'Employee Jdoe'
        tom_repr = 'Manager Tom knows how to manage things'
        jerry_repr = 'Engineer Jerry knows how to program'
        kurt_repr = "Hacker Kurt 'Badass' knows how to hack"

        raw_rows = testing.db.execute(
            all_employees().with_labels().statement).fetchall()
        eq_(len(raw_rows), 4)

        eq_(sess.query(Employee).get(jdoe.employee_id), jdoe)
        eq_(sess.query(Engineer).get(jerry.employee_id), jerry)

        eq_(set(repr(obj) for obj in all_employees()),
            set([jdoe_repr, jerry_repr, tom_repr, kurt_repr]))
        eq_(set(repr(obj) for obj in sess.query(Manager)),
            set([tom_repr]))
        eq_(set(repr(obj) for obj in all_engineers()),
            set([jerry_repr, kurt_repr]))
        eq_(set(repr(obj) for obj in sess.query(Hacker)),
            set([kurt_repr]))

        # test adaption of the column by wrapping the query in a
        # subquery
        wrapped_rows = testing.db.execute(
            all_engineers().from_self().statement).fetchall()
        eq_(len(wrapped_rows), 2)
        eq_(set(repr(obj) for obj in all_engineers().from_self()),
            set([jerry_repr, kurt_repr]))

    def test_relationship(self):
        """Load ``Company.employees`` against the polymorphic union, first
        lazily (two statements), then with ``joinedload`` (one)."""
        union = polymorphic_union(
            {'manager': managers_table, 'engineer': engineers_table},
            'type', 'pjoin')
        mapper(Company, companies,
               properties={'employees': relationship(Employee)})
        base_mapper = mapper(Employee, union,
                             polymorphic_on=union.c.type)
        mapper(Manager, managers_table, inherits=base_mapper,
               concrete=True, polymorphic_identity='manager')
        mapper(Engineer, engineers_table, inherits=base_mapper,
               concrete=True, polymorphic_identity='engineer')

        sess = create_session()
        c = Company()
        c.employees.append(Manager('Tom', 'knows how to manage things'))
        c.employees.append(Engineer('Kurt', 'knows how to hack'))
        sess.add(c)
        sess.flush()
        sess.expunge_all()

        expected = set(['Engineer Kurt knows how to hack',
                        'Manager Tom knows how to manage things'])

        def load_lazily():
            company = sess.query(Company).get(c.id)
            assert set(repr(emp) for emp in company.employees) \
                == expected

        # one statement for the company, one for the collection
        self.assert_sql_count(testing.db, load_lazily, 2)
        sess.expunge_all()

        def load_eagerly():
            eager_query = sess.query(Company).options(
                joinedload(Company.employees))
            company = eager_query.get(c.id)
            assert set(repr(emp) for emp in company.employees) \
                == expected

        # the joined eager load folds everything into one statement
        self.assert_sql_count(testing.db, load_eagerly, 1)


class PropertyInheritanceTest(fixtures.MappedTest):

    @classmethod
    def define_tables(cls, metadata):
        """Create ``a_table``, ``b_table`` and ``c_table``, each with a
        foreign key to ``dest_table``, followed by ``dest_table``."""
        # the three source tables differ only in table name and in the
        # name of their string column
        for table_name, name_column in (
                ('a_table', 'aname'),
                ('b_table', 'bname'),
                ('c_table', 'cname')):
            Table(
                table_name, metadata,
                Column('id', Integer, primary_key=True,
                       test_needs_autoincrement=True),
                Column('some_dest_id', Integer,
                       ForeignKey('dest_table.id')),
                Column(name_column, String(50)))

        Table(
            'dest_table', metadata,
            Column('id', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('name', String(50)))

    @classmethod
    def setup_classes(cls):
        """Declare the classes used by this suite: ``A`` with subclasses
        ``B`` and ``C``, plus the unrelated ``Dest``.

        The test methods read these classes back through
        ``self.classes``.
        """

        # base of the A/B/C hierarchy
        class A(cls.Comparable):
            pass

        class B(A):
            pass

        class C(A):
            pass

        # not part of the A hierarchy
        class Dest(cls.Comparable):
            pass

    def test_noninherited_warning(self):
        """Setting an attribute mapped only on ``A`` on an instance of the
        concrete subclass ``B`` raises ``AttributeError``.

        Checked for a relationship and for a renamed column; the renamed
        column configuration is then mapped once more without further
        assertions.
        """
        A, B = self.classes.A, self.classes.B
        Dest = self.classes.Dest
        a_table, b_table = self.tables.a_table, self.tables.b_table
        dest_table = self.tables.dest_table

        # relationship declared on the base only
        mapper(A, a_table, properties={'some_dest': relationship(Dest)})
        mapper(B, b_table, inherits=A, concrete=True)
        mapper(Dest, dest_table)
        b_obj = B()
        dest_obj = Dest()
        assert_raises(AttributeError, setattr, b_obj, 'some_dest',
                      dest_obj)
        clear_mappers()

        # column property declared on the base only
        mapper(A, a_table, properties={'a_id': a_table.c.id})
        mapper(B, b_table, inherits=A, concrete=True)
        mapper(Dest, dest_table)
        b_obj = B()
        assert_raises(AttributeError, setattr, b_obj, 'a_id', 3)
        clear_mappers()

        # same mapping again, left in place when the test ends
        mapper(A, a_table, properties={'a_id': a_table.c.id})
        mapper(B, b_table, inherits=A, concrete=True)
        mapper(Dest, dest_table)

    def test_inheriting(self):
        """``A`` and concrete ``B`` each declare ``some_dest`` with their
        own back-populated collection on ``Dest``; the collections are
        checked before and after commit."""
        A, B = self.classes.A, self.classes.B
        Dest = self.classes.Dest
        a_table, b_table = self.tables.a_table, self.tables.b_table
        dest_table = self.tables.dest_table

        mapper(A, a_table, properties={
            'some_dest': relationship(Dest, back_populates='many_a'),
        })
        mapper(B, b_table, inherits=A, concrete=True, properties={
            'some_dest': relationship(Dest, back_populates='many_b'),
        })
        mapper(Dest, dest_table, properties={
            'many_a': relationship(A, back_populates='some_dest'),
            'many_b': relationship(B, back_populates='some_dest'),
        })

        sess = sessionmaker()()
        dest1 = Dest(name='c1')
        dest2 = Dest(name='c2')
        a1 = A(some_dest=dest1, aname='a1')
        a2 = A(some_dest=dest2, aname='a2')
        b1 = B(some_dest=dest1, bname='b1')
        b2 = B(some_dest=dest1, bname='b2')

        # column attributes do not cross between A and concrete B
        assert_raises(AttributeError, setattr, b1, 'aname', 'foo')
        assert_raises(AttributeError, getattr, A, 'bname')

        def check_collections():
            assert dest2.many_a == [a2]
            assert dest1.many_a == [a1]
            assert dest1.many_b == [b1, b2]

        check_collections()
        sess.add_all([dest1, dest2])
        sess.commit()

        found = sess.query(Dest).filter(Dest.many_a.contains(a2)).one()
        assert found is dest2
        check_collections()
        assert sess.query(B).filter(B.bname == 'b1').one() is b1

    def test_polymorphic_backref(self):
        """test multiple backrefs to the same polymorphically-loading
        attribute."""

        A, B, C = self.classes.A, self.classes.B, self.classes.C
        Dest = self.classes.Dest
        a_table = self.tables.a_table
        b_table = self.tables.b_table
        c_table = self.tables.c_table
        dest_table = self.tables.dest_table

        ajoin = polymorphic_union(
            {'a': a_table, 'b': b_table, 'c': c_table},
            'type', 'ajoin')

        mapper(
            A, a_table,
            with_polymorphic=('*', ajoin),
            polymorphic_on=ajoin.c.type,
            polymorphic_identity='a',
            properties={
                'some_dest': relationship(Dest, back_populates='many_a'),
            })

        # B and C are mapped identically apart from table and identity
        for sub_class, sub_table, identity in (
                (B, b_table, 'b'),
                (C, c_table, 'c')):
            mapper(
                sub_class, sub_table,
                inherits=A,
                concrete=True,
                polymorphic_identity=identity,
                properties={
                    'some_dest': relationship(
                        Dest, back_populates='many_a'),
                })

        mapper(Dest, dest_table, properties={
            'many_a': relationship(A, back_populates='some_dest',
                                   order_by=ajoin.c.id),
        })

        sess = sessionmaker()()
        dest1 = Dest(name='c1')
        dest2 = Dest(name='c2')
        a1 = A(some_dest=dest1, aname='a1', id=1)
        a2 = A(some_dest=dest2, aname='a2', id=2)
        b1 = B(some_dest=dest1, bname='b1', id=3)
        b2 = B(some_dest=dest1, bname='b2', id=4)
        c1 = C(some_dest=dest1, cname='c1', id=5)
        c2 = C(some_dest=dest2, cname='c2', id=6)

        eq_([a2, c2], dest2.many_a)
        eq_([a1, b1, b2, c1], dest1.many_a)
        sess.add_all([dest1, dest2])
        sess.commit()

        for member, owner in ((a2, dest2), (b1, dest1), (c2, dest2)):
            found = sess.query(Dest).filter(
                Dest.many_a.contains(member)).one()
            assert found is owner

        eq_(dest2.many_a, [a2, c2])
        eq_(dest1.many_a, [a1, b1, b2, c1])
        sess.expire_all()

        def load_eagerly():
            expected = [
                Dest(many_a=[A(aname='a1'),
                             B(bname='b1'),
                             B(bname='b2'),
                             C(cname='c1')]),
                Dest(many_a=[A(aname='a2'), C(cname='c2')]),
            ]
            eager_query = sess.query(Dest).options(
                joinedload(Dest.many_a)).order_by(Dest.id)
            eq_(expected, eager_query.all())

        self.assert_sql_count(testing.db, load_eagerly, 1)

    def test_merge_w_relationship(self):
        """Merge a committed ``C`` into a second session and check that
        its ``some_dest`` relationship and foreign key come across."""
        A, B, C = self.classes.A, self.classes.B, self.classes.C
        Dest = self.classes.Dest
        a_table = self.tables.a_table
        b_table = self.tables.b_table
        c_table = self.tables.c_table
        dest_table = self.tables.dest_table

        ajoin = polymorphic_union(
            {'a': a_table, 'b': b_table, 'c': c_table},
            'type', 'ajoin')

        mapper(
            A, a_table,
            with_polymorphic=('*', ajoin),
            polymorphic_on=ajoin.c.type,
            polymorphic_identity='a',
            properties={
                'some_dest': relationship(Dest, back_populates='many_a'),
            })

        # B and C are mapped identically apart from table and identity
        for sub_class, sub_table, identity in (
                (B, b_table, 'b'),
                (C, c_table, 'c')):
            mapper(
                sub_class, sub_table,
                inherits=A,
                concrete=True,
                polymorphic_identity=identity,
                properties={
                    'some_dest': relationship(
                        Dest, back_populates='many_a'),
                })

        mapper(Dest, dest_table, properties={
            'many_a': relationship(A, back_populates='some_dest',
                                   order_by=ajoin.c.id),
        })

        # each class must own its own "some_dest" property
        for mapped_class in (C, B, A):
            assert mapped_class.some_dest.property.parent \
                is class_mapper(mapped_class)

        sess = sessionmaker()()
        dest1 = Dest(name='d1')
        dest2 = Dest(name='d2')
        a1 = A(some_dest=dest2, aname='a1')
        b1 = B(some_dest=dest1, bname='b1')
        c1 = C(some_dest=dest2, cname='c1')
        sess.add_all([dest1, dest2, c1, a1, b1])
        sess.commit()

        other_sess = sessionmaker()()
        merged_c1 = other_sess.merge(c1)
        eq_(merged_c1.some_dest.name, 'd2')
        eq_(merged_c1.some_dest_id, c1.some_dest_id)

class ManyToManyTest(fixtures.MappedTest):
    """Concrete subclass with its own many-to-many ``related`` collection,
    using a different association table than the base class."""

    @classmethod
    def define_tables(cls, metadata):
        """Create ``base`` and ``sub``, one association table for each,
        and the shared ``related`` target table."""
        Table(
            'base', metadata,
            Column('id', Integer, primary_key=True,
                   test_needs_autoincrement=True))
        Table(
            'sub', metadata,
            Column('id', Integer, primary_key=True,
                   test_needs_autoincrement=True))
        Table(
            'base_mtom', metadata,
            Column('base_id', Integer, ForeignKey('base.id'),
                   primary_key=True),
            Column('related_id', Integer, ForeignKey('related.id'),
                   primary_key=True))
        Table(
            'sub_mtom', metadata,
            Column('base_id', Integer, ForeignKey('sub.id'),
                   primary_key=True),
            Column('related_id', Integer, ForeignKey('related.id'),
                   primary_key=True))
        Table(
            'related', metadata,
            Column('id', Integer, primary_key=True,
                   test_needs_autoincrement=True))

    @classmethod
    def setup_classes(cls):
        """Declare ``Base``, its subclass ``Sub``, and ``Related``."""
        class Base(cls.Comparable):
            pass

        class Sub(Base):
            pass

        class Related(cls.Comparable):
            pass

    def test_selective_relationships(self):
        """``Base.related`` and ``Sub.related`` persist through their own
        association tables and keep their own contents after commit."""
        Base, Sub = self.classes.Base, self.classes.Sub
        Related = self.classes.Related
        base, sub = self.tables.base, self.tables.sub
        base_mtom = self.tables.base_mtom
        sub_mtom = self.tables.sub_mtom
        related = self.tables.related

        mapper(Base, base, properties={
            'related': relationship(Related, secondary=base_mtom,
                                    backref='bases',
                                    order_by=related.c.id),
        })
        mapper(Sub, sub, inherits=Base, concrete=True, properties={
            'related': relationship(Related, secondary=sub_mtom,
                                    backref='subs',
                                    order_by=related.c.id),
        })
        mapper(Related, related)

        sess = sessionmaker()()
        b1 = Base()
        s1 = Sub()
        r1 = Related()
        r2 = Related()
        r3 = Related()

        # r2 is shared by both collections
        b1.related.append(r1)
        b1.related.append(r2)
        s1.related.append(r2)
        s1.related.append(r3)
        sess.add_all([b1, s1])
        sess.commit()

        eq_(s1.related, [r2, r3])
        eq_(b1.related, [r1, r2])


class ColKeysTest(fixtures.MappedTest):
    """Concrete inheritance over tables whose differently named columns
    share the column key ``name``."""

    @classmethod
    def define_tables(cls, metadata):
        """Create ``refugee`` and ``office`` and bind them to module-level
        globals used by the other methods of this class."""
        global offices_table, refugees_table
        refugees_table = Table(
            'refugee', metadata,
            Column('refugee_fid', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('refugee_name', String(30), key='name'))
        offices_table = Table(
            'office', metadata,
            Column('office_fid', Integer, primary_key=True,
                   test_needs_autoincrement=True),
            Column('office_name', String(30), key='name'))

    @classmethod
    def insert_data(cls):
        """Insert two rows into each of the two tables."""
        refugee_rows = (
            {'refugee_fid': 1, 'name': 'refugee1'},
            {'refugee_fid': 2, 'name': 'refugee2'},
        )
        office_rows = (
            {'office_fid': 1, 'name': 'office1'},
            {'office_fid': 2, 'name': 'office2'},
        )
        refugees_table.insert().execute(*refugee_rows)
        offices_table.insert().execute(*office_rows)

    def test_keys(self):
        """Each concrete class loads by primary key and exposes its name
        column under the shared ``name`` key."""
        union = polymorphic_union(
            {'refugee': refugees_table, 'office': offices_table},
            'type', 'pjoin')

        class Location(object):
            pass

        class Refugee(Location):
            pass

        class Office(Location):
            pass

        base_mapper = mapper(Location, union,
                             polymorphic_on=union.c.type,
                             polymorphic_identity='location')
        mapper(Office, offices_table, inherits=base_mapper,
               concrete=True, polymorphic_identity='office')
        mapper(Refugee, refugees_table, inherits=base_mapper,
               concrete=True, polymorphic_identity='refugee')

        sess = create_session()
        for mapped_class, pk, expected_name in (
                (Refugee, 1, 'refugee1'),
                (Refugee, 2, 'refugee2'),
                (Office, 1, 'office1'),
                (Office, 2, 'office2')):
            eq_(sess.query(mapped_class).get(pk).name, expected_name)
