File: channel.py

package info (click to toggle)
python-snitun 0.45.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 640 kB
  • sloc: python: 6,681; sh: 5; makefile: 3
file content (300 lines) | stat: -rw-r--r-- 10,889 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
"""Multiplexer channel."""

from __future__ import annotations

import asyncio
from collections.abc import Callable
from contextlib import suppress
from ipaddress import IPv4Address
import logging
import os

from ..exceptions import MultiplexerTransportClose, MultiplexerTransportError
from ..utils.asyncio import asyncio_timeout
from ..utils.ipaddress import ip_address_to_bytes
from .const import (
    INCOMING_QUEUE_HIGH_WATERMARK,
    INCOMING_QUEUE_LOW_WATERMARK,
    INCOMING_QUEUE_MAX_BYTES_CHANNEL,
    INCOMING_QUEUE_MAX_BYTES_CHANNEL_V0,
)
from .message import (
    CHANNEL_FLOW_CLOSE,
    CHANNEL_FLOW_DATA,
    CHANNEL_FLOW_NEW,
    CHANNEL_FLOW_PAUSE,
    CHANNEL_FLOW_RESUME,
    MIN_PROTOCOL_VERSION_FOR_PAUSE_RESUME,
    MultiplexerChannelId,
    MultiplexerMessage,
)
from .queue import MultiplexerMultiChannelQueue, MultiplexerSingleChannelQueue

_LOGGER = logging.getLogger(__name__)


class ChannelFlowControlBase:
    """A channel that implements flow control.

    The paused state is modelled as a pending future kept in
    ``_pause_future``. Pausing installs a fresh, unresolved future;
    resuming resolves it with ``None`` and forgets it. ``None`` or an
    already-completed future both mean "not paused".

    ``_channel`` is only declared here, not assigned. It is presumably
    provided by a subclass and is used solely for its ``ip_address`` and
    ``id`` in debug log messages.
    """

    _channel: MultiplexerChannel

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        """Initialize a channel that implements flow control.

        Args:
            loop: Event loop used to create the pause future.
        """
        # Cache the level check once; log calls below are skipped otherwise.
        self._debug = _LOGGER.isEnabledFor(logging.DEBUG)
        self._loop = loop
        self._pause_future: asyncio.Future[None] | None = None

    def _pause_resume_reader_callback(self, pause: bool) -> None:
        """Pause (``pause=True``) or resume (``pause=False``) the reader.

        Repeating the current state is a no-op apart from a debug message.
        """
        chan = self._channel
        peer_ip, chan_id = chan.ip_address, chan.id
        pending = self._pause_future
        # A pause is in effect only while an unresolved future is stored.
        pause_in_effect = bool(pending) and not pending.done()

        if pause:
            if pause_in_effect:
                # Already paused - this is idempotent, no error needed
                if self._debug:
                    _LOGGER.debug(
                        "Reader already paused for %s (%s), ignoring",
                        peer_ip,
                        chan_id,
                    )
                return
            if self._debug:
                _LOGGER.debug("Pause reader for %s (%s)", peer_ip, chan_id)
            self._pause_future = self._loop.create_future()
            return

        if not pause_in_effect:
            # Already resumed - this is idempotent, no error needed
            if self._debug:
                _LOGGER.debug(
                    "Reader already resumed for %s (%s), ignoring",
                    peer_ip,
                    chan_id,
                )
            return
        if self._debug:
            _LOGGER.debug("Resuming reader for %s (%s)", peer_ip, chan_id)
        pending.set_result(None)
        self._pause_future = None


