File: client_middleware_cookbook.py

package info (click to toggle)
python-aiohttp 3.13.3-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 17,020 kB
  • sloc: python: 62,860; ansic: 20,773; makefile: 429; sh: 3
file content (143 lines) | stat: -rw-r--r-- 5,000 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""This is a collection of semi-complete examples that get included into the cookbook page."""

import asyncio
import logging
import time
from collections.abc import AsyncIterator, Sequence
from contextlib import asynccontextmanager, suppress

from aiohttp import (
    ClientError,
    ClientHandlerType,
    ClientRequest,
    ClientResponse,
    ClientSession,
    TCPConnector,
)
from aiohttp.abc import ResolveResult
from aiohttp.tracing import Trace


class SSRFError(ClientError):
    """Raised when a request is aimed at a host on the block list.

    Deriving from ``ClientError`` lets callers handle it together with the
    other aiohttp client failures.
    """


async def retry_middleware(
    req: ClientRequest, handler: ClientHandlerType
) -> ClientResponse:
    """Send the request up to three times, stopping at the first OK response.

    A response counts as successful when ``resp.ok`` is true.  If all three
    attempts fail, the last (non-OK) response is returned so the caller can
    inspect it.  Exceptions raised by ``handler`` are not retried; they
    propagate immediately.
    """
    # Two attempts whose failures are discarded, then a final attempt whose
    # response is returned whatever its status (3 tries in total).
    for _ in range(2):
        resp = await handler(req)
        if resp.ok:
            return resp
        # This response is thrown away; release it so its connection goes back
        # to the pool now rather than whenever it is garbage collected.
        resp.release()
    return await handler(req)


async def api_logging_middleware(
    req: ClientRequest, handler: ClientHandlerType
) -> ClientResponse:
    """Report the request's host to ``/log``, then forward the request.

    The report is posted on the request's own session.  A non-OK reply from
    the log endpoint only produces a warning; the original request is sent
    either way.
    """
    log_session = req.session
    # middlewares=() sends the log call without any middleware (this one
    # included), so it cannot recurse into itself forever.
    log_call = log_session.post("/log", data=req.url.host, middlewares=())
    async with log_call as log_resp:
        log_failed = not log_resp.ok
        if log_failed:
            logging.warning("Log endpoint failed")

    return await handler(req)


class TokenRefresh401Middleware:
    def __init__(self, refresh_token: str, access_token: str):
        self.access_token = access_token
        self.refresh_token = refresh_token
        self.lock = asyncio.Lock()

    async def __call__(
        self, req: ClientRequest, handler: ClientHandlerType
    ) -> ClientResponse:
        for _ in range(2):  # Retry at most one time
            token = self.access_token
            req.headers["Authorization"] = f"Bearer {token}"
            resp = await handler(req)
            if resp.status != 401:
                return resp
            async with self.lock:
                if token != self.access_token:  # Already refreshed
                    continue
                url = "https://api.example/refresh"
                async with req.session.post(url, data=self.refresh_token) as resp:
                    # Add error handling as needed
                    data = await resp.json()
                    self.access_token = data["access_token"]
        return resp


class TokenRefreshExpiryMiddleware:
    """Client middleware that refreshes the access token once it has expired.

    Before each request, if ``expires_at`` is not later than ``time.time()``,
    a new access token and expiry are fetched from the refresh endpoint.  The
    current token is then sent as ``Authorization: Bearer <access_token>``.

    Args:
        refresh_token: Sent as the body of the refresh request.
    """

    def __init__(self, refresh_token: str):
        self.access_token = ""
        # Compared against time.time(), so presumably a Unix timestamp in
        # seconds.  0 means "already expired": the first request refreshes.
        self.expires_at = 0
        self.refresh_token = refresh_token
        # Serialises refreshes so concurrent requests trigger a single refresh.
        self.lock = asyncio.Lock()

    async def __call__(
        self, req: ClientRequest, handler: ClientHandlerType
    ) -> ClientResponse:
        if self.expires_at <= time.time():
            token = self.access_token
            async with self.lock:
                # Another task may have refreshed while we waited for the lock.
                if token == self.access_token:  # Still not refreshed
                    url = "https://api.example/refresh"
                    # middlewares=() keeps the refresh request out of the
                    # session's middlewares.  If this middleware is installed
                    # on the session, the nested call would otherwise still see
                    # an expired token and wait on self.lock, which this task
                    # holds (deadlock on the very first request).
                    async with req.session.post(
                        url, data=self.refresh_token, middlewares=()
                    ) as resp:
                        # Add error handling as needed
                        data = await resp.json()
                        self.access_token = data["access_token"]
                        self.expires_at = data["expires_at"]

        req.headers["Authorization"] = f"Bearer {self.access_token}"
        return await handler(req)


async def token_refresh_preemptively_example() -> None:
    """Keep a session's ``Authorization`` header fresh with a background task.

    A task repeatedly posts to ``/refresh`` and stores the returned token in the
    session's default headers.  Between refreshes it sleeps for the
    ``valid_duration`` the server reports.  The session is handed out only once
    the first token is in place.  The task is always cancelled when the
    ``async with`` block exits, including when it exits with an exception.
    """

    async def set_token(session: ClientSession, event: asyncio.Event) -> None:
        # NOTE(review): a relative URL like "/refresh" presumably relies on the
        # session being created with a base_url; confirm for real use.
        while True:
            async with session.post("/refresh") as resp:
                token = await resp.json()
            # Leave the response context before sleeping, rather than keeping
            # it open for the whole token lifetime.
            session.headers["Authorization"] = f"Bearer {token['auth']}"
            event.set()
            await asyncio.sleep(token["valid_duration"])

    @asynccontextmanager
    async def auto_refresh_client() -> AsyncIterator[ClientSession]:
        async with ClientSession() as session:
            ready = asyncio.Event()
            t = asyncio.create_task(set_token(session, ready))
            try:
                # NOTE: if the first refresh fails, set_token dies before
                # setting the event and this wait never returns; real code
                # should add a timeout or also watch ``t``.
                await ready.wait()
                yield session
            finally:
                # Always stop the refresher, even when the body raised, so it
                # never keeps running against a closed session.
                t.cancel()
                with suppress(asyncio.CancelledError):
                    await t

    async with auto_refresh_client() as sess:
        ...


async def ssrf_middleware(
    req: ClientRequest, handler: ClientHandlerType
) -> ClientResponse:
    """Refuse requests whose URL host is literally a local-machine name.

    Raises SSRFError, carrying the offending host, when the URL host is
    ``127.0.0.1`` or ``localhost``; every other request is passed on unchanged.
    """
    # WARNING: This is a simplified example for demonstration purposes only.
    # A complete implementation should also check:
    # - IPv6 loopback (::1)
    # - Private IP ranges (10.x.x.x, 192.168.x.x, 172.16-31.x.x)
    # - Link-local addresses (169.254.x.x, fe80::/10)
    # - Other internal hostnames and aliases
    target = req.url.host
    blocked_hosts = {"127.0.0.1", "localhost"}
    if target not in blocked_hosts:
        return await handler(req)
    raise SSRFError(target)


class SSRFConnector(TCPConnector):
    """TCP connector that refuses hosts resolving to ``127.0.0.1``.

    The check runs on the results of host resolution, so it also applies to
    hostnames that resolve to that address, not only to literal IP URLs.
    """

    async def _resolve_host(
        self, host: str, port: int, traces: Sequence[Trace] | None = None
    ) -> list[ResolveResult]:
        resolved = await super()._resolve_host(host, port, traces)
        # WARNING: This is a simplified example - should also check ::1, private ranges, etc.
        for entry in resolved:
            if entry["host"] == "127.0.0.1":
                raise SSRFError()
        return resolved