import datetime
import sqlalchemy as sa
from sqlalchemy.testing import engines
from sqlalchemy import testing
from sqlalchemy import (
    Integer, String, Date, ForeignKey, orm, exc, select, TypeDecorator)
from sqlalchemy.testing.schema import Table, Column
from sqlalchemy.orm import (
    mapper, relationship, Session, create_session, sessionmaker,
    exc as orm_exc)
from sqlalchemy.testing import (
    eq_, assert_raises, assert_raises_message, fixtures)
from sqlalchemy.testing.assertsql import CompiledSQL
import uuid


def make_uuid():
    return uuid.uuid4().hex


class VersioningTest(fixtures.MappedTest):
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table('version_table', metadata,
              Column('id', Integer, primary_key=True,
                     test_needs_autoincrement=True),
              Column('version_id', Integer, nullable=False),
              Column('value', String(40), nullable=False))

    @classmethod
    def setup_classes(cls):
        class Foo(cls.Basic):
            pass

    def _fixture(self):
        Foo, version_table = self.classes.Foo, self.tables.version_table

        mapper(Foo, version_table, version_id_col=version_table.c.version_id)
        s1 = Session()
        return s1

    @engines.close_open_connections
    def test_notsane_warning(self):
        Foo = self.classes.Foo

        save = testing.db.dialect.supports_sane_rowcount
        testing.db.dialect.supports_sane_rowcount = False
        try:
            s1 = self._fixture()
            f1 = Foo(value='f1')
            f2 = Foo(value='f2')
            s1.add_all((f1, f2))
            s1.commit()

            f1.value = 'f1rev2'
            assert_raises(sa.exc.SAWarning, s1.commit)
        finally:
            testing.db.dialect.supports_sane_rowcount = save

    @testing.emits_warning_on(
        '+zxjdbc', r'.*does not support (update|delete)d rowcount')
    def test_basic(self):
        Foo = self.classes.Foo

        s1 = self._fixture()
        f1 = Foo(value='f1')
        f2 = Foo(value='f2')
        s1.add_all((f1, f2))
        s1.commit()

        f1.value = 'f1rev2'
        s1.commit()

        s2 = create_session(autocommit=False)
        f1_s = s2.query(Foo).get(f1.id)
        f1_s.value = 'f1rev3'
        s2.commit()

        f1.value = 'f1rev3mine'

        # Only dialects with a sane rowcount can detect the
        # StaleDataError
        if testing.db.dialect.supports_sane_rowcount:
            assert_raises_message(
                sa.orm.exc.StaleDataError,
                r"UPDATE statement on table 'version_table' expected "
                r"to update 1 row\(s\); 0 were matched.", s1.commit),
            s1.rollback()
        else:
            s1.commit()

        # new in 0.5 !  don't need to close the session
        f1 = s1.query(Foo).get(f1.id)
        f2 = s1.query(Foo).get(f2.id)

        f1_s.value = 'f1rev4'
        s2.commit()

        s1.delete(f1)
        s1.delete(f2)

        if testing.db.dialect.supports_sane_rowcount:
            assert_raises_message(
                sa.orm.exc.StaleDataError,
                r"DELETE statement on table 'version_table' expected "
                r"to delete 2 row\(s\); 1 were matched.",
                s1.commit)
        else:
            s1.commit()

    @testing.emits_warning_on(
        '+zxjdbc', r'.*does not support (update|delete)d rowcount')
    def test_bump_version(self):
        """test that version number can be bumped.

        Ensures that the UPDATE or DELETE is against the
        last committed version of version_id_col, not the modified
        state.

        """

        Foo = self.classes.Foo

        s1 = self._fixture()
        f1 = Foo(value='f1')
        s1.add(f1)
        s1.commit()
        eq_(f1.version_id, 1)
        f1.version_id = 2
        s1.commit()
        eq_(f1.version_id, 2)

        # skip an id, test that history
        # is honored
        f1.version_id = 4
        f1.value = "something new"
        s1.commit()
        eq_(f1.version_id, 4)

        f1.version_id = 5
        s1.delete(f1)
        s1.commit()
        eq_(s1.query(Foo).count(), 0)

    @testing.emits_warning(r'.*does not support updated rowcount')
    @engines.close_open_connections
    def test_versioncheck(self):
        """query.with_lockmode performs a 'version check' on an already loaded
        instance"""

        Foo = self.classes.Foo

        s1 = self._fixture()
        f1s1 = Foo(value='f1 value')
        s1.add(f1s1)
        s1.commit()

        s2 = create_session(autocommit=False)
        f1s2 = s2.query(Foo).get(f1s1.id)
        f1s2.value = 'f1 new value'
        s2.commit()

        # load, version is wrong
        assert_raises_message(
            sa.orm.exc.StaleDataError,
            r"Instance .* has version id '\d+' which does not "
            r"match database-loaded version id '\d+'",
            s1.query(Foo).with_lockmode('read').get, f1s1.id
        )

        # reload it - this expires the old version first
        s1.refresh(f1s1, lockmode='read')

        # now assert version OK
        s1.query(Foo).with_lockmode('read').get(f1s1.id)

        # assert brand new load is OK too
        s1.close()
        s1.query(Foo).with_lockmode('read').get(f1s1.id)

    @testing.emits_warning(r'.*does not support updated rowcount')
    @engines.close_open_connections
    @testing.requires.update_nowait
    def test_versioncheck_for_update(self):
        """query.with_lockmode performs a 'version check' on an already loaded
        instance"""

        Foo = self.classes.Foo

        s1 = self._fixture()
        f1s1 = Foo(value='f1 value')
        s1.add(f1s1)
        s1.commit()

        s2 = create_session(autocommit=False)
        f1s2 = s2.query(Foo).get(f1s1.id)
        s2.refresh(f1s2, lockmode='update')
        f1s2.value = 'f1 new value'

        assert_raises(
            exc.DBAPIError,
            s1.refresh, f1s1, lockmode='update_nowait'
        )
        s1.rollback()

        s2.commit()
        s1.refresh(f1s1, lockmode='update_nowait')
        assert f1s1.version_id == f1s2.version_id

    @testing.emits_warning(r'.*does not support updated rowcount')
    @engines.close_open_connections
    def test_noversioncheck(self):
        """test query.with_lockmode works when the mapper has no version id
        col"""

        Foo, version_table = self.classes.Foo, self.tables.version_table

        s1 = create_session(autocommit=False)
        mapper(Foo, version_table)
        f1s1 = Foo(value="foo", version_id=0)
        s1.add(f1s1)
        s1.commit()

        s2 = create_session(autocommit=False)
        f1s2 = s2.query(Foo).with_lockmode('read').get(f1s1.id)
        assert f1s2.id == f1s1.id
        assert f1s2.value == f1s1.value

    @testing.emits_warning_on(
        '+zxjdbc', r'.*does not support updated rowcount')
    def test_merge_no_version(self):
        Foo = self.classes.Foo

        s1 = self._fixture()
        f1 = Foo(value='f1')
        s1.add(f1)
        s1.commit()

        f1.value = 'f2'
        s1.commit()

        f2 = Foo(id=f1.id, value='f3')
        f3 = s1.merge(f2)
        assert f3 is f1
        s1.commit()
        eq_(f3.version_id, 3)

    @testing.emits_warning_on(
        '+zxjdbc', r'.*does not support updated rowcount')
    def test_merge_correct_version(self):
        Foo = self.classes.Foo

        s1 = self._fixture()
        f1 = Foo(value='f1')
        s1.add(f1)
        s1.commit()

        f1.value = 'f2'
        s1.commit()

        f2 = Foo(id=f1.id, value='f3', version_id=2)
        f3 = s1.merge(f2)
        assert f3 is f1
        s1.commit()
        eq_(f3.version_id, 3)

    @testing.emits_warning_on(
        '+zxjdbc', r'.*does not support updated rowcount')
    def test_merge_incorrect_version(self):
        Foo = self.classes.Foo

        s1 = self._fixture()
        f1 = Foo(value='f1')
        s1.add(f1)
        s1.commit()

        f1.value = 'f2'
        s1.commit()

        f2 = Foo(id=f1.id, value='f3', version_id=1)
        assert_raises_message(
            orm_exc.StaleDataError,
            "Version id '1' on merged state "
            "<Foo at .*?> does not match existing version '2'. "
            "Leave the version attribute unset when "
            "merging to update the most recent version.",
            s1.merge, f2
        )

    @testing.emits_warning_on(
        '+zxjdbc', r'.*does not support updated rowcount')
    def test_merge_incorrect_version_not_in_session(self):
        Foo = self.classes.Foo

        s1 = self._fixture()
        f1 = Foo(value='f1')
        s1.add(f1)
        s1.commit()

        f1.value = 'f2'
        s1.commit()

        f2 = Foo(id=f1.id, value='f3', version_id=1)
        s1.close()

        assert_raises_message(
            orm_exc.StaleDataError,
            "Version id '1' on merged state "
            "<Foo at .*?> does not match existing version '2'. "
            "Leave the version attribute unset when "
            "merging to update the most recent version.",
            s1.merge, f2
        )