class MultiplexerChannel:
    """Represent a multiplexer channel.

    A channel owns a per-channel input queue (messages received from the
    peer) and registers itself on a shared multi-channel output queue
    (messages sent to the peer). It tracks backpressure in both directions
    and, through an optional callback, asks whatever is connected to the
    channel to pause or resume reading.
    """

    __slots__ = (
        "_closing",
        "_debug",
        "_id",
        "_input",
        "_ip_address",
        "_local_output_under_water",
        "_output",
        "_pause_resume_reader_callback",
        "_peer_protocol_version",
        "_reader_paused",
        "_remote_input_under_water",
        "_throttling",
    )

    def __init__(
        self,
        output: MultiplexerMultiChannelQueue,
        ip_address: IPv4Address,
        peer_protocol_version: int,
        pause_resume_reader_callback: Callable[[bool], None] | None = None,
        channel_id: MultiplexerChannelId | None = None,
        throttling: float | None = None,
    ) -> None:
        """Initialize Multiplexer Channel.

        Args:
            output: Shared output queue used to send messages to the peer.
            ip_address: IP address of the caller this channel belongs to.
            peer_protocol_version: Protocol version spoken by the peer. Peers
                below MIN_PROTOCOL_VERSION_FOR_PAUSE_RESUME are never sent
                pause/resume messages.
            pause_resume_reader_callback: Called with ``True``/``False`` to
                pause/resume the local reader. It can also be set later via
                set_pause_resume_reader_callback().
            channel_id: Explicit channel id. A random 16-byte id is generated
                when it is omitted (or falsy).
            throttling: Optional delay in seconds applied after every write().
        """
        # Every attribute read by the under-water callbacks is assigned
        # *before* those callbacks are handed to the input/output queues.
        # With __slots__, a callback fired during registration would
        # otherwise hit an unset slot and raise AttributeError.
        self._debug = _LOGGER.isEnabledFor(logging.DEBUG)
        self._output = output
        self._id = channel_id or MultiplexerChannelId(os.urandom(16))
        self._ip_address = ip_address
        self._peer_protocol_version = peer_protocol_version
        self._throttling = throttling
        self._closing = False
        # Backpressure - We track when our output queue is under water
        # or the remote input queue is under water so we can pause reading
        # of whatever is connected to this channel to prevent overflowing
        # either queue.
        self._local_output_under_water = False
        self._remote_input_under_water = False
        self._pause_resume_reader_callback = pause_resume_reader_callback
        self._reader_paused = False

        if peer_protocol_version == 0:
            # For protocol version 0, we use a larger queue since
            # we can't tell the client to pause/resume reading.
            # This is a temporary solution until we can remove protocol version 0.
            incoming_queue_max_bytes_channel = INCOMING_QUEUE_MAX_BYTES_CHANNEL_V0
        else:
            incoming_queue_max_bytes_channel = INCOMING_QUEUE_MAX_BYTES_CHANNEL
        self._input = MultiplexerSingleChannelQueue(
            incoming_queue_max_bytes_channel,
            INCOMING_QUEUE_LOW_WATERMARK,
            INCOMING_QUEUE_HIGH_WATERMARK,
            self._on_local_input_under_water,
        )
        self._output.create_channel(self._id, self._on_local_output_under_water)

    def set_pause_resume_reader_callback(
        self,
        pause_resume_reader_callback: Callable[[bool], None],
    ) -> None:
        """Set pause resume reader callback."""
        self._pause_resume_reader_callback = pause_resume_reader_callback

    def _on_local_input_under_water(self, under_water: bool) -> None:
        """Tell the peer that our input queue crossed a watermark.

        Sends CHANNEL_FLOW_PAUSE when the queue goes under water and
        CHANNEL_FLOW_RESUME when it recovers. This is skipped for peers whose
        protocol predates pause/resume. If the output queue is full, the
        message is dropped with a warning.
        """
        if self._peer_protocol_version < MIN_PROTOCOL_VERSION_FOR_PAUSE_RESUME:
            if self._debug:
                _LOGGER.debug(
                    "Remote does not support pause/resume, ignoring %s input water",
                    self._id,
                )
            return
        msg_type = CHANNEL_FLOW_PAUSE if under_water else CHANNEL_FLOW_RESUME
        # Tell the remote that our input queue is under water so it
        # can pause reading from whatever is connected to this channel
        if self._debug:
            _LOGGER.debug(
                "Informing remote that %s input is now %s water",
                self._id,
                "under" if under_water else "above",
            )
        try:
            self._output.put_nowait(self._id, MultiplexerMessage(self._id, msg_type))
        except asyncio.QueueFull:
            _LOGGER.warning(
                "%s: Cannot send pause/resume message to peer, output queue is full",
                self._id,
            )

    def _on_local_output_under_water(self, under_water: bool) -> None:
        """Record our output queue's water state and re-evaluate the reader."""
        if self._debug:
            _LOGGER.debug(
                "Local output is under water: %s for %s",
                under_water,
                self._id,
            )
        self._local_output_under_water = under_water
        self._pause_or_resume_reader()

    def on_remote_input_under_water(self, under_water: bool) -> None:
        """Record the peer's input water state and re-evaluate the reader."""
        if self._debug:
            _LOGGER.debug(
                "Remote input is under water: %s for %s",
                under_water,
                self._id,
            )
        self._remote_input_under_water = under_water
        self._pause_or_resume_reader()

    def _pause_or_resume_reader(self) -> None:
        """Invoke the reader callback when the desired paused state changes.

        The callback is only called on transitions, never for a repeated state.
        """
        # Pause if either local output or remote input is under water
        # Resume if both local output and remote input are not under water
        if self._pause_resume_reader_callback is None:
            return
        pause_reader = self._local_output_under_water or self._remote_input_under_water
        if self._reader_paused != pause_reader:
            # Call the callback first, then update state if successful
            # This ensures state consistency even if callback fails
            self._pause_resume_reader_callback(pause_reader)
            self._reader_paused = pause_reader

    @property
    def id(self) -> MultiplexerChannelId:
        """Return ID of this channel."""
        return self._id

    @property
    def ip_address(self) -> IPv4Address:
        """Return caller IPv4Address."""
        return self._ip_address

    @property
    def unhealthy(self) -> bool:
        """Return True if the input queue is full.

        While it is full, message_transport() drops incoming messages.
        """
        return self._input.full()

    @property
    def closing(self) -> bool:
        """Return True if channel is in closing state."""
        return self._closing

    def close(self) -> None:
        """Close channel on next run.

        Marks the channel as closing and, if there is room, enqueues a
        ``None`` sentinel so a pending read() raises MultiplexerTransportClose.
        """
        _LOGGER.debug("Schedule close channel %s", self._id)
        self._closing = True
        # If the queue is full, read() still terminates: once the queue
        # drains, the closing + empty check in read() raises the close error.
        with suppress(asyncio.QueueFull):
            self._input.put_nowait(None)

    async def write(self, data: bytes) -> None:
        """Send data to peer.

        Raises:
            MultiplexerTransportError: If ``data`` is empty, or the output
                queue stays full for 5 seconds.
            MultiplexerTransportClose: If the channel is closing.
        """
        if not data:
            raise MultiplexerTransportError
        if self._closing:
            raise MultiplexerTransportClose

        # Create message (built directly via tuple.__new__, skipping the
        # MultiplexerMessage constructor - presumably a hot-path shortcut)
        message = tuple.__new__(
            MultiplexerMessage,
            (self._id, CHANNEL_FLOW_DATA, data, b""),
        )

        try:
            # Try to avoid the timer handle if we can
            # add to the queue without waiting
            self._output.put_nowait(self._id, message)
        except asyncio.QueueFull:
            try:
                async with asyncio_timeout.timeout(5):
                    await self._output.put(self._id, message)
            except TimeoutError:
                if self._debug:
                    _LOGGER.debug("Can't write to peer transport")
                raise MultiplexerTransportError from None

        if self._throttling is not None:
            await asyncio.sleep(self._throttling)

    async def read(self) -> bytes:
        """Read data from peer.

        Returns:
            The payload of the next queued message.

        Raises:
            MultiplexerTransportClose: When the ``None`` close sentinel is read,
                or when the channel is closing and no data is left to read.
        """
        if self._closing and self._input.empty():
            message = None
        else:
            message = await self._input.get()

        # Send data
        if message is not None:
            return message.data

        _LOGGER.debug("Read a close message for channel %s", self._id)
        raise MultiplexerTransportClose

    def init_close(self) -> MultiplexerMessage:
        """Build the CHANNEL_FLOW_CLOSE message for this channel."""
        if self._debug:
            _LOGGER.debug("Sending close channel %s", self._id)
        return MultiplexerMessage(self._id, CHANNEL_FLOW_CLOSE)

    def init_new(self) -> MultiplexerMessage:
        """Build the CHANNEL_FLOW_NEW message for this channel.

        The extra field is ``b"4"`` followed by the caller's IP address bytes.
        """
        if self._debug:
            _LOGGER.debug("Sending new channel %s", self._id)
        extra = b"4" + ip_address_to_bytes(self.ip_address)
        return MultiplexerMessage(self._id, CHANNEL_FLOW_NEW, b"", extra)

    def message_transport(self, message: MultiplexerMessage) -> None:
        """Only for internal usage of core transport.

        Enqueues an incoming message for read(). It is ignored when the channel
        is closing, and dropped with a warning when the input queue is full.
        """
        if self._closing:
            return

        try:
            self._input.put_nowait(message)
        except asyncio.QueueFull:
            _LOGGER.warning("Channel %s input is full", self._id)