File: connector.py

package info (click to toggle)
python-snitun 0.45.1-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 664 kB
  • sloc: python: 6,681; sh: 5; makefile: 3
file content (177 lines) | stat: -rw-r--r-- 6,124 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
"""Connector to end resource."""

from __future__ import annotations

import asyncio
from collections.abc import Callable, Coroutine
from contextlib import suppress
import ipaddress
from ipaddress import IPv4Address
import logging
from typing import Any

from ..exceptions import MultiplexerTransportClose, MultiplexerTransportError
from ..multiplexer.channel import ChannelFlowControlBase, MultiplexerChannel
from ..multiplexer.core import Multiplexer

_LOGGER = logging.getLogger(__name__)


class Connector:
    """Accept multiplexer channels and bridge each one to a fixed endpoint.

    Every accepted channel is served by its own ``ConnectorHandler`` that
    relays traffic to ``end_host:end_port``. An optional IP whitelist can
    restrict which peers are served.
    """

    def __init__(
        self,
        end_host: str,
        end_port: int | None = None,
        whitelist: bool = False,
        endpoint_connection_error_callback: Callable[[], Coroutine[Any, Any, None]]
        | None = None,
    ) -> None:
        """Initialize Connector.

        Args:
            end_host: Host of the endpoint every channel is forwarded to.
            end_port: Endpoint port; any falsy value (``None`` or ``0``)
                selects 443.
            whitelist: When True, only addresses contained in
                :attr:`whitelist` are allowed through.
            endpoint_connection_error_callback: Optional coroutine function
                handed to each handler and awaited by it when the endpoint
                connection cannot be opened.
        """
        # Captured once and shared with every ConnectorHandler created later.
        self._loop = asyncio.get_event_loop()
        self._end_host = end_host
        self._end_port = end_port if end_port else 443
        self._whitelist_enabled = whitelist
        self._whitelist: set[IPv4Address] = set()
        self._endpoint_connection_error_callback = endpoint_connection_error_callback

    @property
    def whitelist(self) -> set[IPv4Address]:
        """Return the set of peer addresses allowed when whitelisting is on.

        The set is returned by reference, so callers add or remove entries on
        it directly. It is only consulted when the connector was created with
        ``whitelist=True``.
        """
        return self._whitelist

    def _whitelist_policy(self, ip_address: ipaddress.IPv4Address) -> bool:
        """Return True if ``ip_address`` may reach the endpoint."""
        # Disabled whitelist lets everyone through; otherwise the peer must be listed.
        return not self._whitelist_enabled or ip_address in self._whitelist

    async def handler(
        self,
        multiplexer: Multiplexer,
        channel: MultiplexerChannel,
    ) -> None:
        """Serve a new channel opened by the SNIProxy.

        Peers rejected by the whitelist policy get their channel deleted
        right away; every other channel is handed to a fresh
        ``ConnectorHandler`` which is awaited until it finishes.
        """
        peer_ip = channel.ip_address
        _LOGGER.debug("Receive from %s a request for %s", peer_ip, self._end_host)

        if self._whitelist_policy(peer_ip):
            connector_handler = ConnectorHandler(self._loop, channel)
            await connector_handler.start(
                multiplexer,
                self._end_host,
                self._end_port,
                self._endpoint_connection_error_callback,
            )
        else:
            _LOGGER.warning("Block request from %s per policy", peer_ip)
            multiplexer.delete_channel(channel)


