File: test_connector.py

package info (click to toggle)
python-snitun 0.45.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 640 kB
  • sloc: python: 6,681; sh: 5; makefile: 3
file content (309 lines) | stat: -rw-r--r-- 9,306 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
"""Test client connector."""

import asyncio
import ipaddress
from typing import cast
from unittest.mock import AsyncMock, patch

import pytest

from snitun.client.connector import Connector, ConnectorHandler
from snitun.exceptions import MultiplexerTransportClose
from snitun.multiplexer.channel import MultiplexerChannel
from snitun.multiplexer.core import Multiplexer

from ..conftest import Client

# Peer address used for every channel the connector is expected to accept.
IP_ADDR = ipaddress.IPv4Address("8.8.8.8")
# A different peer address, deliberately never added to a whitelist.
BAD_ADDR = ipaddress.IPv4Address("8.8.1.1")


async def test_init_connector(
    test_endpoint: list[Client],
    multiplexer_client: Multiplexer,
    multiplexer_server: Multiplexer,
) -> None:
    """Initialise a connector and check data reaches the local endpoint."""
    # No endpoint connection may exist before a channel is opened.
    assert not test_endpoint

    # Route new client-side multiplexer channels through the connector.
    plain_connector = Connector("127.0.0.1", "8822")
    multiplexer_client._new_connections = plain_connector.handler

    server_channel = await multiplexer_server.create_channel(
        IP_ADDR,
        lambda _: None,
    )
    # Give the connector time to reach the endpoint.
    await asyncio.sleep(0.1)

    assert test_endpoint
    endpoint_conn = test_endpoint[0]

    await server_channel.write(b"Hallo")
    assert await endpoint_conn.reader.read(1024) == b"Hallo"

    endpoint_conn.close.set()


async def test_flow_connector(
    test_endpoint: list[Client],
    multiplexer_client: Multiplexer,
    multiplexer_server: Multiplexer,
) -> None:
    """Test and perform a bidirectional connector flow.

    Data written to the server-side channel must reach the local endpoint,
    and data written by the endpoint must come back through the channel.
    """
    assert not test_endpoint

    connector = Connector("127.0.0.1", "8822")
    multiplexer_client._new_connections = connector.handler

    channel = await multiplexer_server.create_channel(IP_ADDR, lambda _: None)
    # Give the connector time to reach the endpoint.
    await asyncio.sleep(0.1)

    assert test_endpoint
    test_connection = test_endpoint[0]

    # Server channel -> endpoint.
    await channel.write(b"Hallo")
    data = await test_connection.reader.read(1024)
    assert data == b"Hallo"

    # Endpoint -> server channel.
    test_connection.writer.write(b"Hiro")
    await test_connection.writer.drain()

    data = await channel.read()
    assert data == b"Hiro"

    test_connection.close.set()


async def test_close_connector_remote(
    test_endpoint: list[Client],
    multiplexer_client: Multiplexer,
    multiplexer_server: Multiplexer,
) -> None:
    """Check that deleting the channel remotely closes the endpoint link."""
    assert not test_endpoint

    remote_connector = Connector("127.0.0.1", "8822")
    multiplexer_client._new_connections = remote_connector.handler

    server_channel = await multiplexer_server.create_channel(
        IP_ADDR,
        lambda _: None,
    )
    await asyncio.sleep(0.1)

    assert test_endpoint
    endpoint_conn = test_endpoint[0]

    # Server channel -> endpoint.
    await server_channel.write(b"Hallo")
    assert await endpoint_conn.reader.read(1024) == b"Hallo"

    # Endpoint -> server channel.
    endpoint_conn.writer.write(b"Hiro")
    await endpoint_conn.writer.drain()
    assert await server_channel.read() == b"Hiro"

    # After the server drops the channel, the endpoint reads nothing (EOF).
    multiplexer_server.delete_channel(server_channel)
    remaining = await endpoint_conn.reader.read(1024)
    assert not remaining

    endpoint_conn.close.set()


