1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
|
import functools
import ssl
import typing
from multiprocessing import Process
from unittest import mock

import anyio
from anyio import create_tcp_listener
from anyio.streams.tls import TLSListener
from tiny_proxy import (
    HttpProxyHandler,
    Socks5ProxyHandler,
    Socks4ProxyHandler,
    HttpProxy,
    Socks4Proxy,
    Socks5Proxy,
    AbstractProxy,
)

from tests.mocks import getaddrinfo_async_mock
class ProxyConfig(typing.NamedTuple):
    """Settings for one proxy server launched by this module.

    ``ProxyServer.start`` converts each instance with ``to_dict()`` and passes
    the result as keyword arguments to the module-level ``start`` function.
    """

    proxy_type: str
    host: str
    port: int
    username: typing.Optional[str] = None
    password: typing.Optional[str] = None
    ssl_certfile: typing.Optional[str] = None
    ssl_keyfile: typing.Optional[str] = None

    def to_dict(self):
        """Return ``{field_name: value}`` in field order, omitting ``None`` values."""
        return {
            field_name: field_value
            for field_name, field_value in zip(self._fields, self)
            if field_value is not None
        }
# Lookup table from a ProxyConfig.proxy_type string to the tiny_proxy handler
# class that start() instantiates for it.
cls_map = dict(
    http=HttpProxyHandler,
    socks4=Socks4ProxyHandler,
    socks5=Socks5ProxyHandler,
)
def connect_to_remote_factory(
    cls: typing.Type[AbstractProxy],
    delay: float = 0.01,
):
    """Build a replacement ``connect_to_remote`` that sleeps before connecting.

    The returned coroutine function waits ``delay`` seconds and then delegates
    to ``cls.connect_to_remote`` as it was when this factory was called. It is
    meant to be installed with ``mock.patch.object`` so that the proxy's
    outbound connection to the target host is artificially slow. This lets
    tests simulate a slow or timing-out target host.

    Args:
        cls: Proxy class whose current ``connect_to_remote`` is wrapped.
        delay: Seconds to sleep before delegating. The default, 0.01, matches
            the previously hard-coded value.

    Returns:
        An ``async def`` function suitable as the new ``connect_to_remote``.
    """
    # Capture the original now; after patching, cls.connect_to_remote would
    # resolve to the replacement itself.
    origin_connect_to_remote = cls.connect_to_remote

    @functools.wraps(origin_connect_to_remote)
    async def new_connect_to_remote(self, *args, **kwargs):
        await anyio.sleep(delay)
        # Forward any extra arguments unchanged so the wrapper stays correct
        # even if the wrapped method's signature grows.
        return await origin_connect_to_remote(self, *args, **kwargs)

    return new_connect_to_remote
def _build_ssl_context(
    ssl_certfile: typing.Optional[str],
    ssl_keyfile: typing.Optional[str],
) -> typing.Optional[ssl.SSLContext]:
    """Return a server-side TLS context, or ``None`` when TLS is not configured.

    ``ssl_keyfile`` may be omitted when ``ssl_certfile`` also contains the
    private key, because ``SSLContext.load_cert_chain`` accepts ``keyfile=None``.

    Raises:
        ValueError: if a key file is given without a certificate file.
            Previously this silently produced a plaintext server.
    """
    if not ssl_certfile:
        if ssl_keyfile:
            raise ValueError('ssl_keyfile was given without ssl_certfile')
        return None
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(ssl_certfile, ssl_keyfile or None)
    return context


