File: sqlalchemy_repository_extension.py

package info (click to toggle)
litestar 2.21.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 12,568 kB
  • sloc: python: 70,588; makefile: 254; javascript: 104; sh: 60
file content (188 lines) | stat: -rw-r--r-- 6,050 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
from __future__ import annotations

import random
import re
import string
import unicodedata
from typing import TYPE_CHECKING, Any
from uuid import UUID

from pydantic import BaseModel as _BaseModel
from pydantic import TypeAdapter
from sqlalchemy.orm import Mapped, declarative_mixin, mapped_column
from sqlalchemy.types import String

from litestar import Litestar, get, post
from litestar.di import Provide
from litestar.plugins.sqlalchemy import (
    AsyncSessionConfig,
    SQLAlchemyAsyncConfig,
    SQLAlchemyInitPlugin,
    base,
    repository,
)

if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncSession


class BaseModel(_BaseModel):
    """Shared Pydantic base model with attribute-based validation enabled.

    Setting ``from_attributes`` allows ``model_validate`` to build a model from
    an arbitrary object's attributes (such as an ORM instance) rather than only
    from a mapping. This was called "ORM mode" in Pydantic v1.
    """

    model_config = {"from_attributes": True}


# We are going to add a simple "slug" to our model: a URL-safe surrogate key for
# our database record.
@declarative_mixin
class SlugKey:
    """Declarative mixin that adds a required, unique ``slug`` column to a model."""

    __abstract__ = True
    # Up to 100 characters, NOT NULL, unique. The negative ``sort_order``
    # presumably places this column ahead of columns that use the default
    # ordering in the generated table -- confirm against the base classes.
    slug: Mapped[str] = mapped_column(String(length=100), nullable=False, unique=True, sort_order=-9)


# this class can be re-used with any model that has the `SlugKey` Mixin
class SQLAlchemyAsyncSlugRepository(repository.SQLAlchemyAsyncRepository[repository.ModelT]):
    """Extends the repository with slug generation for models that have a ``slug`` field."""

    async def get_available_slug(
        self,
        value_to_slugify: str,
        **kwargs: Any,
    ) -> str:
        """Get a slug for the supplied value that is not already in use.

        The value is slugified and looked up once. If that slug is free it is
        returned unchanged; otherwise a random 4-character lowercase
        alphanumeric suffix is appended. The suffixed slug is deliberately not
        re-checked, which limits this method to a single database lookup at
        the cost of a small chance of collision.

        If nothing survives slugification (for example an empty, all-punctuation
        or non-ASCII-only value), the random suffix alone is returned so that
        the result is never an empty string.

        Args:
            value_to_slugify (str): A string that should be converted to a unique slug.
            **kwargs: Accepted for use by subclass overrides; not used here.

        Returns:
            str: a non-empty slug for the supplied value. This is safe for URLs
            and other unique identifiers.
        """
        slug = self._slugify(value_to_slugify)
        # An empty slug is unusable as a URL key, so skip the lookup and fall
        # through to the random suffix in that case.
        if slug and await self._is_slug_unique(slug):
            return slug
        # generate a random 4 character alphanumeric string to make the slug
        # unique and avoid another DB lookup.
        random_string = "".join(random.choices(string.ascii_lowercase + string.digits, k=4))
        # NOTE(review): the result is not truncated; a long input can yield a
        # slug longer than the column length declared on the model.
        return f"{slug}-{random_string}" if slug else random_string

    @staticmethod
    def _slugify(value: str) -> str:
        """Convert a string to a lowercase, ASCII-only, dash-separated slug.

        Characters are NFKD-normalized and anything that cannot be encoded as
        ASCII is dropped. Characters that aren't alphanumerics, underscores,
        whitespace or hyphens are removed, runs of whitespace and hyphens
        collapse to a single hyphen, and leading/trailing hyphens and
        underscores are stripped.

        Args:
            value (str): the string to slugify

        Returns:
            str: a slugified string of the value parameter; may be empty when
            the input contains no ASCII alphanumerics.
        """
        value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
        value = re.sub(r"[^\w\s-]", "", value.lower())
        return re.sub(r"[-\s]+", "-", value).strip("-_")

    async def _is_slug_unique(
        self,
        slug: str,
        **kwargs: Any,
    ) -> bool:
        """Return True when no existing record has the given slug.

        Args:
            slug (str): the candidate slug to look up.
            **kwargs: Accepted for use by subclass overrides; not used here.

        Returns:
            bool: True if the lookup by ``slug`` finds no record.
        """
        return await self.get_one_or_none(slug=slug) is None


