File: jwks_server.py

package info (click to toggle)
scitokens-cpp 1.3.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,172 kB
  • sloc: cpp: 11,717; ansic: 596; sh: 161; python: 132; makefile: 22
file content (160 lines) | stat: -rwxr-xr-x 5,800 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#!/usr/bin/env python3
"""
Simple Python web server that hosts JWKS and supports OIDC discovery.
Used for integration testing of scitokens-cpp.
"""

import argparse
import json
import os
import signal
import socket
import ssl
import sys
from http.server import HTTPServer, BaseHTTPRequestHandler
from pathlib import Path


class JWKSHandler(BaseHTTPRequestHandler):
    """HTTP handler for JWKS and OIDC discovery endpoints.

    Reads these attributes from the owning server object:

    * ``issuer_url`` -- base URL advertised in the discovery document;
    * ``jwks_file``  -- path of the JWKS JSON document, served verbatim;
    * ``log_file``   -- optional path that access/error lines are appended to.
    """

    # Use HTTP/1.1 for proper connection handling (keep-alive).  This needs an
    # accurate Content-Length on every response, which _send_json() and
    # send_error() both supply.
    protocol_version = 'HTTP/1.1'

    # Request path that returns the OIDC discovery document.
    _DISCOVERY_PATH = '/.well-known/openid-configuration'
    # Request paths that return the JWKS document.
    _JWKS_PATHS = ('/oauth2/certs', '/jwks')

    def log_message(self, format, *args):
        """Append one log line to ``server.log_file`` instead of stderr.

        The message is dropped when the server has no ``log_file`` (or it is
        None).  The file is written as UTF-8 so request lines containing
        non-ASCII characters cannot fail under an ASCII/C locale.
        """
        log_file = getattr(self.server, 'log_file', None)
        if log_file is None:
            return
        with open(log_file, 'a', encoding='utf-8') as f:
            f.write("%s - - [%s] %s\n" % (
                self.address_string(),
                self.log_date_time_string(),
                format % args))

    def do_GET(self):
        """Route a GET request to the discovery or JWKS document, else 404.

        Any query string (text after '?') is ignored for routing, so
        ``/jwks?v=1`` is served exactly like ``/jwks``.
        """
        route = self.path.partition('?')[0]
        if route == self._DISCOVERY_PATH:
            self.serve_discovery()
        elif route in self._JWKS_PATHS:
            self.serve_jwks()
        else:
            self.send_error(404, "Not Found")

    def serve_discovery(self):
        """Serve an OIDC discovery document derived from ``server.issuer_url``."""
        issuer = self.server.issuer_url
        discovery = {
            "issuer": issuer,
            "jwks_uri": f"{issuer}/oauth2/certs",
            "token_endpoint": f"{issuer}/token",
            "authorization_endpoint": f"{issuer}/authorize",
        }
        self._send_json(json.dumps(discovery).encode())

    def serve_jwks(self):
        """Serve the bytes of ``server.jwks_file`` as the JWKS document.

        The file is re-read on every request, so a test may rotate keys while
        the server runs.  If it cannot be read, a 500 response is sent and the
        cause is logged, rather than letting the exception drop the connection.
        """
        jwks_file = self.server.jwks_file
        try:
            with open(jwks_file, 'rb') as f:
                content = f.read()
        except OSError as exc:
            self.log_error("cannot read JWKS file %s: %s", jwks_file, exc)
            self.send_error(500, "JWKS file unavailable")
            return
        self._send_json(content)

    def _send_json(self, content):
        """Send a 200 response whose body is the JSON-encoded bytes *content*."""
        self.send_response(200)
        self.send_header('Content-Type', 'application/json')
        self.send_header('Content-Length', str(len(content)))
        self.end_headers()
        self.wfile.write(content)


