1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358
|
"""Test for SSL SNI proxy."""
from __future__ import annotations
import asyncio
import errno
import ipaddress
from typing import cast
from unittest.mock import patch
import pytest
from snitun.multiplexer.core import Multiplexer
from snitun.server.listener_sni import ProxyPeerHandler, SNIProxy
from snitun.server.peer import Peer
from snitun.server.peer_manager import PeerManager
from ..conftest import Client
from .const_tls import TLS_1_2
# Client address every proxied channel is expected to report. The tests
# connect over IPv4 loopback, and each test compares `channel.ip_address`
# against this value.
IP_ADDR = ipaddress.ip_address("127.0.0.1")
async def test_proxy_up_down() -> None:
    """Start the SNI proxy and immediately stop it again.

    An empty dict stands in for the peer manager because no client connects.
    """
    listener = SNIProxy({}, "127.0.0.1", "8863")
    for lifecycle_step in (listener.start, listener.stop):
        await lifecycle_step()
@pytest.mark.parametrize(
    "payloads",
    [
        [TLS_1_2],
        [TLS_1_2[:6], TLS_1_2[6:]],
        [TLS_1_2[:6], TLS_1_2[6:20], TLS_1_2[20:]],
        [TLS_1_2[:6], TLS_1_2[6:20], TLS_1_2[20:32], TLS_1_2[32:]],
    ],
)
async def test_sni_proxy_flow(
    multiplexer_client: Multiplexer,
    test_client_ssl: Client,
    payloads: list[bytes],
) -> None:
    """Send a (possibly fragmented) ClientHello, then exchange data both ways.

    Each parametrization splits the TLS 1.2 ClientHello into a different
    number of fragments. The channel must still deliver the complete
    ClientHello in one read.
    """
    writer = test_client_ssl.writer

    # Trickle the ClientHello fragments in with a pause between each one.
    for fragment in payloads:
        writer.write(fragment)
        await asyncio.sleep(0.1)
    await writer.drain()
    await asyncio.sleep(0.1)

    # The proxy should have opened a channel on the peer's multiplexer.
    channels = multiplexer_client._channels
    assert channels
    channel = next(iter(channels.values()))
    assert channel.ip_address == IP_ADDR
    assert await channel.read() == TLS_1_2

    # Client -> peer.
    writer.write(b"Very secret!")
    await writer.drain()
    assert await channel.read() == b"Very secret!"

    # Peer -> client.
    await channel.write(b"my answer")
    assert await test_client_ssl.reader.read(1024) == b"my answer"
async def test_sni_proxy_flow_close_by_client(
    multiplexer_client: Multiplexer,
    test_client_ssl: Client,
) -> None:
    """Delete the channel on the multiplexer side and check the TCP read ends."""
    loop = asyncio.get_running_loop()
    tcp_writer = test_client_ssl.writer

    tcp_writer.write(TLS_1_2)
    await tcp_writer.drain()
    await asyncio.sleep(0.1)

    assert multiplexer_client._channels
    channel = next(iter(multiplexer_client._channels.values()))
    assert channel.ip_address == IP_ADDR
    assert await channel.read() == TLS_1_2

    tcp_writer.write(b"Very secret!")
    await tcp_writer.drain()
    assert await channel.read() == b"Very secret!"

    # Park a read on the TCP client; nothing has been sent back yet.
    pending_read = loop.create_task(test_client_ssl.reader.read(2024))
    await asyncio.sleep(0.1)
    assert not pending_read.done()

    # Dropping the channel from the multiplexer should unblock that read.
    multiplexer_client.delete_channel(channel)
    await asyncio.sleep(0.1)
    assert pending_read.done()
async def test_sni_proxy_flow_close_by_server(
    multiplexer_client: Multiplexer,
    test_client_ssl: Client,
) -> None:
    """Close the TCP connection and check the channel is torn down."""
    loop = asyncio.get_running_loop()
    tcp_writer = test_client_ssl.writer

    tcp_writer.write(TLS_1_2)
    await tcp_writer.drain()
    await asyncio.sleep(0.1)

    assert multiplexer_client._channels
    channel = next(iter(multiplexer_client._channels.values()))
    assert channel.ip_address == IP_ADDR
    assert await channel.read() == TLS_1_2

    tcp_writer.write(b"Very secret!")
    await tcp_writer.drain()
    assert await channel.read() == b"Very secret!"

    # Park a read on the channel; the TCP client has nothing more to send.
    channel_read = loop.create_task(channel.read())
    await asyncio.sleep(0.1)
    assert not channel_read.done()

    # Closing the TCP side should remove the channel and wake the reader.
    tcp_writer.close()
    await asyncio.sleep(0.1)
    assert not multiplexer_client._channels
    assert channel_read.done()
