File: using_abstract_authentication_middleware.py

package info (click to toggle)
litestar 2.19.0-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 12,500 kB
  • sloc: python: 70,169; makefile: 254; javascript: 105; sh: 60
file content (88 lines) | stat: -rw-r--r-- 3,013 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from dataclasses import dataclass
from typing import Any

import anyio

from litestar import Litestar, MediaType, Request, Response, WebSocket, get, websocket
from litestar.connection import ASGIConnection
from litestar.datastructures import State
from litestar.di import Provide
from litestar.exceptions import NotAuthorizedException, NotFoundException
from litestar.middleware import AbstractAuthenticationMiddleware, AuthenticationResult
from litestar.middleware.base import DefineMiddleware

# Header the client must send; its value is treated as the API key.
API_KEY_HEADER = "X-API-KEY"

# In-memory stand-in for a user store, mapping API key -> user name.
TOKEN_USER_DATABASE = {
    "1": "user_authorized",
}


@dataclass()
class MyUser:
    """User object that authentication attaches to each connection."""

    name: str  # display name looked up from the API key


@dataclass()
class MyToken:
    """Credential object that authentication attaches to each connection."""

    api_key: str  # raw value taken from the API key header


class CustomAuthenticationMiddleware(AbstractAuthenticationMiddleware):
    """Authenticate connections using the API key carried in ``API_KEY_HEADER``."""

    async def authenticate_request(self, connection: ASGIConnection) -> AuthenticationResult:
        """Resolve the user that owns the API key sent with the connection.

        Args:
            connection: The incoming ASGI connection whose headers are inspected.

        Returns:
            An ``AuthenticationResult`` whose ``user`` is a ``MyUser`` and whose
            ``auth`` is a ``MyToken`` wrapping the header value.

        Raises:
            NotAuthorizedException: If the header is missing or empty, or if the
                key does not map to a known (non-empty) user name.
        """
        api_key = connection.headers.get(API_KEY_HEADER)
        if not api_key:
            raise NotAuthorizedException()

        # Stand-in for a database lookup. Resolve the name *before* building
        # MyUser so a user object is never constructed with ``name=None``.
        user_name = TOKEN_USER_DATABASE.get(api_key)
        if not user_name:
            raise NotAuthorizedException()

        return AuthenticationResult(user=MyUser(name=user_name), auth=MyToken(api_key=api_key))


@get("/protected", sync_to_thread=False)
def my_http_handler(request: Request[MyUser, MyToken, State]) -> None:
    """Show typed access to the authenticated user and token on an HTTP route.

    Mounted at ``/protected`` because ``site_index`` already registers ``GET /``.
    Litestar rejects two handlers for the same path and HTTP method at startup.
    ``sync_to_thread=False`` is safe because the body does no blocking work. Setting
    it explicitly avoids Litestar's warning about sync handlers without a setting.
    """
    user = request.user  # correctly typed as MyUser
    auth = request.auth  # correctly typed as MyToken
    assert isinstance(user, MyUser)
    assert isinstance(auth, MyToken)


@websocket("/")
async def my_ws_handler(socket: WebSocket[MyUser, MyToken, State]) -> None:
    """Show that ``socket.user`` / ``socket.auth`` carry the authenticated types.

    The handler only inspects these attributes; it does not accept, use, or close
    the WebSocket itself.
    """
    current_user, current_token = socket.user, socket.auth  # MyUser, MyToken
    assert isinstance(current_user, MyUser)
    assert isinstance(current_token, MyToken)


@get(path="/", exclude_from_auth=True)
async def site_index() -> Response:
    """Serve ``index.html`` (resolved against the working directory) as HTML.

    ``exclude_from_auth=True`` lets this route bypass the authentication middleware.

    Raises:
        NotFoundException: If ``index.html`` does not exist.
    """
    # Read directly and handle the missing-file case instead of calling
    # ``exists()`` first. The file could disappear between check and read,
    # which would surface as a 500 instead of a 404.
    # Decode explicitly as UTF-8 rather than relying on the locale default.
    try:
        content = await anyio.Path("index.html").read_text(encoding="utf-8")
    except FileNotFoundError:
        raise NotFoundException("Site index was not found") from None
    return Response(content=content, status_code=200, media_type=MediaType.HTML)


async def my_dependency(request: Request[MyUser, MyToken, State]) -> Any:
    """Dependency provider showing typed ``request.user`` / ``request.auth`` access.

    It only checks the attribute types and provides ``None`` as its value.
    """
    authenticated_user, token = request.user, request.auth  # MyUser, MyToken
    assert isinstance(authenticated_user, MyUser)
    assert isinstance(token, MyToken)
    return None


# You can optionally exclude certain paths from authentication.
# ``exclude`` takes regex patterns that are searched for anywhere in the request
# path. Anchoring the pattern ("^/schema") skips auth only for paths starting
# with ``/schema`` (e.g. the OpenAPI schema routes). A bare "schema" would also
# bypass auth for unrelated paths such as ``/users/schema``.
auth_mw = DefineMiddleware(CustomAuthenticationMiddleware, exclude="^/schema")

app = Litestar(
    route_handlers=[site_index, my_http_handler, my_ws_handler],
    middleware=[auth_mw],
    dependencies={"some_dependency": Provide(my_dependency)},
)