def _build_tls_context(cert, key):
    """Return a server-side SSLContext loaded with the *cert*/*key* pair.

    Raises whatever ``SSLContext.load_cert_chain`` raises (e.g. OSError or
    ssl.SSLError) if the certificate or key cannot be loaded.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(cert, key)
    # Require TLS 1.2 or newer.  ssl.TLSVersion only exists from Python 3.7;
    # on Python 3.6 (EL8) fall back to disabling the older protocols via
    # option flags.
    try:
        context.minimum_version = ssl.TLSVersion.TLSv1_2
    except AttributeError:
        context.options |= (ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3
                            | ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1)
    # Cipher suites for OpenSSL 3.0.2 compatibility: SECLEVEL=1 allows
    # 2048-bit RSA and SHA-1 for test certificates.  If the cipher string is
    # rejected (older Python/OpenSSL), use the plain default list.
    try:
        context.set_ciphers('DEFAULT:@SECLEVEL=1')
    except ssl.SSLError:
        context.set_ciphers('DEFAULT')
    # Disable TLS session tickets to avoid issues with session resumption.
    context.options |= ssl.OP_NO_TICKET
    # Do not request or verify client certificates.
    context.check_hostname = False
    context.verify_mode = ssl.CERT_NONE
    return context


def _write_ready_file(ready_file, issuer_url, port):
    """Atomically write PID, ISSUER_URL and PORT lines to *ready_file*.

    The content goes to a sibling ``.tmp`` file that is then renamed over the
    target.  A harness polling for *ready_file* therefore sees either no file
    or the complete file, never a partial write.
    """
    tmp_file = ready_file.with_name(ready_file.name + '.tmp')
    with open(tmp_file, 'w') as f:
        f.write(f"PID={os.getpid()}\n")
        f.write(f"ISSUER_URL={issuer_url}\n")
        f.write(f"PORT={port}\n")
    os.replace(str(tmp_file), str(ready_file))


def main(argv=None):
    """Run the JWKS/OIDC-discovery test server until SIGTERM or SIGINT.

    Parses options from *argv* (``sys.argv[1:]`` when None).  Then it:

    * binds an ephemeral localhost port;
    * wraps the socket in TLS when --cert/--key are given;
    * writes ``<build-dir>/tests/<test-name>/server_ready`` with the PID,
      issuer URL and port;
    * serves requests.

    Exits with an argparse usage error (status 2) if only one of --cert and
    --key is supplied.
    """
    parser = argparse.ArgumentParser(description='JWKS test server')
    parser.add_argument('--jwks', required=True, help='Path to JWKS file')
    parser.add_argument('--build-dir', required=True, help='Build directory')
    parser.add_argument('--test-name', default='integration', help='Test name')
    parser.add_argument('--cert', help='Path to TLS certificate file')
    parser.add_argument('--key', help='Path to TLS key file')
    args = parser.parse_args(argv)

    # TLS needs both halves.  Silently falling back to plain HTTP when one is
    # missing would hand the harness an http:// issuer it never asked for.
    if bool(args.cert) != bool(args.key):
        parser.error('--cert and --key must be given together')
    use_https = bool(args.cert)
    protocol = "https" if use_https else "http"

    # Create the per-test directory that holds the ready file and the log.
    test_dir = Path(args.build_dir) / 'tests' / args.test_name
    test_dir.mkdir(parents=True, exist_ok=True)
    ready_file = test_dir / 'server_ready'
    log_file = test_dir / 'server.log'

    # Bind to port 0 so the OS picks a free port.
    server = HTTPServer(('localhost', 0), JWKSHandler)
    server.jwks_file = args.jwks
    server.log_file = str(log_file)

    if use_https:
        context = _build_tls_context(args.cert, args.key)
        server.socket = context.wrap_socket(server.socket, server_side=True)

    # Read back the port the OS actually assigned.
    port = server.server_address[1]
    issuer_url = f"{protocol}://localhost:{port}"
    server.issuer_url = issuer_url

    # Publish server info; the listening socket already exists, so clients
    # that connect as soon as they see this file are queued, not refused.
    _write_ready_file(ready_file, issuer_url, port)

    print(f"Server started on {issuer_url}", flush=True)
    print(f"Server PID: {os.getpid()}", flush=True)
    print(f"Server ready file: {ready_file}", flush=True)

    def signal_handler(signum, frame):
        print("Shutting down server...", flush=True)
        # server.shutdown() cannot be used here: it waits for serve_forever(),
        # which is running on this same thread.  Terminate immediately
        # instead.  os._exit skips the finally clause below.
        os._exit(0)

    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    try:
        server.serve_forever()
    except KeyboardInterrupt:
        # Defensive only: SIGINT is normally intercepted by signal_handler.
        pass
    finally:
        server.server_close()


if __name__ == "__main__":
    # Start the server only when run as a script, never on import.
    main()