class ColumnTypeTest(fixtures.MappedTest):
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        class SpecialType(TypeDecorator):
            impl = Date

            def process_bind_param(self, value, dialect):
                assert isinstance(value, datetime.date)
                return value

        Table('version_table', metadata,
              Column('id', SpecialType, primary_key=True),
              Column('version_id', Integer, nullable=False),
              Column('value', String(40), nullable=False))

    @classmethod
    def setup_classes(cls):
        class Foo(cls.Basic):
            pass

    def _fixture(self):
        Foo, version_table = self.classes.Foo, self.tables.version_table

        mapper(Foo, version_table, version_id_col=version_table.c.version_id)
        s1 = Session()
        return s1

    @engines.close_open_connections
    def test_update(self):
        Foo = self.classes.Foo

        s1 = self._fixture()
        f1 = Foo(id=datetime.date.today(), value='f1')
        s1.add(f1)
        s1.commit()

        f1.value = 'f1rev2'
        s1.commit()


class RowSwitchTest(fixtures.MappedTest):
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            'p', metadata,
            Column('id', String(10), primary_key=True),
            Column('version_id', Integer, default=1, nullable=False),
            Column('data', String(50))
        )
        Table(
            'c', metadata,
            Column('id', String(10), ForeignKey('p.id'), primary_key=True),
            Column('version_id', Integer, default=1, nullable=False),
            Column('data', String(50))
        )

    @classmethod
    def setup_classes(cls):

        class P(cls.Basic):
            pass

        class C(cls.Basic):
            pass

    @classmethod
    def setup_mappers(cls):
        p, c, C, P = cls.tables.p, cls.tables.c, cls.classes.C, cls.classes.P

        mapper(
            P, p, version_id_col=p.c.version_id, properties={
                'c': relationship(
                    C, uselist=False, cascade='all, delete-orphan')})
        mapper(C, c, version_id_col=c.c.version_id)

    @testing.emits_warning_on(
        '+zxjdbc', r'.*does not support updated rowcount')
    def test_row_switch(self):
        P = self.classes.P

        session = sessionmaker()()
        session.add(P(id='P1', data='P version 1'))
        session.commit()
        session.close()

        p = session.query(P).first()
        session.delete(p)
        session.add(P(id='P1', data="really a row-switch"))
        session.commit()

    @testing.emits_warning_on(
        '+zxjdbc', r'.*does not support updated rowcount')
    def test_child_row_switch(self):
        P, C = self.classes.P, self.classes.C

        assert P.c.property.strategy.use_get

        session = sessionmaker()()
        session.add(P(id='P1', data='P version 1'))
        session.commit()
        session.close()

        p = session.query(P).first()
        p.c = C(data='child version 1')
        session.commit()

        p = session.query(P).first()
        p.c = C(data='child row-switch')
        session.commit()