@mock.patch.object(
    HttpProxy,
    attribute='connect_to_remote',
    new=connect_to_remote_factory(HttpProxy),
)
@mock.patch.object(
    Socks4Proxy,
    attribute='connect_to_remote',
    new=connect_to_remote_factory(Socks4Proxy),
)
@mock.patch.object(
    Socks5Proxy,
    attribute='connect_to_remote',
    new=connect_to_remote_factory(Socks5Proxy),
)
@mock.patch('anyio._core._sockets.getaddrinfo', new=getaddrinfo_async_mock(anyio.getaddrinfo))
def start(
    proxy_type,
    host,
    port,
    ssl_certfile=None,
    ssl_keyfile=None,
    **kwargs,
):
    """Run one proxy server in the current process, blocking in ``anyio.run``.

    While this call runs, the decorators replace ``connect_to_remote`` on
    ``HttpProxy``, ``Socks4Proxy`` and ``Socks5Proxy`` with delayed versions
    from ``connect_to_remote_factory``. They also replace anyio's internal
    ``getaddrinfo`` with ``getaddrinfo_async_mock``; that mock's exact
    behaviour is defined in ``tests.mocks``.

    Args:
        proxy_type: Key into ``cls_map`` (``'http'``, ``'socks4'`` or ``'socks5'``).
        host: Local address to listen on.
        port: Local port to listen on.
        ssl_certfile: Optional certificate (chain) file. When given, the
            listener is wrapped in TLS.
        ssl_keyfile: Optional private-key file. It may be omitted if the key
            is inside ``ssl_certfile``.
        **kwargs: Forwarded to the handler class constructor (for example the
            ``username``/``password`` fields of a ``ProxyConfig``).

    Raises:
        RuntimeError: if ``proxy_type`` is not in ``cls_map``.
        ValueError: if ``ssl_keyfile`` is given without ``ssl_certfile``.
    """
    handler_cls = cls_map.get(proxy_type)
    if not handler_cls:
        raise RuntimeError(f'Unsupported type: {proxy_type}')

    ssl_context = _build_ssl_context(ssl_certfile, ssl_keyfile)

    # flush=True: this usually runs in a worker process that is stopped with
    # Process.terminate() (SIGTERM), which does not flush stdio buffers.
    print(f'Starting {proxy_type} proxy on {host}:{port}...', flush=True)
    handler = handler_cls(**kwargs)

    async def serve():
        listener = await create_tcp_listener(local_host=host, local_port=port)
        if ssl_context is not None:
            listener = TLSListener(listener=listener, ssl_context=ssl_context)
        async with listener:
            await listener.serve(handler.handle)

    anyio.run(serve)
class ProxyServer:
    """Run a set of proxy servers, one daemon subprocess per ``ProxyConfig``.

    ``start()`` launches the processes (each runs the module-level ``start``
    function). ``terminate()`` stops them and waits for them to exit.
    """

    # Only processes that have actually been started are recorded here, so
    # terminate()/join() are never called on a never-launched Process.
    workers: typing.List[Process]

    def __init__(self, config: typing.Iterable[ProxyConfig]):
        self.config = config
        self.workers = []

    def start(self):
        """Create and start one daemon worker process per configured proxy."""
        new_workers = []
        for cfg in self.config:
            # flush=True so buffered parent output is not duplicated into
            # children created with the fork start method.
            print(
                'Starting {} proxy on {}:{}; certfile={}, keyfile={}...'.format(
                    cfg.proxy_type,
                    cfg.host,
                    cfg.port,
                    cfg.ssl_certfile,
                    cfg.ssl_keyfile,
                ),
                flush=True,
            )
            new_workers.append(
                Process(target=start, kwargs=cfg.to_dict(), daemon=True)
            )
        # Start only the processes created by this call: calling
        # Process.start() twice on the same object raises AssertionError.
        for proc in new_workers:
            proc.start()
            self.workers.append(proc)

    def terminate(self, timeout: typing.Optional[float] = 5.0):
        """Stop every started worker and wait up to ``timeout`` seconds for each.

        ``Process.terminate()`` only sends the signal. Joining afterwards
        ensures the children have really exited (and their listening ports
        are released) before the caller continues.
        """
        for proc in self.workers:
            proc.terminate()
        for proc in self.workers:
            proc.join(timeout)
|