File: states.py

package info (click to toggle)
python-proton-vpn-api-core 0.39.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 892 kB
  • sloc: python: 6,582; makefile: 8
file content (362 lines) | stat: -rw-r--r-- 14,156 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
"""
The different VPN connection states and their transitions are defined here.


Copyright (c) 2023 Proton AG

This file is part of Proton VPN.

Proton VPN is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Proton VPN is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with ProtonVPN.  If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional, ClassVar

from proton.vpn import logging
from proton.vpn.connection import events
from proton.vpn.connection.enum import ConnectionStateEnum, KillSwitchSetting
from proton.vpn.connection.events import EventContext
from proton.vpn.connection.exceptions import ConcurrentConnectionsError
from proton.vpn.killswitch.interface import KillSwitch


if TYPE_CHECKING:
    from proton.vpn.connection.vpnconnection import VPNConnection


logger = logging.getLogger(__name__)


@dataclass
class StateContext:
    """
    Relevant state context data.

    Attributes:
        event: Event that led to the current state. Defaults to a new
            ``events.Initialized`` instance.
        connection: current VPN connection. The only case where this
            attribute could be None is on the initial state, if there is not
            already an existing VPN connection.
        reconnection: optional VPN connection to connect to as soon as the
            current one has been stopped.
        kill_switch: kill switch implementation. Declared as a ``ClassVar``,
            so it is shared by every ``StateContext`` instance and is not a
            dataclass field (it cannot be passed to ``__init__``).
        kill_switch_setting: on, off, permanent. Also a shared ``ClassVar``.
    """
    event: events.Event = field(default_factory=events.Initialized)
    connection: Optional["VPNConnection"] = None
    reconnection: Optional["VPNConnection"] = None
    # Shared, class-level values. They default to None, so they have to be
    # assigned elsewhere before the states' run_tasks() call methods on them.
    kill_switch: ClassVar[Optional[KillSwitch]] = None
    kill_switch_setting: ClassVar[Optional[KillSwitchSetting]] = None


class State(ABC):
    """
    This is the base state from which all other states derive. Each concrete
    state has to set the ``type`` class attribute and implement the
    ``_on_event`` method.

    These states are backend agnostic, so whoever implements a new backend
    has to take special care to correctly translate the backend-specific
    events to known events (see `proton.vpn.connection.events`).

    Events are dispatched through the `on_event` method. Generally, if a state
    receives an unexpected event, it does not transition but keeps the same
    state, and the occurrence is logged.

    The general idea of state transitions:

        1) Connect happy path:      Disconnected -> Connecting -> Connected
        2) Connect with error path: Disconnected -> Connecting -> Error
        3) Disconnect happy path:   Connected -> Disconnecting -> Disconnected
        4) Active connection error path: Connected -> Error

    Side effects live in `run_tasks`, which runs when a state becomes the
    current VPN state. For example, `Connecting` calls
    `vpn_connection.start()` and `Disconnecting` calls
    `vpn_connection.stop()`. Both are awaited from the async `run_tasks`
    coroutines.

    States also carry a `context` (see `StateContext`), built from the event
    that led to them. It can help to find out why a state shows unexpected
    behavior. Note, though, that event contexts are backend specific.
    """
    # Every concrete subclass sets this to its ConnectionStateEnum member.
    type = None

    def __init__(self, context: Optional[StateContext] = None):
        # Fail fast: a state without a declared type is a programming error,
        # so check it before building any default context.
        if self.type is None:
            raise TypeError(
                f"Undefined attribute \"type\" in {type(self).__name__}"
            )

        # A missing context falls back to a fresh default one.
        self.context = context or StateContext()

    def _assert_no_concurrent_connections(self, event: events.Event):
        """
        Ensure `event` belongs to the connection tracked by this state.

        Raises:
            ConcurrentConnectionsError: if `event` is not an `Up` event and
                carries a different connection object than this state's
                context.
        """
        not_up_event = not isinstance(event, events.Up)
        different_connection = event.context.connection is not self.context.connection
        if not_up_event and different_connection:
            # Any state should always receive events for the same connection, the only
            # exception being when the Up event is received. In this case, the Up event
            # always carries a new connection: the new connection to be initiated.
            raise ConcurrentConnectionsError(
                f"State {self} expected events from {self.context.connection} "
                f"but received an event from {event.context.connection} instead."
            )

    def on_event(self, event: events.Event) -> State:
        """
        Return the new state based on the received event.

        A warning is logged whenever the subclass keeps the current state,
        i.e. `_on_event` returns `self`.

        Raises:
            ConcurrentConnectionsError: see `_assert_no_concurrent_connections`.
            Any exception raised by `event.check_for_errors()` is propagated.
        """
        self._assert_no_concurrent_connections(event)

        event.check_for_errors()

        new_state = self._on_event(event)

        if new_state is self:
            logger.warning(
                f"{self.type.name} state received unexpected "
                f"event: {type(event).__name__}",
                category="CONN", event="WARNING"
            )

        return new_state

    @abstractmethod
    def _on_event(
            self, event: events.Event
    ) -> State:
        """Given an event, it returns the new state (or `self` to stay put)."""

    async def run_tasks(self) -> Optional[events.Event]:
        """
        Tasks to be run when this state instance becomes the current VPN state.

        Subclasses may return an event. For example, `Disconnected` returns an
        `Up` event to start a pending reconnection. The caller presumably feeds
        that event back into the state machine (TODO: confirm against the
        caller). The default implementation does nothing and returns None.
        """
        return None

    @property
    def forwarded_port(self) -> Optional[int]:
        """Returns the forwarded port carried by the current event's context."""
        return self.context.event.context.forwarded_port


