File: test_example_sqla_pre_fetched_data.py

package info (click to toggle)
python-polyfactory 2.22.2-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,892 kB
  • sloc: python: 11,338; makefile: 103; sh: 37
file content (53 lines) | stat: -rw-r--r-- 1,697 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from __future__ import annotations

from sqlalchemy import ForeignKey, select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory

# Module-level async engine shared by the helper and the test below; the URL
# selects the aiosqlite driver with an in-memory SQLite database.
async_engine = create_async_engine("sqlite+aiosqlite:///:memory:")


class Base(DeclarativeBase):
    """Declarative base whose metadata collects the tables mapped in this module."""


class User(Base):
    """ORM model mapped to the ``users`` table."""

    __tablename__ = "users"

    # Integer primary key; ``Department.director_id`` holds a foreign key to it.
    id: Mapped[int] = mapped_column(primary_key=True)


class Department(Base):
    """ORM model mapped to the ``departments`` table.

    Each row references the user acting as its director via ``director_id``.
    """

    __tablename__ = "departments"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Annotated as ``int`` so the column type matches the integer ``users.id``
    # primary key it references; a ``str`` annotation would map this foreign
    # key to a string column instead.
    director_id: Mapped[int] = mapped_column(ForeignKey("users.id"))


class UserFactory(SQLAlchemyFactory[User]):
    """Factory producing ``User`` instances."""


class DepartmentFactory(SQLAlchemyFactory[Department]):
    """Factory producing ``Department`` instances."""


async def get_director_ids() -> int:
    """Return the id of one existing user, picked at random.

    Loads every ``User.id`` through a short-lived session on ``async_engine``
    and selects one of them with ``UserFactory.__random__``.
    """
    id_query = select(User.id)
    async with AsyncSession(async_engine) as session:
        scalar_result = await session.scalars(id_query)
        user_ids = scalar_result.all()
    # The ids are already materialised, so choosing one needs no open session.
    return UserFactory.__random__.choice(user_ids)


async def test_factory_with_pre_fetched_async_data() -> None:
    """Create a department whose director id was fetched from the database beforehand."""
    # Rebuild the schema from scratch: drop every table, then create them again.
    async with async_engine.begin() as connection:
        for schema_step in (Base.metadata.drop_all, Base.metadata.create_all):
            await connection.run_sync(schema_step)

    # Persist a few users so there are ids available to pick a director from.
    async with AsyncSession(async_engine) as user_session:
        UserFactory.__async_session__ = user_session
        await UserFactory.create_batch_async(3)

    async with AsyncSession(async_engine) as department_session:
        DepartmentFactory.__async_session__ = department_session
        # Resolve the director id first so the factory receives a plain value.
        director_id = await get_director_ids()
        department = await DepartmentFactory.create_async(director_id=director_id)

        # The stored director id must point at a real row in the users table.
        director_query = select(User).where(User.id == department.director_id)
        director = await department_session.scalar(director_query)
        assert isinstance(director, User)