1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
|
"""Tests for aiohttp snitun client."""
from unittest.mock import AsyncMock, MagicMock, patch
from snitun.utils.aiohttp_client import SniTunClientAioHttp
async def test_init_client() -> None:
    """Check that a freshly constructed client is not connected."""
    # SockSite is replaced by a mock while the client is being constructed.
    sock_site_patch = patch("snitun.utils.aiohttp_client.SockSite")
    with sock_site_patch:
        snitun_client = SniTunClientAioHttp(None, None, "127.0.0.1")

    assert not snitun_client.is_connected
async def test_client_stop_no_wait() -> None:
    """Check that stop() calls the socket-closed wait helper only with wait=True."""
    with patch("snitun.utils.aiohttp_client.SockSite"):
        snitun_client = SniTunClientAioHttp(None, None, "127.0.0.1")

    wait_helper_patch = patch(
        "snitun.utils.aiohttp_client._async_waitfor_socket_closed",
    )
    with wait_helper_patch as wait_helper_mock:
        # Baseline: the helper has not been called before any stop().
        wait_helper_mock.assert_not_called()

        # A plain stop() must leave the helper uncalled.
        await snitun_client.stop()
        wait_helper_mock.assert_not_called()

        # stop(wait=True) must call the helper.
        await snitun_client.stop(wait=True)
        wait_helper_mock.assert_called()
async def test_client_connect_with_protocol_version() -> None:
    """Check that connect() calls the peer's start() with a protocol_version keyword.

    Two connects are made: one without ``protocol_version`` and one passing it
    explicitly. In both cases the mocked peer's ``start`` must be called once
    with ``protocol_version == 0``.
    """
    # Stand-ins for the collaborators that are patched into the client module.
    peer_mock = MagicMock(start=AsyncMock(), is_connected=False)
    connector_mock = MagicMock()
    site_mock = MagicMock(start=AsyncMock())

    # Key material shared by both connect() calls.
    credentials = {
        "fernet_key": b"test_token",
        "aes_key": b"0" * 32,
        "aes_iv": b"0" * 16,
    }

    with (
        patch("snitun.utils.aiohttp_client.SockSite", return_value=site_mock),
        patch("snitun.utils.aiohttp_client.ClientPeer", return_value=peer_mock),
        patch("snitun.utils.aiohttp_client.Connector", return_value=connector_mock),
    ):
        snitun_client = SniTunClientAioHttp(None, None, "127.0.0.1")
        await snitun_client.start()

        # First connect: protocol_version omitted, so 0 is expected.
        await snitun_client.connect(**credentials)
        peer_mock.start.assert_called_once()
        start_call = peer_mock.start.call_args
        assert "protocol_version" in start_call.kwargs
        assert start_call.kwargs["protocol_version"] == 0  # DEFAULT_PROTOCOL_VERSION

        # Clear the call history so the second connect is checked on its own.
        peer_mock.start.reset_mock()

        # Second connect: protocol_version passed explicitly.
        # NOTE(review): the explicit value equals the default asserted above, so
        # this half cannot distinguish "forwarded" from "defaulted" - confirm
        # whether a non-default version should be used here.
        await snitun_client.connect(**credentials, protocol_version=0)
        peer_mock.start.assert_called_once()
        start_call = peer_mock.start.call_args
        assert "protocol_version" in start_call.kwargs
        assert start_call.kwargs["protocol_version"] == 0
|