class AlternateGeneratorTest(fixtures.MappedTest):
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            'p', metadata,
            Column('id', String(10), primary_key=True),
            Column('version_id', String(32), nullable=False),
            Column('data', String(50))
        )
        Table(
            'c', metadata,
            Column('id', String(10), ForeignKey('p.id'), primary_key=True),
            Column('version_id', String(32), nullable=False),
            Column('data', String(50))
        )

    @classmethod
    def setup_classes(cls):

        class P(cls.Basic):
            pass

        class C(cls.Basic):
            pass

    @classmethod
    def setup_mappers(cls):
        p, c, C, P = cls.tables.p, cls.tables.c, cls.classes.C, cls.classes.P

        mapper(
            P, p, version_id_col=p.c.version_id,
            version_id_generator=lambda x: make_uuid(),
            properties={
                'c': relationship(
                    C, uselist=False, cascade='all, delete-orphan')
            })
        mapper(
            C, c, version_id_col=c.c.version_id,
            version_id_generator=lambda x: make_uuid(),
        )

    @testing.emits_warning_on(
        '+zxjdbc', r'.*does not support updated rowcount')
    def test_row_switch(self):
        P = self.classes.P

        session = sessionmaker()()
        session.add(P(id='P1', data='P version 1'))
        session.commit()
        session.close()

        p = session.query(P).first()
        session.delete(p)
        session.add(P(id='P1', data="really a row-switch"))
        session.commit()

    @testing.emits_warning_on(
        '+zxjdbc', r'.*does not support (update|delete)d rowcount')
    def test_child_row_switch_one(self):
        P, C = self.classes.P, self.classes.C

        assert P.c.property.strategy.use_get

        session = sessionmaker()()
        session.add(P(id='P1', data='P version 1'))
        session.commit()
        session.close()

        p = session.query(P).first()
        p.c = C(data='child version 1')
        session.commit()

        p = session.query(P).first()
        p.c = C(data='child row-switch')
        session.commit()

    @testing.emits_warning_on(
        '+zxjdbc', r'.*does not support (update|delete)d rowcount')
    def test_child_row_switch_two(self):
        P = self.classes.P

        Session = sessionmaker()

        # TODO: not sure this test is
        # testing exactly what its looking for

        sess1 = Session()
        sess1.add(P(id='P1', data='P version 1'))
        sess1.commit()
        sess1.close()

        p1 = sess1.query(P).first()

        sess2 = Session()
        p2 = sess2.query(P).first()

        sess1.delete(p1)
        sess1.commit()

        # this can be removed and it still passes
        sess1.add(P(id='P1', data='P version 2'))
        sess1.commit()

        p2.data = 'P overwritten by concurrent tx'
        if testing.db.dialect.supports_sane_rowcount:
            assert_raises_message(
                orm.exc.StaleDataError,
                r"UPDATE statement on table 'p' expected to update "
                r"1 row\(s\); 0 were matched.",
                sess2.commit
            )
        else:
            sess2.commit