async def test_sni_proxy_flow_peer_not(
    peer: Peer,
    multiplexer_client: Multiplexer,
    test_client_ssl: Client,
) -> None:
    """No channel is opened when the target peer has no multiplexer."""
    # Fake a peer that is not ready by dropping its multiplexer.
    peer._multiplexer = None

    writer = test_client_ssl.writer
    writer.write(TLS_1_2)
    await writer.drain()
    await asyncio.sleep(0.1)

    assert not multiplexer_client._channels
async def test_sni_proxy_timeout(
    multiplexer_client: Multiplexer,
    test_client_ssl: Client,
    raise_timeout: None,
) -> None:
    """No channel is opened while the ``raise_timeout`` fixture is active.

    ``raise_timeout`` is defined in conftest (not visible here). It is
    requested only for its side effect, which presumably forces the proxy's
    timeout path; confirm against the fixture.
    """
    ssl_writer = test_client_ssl.writer
    ssl_writer.write(TLS_1_2)
    await ssl_writer.drain()
    await asyncio.sleep(0.1)

    # No channel should exist on the peer's multiplexer.
    assert not multiplexer_client._channels
async def test_sni_proxy_flow_timeout(
    multiplexer_client: Multiplexer,
    test_client_ssl: Client,
) -> None:
    """Exchange data, then check an idle session is closed after the timeout.

    ``TCP_SESSION_TIMEOUT`` is lowered to 0.2s only for this test. Patching it,
    instead of assigning the module attribute, restores the original value on
    exit so the short timeout cannot leak into later tests.
    """
    with patch("snitun.server.listener_sni.TCP_SESSION_TIMEOUT", 0.2):
        writer = test_client_ssl.writer

        writer.write(TLS_1_2)
        await writer.drain()
        await asyncio.sleep(0.1)

        assert multiplexer_client._channels
        channel = next(iter(multiplexer_client._channels.values()))
        assert channel.ip_address == IP_ADDR
        assert await channel.read() == TLS_1_2

        writer.write(b"Very secret!")
        await writer.drain()
        assert await channel.read() == b"Very secret!"

        await channel.write(b"my answer")
        assert await test_client_ssl.reader.read(1024) == b"my answer"

        # Stay idle longer than the patched 0.2s timeout; the proxy should
        # have torn the session down by then.
        await asyncio.sleep(0.3)
        assert not multiplexer_client._channels
async def test_proxy_peer_handler_can_pause(
    multiplexer_client: Multiplexer,
    peer_manager: PeerManager,
) -> None:
    """Test that the proxy peer handler can pause and resume client reads.

    ``ProxyPeerHandler`` is patched with a factory that records the handler
    the proxy creates. This lets the test drive the channel's under-water
    callbacks directly. The proxy listens on the fixed port 8863, which other
    tests in this module also use. It is therefore stopped in ``finally``, so
    a failing assertion cannot leave the port bound.
    """
    proxy_peer_handler: ProxyPeerHandler | None = None
    loop = asyncio.get_running_loop()

    def save_proxy_peer_handler(
        loop: asyncio.AbstractEventLoop,
        ip_address: ipaddress.IPv4Address,
    ) -> ProxyPeerHandler:
        """Create a real handler and remember it for the assertions below."""
        nonlocal proxy_peer_handler
        proxy_peer_handler = ProxyPeerHandler(loop, ip_address)
        return proxy_peer_handler

    with patch("snitun.server.listener_sni.ProxyPeerHandler", save_proxy_peer_handler):
        proxy = SNIProxy(peer_manager, "127.0.0.1", "8863")
        await proxy.start()
        try:
            reader, writer = await asyncio.open_connection(
                host="127.0.0.1", port="8863"
            )
            test_client_ssl = Client(reader, writer)
            test_client_ssl.writer.write(TLS_1_2)
            await test_client_ssl.writer.drain()
            await asyncio.sleep(0.1)

            assert isinstance(proxy_peer_handler, ProxyPeerHandler)
            handler = cast(ProxyPeerHandler, proxy_peer_handler)
            client_channel = handler._channel
            # The handler must have wired its pause/resume callback into the
            # channel it created.
            assert client_channel._pause_resume_reader_callback is not None
            assert (
                client_channel._pause_resume_reader_callback
                == handler._pause_resume_reader_callback
            )

            assert multiplexer_client._channels
            server_channel = next(iter(multiplexer_client._channels.values()))
            assert server_channel.ip_address == IP_ADDR
            client_hello = await server_channel.read()
            assert client_hello == TLS_1_2

            test_client_ssl.writer.write(b"Very secret!")
            await test_client_ssl.writer.drain()
            data = await server_channel.read()
            assert data == b"Very secret!"

            # Now simulate that the remote input is under water
            client_channel.on_remote_input_under_water(True)
            assert handler._pause_future is not None
            assert not handler._pause_future.done()

            # This is an implementation detail that we might
            # change in the future, but for now we need to
            # read one more message because we don't cancel
            # the current read when the reader pauses as the additional
            # complexity is not worth it.
            test_client_ssl.writer.write(b"one more in before we pause")
            await test_client_ssl.writer.drain()
            data = await server_channel.read()
            assert data == b"one more in before we pause"

            test_client_ssl.writer.write(b"now we are paused")
            await test_client_ssl.writer.drain()
            read_task = loop.create_task(server_channel.read())
            await asyncio.sleep(0.1)
            # Make sure reader is actually paused
            assert not read_task.done()

            # Now simulate that the remote input is no longer under water
            assert handler._pause_future is not None
            assert not handler._pause_future.done()
            client_channel.on_remote_input_under_water(False)
            assert handler._pause_future is None
            data = await read_task
            assert data == b"now we are paused"

            test_client_ssl.writer.close()
            await asyncio.sleep(0.1)
            assert not multiplexer_client._channels
        finally:
            await proxy.stop()
