1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
|
from pathlib import Path
from typing import Any
from litestar import Litestar, get
from litestar.testing import TestClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, Mapped, Session
from advanced_alchemy._listeners import is_async_context
from advanced_alchemy.base import BigIntPrimaryKey
from advanced_alchemy.extensions.litestar import (
SQLAlchemyAsyncConfig,
SQLAlchemyInitPlugin,
SQLAlchemySyncConfig,
)
# Tests for is_async_context() under sync and async Litestar plugin configurations
def test_litestar_is_async_context(tmp_path: Path) -> None:
    """Check ``is_async_context()`` inside Litestar handlers for both config kinds.

    A handler served by an app built with ``SQLAlchemySyncConfig`` must report
    ``False``; a handler served by an app built with ``SQLAlchemyAsyncConfig``
    must report ``True``. Both apps point at the same SQLite file.

    Args:
        tmp_path: Temporary directory in which the SQLite database file is
            created.
    """
    import asyncio

    db_path = tmp_path / "litestar_context_test.db"

    class Base(DeclarativeBase):
        pass

    # These models are never queried; they only register tables on
    # ``Base.metadata`` so that ``create_all`` has something to emit.
    class SyncModel(BigIntPrimaryKey, Base):  # type: ignore
        __tablename__ = "sync_model_litestar_test"
        name: Mapped[str]

    class AsyncModel(BigIntPrimaryKey, Base):  # type: ignore
        __tablename__ = "async_model_litestar_test"
        name: Mapped[str]

    # ------------------------------------------------------------------
    # Sync app
    # ------------------------------------------------------------------
    sync_config = SQLAlchemySyncConfig(connection_string=f"sqlite:///{db_path}")
    sync_plugin = SQLAlchemyInitPlugin(config=sync_config)

    @get("/test_sync")
    def sync_handler(db_session: Session) -> dict[str, Any]:
        # ``db_session`` is not used in the body; it is declared so that the
        # session dependency is resolved for this request.
        return {"is_async": is_async_context()}

    sync_app = Litestar(route_handlers=[sync_handler], plugins=[sync_plugin])

    # Create tables for the sync app.
    with sync_config.get_engine().begin() as conn:
        Base.metadata.create_all(conn)

    with TestClient(app=sync_app) as sync_client:
        response = sync_client.get("/test_sync")
        assert response.status_code == 200
        assert response.json() == {"is_async": False}

    # ------------------------------------------------------------------
    # Async app
    # ------------------------------------------------------------------
    async_config = SQLAlchemyAsyncConfig(connection_string=f"sqlite+aiosqlite:///{db_path}")
    async_plugin = SQLAlchemyInitPlugin(config=async_config)

    @get("/test_async")
    async def async_handler(db_session: AsyncSession) -> dict[str, Any]:
        # Same as above: the parameter exists only to request the session.
        return {"is_async": is_async_context()}

    async_app = Litestar(route_handlers=[async_handler], plugins=[async_plugin])

    async def create_async_tables() -> None:
        """Create the tables through the async engine, then release its pool."""
        engine = async_config.get_engine()
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        # ``asyncio.run`` closes its event loop on return, so any later use of
        # this engine happens on a different loop. SQLAlchemy asks for
        # ``dispose()`` before an AsyncEngine is reused on another loop, so
        # drop the pooled connections while the creating loop is still alive.
        await engine.dispose()

    asyncio.run(create_async_tables())

    with TestClient(app=async_app) as async_client:
        response = async_client.get("/test_async")
        assert response.status_code == 200
        assert response.json() == {"is_async": True}
def test_plugin_is_async_context(tmp_path: Path) -> None:
    """Check ``is_async_context()`` in handlers wired through ``SQLAlchemyInitPlugin``.

    The handler of the app configured with ``SQLAlchemySyncConfig`` must report
    ``False`` and the handler of the app configured with
    ``SQLAlchemyAsyncConfig`` must report ``True``. Both apps use the same
    SQLite file.

    Args:
        tmp_path: Temporary directory in which the SQLite database file is
            created.
    """
    import asyncio

    db_path = tmp_path / "litestar_plugin_context.db"

    class Base(DeclarativeBase):
        pass

    # These models are never queried; they only register tables on
    # ``Base.metadata`` so that ``create_all`` has something to emit.
    class SyncModel(BigIntPrimaryKey, Base):  # type: ignore
        __tablename__ = "sync_model_litestar_test"
        name: Mapped[str]

    class AsyncModel(BigIntPrimaryKey, Base):  # type: ignore
        __tablename__ = "async_model_litestar_test"
        name: Mapped[str]

    # ------------------------------------------------------------------
    # Sync app
    # ------------------------------------------------------------------
    sync_config = SQLAlchemySyncConfig(connection_string=f"sqlite:///{db_path}")
    sync_plugin = SQLAlchemyInitPlugin(config=sync_config)

    @get("/test_sync_plugin")
    def sync_plugin_handler(db_session: Session) -> dict[str, Any]:  # type: ignore[arg-type]
        # ``db_session`` is not used in the body; it is declared so that the
        # session dependency is resolved for this request.
        return {"is_async": is_async_context()}

    sync_app = Litestar(route_handlers=[sync_plugin_handler], plugins=[sync_plugin])

    # Create tables for the sync app.
    with sync_config.get_engine().begin() as conn:
        Base.metadata.create_all(conn)

    with TestClient(app=sync_app) as sync_client:
        response = sync_client.get("/test_sync_plugin")
        assert response.status_code == 200
        assert response.json() == {"is_async": False}

    # ------------------------------------------------------------------
    # Async app
    # ------------------------------------------------------------------
    async_config = SQLAlchemyAsyncConfig(connection_string=f"sqlite+aiosqlite:///{db_path}")
    async_plugin = SQLAlchemyInitPlugin(config=async_config)

    @get("/test_async_plugin")
    async def async_plugin_handler(db_session: AsyncSession) -> dict[str, Any]:  # type: ignore[arg-type]
        # Same as above: the parameter exists only to request the session.
        return {"is_async": is_async_context()}

    async_app = Litestar(route_handlers=[async_plugin_handler], plugins=[async_plugin])

    async def create_async_tables() -> None:
        """Create the tables through the async engine, then release its pool."""
        engine = async_config.get_engine()
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        # ``asyncio.run`` closes its event loop on return, so any later use of
        # this engine happens on a different loop. SQLAlchemy asks for
        # ``dispose()`` before an AsyncEngine is reused on another loop, so
        # drop the pooled connections while the creating loop is still alive.
        await engine.dispose()

    asyncio.run(create_async_tables())

    with TestClient(app=async_app) as async_client:
        response = async_client.get("/test_async_plugin")
        assert response.status_code == 200
        assert response.json() == {"is_async": True}
|