File: aiohttp_client.py

package info (click to toggle)
python-snitun 0.45.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 640 kB
  • sloc: python: 6,681; sh: 5; makefile: 3
file content (139 lines) | stat: -rw-r--r-- 4,243 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""Helper for handle aiohttp internal server."""

from __future__ import annotations

import asyncio
from collections.abc import Callable, Coroutine
from contextlib import suppress
import logging
import socket
import ssl
from typing import Any

from aiohttp.web import AppRunner, SockSite

from ..client.client_peer import ClientPeer
from ..client.connector import Connector
from . import DEFAULT_PROTOCOL_VERSION
from .asyncio import asyncio_timeout

_LOGGER = logging.getLogger(__name__)


class SniTunClientAioHttp:
    """Help to handle an internal aiohttp app runner through SniTun.

    The runner's application is served by a TLS ``SockSite`` listening on an
    ephemeral port of ``127.0.0.1``. :meth:`start` creates a
    :class:`Connector` for that local address, and :meth:`connect` hands the
    connector to the :class:`ClientPeer` that talks to the SniTun server.
    """

    def __init__(
        self,
        runner: AppRunner,
        context: ssl.SSLContext,
        snitun_server: str,
        snitun_port: int | None = None,
    ) -> None:
        """Initialize SniTunClient with aiohttp.

        Args:
            runner: aiohttp app runner whose application is served locally.
            context: SSL context used by the local ``SockSite``.
            snitun_server: host name of the SniTun server.
            snitun_port: port of the SniTun server; ``None`` is passed through
                to :class:`ClientPeer` unchanged.

        If preparing the local socket or site fails, the socket is closed
        before the exception propagates so no file descriptor is leaked.
        """
        self._connector: Connector | None = None
        self._client = ClientPeer(snitun_server, snitun_port)

        # Only used in log messages; don't render a missing port as "None".
        if snitun_port is None:
            self._server_name = snitun_server
        else:
            self._server_name = f"{snitun_server}:{snitun_port}"

        # Init interface: a non-blocking socket on an ephemeral loopback port.
        self._socket = socket.socket()
        try:
            self._socket.setblocking(False)
            self._socket.bind(("127.0.0.1", 0))
            self._site = SockSite(runner, self._socket, ssl_context=context)
        except BaseException:
            # e.g. bind() or SockSite() raising: don't leak the descriptor.
            self._socket.close()
            raise

    @property
    def is_connected(self) -> bool:
        """Return True if we are connected to snitun."""
        return self._client.is_connected

    @property
    def whitelist(self) -> set:
        """Return the connector's whitelist, or an empty set before start()."""
        if self._connector:
            return self._connector.whitelist
        return set()

    def wait(self) -> asyncio.Future[None]:
        """Return the future of ``ClientPeer.wait()``.

        Await it to block until the connection to snitun is closed.
        """
        return self._client.wait()

    async def start(
        self,
        whitelist: bool = False,
        endpoint_connection_error_callback: Callable[[], Coroutine[Any, Any, None]]
        | None = None,
    ) -> None:
        """Start internal server and create the connector for it.

        Args:
            whitelist: passed to the :class:`Connector`.
            endpoint_connection_error_callback: optional coroutine factory,
                passed to the :class:`Connector`.
        """
        await self._site.start()

        # The socket was bound to port 0, so read back the actual address.
        host, port = self._socket.getsockname()[:2]
        self._connector = Connector(
            host,
            port,
            whitelist,
            endpoint_connection_error_callback=endpoint_connection_error_callback,
        )

        _LOGGER.info("AioHTTP snitun client started on %s:%s", host, port)

    async def stop(self, *, wait: bool = False) -> None:
        """
        Stop internal server.

        Disconnects from SniTun, closes the local socket and unregisters the
        site from its runner.

        Args:
            wait: wait for the socket to close.
        """
        await self.disconnect()
        with suppress(OSError):
            self._socket.close()

        # Raised when the site was never registered (e.g. start() not called).
        with suppress(RuntimeError):
            self._site._runner._unreg_site(self._site)  # noqa: SLF001

        if wait:
            # Wait for the socket to close
            await _async_waitfor_socket_closed(self._socket)

        _LOGGER.info("AioHTTP snitun client closed")

    async def connect(
        self,
        fernet_key: bytes,
        aes_key: bytes,
        aes_iv: bytes,
        throttling: int | None = None,
        protocol_version: int = DEFAULT_PROTOCOL_VERSION,
    ) -> None:
        """Connect to SniTun server.

        Does nothing if already connected. :meth:`start` must have been
        called first so that the connector exists (checked with an assert).

        Args:
            fernet_key: key material forwarded to ``ClientPeer.start``.
            aes_key: key material forwarded to ``ClientPeer.start``.
            aes_iv: IV forwarded to ``ClientPeer.start``.
            throttling: optional throttling value forwarded to the peer.
            protocol_version: protocol version forwarded to the peer.
        """
        if self._client.is_connected:
            return
        assert self._connector is not None, "Connector is not initialized"
        await self._client.start(
            self._connector,
            fernet_key,
            aes_key,
            aes_iv,
            throttling=throttling,
            protocol_version=protocol_version,
        )
        _LOGGER.info("AioHTTP snitun client connected to: %s", self._server_name)

    async def disconnect(self) -> None:
        """Disconnect from SniTun server; does nothing if not connected."""
        if not self._client.is_connected:
            return
        await self._client.stop()
        _LOGGER.info("AioHTTP snitun client disconnected from: %s", self._server_name)


async def _async_waitfor_socket_closed(
    sock: socket.socket | None = None,
    *,
    timeout: float = 60,
    poll_interval: float = 1,
) -> None:
    """Wait for the socket to be closed (``sock.fileno() == -1``).

    Returns immediately when ``sock`` is ``None`` or already closed.
    Otherwise the socket is polled every ``poll_interval`` seconds. After
    ``timeout`` seconds a warning is logged and the function returns without
    raising.

    Args:
        sock: socket to watch, or ``None``.
        timeout: maximum number of seconds to wait.
        poll_interval: seconds to sleep between checks.
    """
    # fileno() is a cheap, non-blocking getter, so there is no need to run it
    # in an executor; the common case (already closed) returns right away.
    if sock is None or sock.fileno() == -1:
        return
    try:
        async with asyncio_timeout.timeout(timeout):
            while sock.fileno() != -1:
                await asyncio.sleep(poll_interval)
    except (TimeoutError, asyncio.TimeoutError):
        # asyncio.TimeoutError is only an alias of the builtin TimeoutError
        # from Python 3.11 on; catch both so older interpreters are covered.
        _LOGGER.warning("Timeout while waiting for the socket to close.")