class InheritanceTwoVersionIdsTest(fixtures.MappedTest):
    """Test versioning where both parent/child table have a
    versioning column.

    """
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            'base', metadata,
            Column(
                'id', Integer, primary_key=True,
                test_needs_autoincrement=True),
            Column('version_id', Integer, nullable=True),
            Column('data', String(50))
        )
        Table(
            'sub', metadata,
            Column('id', Integer, ForeignKey('base.id'), primary_key=True),
            Column('version_id', Integer, nullable=False),
            Column('sub_data', String(50))
        )

    @classmethod
    def setup_classes(cls):

        class Base(cls.Basic):
            pass

        class Sub(Base):
            pass

    def test_base_both(self):
        Base, sub, base, Sub = (
            self.classes.Base, self.tables.sub, self.tables.base,
            self.classes.Sub)

        mapper(Base, base, version_id_col=base.c.version_id)
        mapper(Sub, sub, inherits=Base)

        session = Session()
        b1 = Base(data='b1')
        session.add(b1)
        session.commit()
        eq_(b1.version_id, 1)
        # base is populated
        eq_(select([base.c.version_id]).scalar(), 1)

    def test_sub_both(self):
        Base, sub, base, Sub = (
            self.classes.Base, self.tables.sub, self.tables.base,
            self.classes.Sub)

        mapper(Base, base, version_id_col=base.c.version_id)
        mapper(Sub, sub, inherits=Base)

        session = Session()
        s1 = Sub(data='s1', sub_data='s1')
        session.add(s1)
        session.commit()

        # table is populated
        eq_(select([sub.c.version_id]).scalar(), 1)

        # base is populated
        eq_(select([base.c.version_id]).scalar(), 1)

    def test_sub_only(self):
        Base, sub, base, Sub = (
            self.classes.Base, self.tables.sub, self.tables.base,
            self.classes.Sub)

        mapper(Base, base)
        mapper(Sub, sub, inherits=Base, version_id_col=sub.c.version_id)

        session = Session()
        s1 = Sub(data='s1', sub_data='s1')
        session.add(s1)
        session.commit()

        # table is populated
        eq_(select([sub.c.version_id]).scalar(), 1)

        # base is not
        eq_(select([base.c.version_id]).scalar(), None)

    def test_mismatch_version_col_warning(self):
        Base, sub, base, Sub = (
            self.classes.Base, self.tables.sub, self.tables.base,
            self.classes.Sub)

        mapper(Base, base, version_id_col=base.c.version_id)

        assert_raises_message(
            exc.SAWarning,
            "Inheriting version_id_col 'version_id' does not "
            "match inherited version_id_col 'version_id' and will not "
            "automatically populate the inherited versioning column. "
            "version_id_col should only be specified on "
            "the base-most mapper that includes versioning.",
            mapper, Sub, sub, inherits=Base,
            version_id_col=sub.c.version_id)


