File: dispatcher.py

package info (click to toggle)
python-urllib3 2.5.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 2,340 kB
  • sloc: python: 26,167; makefile: 122; javascript: 92; sh: 11
file content (108 lines) | stat: -rw-r--r-- 4,306 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from __future__ import annotations

import asyncio
from functools import partial
from typing import Callable, Dict

from ..asyncio.task_group import TaskGroup
from ..typing import ASGIFramework, Scope

# Upper bound on the number of lifespan messages buffered per mounted
# application; it is passed as the size of each per-mount asyncio queue /
# trio memory channel created by the dispatcher classes below.
MAX_QUEUE_SIZE = 10


class _DispatcherMiddleware:
    """Route ASGI connections to applications mounted under path prefixes.

    ``mounts`` maps a path prefix to the ASGI application serving it.
    Prefixes are tried in the mapping's iteration order and the first match
    wins. Lifespan scopes are delegated to ``_handle_lifespan``, which is a
    no-op here and is overridden by the event-loop specific subclasses.
    """

    def __init__(self, mounts: Dict[str, ASGIFramework]) -> None:
        self.mounts = mounts

    async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> None:
        """Dispatch one ASGI connection.

        For non-lifespan scopes the first mount whose prefix matches
        ``scope["path"]`` is awaited with the prefix stripped from the path
        (``"/"`` when nothing remains) and its result is returned. When no
        mount matches, an empty 404 response is sent.
        """
        if scope["type"] == "lifespan":
            await self._handle_lifespan(scope, receive, send)
            return

        request_path = scope["path"]
        for prefix, app in self.mounts.items():
            if self._is_mounted_under(prefix, request_path):
                # Give the application its own shallow copy so the rewritten
                # path does not leak back into the caller's scope object.
                sub_scope = dict(scope)
                sub_scope["path"] = request_path[len(prefix) :] or "/"
                return await app(sub_scope, receive, send)

        # NOTE(review): this fallback uses HTTP response messages for every
        # non-lifespan scope type; confirm that is acceptable for websockets.
        await send(
            {
                "type": "http.response.start",
                "status": 404,
                "headers": [(b"content-length", b"0")],
            }
        )
        await send({"type": "http.response.body"})

    @staticmethod
    def _is_mounted_under(prefix: str, path: str) -> bool:
        """Return True when *path* lies under the mount *prefix*.

        A bare string-prefix test would let ``/api`` capture ``/apiary``, so
        the prefix must additionally end on a path-segment boundary: it ends
        with ``/`` itself, equals the whole path, or is followed by ``/``.
        """
        if not path.startswith(prefix):
            return False
        if prefix.endswith("/") or len(path) == len(prefix):
            return True
        return path[len(prefix)] == "/"

    async def _handle_lifespan(self, scope: Scope, receive: Callable, send: Callable) -> None:
        """Handle a lifespan scope; the base implementation does nothing."""
        pass