class ConnectorHandler(ChannelFlowControlBase):
    """Relay traffic between one multiplexer channel and the endpoint.

    Bytes read from the endpoint TCP connection are written to the channel,
    and bytes read from the channel are written to the endpoint. Reading from
    the endpoint is suspended while the flow-control pause future inherited
    from ``ChannelFlowControlBase`` (defined elsewhere) is pending.
    """

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        channel: MultiplexerChannel,
    ) -> None:
        """Initialize ConnectorHandler.

        Args:
            loop: Event loop passed to ``ChannelFlowControlBase``. ``start``
                creates its read tasks via ``self._loop``, which this class
                never assigns itself (presumably the base class stores it).
            channel: Multiplexer channel served by this handler.
        """
        super().__init__(loop)
        self._channel = channel

    async def start(
        self,
        multiplexer: Multiplexer,
        end_host: str,
        end_port: int,
        endpoint_connection_error_callback: Callable[[], Coroutine[Any, Any, None]]
        | None = None,
    ) -> None:
        """Connect to the endpoint and pump data until either side closes.

        Args:
            multiplexer: Multiplexer owning the channel; used to delete the
                channel when the endpoint cannot be reached or when relaying
                fails with an endpoint-side transport error.
            end_host: Endpoint host passed to ``asyncio.open_connection``.
            end_port: Endpoint port passed to ``asyncio.open_connection``.
            endpoint_connection_error_callback: Optional coroutine function
                awaited (after the channel has been deleted) when opening the
                endpoint connection raises ``OSError``.

        While relaying, ``MultiplexerTransportError``, ``OSError`` and
        ``RuntimeError`` are logged and cause the channel to be deleted;
        ``MultiplexerTransportClose`` is only logged. Any other exception,
        including task cancellation, propagates. In every case outstanding
        read tasks are cancelled and the endpoint transport is closed (if not
        already closing) before returning.
        """
        channel = self._channel
        # Let the channel pause/resume endpoint reads; the callback comes from
        # ChannelFlowControlBase and presumably (re)sets self._pause_future.
        channel.set_pause_resume_reader_callback(self._pause_resume_reader_callback)
        # Open connection to endpoint
        try:
            reader, writer = await asyncio.open_connection(host=end_host, port=end_port)
        except OSError:
            _LOGGER.error(
                "Can't connect to endpoint %s:%s",
                end_host,
                end_port,
            )
            multiplexer.delete_channel(channel)
            if endpoint_connection_error_callback:
                await endpoint_connection_error_callback()
            return

        # At most one outstanding awaitable per direction. Each is reset to None
        # once its result has been consumed, so a new one is created next round.
        from_endpoint: asyncio.Future[None] | asyncio.Task[bytes] | None = None
        from_peer: asyncio.Task[bytes] | None = None
        try:
            # Process stream from multiplexer
            while not writer.transport.is_closing():
                if not from_endpoint:
                    # If the multiplexer channel queue is under water, pause the reader
                    # by waiting for the future to be set, once the queue is not under
                    # water the future will be set and cleared to resume the reader.
                    # In the paused case from_endpoint is the pause future itself, not
                    # a read task; its None result is filtered out further below.
                    # NOTE(review): reader.read returns b"" at EOF and that empty chunk
                    # is forwarded to channel.write; presumably channel.write rejects
                    # empty data with MultiplexerTransportError, which is what ends the
                    # relay on endpoint EOF. Confirm.
                    from_endpoint = self._pause_future or self._loop.create_task(
                        reader.read(4096),  # type: ignore[arg-type]
                    )
                if not from_peer:
                    from_peer = self._loop.create_task(channel.read())

                # Wait until data need to be processed
                await asyncio.wait(
                    [from_endpoint, from_peer],
                    return_when=asyncio.FIRST_COMPLETED,
                )

                # From proxy (endpoint -> peer direction)
                if from_endpoint.done():
                    # Re-raise so the except clauses below decide how to clean up.
                    if from_endpoint_exc := from_endpoint.exception():
                        raise from_endpoint_exc

                    # None means the pause future resolved (reads resumed), not data.
                    if (from_endpoint_result := from_endpoint.result()) is not None:
                        await channel.write(from_endpoint_result)
                    from_endpoint = None

                # From peer (peer -> endpoint direction)
                if from_peer.done():
                    if from_peer_exc := from_peer.exception():
                        raise from_peer_exc

                    writer.write(from_peer.result())
                    from_peer = None

                    # Flush buffer; waits until it is appropriate to resume writing
                    # to the endpoint, applying backpressure to the peer side.
                    await writer.drain()

            # NOTE(review): if the loop ends only because the endpoint transport is
            # closing (no exception raised), the channel is not deleted here.
            # Confirm this is intended.

        except (MultiplexerTransportError, OSError, RuntimeError):
            _LOGGER.debug("Transport closed by endpoint for %s", channel.id)
            multiplexer.delete_channel(channel)

        except MultiplexerTransportClose:
            _LOGGER.debug("Peer close connection for %s", channel.id)

        finally:
            # Cleanup peer reader
            if from_peer:
                if not from_peer.done():
                    from_peer.cancel()
                else:
                    # Avoid exception was never retrieved (calling exception()
                    # marks any stored exception as retrieved)
                    from_peer.exception()

            # Cleanup endpoint reader.
            # NOTE(review): from_endpoint may be the base class's pause future
            # rather than our read task; cancelling it changes state owned by
            # ChannelFlowControlBase. Confirm the base tolerates a cancelled
            # pause future.
            if from_endpoint and not from_endpoint.done():
                from_endpoint.cancel()

            # Close Transport (not followed by writer.wait_closed())
            if not writer.transport.is_closing():
                with suppress(OSError):
                    writer.close()