class Disconnected(State):
    """
    Initial and final state of a VPN connection.

    Every connection starts here and normally ends here as well. The
    exception is a connection that could not be established because of an
    error.
    """
    type = ConnectionStateEnum.DISCONNECTED

    def _on_event(self, event: events.Event):
        # Only a connection request moves the machine out of this state.
        if not isinstance(event, events.Up):
            return self

        connecting_context = StateContext(
            event=event,
            connection=event.context.connection,
        )
        return Connecting(connecting_context)

    async def run_tasks(self):
        context = self.context

        # The machine can start in this state before any VPN connection has
        # been created, hence the guard.
        if context.connection:
            await context.connection.remove_persistence()

        if context.reconnection:
            # Keep traffic blocked while hopping to the next server, whatever
            # the user's kill switch setting is, so nothing leaks in between.
            await context.kill_switch.enable()

            # Hand back an Up event so the pending connection starts straight away.
            return events.Up(EventContext(connection=context.reconnection))

        if context.kill_switch_setting != KillSwitchSetting.PERMANENT:
            await context.kill_switch.disable()
            await context.kill_switch.disable_ipv6_leak_protection()
            return None

        # Abstraction leak of the network manager KS: re-enabling the
        # permanent KS here only matters when the user cancelled while
        # Connecting, to swap the routed KS for the full KS. Otherwise the
        # full KS is already in place.
        await context.kill_switch.enable(permanent=True)
        return None


class Connecting(State):
    """
    State entered as soon as a VPN connection has been requested.
    """
    type = ConnectionStateEnum.CONNECTING
    # NOTE(review): not referenced anywhere in this module; kept for compatibility.
    _counter = 0

    def _on_event(self, event: events.Event):
        def carry_over(next_state_cls):
            # Build the next state, forwarding the connection carried by the event.
            return next_state_cls(
                StateContext(event=event, connection=event.context.connection)
            )

        if isinstance(event, events.Connected):
            return carry_over(Connected)
        if isinstance(event, events.Down):
            return carry_over(Disconnecting)
        if isinstance(event, events.Error):
            return carry_over(Error)
        if isinstance(event, events.Up):
            # A new connection was requested mid-connect: cancel the current
            # attempt and queue the requested connection, so it starts once
            # the current one is down.
            switch_context = StateContext(
                event=event,
                connection=self.context.connection,
                reconnection=event.context.connection,
            )
            return Disconnecting(switch_context)
        if isinstance(event, events.Disconnected):
            # Another process brought the VPN down. Otherwise this event
            # would have been received by the Disconnecting state.
            return carry_over(Disconnected)

        return self

    async def run_tasks(self):
        context = self.context
        permanent = context.kill_switch_setting == KillSwitchSetting.PERMANENT

        # The kill switch is always raised while connecting, even when the
        # setting is off, so that switching servers cannot leak traffic.
        # With the setting off, Connected.run_tasks removes it again.
        await context.kill_switch.enable(
            context.connection.server,
            permanent=permanent
        )

        await context.connection.start()