async def test_close_connector_local(
    test_endpoint: list[Client],
    multiplexer_client: Multiplexer,
    multiplexer_server: Multiplexer,
) -> None:
    """Test that closing the endpoint connection locally closes the channel.

    After the local endpoint closes its side, reading from the server-side
    channel must raise ``MultiplexerTransportClose``.
    """
    assert not test_endpoint

    connector = Connector("127.0.0.1", "8822")
    multiplexer_client._new_connections = connector.handler

    channel = await multiplexer_server.create_channel(IP_ADDR, lambda _: None)
    # Give the connector time to reach the endpoint.
    await asyncio.sleep(0.1)

    assert test_endpoint
    test_connection = test_endpoint[0]

    # Server channel -> endpoint.
    await channel.write(b"Hallo")
    data = await test_connection.reader.read(1024)
    assert data == b"Hallo"

    # Endpoint -> server channel.
    test_connection.writer.write(b"Hiro")
    await test_connection.writer.drain()

    data = await channel.read()
    assert data == b"Hiro"

    # Close from the endpoint side and let the close propagate.
    test_connection.writer.close()
    test_connection.close.set()
    await asyncio.sleep(0.1)

    with pytest.raises(MultiplexerTransportClose):
        await channel.read()


async def test_init_connector_whitelist(
    test_endpoint: list[Client],
    multiplexer_client: Multiplexer,
    multiplexer_server: Multiplexer,
) -> None:
    """Accept a channel whose peer address is on the connector whitelist."""
    assert not test_endpoint

    guarded_connector = Connector("127.0.0.1", "8822", True)
    multiplexer_client._new_connections = guarded_connector.handler

    guarded_connector.whitelist.add(IP_ADDR)
    assert IP_ADDR in guarded_connector.whitelist
    server_channel = await multiplexer_server.create_channel(
        IP_ADDR,
        lambda _: None,
    )
    await asyncio.sleep(0.1)

    # The whitelisted peer got through: an endpoint connection exists.
    assert test_endpoint
    endpoint_conn = test_endpoint[0]

    await server_channel.write(b"Hallo")
    assert await endpoint_conn.reader.read(1024) == b"Hallo"

    endpoint_conn.close.set()


async def test_init_connector_whitelist_bad(
    test_endpoint: list[Client],
    multiplexer_client: Multiplexer,
    multiplexer_server: Multiplexer,
) -> None:
    """Reject a channel whose peer address is missing from the whitelist."""
    assert not test_endpoint

    guarded_connector = Connector("127.0.0.1", "8822", True)
    multiplexer_client._new_connections = guarded_connector.handler

    guarded_connector.whitelist.add(IP_ADDR)
    assert IP_ADDR in guarded_connector.whitelist
    assert BAD_ADDR not in guarded_connector.whitelist
    server_channel = await multiplexer_server.create_channel(
        BAD_ADDR,
        lambda _: None,
    )
    await asyncio.sleep(0.1)

    # No endpoint connection was opened for the rejected peer ...
    assert not test_endpoint

    # ... and reading from the server-side channel reports it as closed.
    with pytest.raises(MultiplexerTransportClose):
        await server_channel.read()


async def test_connector_error_callback(
    multiplexer_client: Multiplexer,
    multiplexer_server: Multiplexer,
) -> None:
    """Invoke the endpoint error callback when connecting to the endpoint fails."""
    on_endpoint_error = AsyncMock()
    failing_connector = Connector("127.0.0.1", "8822", False, on_endpoint_error)

    client_channel = await multiplexer_client.create_channel(
        IP_ADDR,
        lambda _: None,
    )

    # Nothing has failed yet, so the callback must still be untouched.
    on_endpoint_error.assert_not_called()

    # Make the endpoint connection attempt fail with an OSError.
    refused = OSError("Lorem ipsum...")
    with patch("asyncio.open_connection", side_effect=refused):
        await failing_connector.handler(multiplexer_client, client_channel)

    on_endpoint_error.assert_called_once()