class AsyncioDispatcherMiddleware(_DispatcherMiddleware):
    """Dispatcher middleware that fans lifespan events out using asyncio.

    Every mounted application is run with the lifespan scope in its own
    task. Messages received from the server are copied into a per-mount
    queue, and completion messages from the applications are aggregated so
    the server sees a single ``complete`` once every mount has reported.
    """

    async def _handle_lifespan(self, scope: Scope, receive: Callable, send: Callable) -> None:
        """Run the lifespan protocol for all mounted applications.

        Returns after a ``lifespan.shutdown`` message has been forwarded and
        the task group has exited.
        """
        self.app_queues: Dict[str, asyncio.Queue] = {
            path: asyncio.Queue(MAX_QUEUE_SIZE) for path in self.mounts
        }
        self.startup_complete = {path: False for path in self.mounts}
        self.shutdown_complete = {path: False for path in self.mounts}

        # This coroutine only ever executes inside a running loop, so ask for
        # that loop explicitly rather than going through the policy lookup.
        async with TaskGroup(asyncio.get_running_loop()) as task_group:
            for path, app in self.mounts.items():
                task_group.spawn(
                    app,
                    scope,
                    self.app_queues[path].get,
                    partial(self.send, path, send),
                )

            while True:
                message = await receive()
                # Every application gets its own copy of the event stream.
                for queue in self.app_queues.values():
                    await queue.put(message)
                if message["type"] == "lifespan.shutdown":
                    break

    async def send(self, path: str, send: Callable, message: dict) -> None:
        """Aggregate a lifespan message sent by the application at *path*.

        ``complete`` messages are recorded per mount and only forwarded once
        every mount has reported. ``failed`` messages are forwarded
        immediately and unchanged, because a single failing application means
        the combined startup/shutdown can no longer complete. Any other
        message type is ignored.
        """
        message_type = message["type"]
        if message_type == "lifespan.startup.complete":
            self.startup_complete[path] = True
            if all(self.startup_complete.values()):
                await send({"type": "lifespan.startup.complete"})
        elif message_type == "lifespan.shutdown.complete":
            self.shutdown_complete[path] = True
            if all(self.shutdown_complete.values()):
                await send({"type": "lifespan.shutdown.complete"})
        elif message_type in ("lifespan.startup.failed", "lifespan.shutdown.failed"):
            # Previously dropped, which left the server waiting for a
            # completion that would never arrive.
            await send(message)


class TrioDispatcherMiddleware(_DispatcherMiddleware):
    """Dispatcher middleware that fans lifespan events out using trio.

    Every mounted application is run with the lifespan scope in a nursery
    task. Messages received from the server are copied into a per-mount
    memory channel, and completion messages from the applications are
    aggregated so the server sees a single ``complete`` once every mount has
    reported.
    """

    async def _handle_lifespan(self, scope: Scope, receive: Callable, send: Callable) -> None:
        """Run the lifespan protocol for all mounted applications.

        Returns after a ``lifespan.shutdown`` message has been forwarded and
        the nursery has exited.
        """
        # Imported lazily so that trio is only needed when this class is used.
        import trio

        # Each value is the (send_channel, receive_channel) pair returned by
        # trio.open_memory_channel.
        self.app_queues = {path: trio.open_memory_channel(MAX_QUEUE_SIZE) for path in self.mounts}
        self.startup_complete = {path: False for path in self.mounts}
        self.shutdown_complete = {path: False for path in self.mounts}

        async with trio.open_nursery() as nursery:
            for path, app in self.mounts.items():
                nursery.start_soon(
                    app,
                    scope,
                    self.app_queues[path][1].receive,
                    partial(self.send, path, send),
                )

            while True:
                message = await receive()
                # Every application gets its own copy of the event stream.
                for channels in self.app_queues.values():
                    await channels[0].send(message)
                if message["type"] == "lifespan.shutdown":
                    break

    async def send(self, path: str, send: Callable, message: dict) -> None:
        """Aggregate a lifespan message sent by the application at *path*.

        ``complete`` messages are recorded per mount and only forwarded once
        every mount has reported. ``failed`` messages are forwarded
        immediately and unchanged, because a single failing application means
        the combined startup/shutdown can no longer complete. Any other
        message type is ignored.
        """
        message_type = message["type"]
        if message_type == "lifespan.startup.complete":
            self.startup_complete[path] = True
            if all(self.startup_complete.values()):
                await send({"type": "lifespan.startup.complete"})
        elif message_type == "lifespan.shutdown.complete":
            self.shutdown_complete[path] = True
            if all(self.shutdown_complete.values()):
                await send({"type": "lifespan.shutdown.complete"})
        elif message_type in ("lifespan.startup.failed", "lifespan.shutdown.failed"):
            # Previously dropped, which left the server waiting for a
            # completion that would never arrive.
            await send(message)


# Backwards-compatible alias: the unqualified name refers to the asyncio
# implementation.
DispatcherMiddleware = AsyncioDispatcherMiddleware  # Remove with version 0.11