1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182
|
"""
The functions in this module can be used for testing the constraints of
your models. Each assert function runs SQL UPDATEs that check for the existence
of a given constraint. Consider the following model::
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(200), nullable=True)
email = sa.Column(sa.String(255), nullable=False)
user = User(name='John Doe', email='john@example.com')
session.add(user)
session.commit()
We can easily test the constraints by assert_* functions::
from sqlalchemy_utils import (
assert_nullable,
assert_non_nullable,
assert_max_length
)
assert_nullable(user, 'name')
assert_non_nullable(user, 'email')
assert_max_length(user, 'name', 200)
# raises AssertionError because the max length of email is 255
assert_max_length(user, 'email', 300)
"""
from decimal import Decimal
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.exc import DataError, IntegrityError
def _update_field(obj, field, value):
    """
    Run an SQL UPDATE that sets `field` to `value`, then flush the session.

    The statement is built against the table that owns the mapped column and
    is executed through the session `obj` is attached to. No WHERE clause is
    added, so the UPDATE targets every row of that table, not only the row
    backing `obj`.

    :param obj: SQLAlchemy declarative model object
    :param field: Name of the column to update
    :param value: Value to assign to the column
    """
    db_session = sa.orm.object_session(obj)
    mapper = sa.inspect(obj.__class__)
    target_table = mapper.columns[field].table
    assignments = {field: value}
    statement = target_table.update().values(**assignments)
    db_session.execute(statement)
    db_session.flush()
def _expect_successful_update(obj, field, value, reraise_exc):
    """
    Run an UPDATE that is expected to succeed.

    If the update raises `reraise_exc`, the session of `obj` is rolled back
    and the failure is reported as an AssertionError carrying the original
    error message. Any other exception propagates unchanged.

    :param obj: SQLAlchemy declarative model object
    :param field: Name of the column to update
    :param value: Value to assign to the column
    :param reraise_exc: Exception class (or tuple of classes) that signals a
        constraint violation and should be converted to an AssertionError
    :raises AssertionError: if the update raised `reraise_exc`
    """
    try:
        _update_field(obj, field, value)
    except reraise_exc as e:
        # Leave the session usable for subsequent assertions.
        session = sa.orm.object_session(obj)
        session.rollback()
        # Raise explicitly instead of `assert False`: assert statements are
        # stripped when Python runs with -O, which would silently swallow
        # the constraint violation.
        raise AssertionError(str(e))
def _expect_failing_update(obj, field, value, expected_exc):
    """
    Run an UPDATE that is expected to be rejected with `expected_exc`.

    The session of `obj` is always rolled back afterwards, whether or not the
    update raised.

    :param obj: SQLAlchemy declarative model object
    :param field: Name of the column to update
    :param value: Value to assign to the column
    :param expected_exc: Exception class the update is expected to raise
    :raises AssertionError: if the update completed without raising
        `expected_exc`
    """
    rejected = False
    try:
        try:
            _update_field(obj, field, value)
        except expected_exc:
            rejected = True
    finally:
        # Undo the attempted change regardless of the outcome.
        sa.orm.object_session(obj).rollback()
    if not rejected:
        raise AssertionError('Expected update to raise %s' % expected_exc)
def _repeated_value(type_):
    """
    Return a length-one value that can be multiplied to a desired length.

    For PostgreSQL ARRAY types this is a single-element list whose element
    matches the array item type (Integer, String or Numeric); for every other
    type it is the one-character string ``u'a'``.

    :param type_: SQLAlchemy column type instance
    :raises TypeError: if `type_` is an ARRAY of an unsupported item type
    """
    if not isinstance(type_, ARRAY):
        return u'a'

    item_type = type_.item_type
    # Checked in order: Integer, String, Numeric.
    samples = (
        (sa.Integer, [0]),
        (sa.String, [u'a']),
        (sa.Numeric, [Decimal('0')]),
    )
    for sa_type, sample in samples:
        if isinstance(item_type, sa_type):
            return sample
    raise TypeError('Unknown array item type')
