1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
|
"""SniTun client for server connection."""
from __future__ import annotations
import asyncio
import hashlib
import logging
from ..exceptions import (
MultiplexerTransportDecrypt,
MultiplexerTransportError,
SniTunConnectionError,
)
from ..multiplexer.core import Multiplexer
from ..multiplexer.crypto import CryptoTransport
from ..utils import DEFAULT_PROTOCOL_VERSION
from ..utils.asyncio import asyncio_timeout, make_task_waiter_future
from .connector import Connector
# Logger shared by everything in this module.
_LOGGER: logging.Logger = logging.getLogger(__name__)
# Upper bound handed to asyncio_timeout.timeout() for each connection-setup
# step of ClientPeer.start (presumably seconds — confirm against the helper).
CONNECTION_TIMEOUT: int = 60
class ClientPeer:
    """Client to SniTun Server.

    Opens a TCP connection to the SniTun server, sends a fernet token, answers
    an AES challenge, and then runs a ``Multiplexer`` over the connection
    until it closes or :meth:`stop` is called.
    """

    def __init__(self, snitun_host: str, snitun_port: int | None = None) -> None:
        """Initialize ClientPeer connector.

        Args:
            snitun_host: Host of the SniTun server.
            snitun_port: TCP port of the SniTun server; ``8080`` when omitted
                (or falsy).

        No event loop is looked up here. The handler task is scheduled on the
        loop that runs :meth:`start`, so the peer may be created outside of a
        running loop.
        """
        self._multiplexer: Multiplexer | None = None
        self._snitun_host = snitun_host
        self._snitun_port = snitun_port or 8080
        self._handler_task: asyncio.Task[None] | None = None

    @property
    def is_connected(self) -> bool:
        """Return true, if a connection exists."""
        return self._multiplexer is not None

    def wait(self) -> asyncio.Future[None]:
        """Block until connection to peer is closed.

        Raises:
            RuntimeError: If there is no active SniTun connection.
        """
        if not self._multiplexer or not self._handler_task:
            raise RuntimeError("No SniTun connection available")

        # Wait until the handler task is done
        # as we know the connection is closed
        return make_task_waiter_future(self._handler_task)

    async def start(
        self,
        connector: Connector,
        fernet_token: bytes,
        aes_key: bytes,
        aes_iv: bytes,
        throttling: int | None = None,
        protocol_version: int = DEFAULT_PROTOCOL_VERSION,
    ) -> None:
        """Connect and start ClientPeer.

        Args:
            connector: Its ``handler`` is given to the multiplexer as the
                callback for new connections.
            fernet_token: Token written to the server right after connecting.
            aes_key: Key for the ``CryptoTransport`` used by the handshake
                and the multiplexer.
            aes_iv: IV for the ``CryptoTransport``.
            throttling: Passed through to the ``Multiplexer``.
            protocol_version: Multiplexer protocol version to use.

        Raises:
            RuntimeError: If a SniTun connection is already available.
            SniTunConnectionError: If connecting, sending the token, or the
                challenge/response fails or times out.
        """
        if self._multiplexer:
            raise RuntimeError("SniTun connection available")

        reader, writer = await self._open_connection()

        try:
            await self._send_token(writer, fernet_token)
            crypto = CryptoTransport(aes_key, aes_iv)
            await self._challenge_response(reader, writer, crypto)
        except BaseException:
            # The handshake failed or we were cancelled. Close the socket
            # instead of leaking it, then propagate the original error.
            writer.close()
            raise

        # Run multiplexer
        self._multiplexer = Multiplexer(
            crypto,
            reader,
            writer,
            # By default we always assume the server can handle the
            # latest protocol version since the server is deployed
            # before the client is updated in the wild, however
            # we can override this if needed by passing a different
            # protocol version.
            protocol_version,
            new_connections=connector.handler,
            throttling=throttling,
        )

        # Task a process for pings/cleanups
        assert not self._handler_task or self._handler_task.done(), (
            "SniTun connection already running"
        )
        # Schedule on the loop that is actually running this coroutine.
        self._handler_task = asyncio.get_running_loop().create_task(self._handler())

    async def _open_connection(
        self,
    ) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
        """Open the TCP connection to the SniTun server.

        Raises:
            SniTunConnectionError: On timeout or any ``OSError`` while connecting.
        """
        _LOGGER.debug(
            "Opening connection to %s:%s",
            self._snitun_host,
            self._snitun_port,
        )
        try:
            async with asyncio_timeout.timeout(CONNECTION_TIMEOUT):
                return await asyncio.open_connection(
                    host=self._snitun_host,
                    port=self._snitun_port,
                )
        except TimeoutError:
            raise SniTunConnectionError(
                "Connection timeout for SniTun server "
                f"{self._snitun_host}:{self._snitun_port}",
            ) from None
        except OSError as err:
            raise SniTunConnectionError(
                "Can't connect to SniTun server "
                f"{self._snitun_host}:{self._snitun_port} with: {err}",
            ) from err

    @staticmethod
    async def _send_token(writer: asyncio.StreamWriter, fernet_token: bytes) -> None:
        """Write the fernet token and wait for it to be flushed.

        Raises:
            SniTunConnectionError: On timeout or socket error while flushing.
        """
        writer.write(fernet_token)
        try:
            async with asyncio_timeout.timeout(CONNECTION_TIMEOUT):
                await writer.drain()
        except TimeoutError:
            raise SniTunConnectionError(
                "Timeout for writting connection token",
            ) from None
        except OSError as err:
            # Same wrapping as the challenge/response step, so callers only
            # need to handle SniTunConnectionError.
            raise SniTunConnectionError(
                f"Can't send connection token to SniTun server ({err})",
            ) from err

    @staticmethod
    async def _challenge_response(
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        crypto: CryptoTransport,
    ) -> None:
        """Answer the server challenge.

        Reads the 32-byte challenge and replies with the encrypted SHA-256
        digest of its decrypted value.

        Raises:
            SniTunConnectionError: On timeout, decrypt failure, short read, or
                socket error.
        """
        try:
            async with asyncio_timeout.timeout(CONNECTION_TIMEOUT):
                challenge = await reader.readexactly(32)
                answer = hashlib.sha256(crypto.decrypt(challenge)).digest()

                writer.write(crypto.encrypt(answer))
                await writer.drain()
        except TimeoutError:
            raise SniTunConnectionError(
                "Challenge/Response timeout error to SniTun server",
            ) from None
        except (
            MultiplexerTransportDecrypt,
            asyncio.IncompleteReadError,
            OSError,
        ) as err:
            raise SniTunConnectionError(
                f"Challenge/Response error with SniTun server ({err})",
            ) from err

    async def stop(self) -> None:
        """Stop connection to SniTun server.

        Raises:
            RuntimeError: If there is no active SniTun connection.
        """
        if not self._multiplexer:
            raise RuntimeError("No SniTun connection available")

        self._multiplexer.shutdown()
        await self._multiplexer.wait()
        await self._stop_handler()

    async def _stop_handler(self) -> None:
        """Cancel the handler task and wait for it to finish."""
        assert self._handler_task, "Handler task not started"
        self._handler_task.cancel()
        try:
            await self._handler_task
        except asyncio.CancelledError:
            # Don't swallow cancellation
            if (current_task := asyncio.current_task()) and current_task.cancelling():
                raise
        finally:
            self._handler_task = None

    async def _handler(self) -> None:
        """Wait until connection is closed, pinging every 50 s while it is open.

        On exit, shuts down the multiplexer (if still set) and clears it.
        """

        async def _wait_with_timeout(multiplexer: Multiplexer) -> None:
            try:
                async with asyncio_timeout.timeout(50):
                    await multiplexer.wait()
            except TimeoutError:
                await multiplexer.ping()

        try:
            while self._multiplexer and self._multiplexer.is_connected:
                await _wait_with_timeout(self._multiplexer)
        except MultiplexerTransportError:
            # Transport failure ends the loop; cleanup happens in finally.
            pass
        finally:
            if self._multiplexer:
                self._multiplexer.shutdown()
            self._multiplexer = None
|