class ServerVersioningTest(fixtures.MappedTest):
    run_define_tables = 'each'
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        from sqlalchemy.sql import ColumnElement
        from sqlalchemy.ext.compiler import compiles
        import itertools

        counter = itertools.count(1)

        class IncDefault(ColumnElement):
            pass

        @compiles(IncDefault)
        def compile(element, compiler, **kw):
            # cache the counter value on the statement
            # itself so the assertsql system gets the same
            # value when it compiles the statement a second time
            stmt = compiler.statement
            if hasattr(stmt, "_counter"):
                return stmt._counter
            else:
                stmt._counter = str(next(counter))
                return stmt._counter

        Table(
            'version_table', metadata,
            Column(
                'id', Integer, primary_key=True,
                test_needs_autoincrement=True),
            Column(
                'version_id', Integer, nullable=False,
                default=IncDefault(), onupdate=IncDefault()),
            Column('value', String(40), nullable=False))

    @classmethod
    def setup_classes(cls):

        class Foo(cls.Basic):
            pass

        class Bar(cls.Basic):
            pass

    def _fixture(self, expire_on_commit=True):
        Foo, version_table = self.classes.Foo, self.tables.version_table

        mapper(
            Foo, version_table, version_id_col=version_table.c.version_id,
            version_id_generator=False,
        )

        s1 = Session(expire_on_commit=expire_on_commit)
        return s1

    def test_insert_col(self):
        sess = self._fixture()

        f1 = self.classes.Foo(value='f1')
        sess.add(f1)

        statements = [
            # note that the assertsql tests the rule against
            # "default" - on a "returning" backend, the statement
            # includes "RETURNING"
            CompiledSQL(
                "INSERT INTO version_table (version_id, value) "
                "VALUES (1, :value)",
                lambda ctx: [{'value': 'f1'}]
            )
        ]
        if not testing.db.dialect.implicit_returning:
            # DBs without implicit returning, we must immediately
            # SELECT for the new version id
            statements.append(
                CompiledSQL(
                    "SELECT version_table.version_id "
                    "AS version_table_version_id "
                    "FROM version_table WHERE version_table.id = :param_1",
                    lambda ctx: [{"param_1": 1}]
                )
            )
        self.assert_sql_execution(testing.db, sess.flush, *statements)

    def test_update_col(self):
        sess = self._fixture()

        f1 = self.classes.Foo(value='f1')
        sess.add(f1)
        sess.flush()

        f1.value = 'f2'

        statements = [
            # note that the assertsql tests the rule against
            # "default" - on a "returning" backend, the statement
            # includes "RETURNING"
            CompiledSQL(
                "UPDATE version_table SET version_id=2, value=:value "
                "WHERE version_table.id = :version_table_id AND "
                "version_table.version_id = :version_table_version_id",
                lambda ctx: [
                    {
                        "version_table_id": 1,
                        "version_table_version_id": 1, "value": "f2"}]
            )
        ]
        if not testing.db.dialect.implicit_returning:
            # DBs without implicit returning, we must immediately
            # SELECT for the new version id
            statements.append(
                CompiledSQL(
                    "SELECT version_table.version_id "
                    "AS version_table_version_id "
                    "FROM version_table WHERE version_table.id = :param_1",
                    lambda ctx: [{"param_1": 1}]
                )
            )
        self.assert_sql_execution(testing.db, sess.flush, *statements)

    def test_delete_col(self):
        sess = self._fixture()

        f1 = self.classes.Foo(value='f1')
        sess.add(f1)
        sess.flush()

        sess.delete(f1)

        statements = [
            # note that the assertsql tests the rule against
            # "default" - on a "returning" backend, the statement
            # includes "RETURNING"
            CompiledSQL(
                "DELETE FROM version_table "
                "WHERE version_table.id = :id AND "
                "version_table.version_id = :version_id",
                lambda ctx: [{"id": 1, "version_id": 1}]
            )
        ]
        self.assert_sql_execution(testing.db, sess.flush, *statements)

    def test_concurrent_mod_err_expire_on_commit(self):
        sess = self._fixture()

        f1 = self.classes.Foo(value='f1')
        sess.add(f1)
        sess.commit()

        f1.value

        s2 = Session()
        f2 = s2.query(self.classes.Foo).first()
        f2.value = 'f2'
        s2.commit()

        f1.value = 'f3'

        assert_raises_message(
            orm.exc.StaleDataError,
            r"UPDATE statement on table 'version_table' expected to "
            r"update 1 row\(s\); 0 were matched.",
            sess.commit
        )

    def test_concurrent_mod_err_noexpire_on_commit(self):
        sess = self._fixture(expire_on_commit=False)

        f1 = self.classes.Foo(value='f1')
        sess.add(f1)
        sess.commit()

        # here, we're not expired overall, so no load occurs and we
        # stay without a version id, unless we've emitted
        # a SELECT for it within the flush.
        f1.value

        s2 = Session(expire_on_commit=False)
        f2 = s2.query(self.classes.Foo).first()
        f2.value = 'f2'
        s2.commit()

        f1.value = 'f3'

        assert_raises_message(
            orm.exc.StaleDataError,
            r"UPDATE statement on table 'version_table' expected to "
            r"update 1 row\(s\); 0 were matched.",
            sess.commit
        )


