1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
|
import os
import socket
import socketserver
import ssl
import threading
import typing
import unittest
from pygopherd import testutil
from pygopherd.server import BaseServer, ForkingTCPServer, ThreadingTCPServer
# Certificate / private-key pair from the test data directory, used by the
# test servers for TLS.
crt_file = os.path.join(testutil.TEST_DATA, "demo.crt")
key_file = os.path.join(testutil.TEST_DATA, "demo.key")

# Server-side TLS context (Purpose.CLIENT_AUTH is the purpose used by sockets
# that accept client connections), loaded with the demo certificate chain.
context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
context.load_cert_chain(certfile=crt_file, keyfile=key_file)
class EchoHandler(socketserver.StreamRequestHandler):
    """Request handler that echoes one line back to the client."""

    def handle(self):
        """Read a single line from the request and write it back unchanged."""
        self.wfile.write(self.rfile.readline())
class ServerTestCase(unittest.TestCase):
    """
    Base fixture that runs ``server_class`` with an echo handler.

    Subclasses set ``server_class`` and add the actual test methods; the
    running server is available to them as ``cls.server``.
    """

    # Concrete server implementation under test; set by each subclass.
    server_class: typing.Type[BaseServer]
    # Server instance created in setUpClass.
    server: BaseServer
    # Background thread running ``server.serve_forever``.
    thread: threading.Thread

    @classmethod
    def setUpClass(cls):
        """
        Spin up a test server in a separate thread.
        """
        config = testutil.get_config()
        # Port 0 lets the OS pick a free port; tests read the real address
        # back from ``cls.server.server_address``.
        server_address = ("localhost", 0)
        cls.server = cls.server_class(
            config, server_address, EchoHandler, context=context
        )
        cls.thread = threading.Thread(target=cls.server.serve_forever)
        cls.thread.start()

    @classmethod
    def tearDownClass(cls):
        """
        Stop the serve loop, wait for its thread, and release the socket.
        """
        cls.server.shutdown()
        cls.thread.join(timeout=5)
        # shutdown() only stops serve_forever(); the listening socket stays
        # open until server_close() is called, which would otherwise leak one
        # socket per test class.
        # NOTE(review): assumes BaseServer derives from socketserver's
        # BaseServer (it already exposes shutdown/serve_forever) -- confirm.
        cls.server.server_close()
class ThreadingTCPServerTestCase(ServerTestCase):
    """Echo round-trip tests against ``ThreadingTCPServer``."""

    server_class = ThreadingTCPServer

    def test_send_data(self):
        """A line sent over plain TCP is echoed back unchanged."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.connect(self.server.server_address)
            sock.sendall(b"Hello World\n")
            self.assertEqual(sock.recv(4096), b"Hello World\n")

    def test_send_data_tls(self):
        """A line sent over a TLS connection is echoed back unchanged."""
        # ssl.wrap_socket() was deprecated in Python 3.7 and removed in 3.12.
        # Build an explicit client context instead, and keep the old
        # behaviour of not verifying the server certificate
        # (presumably demo.crt does not chain to a trusted CA -- confirm).
        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        # check_hostname must be disabled before verify_mode can be CERT_NONE.
        client_context.check_hostname = False
        client_context.verify_mode = ssl.CERT_NONE
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            with client_context.wrap_socket(sock) as ssock:
                # Connecting through the wrapped socket also performs the
                # TLS handshake.
                ssock.connect(self.server.server_address)
                ssock.sendall(b"Hello World\n")
                self.assertEqual(ssock.recv(4096), b"Hello World\n")
class ForkingTCPServerTestCase(ServerTestCase):
    """Echo round-trip tests against ``ForkingTCPServer``."""

    server_class = ForkingTCPServer

    def test_send_data(self):
        """A line sent over plain TCP is echoed back unchanged."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.connect(self.server.server_address)
            sock.sendall(b"Hello World\n")
            self.assertEqual(sock.recv(4096), b"Hello World\n")

    def test_send_data_tls(self):
        """A line sent over a TLS connection is echoed back unchanged."""
        # ssl.wrap_socket() was deprecated in Python 3.7 and removed in 3.12.
        # Build an explicit client context instead, and keep the old
        # behaviour of not verifying the server certificate
        # (presumably demo.crt does not chain to a trusted CA -- confirm).
        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        # check_hostname must be disabled before verify_mode can be CERT_NONE.
        client_context.check_hostname = False
        client_context.verify_mode = ssl.CERT_NONE
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            with client_context.wrap_socket(sock) as ssock:
                # Connecting through the wrapped socket also performs the
                # TLS handshake.
                ssock.connect(self.server.server_address)
                ssock.sendall(b"Hello World\n")
                self.assertEqual(ssock.recv(4096), b"Hello World\n")
|