File: basic_auth_middleware.py

package info (click to toggle)
python-aiohttp 3.13.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 16,952 kB
  • sloc: python: 61,881; ansic: 20,773; makefile: 414; sh: 3
file content (190 lines) | stat: -rw-r--r-- 6,441 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
#!/usr/bin/env python3
"""
Example of using basic authentication middleware with aiohttp client.

This example shows how to implement a middleware that automatically adds
Basic Authentication headers to all requests. The middleware encodes the
username and password in base64 format as required by the HTTP Basic Auth
specification.

This example includes a test server that validates basic auth credentials.
"""

import asyncio
import base64
import binascii
import hmac
import logging

from aiohttp import (
    ClientHandlerType,
    ClientRequest,
    ClientResponse,
    ClientSession,
    hdrs,
    web,
)

_LOGGER = logging.getLogger(__name__)
# Send DEBUG-and-above records from every logger to stderr so the example's
# log lines are visible when run as a script.
logging.basicConfig(level=logging.DEBUG)


class BasicAuthMiddleware:
    """Middleware that adds Basic Authentication to all requests.

    The ``Authorization`` header value is computed once, at construction
    time. Requests that already carry an ``Authorization`` header are passed
    through unchanged.
    """

    def __init__(self, username: str, password: str) -> None:
        """Store the credentials and pre-compute the header value.

        Args:
            username: User-id to send; must not contain ``":"``.
            password: Password to send (may contain ``":"``).

        Raises:
            ValueError: if *username* contains ``":"``. Basic auth joins the
                user-id and password with a colon, so a colon in the user-id
                would be read by the server as the start of the password
                (RFC 7617, section 2).
        """
        if ":" in username:
            raise ValueError('A ":" is not allowed in the username (RFC 7617)')
        self.username = username
        self.password = password
        # NOTE: the header is cached here; reassigning ``username`` or
        # ``password`` afterwards does not change what is sent.
        self._auth_header = self._encode_credentials()

    def _encode_credentials(self) -> str:
        """Return ``"Basic " + base64("username:password")``, UTF-8 encoded."""
        credentials = f"{self.username}:{self.password}"
        encoded = base64.b64encode(credentials.encode("utf-8")).decode("ascii")
        return f"Basic {encoded}"

    async def __call__(
        self,
        request: ClientRequest,
        handler: ClientHandlerType,
    ) -> ClientResponse:
        """Add the Basic Auth header to *request*, then delegate to *handler*."""
        # Respect an explicit Authorization header set by the caller.
        if hdrs.AUTHORIZATION not in request.headers:
            request.headers[hdrs.AUTHORIZATION] = self._auth_header

        # Proceed with the request
        return await handler(request)


class TestServer:
    """Test server for basic auth endpoints."""

    @staticmethod
    def _unauthorized(text: str, realm: str) -> web.Response:
        """Build a 401 response that challenges the client for Basic auth."""
        return web.Response(
            status=401,
            text=text,
            headers={hdrs.WWW_AUTHENTICATE: f'Basic realm="{realm}"'},
        )

    @staticmethod
    def _parse_basic_auth(auth_header: str) -> "tuple[str, str] | None":
        """Decode an ``Authorization`` header value using the Basic scheme.

        Returns:
            ``(username, password)``, or None when the header does not use
            the Basic scheme. The scheme name is matched case-insensitively
            (RFC 7235, section 2.1).

        Raises:
            ValueError: if the token is not strict base64, is not valid
                UTF-8, or lacks the ``user:password`` colon separator.
                ``binascii.Error`` and ``UnicodeDecodeError`` are both
                ``ValueError`` subclasses.
        """
        scheme, _, token = auth_header.partition(" ")
        if scheme.lower() != "basic":
            return None
        # validate=True rejects characters outside the base64 alphabet
        # instead of silently discarding them.
        decoded = base64.b64decode(token.strip(), validate=True).decode("utf-8")
        username, sep, password = decoded.partition(":")
        if not sep:
            raise ValueError("missing ':' separator in Basic credentials")
        return username, password

    async def handle_basic_auth(self, request: web.Request) -> web.Response:
        """Accept only the credentials named in the ``{user}/{pass}`` path."""
        # Get expected credentials from path
        expected_user = request.match_info["user"]
        expected_pass = request.match_info["pass"]

        try:
            credentials = self._parse_basic_auth(
                request.headers.get(hdrs.AUTHORIZATION, "")
            )
        except (ValueError, binascii.Error):
            return self._unauthorized("Invalid credentials format", "test")
        if credentials is None:
            return self._unauthorized("Unauthorized", "test")

        username, password = credentials
        # Constant-time comparisons; both are evaluated so the response time
        # does not reveal which of the two fields was wrong.
        user_ok = hmac.compare_digest(username.encode(), expected_user.encode())
        pass_ok = hmac.compare_digest(password.encode(), expected_pass.encode())
        if not (user_ok and pass_ok):
            return self._unauthorized("Invalid username or password", "test")

        return web.json_response({"authenticated": True, "user": username})

    async def handle_protected_resource(self, request: web.Request) -> web.Response:
        """A protected resource that requires any well-formed Basic credentials."""
        try:
            credentials = self._parse_basic_auth(
                request.headers.get(hdrs.AUTHORIZATION, "")
            )
        except (ValueError, binascii.Error):
            return self._unauthorized("Invalid credentials format", "protected")
        if credentials is None:
            return self._unauthorized("Authentication required", "protected")

        return web.json_response(
            {
                "message": "Access granted to protected resource",
                "auth_provided": True,
            }
        )


