File: test_wgkillswitch.py

package info (click to toggle)
python-proton-vpn-api-core 4.16.0-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,312 kB
  • sloc: python: 11,057; makefile: 9
file content (112 lines) | stat: -rw-r--r-- 3,497 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
Copyright (c) 2023 Proton AG

This file is part of Proton VPN.

Proton VPN is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Proton VPN is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with ProtonVPN.  If not, see <https://www.gnu.org/licenses/>.
"""
from unittest.mock import Mock, AsyncMock, call, PropertyMock, patch
import pytest

from proton.vpn.backend.networkmanager.killswitch.wireguard import WGKillSwitch
from proton.vpn.backend.networkmanager.killswitch.wireguard.killswitch_connection_handler import KillSwitchConnectionHandler


@pytest.fixture
def vpn_server():
    """Provide a stand-in VPN server object exposing only ``server_ip``."""
    return Mock(server_ip="1.1.1.1")


@pytest.mark.asyncio
async def test_enable_without_vpn_server_adds_ks_connection():
    """Enabling with no server only asks the handler to add the kill switch connection."""
    handler = AsyncMock()
    expected_calls = [call.add_kill_switch_connection(False)]

    await WGKillSwitch(handler).enable()

    assert handler.method_calls == expected_calls


@pytest.mark.asyncio
async def test_enable_with_vpn_server_adds_ks_connection_and_route_for_server(vpn_server):
    """Enabling with a server adds the kill switch connection, then a route to that server's IP."""
    handler = AsyncMock()
    expected_calls = [
        call.add_kill_switch_connection(False),
        call.add_vpn_server_route(server_ip=vpn_server.server_ip),
    ]

    await WGKillSwitch(handler).enable(vpn_server)

    assert handler.method_calls == expected_calls


@pytest.mark.asyncio
async def test_disable_killswitch_removes_full_and_route_for_server():
    """Disabling removes the kill switch connection first, then the server route."""
    handler = AsyncMock()
    expected_calls = [
        call.remove_killswitch_connection(),
        call.remove_vpn_server_route(),
    ]

    await WGKillSwitch(handler).disable()

    assert handler.method_calls == expected_calls


@pytest.mark.asyncio
async def test_enable_ipv6_leak_protection_adds_ipv6_ks():
    """Turning on IPv6 leak protection issues exactly one ``add_ipv6_leak_protection`` call."""
    handler = AsyncMock()
    kill_switch = WGKillSwitch(handler)

    await kill_switch.enable_ipv6_leak_protection()

    expected_calls = [call.add_ipv6_leak_protection()]
    assert handler.method_calls == expected_calls


@pytest.mark.asyncio
async def test_disable_ipv6_leak_protection_removes_ipv6_ks():
    """Turning off IPv6 leak protection issues exactly one ``remove_ipv6_leak_protection`` call."""
    handler = AsyncMock()
    kill_switch = WGKillSwitch(handler)

    await kill_switch.disable_ipv6_leak_protection()

    expected_calls = [call.remove_ipv6_leak_protection()]
    assert handler.method_calls == expected_calls


@pytest.fixture
def monkey_patch_connection_handler():
    """Make ``KillSwitchConnectionHandler.is_network_manager_running`` report True.

    The attribute is replaced by a ``PropertyMock`` returning True for the
    duration of the test. ``patch.object`` is used instead of a manual
    save/assign/restore so that teardown puts the attribute back exactly as it
    was found. A manual ``Cls.attr`` read followed by re-assignment would copy
    an inherited attribute onto the class, or unwrap a ``staticmethod``.

    Yields:
        The ``KillSwitchConnectionHandler`` class with the patched attribute.
    """
    with patch.object(
        KillSwitchConnectionHandler,
        "is_network_manager_running",
        new_callable=PropertyMock,
        return_value=True,  # forwarded to PropertyMock(return_value=True)
    ):
        yield KillSwitchConnectionHandler


@pytest.mark.parametrize("validate_params_dict, assert_bool", [
    (None, False),
    ({}, False),
    ({"protocol": "openvpn"}, False),
    ({"protocol": "wireguard"}, True),
])
@patch("proton.vpn.backend.networkmanager.killswitch.wireguard.wgkillswitch.subprocess")
def test_backend_validate(_subprocess_mock, validate_params_dict, assert_bool, monkey_patch_connection_handler):
    """``WGKillSwitch._validate`` accepts only a params dict whose protocol is "wireguard".

    ``subprocess`` in the wgkillswitch module is patched out, and the
    connection handler fixture forces NetworkManager to appear to be running.
    """
    outcome = WGKillSwitch._validate(validate_params_dict)
    assert outcome == assert_bool