1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
|
import unittest
from sqlalchemy import Table, MetaData, Column, String, create_engine
from sqlalchemy.engine.reflection import Inspector
from alembic.util import CommandError
version_table = Table('version_table', MetaData(),
Column('version_num', String(32), nullable=False))
class TestMigrationContext(unittest.TestCase):
_bind = []
@property
def bind(self):
if not self._bind:
engine = create_engine('sqlite:///', echo=True)
self._bind.append(engine)
return self._bind[0]
def setUp(self):
self.connection = self.bind.connect()
self.transaction = self.connection.begin()
def tearDown(self):
version_table.drop(self.connection, checkfirst=True)
self.transaction.rollback()
def make_one(self, **kwargs):
from alembic.migration import MigrationContext
return MigrationContext.configure(**kwargs)
def get_revision(self):
result = self.connection.execute(version_table.select())
rows = result.fetchall()
if len(rows) == 0:
return None
self.assertEqual(len(rows), 1)
return rows[0]['version_num']
def test_config_default_version_table_name(self):
context = self.make_one(dialect_name='sqlite')
self.assertEqual(context._version.name, 'alembic_version')
def test_config_explicit_version_table_name(self):
context = self.make_one(dialect_name='sqlite',
opts={'version_table': 'explicit'})
self.assertEqual(context._version.name, 'explicit')
def test_get_current_revision_creates_version_table(self):
context = self.make_one(connection=self.connection,
opts={'version_table': 'version_table'})
self.assertEqual(context.get_current_revision(), None)
insp = Inspector(self.connection)
self.assertTrue('version_table' in insp.get_table_names())
def test_get_current_revision(self):
context = self.make_one(connection=self.connection,
opts={'version_table': 'version_table'})
version_table.create(self.connection)
self.assertEqual(context.get_current_revision(), None)
self.connection.execute(
version_table.insert().values(version_num='revid'))
self.assertEqual(context.get_current_revision(), 'revid')
def test_get_current_revision_error_if_starting_rev_given_online(self):
context = self.make_one(connection=self.connection,
opts={'starting_rev': 'boo'})
self.assertRaises(CommandError, context.get_current_revision)
def test_get_current_revision_offline(self):
context = self.make_one(dialect_name='sqlite',
opts={'starting_rev': 'startrev',
'as_sql': True})
self.assertEqual(context.get_current_revision(), 'startrev')
def test__update_current_rev(self):
version_table.create(self.connection)
context = self.make_one(connection=self.connection,
opts={'version_table': 'version_table'})
context._update_current_rev(None, 'a')
self.assertEqual(self.get_revision(), 'a')
context._update_current_rev('a', 'b')
self.assertEqual(self.get_revision(), 'b')
context._update_current_rev('b', None)
self.assertEqual(self.get_revision(), None)
|