class Connected(State):
    """
    State reached once the VPN connection has been established successfully.
    """
    type = ConnectionStateEnum.CONNECTED

    def _on_event(self, event: events.Event):
        if isinstance(event, events.Down):
            stopping_context = StateContext(event=event, connection=event.context.connection)
            return Disconnecting(stopping_context)

        if isinstance(event, events.Up):
            # Switching servers: stop the current connection and queue the
            # requested one, so it starts as soon as the current one is down.
            switch_context = StateContext(
                event=event,
                connection=self.context.connection,
                reconnection=event.context.connection,
            )
            return Disconnecting(switch_context)

        if isinstance(event, events.Error):
            failed_context = StateContext(event=event, connection=event.context.connection)
            return Error(failed_context)

        if isinstance(event, events.Disconnected):
            # Another process brought the VPN down. Otherwise this event
            # would have been received by the Disconnecting state.
            dropped_context = StateContext(event=event, connection=event.context.connection)
            return Disconnected(dropped_context)

        return self

    async def run_tasks(self):
        context = self.context
        setting = context.kill_switch_setting

        if setting != KillSwitchSetting.OFF:
            # Specific to the routing-table KS implementation and meant to be
            # removed: here the routed KS is swapped for the full-on KS.
            await context.kill_switch.enable(
                permanent=(setting == KillSwitchSetting.PERMANENT)
            )
        else:
            # Setting is off: keep IPv6 leak protection, but drop the kill
            # switch that Connecting.run_tasks raised.
            await context.kill_switch.enable_ipv6_leak_protection()
            await context.kill_switch.disable()

        await context.connection.add_persistence()


class Disconnecting(State):
    """
    State reached once a VPN disconnection has been requested.
    """
    type = ConnectionStateEnum.DISCONNECTING

    def _on_event(self, event: events.Event):
        reached_end = isinstance(event, (events.Disconnected, events.Error))

        if reached_end:
            # An Error here means the VPN dropped for unexpected reasons.
            # This state only aims to reach Disconnected, so an error is as
            # good as a clean disconnection.
            if isinstance(event, events.Error):
                logger.warning(
                    "Error event while disconnecting: %s (%s)",
                    type(event).__name__,
                    event.context.error
                )

            # Carry any queued reconnection over so Disconnected can start it.
            final_context = StateContext(
                event=event,
                connection=event.context.connection,
                reconnection=self.context.reconnection,
            )
            return Disconnected(final_context)

        if isinstance(event, events.Up):
            # Queue the requested connection in this state's context, so it is
            # started as soon as the current connection is down. The state
            # itself stays the same, which means State.on_event still logs
            # this as an unexpected event.
            self.context.reconnection = event.context.connection

        return self

    async def run_tasks(self):
        """Ask the current VPN connection to stop."""
        connection = self.context.connection
        await connection.stop()


class Error(State):
    """
    State reached after a connection error.
    """
    type = ConnectionStateEnum.ERROR

    def _on_event(self, event: events.Event):
        if isinstance(event, events.Down):
            # Goes straight to Disconnected, without passing through Disconnecting.
            down_context = StateContext(event=event, connection=event.context.connection)
            return Disconnected(down_context)

        if isinstance(event, events.Up):
            # A new connection was requested: tear down the failed one first
            # and queue the requested connection.
            retry_context = StateContext(
                event=event,
                connection=self.context.connection,
                reconnection=event.context.connection,
            )
            return Disconnecting(retry_context)

        if isinstance(event, events.Connected):
            # Recovered: keep tracking the connection this state already holds.
            recovered_context = StateContext(
                event=event,
                connection=self.context.connection,
            )
            return Connected(recovered_context)

        if isinstance(event, events.Error):
            repeat_context = StateContext(event=event, connection=event.context.connection)
            return Error(repeat_context)

        return self

    async def run_tasks(self):
        """Log the event that put the connection into the error state."""
        failing_event = self.context.event
        logger.warning(
            "Reached connection error state: %s (%s)",
            type(failing_event).__name__,
            failing_event.context.error
        )