1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
|
import ormar
import pytest
from ormar.databases.connection import DatabaseConnection
from sqlalchemy import text
from tests.lifespan import init_tests
from tests.settings import ASYNC_DATABASE_URL, DATABASE_URL, create_config
base_ormar_config = create_config()
class Team(ormar.Model):
ormar_config = base_ormar_config.copy(tablename="teams")
id: int = ormar.Integer(primary_key=True)
name: str = ormar.String(max_length=100)
create_test_database = init_tests(base_ormar_config)
def test_url_is_replaced_to_async_and_accessible():
database = DatabaseConnection(ASYNC_DATABASE_URL)
assert database.url == ASYNC_DATABASE_URL
expected_drivers = {
"mysql": "mysql+aiomysql",
"sqlite": "sqlite+aiosqlite",
"postgresql": "postgresql+asyncpg",
}
dialect = DATABASE_URL.split(":")[0]
assert dialect in expected_drivers
assert expected_drivers[dialect] in database.url
@pytest.mark.asyncio
async def test_getting_raw_connection():
database = DatabaseConnection(ASYNC_DATABASE_URL)
async with database:
async with database.connection() as conn:
result = await conn.execute(text("SELECT 123"))
async with database.transaction():
async with database.connection() as conn2:
assert conn2 != conn
async with database.connection() as conn3:
assert conn3 == conn2
async with database.connection() as conn4:
assert conn4 != conn
assert conn3 == conn4
async with database.connection() as conn5:
assert conn5 != conn4
assert conn5 != conn
assert result.fetchone()[0] == 123
@pytest.mark.asyncio
async def test_getting_commit_in_transaction():
async with base_ormar_config.database:
async with base_ormar_config.database.transaction(force_rollback=True) as tran:
assert tran._depth == 0
async with base_ormar_config.database.transaction() as tran2:
assert tran2._depth == 1
await Team.objects.create(name="Red Team")
await Team.objects.create(name="Blue Team")
async with base_ormar_config.database.transaction() as tran3:
assert tran3._depth == 2
await Team.objects.create(name="Yellow Team")
yellow = await Team.objects.get(name="Yellow Team")
yellow.name = "Green Team"
await yellow.update()
teams = await Team.objects.all()
assert len(teams) == 3
assert {teams.name for teams in teams} == {
"Red Team",
"Blue Team",
"Green Team",
}
@pytest.mark.asyncio
async def test_exception_in_transaction_rollbacks():
async with base_ormar_config.database:
async with base_ormar_config.database.transaction(force_rollback=True):
async with base_ormar_config.database.transaction():
await Team.objects.create(name="Red Team")
await Team.objects.create(name="Blue Team")
try:
async with base_ormar_config.database.transaction() as tran2:
assert tran2._depth == 2
await Team.objects.create(name="Yellow Team")
raise ValueError("test")
except ValueError:
pass
teams = await Team.objects.all()
assert len(teams) == 2
assert {teams.name for teams in teams} == {
"Red Team",
"Blue Team",
}
@pytest.mark.asyncio
async def test_parent_rollback_cascades_to_children():
"""Test that rolling back parent transaction also rolls back child transactions."""
async with base_ormar_config.database:
try:
async with base_ormar_config.database.transaction() as parent:
assert parent._depth == 0
await Team.objects.create(name="Parent Team")
async with base_ormar_config.database.transaction() as child1:
assert child1._depth == 1
await Team.objects.create(name="Child Team 1")
async with base_ormar_config.database.transaction() as child2:
assert child2._depth == 1
await Team.objects.create(name="Child Team 2")
async with base_ormar_config.database.transaction() as grandchild:
assert grandchild._depth == 2
await Team.objects.create(name="Grandchild Team")
raise ValueError("rollback parent")
except ValueError:
pass
teams = await Team.objects.all()
assert len(teams) == 0
@pytest.mark.asyncio
async def test_force_rollback_cascades():
"""Test that force_rollback also cascades to children."""
async with base_ormar_config.database:
async with base_ormar_config.database.transaction(
force_rollback=True
) as parent:
assert parent._depth == 0
await Team.objects.create(name="Parent Team")
async with base_ormar_config.database.transaction() as child:
assert child._depth == 1
await Team.objects.create(name="Child Team")
teams = await Team.objects.all()
assert len(teams) == 0
|