# The `UUIDAuditBase` class includes a `UUID`-based primary key (`id`) and 2
# additional columns: `created_at` and `updated_at`. `created_at` is a timestamp of when the
# record was created, and `updated_at` is the last time the record was modified.
class BlogPost(base.UUIDAuditBase, SlugKey):
    """Blog post model: a title and content, plus the ``slug`` column from ``SlugKey``.

    Any remaining columns are inherited from ``base.UUIDAuditBase``.
    """

    title: Mapped[str]
    content: Mapped[str]


class BlogPostRepository(SQLAlchemyAsyncSlugRepository[BlogPost]):
    """Slug-aware repository bound to the ``BlogPost`` model."""

    # Model class this repository operates on.
    model_type = BlogPost


class BlogPostDTO(BaseModel):
    # Response shape returned by the route handlers for a blog post.
    # `id` has no default, so it must be supplied, but it may be None.
    id: UUID | None
    slug: str
    title: str
    content: str


class BlogPostCreate(BaseModel):
    # Request payload for creating a blog post. There is no `slug` field:
    # the `create_blog` handler derives the slug from `title`.
    title: str
    content: str


# The provider below is where we can optionally override the default `select` used
# for the repository to pass in specific SQL options such as join details.
async def provide_blog_post_repo(db_session: AsyncSession) -> BlogPostRepository:
    """Dependency provider that builds a ``BlogPostRepository`` for a session.

    This is the place to customize how the repository is constructed; this
    example passes only the session.

    Args:
        db_session: the async session the repository should use.

    Returns:
        A ``BlogPostRepository`` bound to ``db_session``.
    """
    repo = BlogPostRepository(session=db_session)
    return repo


# With `expire_on_commit=False`, objects are not expired when the session
# commits, so a handler can still read their attributes after `commit()`.
session_config = AsyncSessionConfig(expire_on_commit=False)
sqlalchemy_config = SQLAlchemyAsyncConfig(
    connection_string="sqlite+aiosqlite:///test.sqlite", session_config=session_config
)  # Engine/session configuration; presumably the source of the `db_session` dependency received by `provide_blog_post_repo` -- confirm.
# Plugin instance built from the configuration above.
sqlalchemy_plugin = SQLAlchemyInitPlugin(config=sqlalchemy_config)


async def on_startup() -> None:
    """Create the tables registered on ``UUIDAuditBase``'s metadata."""
    engine = sqlalchemy_config.get_engine()
    async with engine.begin() as connection:
        # `create_all` is synchronous, so it is run through `run_sync`.
        await connection.run_sync(base.UUIDAuditBase.metadata.create_all)


@get(path="/")
async def get_blogs(
    blog_post_repo: BlogPostRepository,
) -> list[BlogPostDTO]:
    """List all blog posts, converted to ``BlogPostDTO`` instances."""
    posts = await blog_post_repo.list()
    # Validate the whole list in one pass rather than item by item.
    return TypeAdapter(list[BlogPostDTO]).validate_python(posts)


@get(path="/{post_slug:str}")
async def get_blog_details(
    post_slug: str,
    blog_post_repo: BlogPostRepository,
) -> BlogPostDTO:
    """Fetch the single blog post whose slug matches ``post_slug``."""
    post_obj = await blog_post_repo.get_one(slug=post_slug)
    dto = BlogPostDTO.model_validate(post_obj)
    return dto


@post(path="/")
async def create_blog(
    blog_post_repo: BlogPostRepository,
    data: BlogPostCreate,
) -> BlogPostDTO:
    """Create a blog post from the payload, deriving its slug from the title."""
    payload = data.model_dump(exclude_unset=True, by_alias=False, exclude_none=True)
    # The client never supplies a slug; it is generated from the title.
    payload["slug"] = await blog_post_repo.get_available_slug(payload["title"])
    new_post = BlogPost(**payload)
    saved = await blog_post_repo.add(new_post)
    await blog_post_repo.session.commit()
    return BlogPostDTO.model_validate(saved)


# Application wiring: the three route handlers, the repository provider exposed
# to handlers as `blog_post_repo`, table creation at startup, and the
# SQLAlchemy plugin. The module-level `sqlalchemy_plugin` is reused here rather
# than constructing a second plugin instance from the same config.
app = Litestar(
    route_handlers=[create_blog, get_blogs, get_blog_details],
    dependencies={"blog_post_repo": Provide(provide_blog_post_repo, sync_to_thread=False)},
    on_startup=[on_startup],
    plugins=[sqlalchemy_plugin],
)