async def test_connector_no_error_callback(
    multiplexer_client: Multiplexer,
    multiplexer_server: Multiplexer,
) -> None:
    """Test connector without an endpoint error callback.

    A failed endpoint connection must be handled without raising even when
    no error callback is configured.
    """
    connector = Connector("127.0.0.1", "8822", False, None)
    channel = await multiplexer_client.create_channel(IP_ADDR, lambda _: None)
    # Force the endpoint connection to fail; the handler must still return
    # normally (the test fails if it raises).
    with patch("asyncio.open_connection", side_effect=OSError("Lorem ipsum...")):
        await connector.handler(multiplexer_client, channel)


async def test_connector_handler_can_pause(
    multiplexer_client: Multiplexer,
    multiplexer_server: Multiplexer,
    test_endpoint: list[Client],
) -> None:
    """Test connector handler can pause.

    When the client channel reports that the remote input is under water,
    the ``ConnectorHandler`` must stop reading from the local endpoint (after
    finishing the read already in flight). It must resume once the remote
    input recovers. The server -> endpoint direction keeps working while
    paused.
    """
    assert not test_endpoint

    connector = Connector("127.0.0.1", "8822")
    multiplexer_client._new_connections = connector.handler

    # Captures the ConnectorHandler the connector creates, so its private
    # pause state can be inspected below.
    connector_handler: ConnectorHandler | None = None

    def save_connector_handler(
        loop: asyncio.AbstractEventLoop,
        channel: MultiplexerChannel,
    ) -> ConnectorHandler:
        """Stand-in for the ConnectorHandler class that records the instance."""
        nonlocal connector_handler
        connector_handler = ConnectorHandler(loop, channel)
        return connector_handler

    # Only the handler creation for this channel needs to be intercepted.
    with patch("snitun.client.connector.ConnectorHandler", save_connector_handler):
        server_channel = await multiplexer_server.create_channel(
            IP_ADDR,
            lambda _: None,
        )
        await asyncio.sleep(0.1)

    assert isinstance(connector_handler, ConnectorHandler)
    # cast() keeps type checkers happy; connector_handler is only assigned
    # inside the nested function.
    handler = cast(ConnectorHandler, connector_handler)
    client_channel = handler._channel
    # The handler must have registered its pause/resume callback on the channel.
    assert client_channel._pause_resume_reader_callback is not None
    assert (
        client_channel._pause_resume_reader_callback
        == handler._pause_resume_reader_callback
    )

    assert test_endpoint
    test_connection = test_endpoint[0]

    # Baseline: traffic flows in both directions before any pause.
    await server_channel.write(b"Hallo")
    data = await test_connection.reader.read(1024)
    assert data == b"Hallo"

    test_connection.writer.write(b"Hiro")
    await test_connection.writer.drain()

    data = await server_channel.read()
    assert data == b"Hiro"

    assert handler._pause_future is None
    # Simulate that the remote input goes under water
    client_channel.on_remote_input_under_water(True)
    assert handler._pause_future is not None

    # Server -> endpoint traffic is not affected by the pause.
    await server_channel.write(b"Goodbye")
    data = await test_connection.reader.read(1024)
    assert data == b"Goodbye"

    # This is an implementation detail that we might
    # change in the future, but for now we need
    # to read one more message because we don't cancel
    # the current read when the reader pauses as the additional
    # complexity is not worth it.
    test_connection.writer.write(b"Should read one more")
    await test_connection.writer.drain()
    assert await server_channel.read() == b"Should read one more"

    # This message must be held back while the handler is paused.
    test_connection.writer.write(b"ByeBye")
    await test_connection.writer.drain()

    read_task = asyncio.create_task(server_channel.read())
    await asyncio.sleep(0.1)
    # Make sure reader is actually paused
    assert not read_task.done()

    # Now simulate that the remote input is no longer under water
    client_channel.on_remote_input_under_water(False)
    assert handler._pause_future is None
    # The held-back message is delivered after resuming.
    data = await read_task
    assert data == b"ByeBye"

    # Closing the endpoint side must close the server-side channel.
    test_connection.writer.close()
    test_connection.close.set()
    await asyncio.sleep(0.1)

    with pytest.raises(MultiplexerTransportClose):
        await server_channel.read()