class ManualVersionTest(fixtures.MappedTest):
    run_define_tables = 'each'
    __backend__ = True

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "a", metadata,
            Column(
                'id', Integer, primary_key=True,
                test_needs_autoincrement=True),
            Column('data', String(30)),
            Column('vid', Integer)
        )

    @classmethod
    def setup_classes(cls):
        class A(cls.Basic):
            pass

    @classmethod
    def setup_mappers(cls):
        mapper(
            cls.classes.A, cls.tables.a, version_id_col=cls.tables.a.c.vid,
            version_id_generator=False)

    def test_insert(self):
        sess = Session()
        a1 = self.classes.A()

        a1.vid = 1
        sess.add(a1)
        sess.commit()

        eq_(a1.vid, 1)

    def test_update(self):
        sess = Session()
        a1 = self.classes.A()

        a1.vid = 1
        a1.data = 'd1'
        sess.add(a1)
        sess.commit()

        a1.vid = 2
        a1.data = 'd2'

        sess.commit()

        eq_(a1.vid, 2)

    def test_update_concurrent_check(self):
        sess = Session()
        a1 = self.classes.A()

        a1.vid = 1
        a1.data = 'd1'
        sess.add(a1)
        sess.commit()

        a1.vid = 2
        sess.execute(self.tables.a.update().values(vid=3))
        a1.data = 'd2'
        assert_raises(
            orm_exc.StaleDataError,
            sess.commit
        )

    def test_update_version_conditional(self):
        sess = Session()
        a1 = self.classes.A()

        a1.vid = 1
        a1.data = 'd1'
        sess.add(a1)
        sess.commit()

        # change the data and UPDATE without
        # incrementing version id
        a1.data = 'd2'
        sess.commit()

        eq_(a1.vid, 1)

        a1.data = 'd3'
        a1.vid = 2
        sess.commit()

        eq_(a1.vid, 2)