def _expected_exception(type_):
    """
    Return the exception class expected when a length limit is exceeded.

    ARRAY typed columns map to IntegrityError; all other column types map to
    DataError.

    :param type_: SQLAlchemy column type instance
    """
    return IntegrityError if isinstance(type_, ARRAY) else DataError
def assert_nullable(obj, column):
    """
    Assert that the given column accepts NULL.

    The check runs an SQL UPDATE that sets the column to None and fails with
    an AssertionError if the database answers with an IntegrityError.

    :param obj: SQLAlchemy declarative model object
    :param column: Name of the column
    """
    _expect_successful_update(
        obj,
        field=column,
        value=None,
        reraise_exc=IntegrityError,
    )
def assert_non_nullable(obj, column):
    """
    Assert that the given column rejects NULL.

    The check runs an SQL UPDATE that sets the column to None and fails with
    an AssertionError unless the database answers with an IntegrityError.

    :param obj: SQLAlchemy declarative model object
    :param column: Name of the column
    """
    _expect_failing_update(
        obj,
        field=column,
        value=None,
        expected_exc=IntegrityError,
    )
def assert_max_length(obj, column, max_length):
    """
    Assert that the given column has a maximum length of `max_length`.

    Both string typed columns and PostgreSQL array typed columns are
    supported. A value of exactly `max_length` items/characters must be
    accepted and a value one longer must be rejected (with IntegrityError for
    array columns, DataError otherwise).

    The example below adds a check constraint limiting a user to at most 5
    favorite colors and then tests it::


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            favorite_colors = sa.Column(ARRAY(sa.String), nullable=False)
            __table_args__ = (
                sa.CheckConstraint(
                    sa.func.array_length(favorite_colors, 1) <= 5
                ),
            )


        user = User(favorite_colors=['red', 'blue'])
        session.add(user)
        session.commit()


        assert_max_length(user, 'favorite_colors', 5)


    :param obj: SQLAlchemy declarative model object
    :param column: Name of the column
    :param max_length: Maximum length of given column
    """
    column_type = sa.inspect(obj.__class__).columns[column].type
    unit = _repeated_value(column_type)
    violation_exc = _expected_exception(column_type)

    # A value of exactly max_length must be accepted ...
    _expect_successful_update(obj, column, unit * max_length, violation_exc)
    # ... and one element/character more must be rejected.
    _expect_failing_update(
        obj, column, unit * (max_length + 1), violation_exc
    )
def assert_min_value(obj, column, min_value):
    """
    Assert that the given column has a lower bound of `min_value`.

    Updating the column to `min_value` must succeed, and updating it to
    `min_value - 1` must be rejected with an IntegrityError.

    :param obj: SQLAlchemy declarative model object
    :param column: Name of the column
    :param min_value: The minimum allowed value for given column
    """
    _expect_successful_update(
        obj, field=column, value=min_value, reraise_exc=IntegrityError
    )
    just_below = min_value - 1
    _expect_failing_update(
        obj, field=column, value=just_below, expected_exc=IntegrityError
    )
def assert_max_value(obj, column, max_value=None, min_value=None):
    """
    Assert that the given column must have a maximum value of `max_value`.

    Updating the column to `max_value` must succeed, and updating it to
    `max_value + 1` must be rejected with an IntegrityError.

    :param obj: SQLAlchemy declarative model object
    :param column: Name of the column
    :param max_value: The maximum allowed value for given column
    :param min_value: Legacy alias of `max_value`. The third parameter used to
        be misnamed `min_value`; the keyword is still accepted so existing
        callers keep working. Do not pass both.
    :raises TypeError: if neither or both of `max_value` and `min_value` are
        given
    """
    if max_value is not None and min_value is not None:
        raise TypeError(
            'assert_max_value() got both max_value and its legacy alias '
            'min_value'
        )
    if max_value is None:
        # Fall back to the historical (misnamed) keyword.
        max_value = min_value
    if max_value is None:
        raise TypeError('assert_max_value() requires a max_value argument')

    _expect_successful_update(obj, column, max_value, IntegrityError)
    _expect_failing_update(obj, column, max_value + 1, IntegrityError)
|