async def run_test_server(
    host: str = "localhost", port: int = 8080
) -> web.AppRunner:
    """Run a simple test server with basic auth endpoints.

    Args:
        host: Interface to bind. Defaults to ``"localhost"``.
        port: TCP port to listen on. Defaults to ``8080``, the port the
            client examples in this file request.

    Returns:
        The set-up ``AppRunner``. The caller must ``await runner.cleanup()``.

    If binding the site fails (for example, the port is already in use),
    the runner is cleaned up before the exception propagates.
    """
    app = web.Application()
    server = TestServer()

    app.router.add_get("/basic-auth/{user}/{pass}", server.handle_basic_auth)
    app.router.add_get("/protected", server.handle_protected_resource)

    runner = web.AppRunner(app)
    await runner.setup()
    try:
        site = web.TCPSite(runner, host, port)
        await site.start()
    except BaseException:
        # Don't leak the runner when the caller never receives it.
        await runner.cleanup()
        raise
    return runner


async def run_tests() -> None:
    """Send three example requests through a BasicAuthMiddleware session.

    Every request carries ``user``/``pass`` credentials added by the
    middleware. The outcome of each request is printed:

    1. ``/basic-auth/user/pass``: path matches the credentials, 200 expected.
    2. ``/basic-auth/other/secret``: path does not match, 401 expected.
    3. ``/protected``: 200 expected.
    """
    base_url = "http://localhost:8080"
    middleware = BasicAuthMiddleware("user", "pass")

    async with ClientSession(middlewares=(middleware,)) as session:
        print("=== Test 1: Correct credentials ===")
        async with session.get(f"{base_url}/basic-auth/user/pass") as resp:
            _LOGGER.info("Status: %s", resp.status)
            if resp.status != 200:
                print("Authentication failed!")
                print(f"Status: {resp.status}")
                print(f"Response: {await resp.text()}")
            else:
                payload = await resp.json()
                _LOGGER.info("Response: %s", payload)
                print("Authentication successful!")
                print(f"Authenticated: {payload.get('authenticated')}")
                print(f"User: {payload.get('user')}")

        print("\n=== Test 2: Wrong credentials endpoint ===")
        async with session.get(f"{base_url}/basic-auth/other/secret") as resp:
            if resp.status != 401:
                print(f"Unexpected status: {resp.status}")
            else:
                print("Authentication failed as expected (wrong credentials)")
                print(f"Response: {await resp.text()}")

        print("\n=== Test 3: Access protected resource ===")
        async with session.get(f"{base_url}/protected") as resp:
            if resp.status != 200:
                print(f"Failed to access protected resource: {resp.status}")
            else:
                body = await resp.json()
                print("Successfully accessed protected resource!")
                print(f"Response: {body}")


async def main() -> None:
    """Start the test server, run the client examples, then shut it down."""
    runner = await run_test_server()
    try:
        await run_tests()
    finally:
        # Always release the listening socket, even if a request failed.
        await runner.cleanup()


if __name__ == "__main__":
    # Script entry point: run main() (start server, run client requests,
    # clean up) on a fresh event loop.
    asyncio.run(main())