async def test_proxy_peer_os_error_on_write(
    multiplexer_client: Multiplexer,
    peer_manager: PeerManager,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test that an OSError while writing to the client closes the session.

    The handler's client-side ``StreamWriter.write`` is patched to raise
    ``EPIPE``. The channel must be removed and the error must be logged. The
    proxy uses the fixed port 8863, so it is stopped in ``finally`` to avoid
    leaving the port bound when an assertion fails.
    """

    class InstrumentedProxyPeerHandler(ProxyPeerHandler):
        """ProxyPeerHandler that records the client streams it is started with.

        Used so the test can patch the handler's writer.
        """

        # Client-side streams captured in ``start``.
        writer: asyncio.StreamWriter
        reader: asyncio.StreamReader

        async def start(
            self,
            multiplexer: Multiplexer,
            client_hello: bytes,
            reader: asyncio.StreamReader,
            writer: asyncio.StreamWriter,
        ) -> None:
            """Record the client streams, then defer to the real handler."""
            self.reader = reader
            self.writer = writer
            await super().start(multiplexer, client_hello, reader, writer)

    # Typed as the subclass so ``.writer`` is known to exist after narrowing.
    proxy_peer_handler: InstrumentedProxyPeerHandler | None = None

    def save_proxy_peer_handler(
        loop: asyncio.AbstractEventLoop,
        ip_address: ipaddress.IPv4Address,
    ) -> ProxyPeerHandler:
        """Create an instrumented handler and remember it for the test."""
        nonlocal proxy_peer_handler
        proxy_peer_handler = InstrumentedProxyPeerHandler(loop, ip_address)
        return proxy_peer_handler

    with patch("snitun.server.listener_sni.ProxyPeerHandler", save_proxy_peer_handler):
        proxy = SNIProxy(peer_manager, "127.0.0.1", "8863")
        await proxy.start()
        try:
            reader, writer = await asyncio.open_connection(
                host="127.0.0.1", port="8863"
            )
            test_client_ssl = Client(reader, writer)
            test_client_ssl.writer.write(TLS_1_2)
            await test_client_ssl.writer.drain()
            await asyncio.sleep(0.1)

            assert isinstance(proxy_peer_handler, InstrumentedProxyPeerHandler)
            assert multiplexer_client._channels
            server_channel = next(iter(multiplexer_client._channels.values()))
            assert server_channel.ip_address == IP_ADDR
            client_hello = await server_channel.read()
            assert client_hello == TLS_1_2

            test_client_ssl.writer.write(b"Very secret!")
            await test_client_ssl.writer.drain()
            data = await server_channel.read()
            assert data == b"Very secret!"

            # The handler forwards channel data to the TCP writer
            # asynchronously, so stay inside the patch until it has run.
            with patch.object(
                proxy_peer_handler.writer,
                "write",
                side_effect=OSError(errno.EPIPE, "Broken Pipe"),
            ):
                await server_channel.write(b"some data that will trigger oserror")
                await asyncio.sleep(0.1)

            assert not multiplexer_client._channels
            assert "Broken Pipe" in caplog.text
            test_client_ssl.writer.close()
        finally:
            await proxy.stop()
|