File: test_nmkillswitch.py

package info (click to toggle)
python-proton-vpn-api-core 4.16.0-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,312 kB
  • sloc: python: 11,057; makefile: 9
file content (118 lines) | stat: -rw-r--r-- 3,719 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
Copyright (c) 2023 Proton AG

This file is part of Proton VPN.

Proton VPN is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Proton VPN is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with ProtonVPN.  If not, see <https://www.gnu.org/licenses/>.
"""
from unittest.mock import Mock, AsyncMock, call
import pytest
from ipaddress import ip_network, collapse_addresses

from proton.vpn.backend.networkmanager.killswitch.default import NMKillSwitch
from proton.vpn.backend.networkmanager.killswitch.default.killswitch_connection_handler \
    import build_routes_list, LOCAL_AGENT_SERVER_ADDR


@pytest.fixture
def vpn_server():
    """
    Provide a stand-in VPN server object.

    Only the ``server_ip`` attribute is configured (set to "1.1.1.1"),
    which is the attribute the kill switch tests below read.
    """
    return Mock(server_ip="1.1.1.1")


@pytest.mark.asyncio
async def test_enable_without_vpn_server_adds_full_ks_and_removes_routed_ks():
    """
    Enabling the KS without a VPN server should add the full KS and then
    remove the routed KS, in that order, and do nothing else.
    """
    handler = AsyncMock()

    await NMKillSwitch(handler).enable()

    expected_calls = [
        call.add_full_killswitch_connection(False),
        call.remove_routed_killswitch_connection(),
    ]
    assert handler.method_calls == expected_calls


@pytest.mark.asyncio
async def test_enable_with_vpn_server(vpn_server):
    """
    When enabling the KS specifying a vpn server to connect to we expect:
     1) The full KS is added first, to block all network traffic until the routed KS is set up.
     2) The routed KS is removed (if found).
     3) A new routed KS whitelisting the VPN server IP is added.
     4) The full KS is removed to let the routed KS take over.
    """
    ks_handler_mock = AsyncMock()
    nm_killswitch = NMKillSwitch(ks_handler_mock)

    await nm_killswitch.enable(vpn_server)

    # The exact order matters: the full KS must stay in place until the new
    # routed KS has been added, so there is no window with no KS at all.
    assert ks_handler_mock.method_calls == [
        call.add_full_killswitch_connection(False),
        call.remove_routed_killswitch_connection(),
        call.add_routed_killswitch_connection(vpn_server.server_ip, False),
        call.remove_full_killswitch_connection(),
    ]


@pytest.mark.asyncio
async def test_disable_killswitch_removes_full_and_routed_ks():
    """
    Disabling the KS should remove the full KS first and then the routed KS.
    """
    handler = AsyncMock()

    await NMKillSwitch(handler).disable()

    expected_calls = [
        call.remove_full_killswitch_connection(),
        call.remove_routed_killswitch_connection(),
    ]
    assert handler.method_calls == expected_calls


@pytest.mark.asyncio
async def test_enable_ipv6_leak_protection_adds_ipv6_ks():
    """
    Enabling IPv6 leak protection should make exactly one handler call:
    ``add_ipv6_leak_protection``.
    """
    handler = AsyncMock()
    killswitch = NMKillSwitch(handler)

    await killswitch.enable_ipv6_leak_protection()

    expected_calls = [call.add_ipv6_leak_protection()]
    assert handler.method_calls == expected_calls


@pytest.mark.asyncio
async def test_disable_ipv6_leak_protection_removes_ipv6_ks():
    """
    Disabling IPv6 leak protection should make exactly one handler call:
    ``remove_ipv6_leak_protection``.
    """
    handler = AsyncMock()
    killswitch = NMKillSwitch(handler)

    await killswitch.disable_ipv6_leak_protection()

    expected_calls = [call.remove_ipv6_leak_protection()]
    assert handler.method_calls == expected_calls


def test_build_routes_list():
    """
    The forbidden routes returned by ``build_routes_list`` must not overlap
    the allowed addresses (the local agent server and the VPN server).
    Together with those allowed addresses, they must cover the whole IPv4
    address space.
    """
    vpn_server = "192.168.2.1"

    allowed_routes = [
        ip_network(LOCAL_AGENT_SERVER_ADDR),
        ip_network(vpn_server),
    ]
    forbidden_routes = list(build_routes_list(vpn_server))

    # collapse_addresses() would still produce 0.0.0.0/0 even if a forbidden
    # route overlapped an allowed one, so disjointness is checked explicitly.
    overlapping = [
        (forbidden, allowed)
        for forbidden in forbidden_routes
        for allowed in allowed_routes
        if forbidden.overlaps(allowed)
    ]
    assert not overlapping, f"Forbidden routes overlap allowed ones: {overlapping}"

    # Allowed + forbidden must leave no gap anywhere in the IPv4 space.
    all_routes = list(collapse_addresses(allowed_routes + forbidden_routes))
    assert all_routes == [ip_network('0.0.0.0/0')]