1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362
|
"""
The different VPN connection states and their transitions are defined here.
Copyright (c) 2023 Proton AG
This file is part of Proton VPN.
Proton VPN is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
Proton VPN is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with ProtonVPN. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional, ClassVar
from proton.vpn import logging
from proton.vpn.connection import events
from proton.vpn.connection.enum import ConnectionStateEnum, KillSwitchSetting
from proton.vpn.connection.events import EventContext
from proton.vpn.connection.exceptions import ConcurrentConnectionsError
from proton.vpn.killswitch.interface import KillSwitch
if TYPE_CHECKING:
from proton.vpn.connection.vpnconnection import VPNConnection
# Module-level logger shared by every state class defined in this module.
logger = logging.getLogger(__name__)
@dataclass
class StateContext:
    """
    Relevant state context data.

    Attributes:
        event: Event that led to the current state.
        connection: current VPN connection. The only case where this
            attribute could be None is on the initial state, if there is not
            already an existing VPN connection.
        reconnection: optional VPN connection to connect to as soon as the
            current one has been stopped.
        kill_switch: kill switch implementation.
        kill_switch_setting: on, off, permanent.

    ``kill_switch`` and ``kill_switch_setting`` are declared as ``ClassVar``,
    so dataclass does not turn them into constructor arguments: they live on
    the class and are shared by every ``StateContext`` instance.
    """
    event: events.Event = field(default_factory=events.Initialized)
    connection: Optional["VPNConnection"] = None
    reconnection: Optional["VPNConnection"] = None
    # Shared, class-level attributes. They stay None until assigned on the
    # class; the states' run_tasks coroutines dereference them, so they must
    # be set before those coroutines are awaited.
    kill_switch: ClassVar[Optional[KillSwitch]] = None
    kill_switch_setting: ClassVar[Optional[KillSwitchSetting]] = None
class State(ABC):
    """
    This is the base state from which all other states derive. Each new
    state has to implement the `_on_event` method, which `on_event` calls.

    These states are backend agnostic, so whoever implements a new backend
    must take special care to correctly translate the backend-specific
    events to known events (see `proton.vpn.connection.events`).

    Each state acts on the `on_event` method. Generally, if a state receives
    an unexpected event, it does not update the state but keeps the same
    state, and the occurrence is logged.

    The general idea of state transitions:
        1) Connect happy path: Disconnected -> Connecting -> Connected
        2) Connect with error path: Disconnected -> Connecting -> Error
        3) Disconnect happy path: Connected -> Disconnecting -> Disconnected
        4) Active connection error path: Connected -> Error

    Some states drive the VPN connection itself from their `run_tasks`
    coroutine: `Connecting` awaits `connection.start()` and `Disconnecting`
    awaits `connection.stop()`. Because `run_tasks` is async, these calls
    are awaited rather than run synchronously, and they are expected not to
    block the caller.

    Each state also holds a `StateContext` (`self.context`) that records the
    event that led to it. The event context can help in discovering why a
    state shows unexpected behavior. Note, though, that event contexts are
    always backend specific.
    """

    # Concrete subclasses must override this with their ConnectionStateEnum
    # member; the constructor rejects classes that leave it as None.
    type = None

    def __init__(self, context: Optional[StateContext] = None):
        self.context = context if context is not None else StateContext()

        if self.type is None:
            # The class attribute being checked is `type`, so name it
            # (and the offending class) in the error message.
            raise TypeError(
                f'{type(self).__name__} does not define the class attribute "type"'
            )

    def _assert_no_concurrent_connections(self, event: events.Event):
        """
        Raise ConcurrentConnectionsError if ``event`` belongs to a connection
        other than the one tracked by this state.

        Up events are exempt: they always carry the new connection that is
        about to be initiated.
        """
        not_up_event = not isinstance(event, events.Up)
        different_connection = event.context.connection is not self.context.connection
        if not_up_event and different_connection:
            # Any state should always receive events for the same connection, the only
            # exception being when the Up event is received. In this case, the Up event
            # always carries a new connection: the new connection to be initiated.
            raise ConcurrentConnectionsError(
                f"State {self} expected events from {self.context.connection} "
                f"but received an event from {event.context.connection} instead."
            )

    def on_event(self, event: events.Event) -> State:
        """
        Return the new state based on the received event.

        Raises:
            ConcurrentConnectionsError: if the event belongs to a different
                connection than the current one (Up events excepted).
            Any exception raised by ``event.check_for_errors()`` propagates.
        """
        self._assert_no_concurrent_connections(event)
        event.check_for_errors()

        new_state = self._on_event(event)

        if new_state is self:
            # _on_event returning this very instance means "no transition".
            logger.warning(
                f"{self.type.name} state received unexpected "
                f"event: {type(event).__name__}",
                category="CONN", event="WARNING"
            )

        return new_state

    @abstractmethod
    def _on_event(
        self, event: events.Event
    ) -> State:
        """Given an event, return the new state (``self`` if there is no transition)."""

    async def run_tasks(self) -> Optional[events.Event]:
        """
        Tasks to be run when this state instance becomes the current VPN state.

        The base implementation does nothing and returns None. Subclasses may
        return an event; for example, Disconnected returns an Up event to
        start a pending reconnection straight away.
        """

    @property
    def forwarded_port(self) -> Optional[int]:
        """Return the forwarded port recorded in the event that led to this state, if any."""
        return self.context.event.context.forwarded_port
class Disconnected(State):
    """
    Initial state of every connection. It is also the final state, unless
    the connection could not be established because of an error.
    """
    type = ConnectionStateEnum.DISCONNECTED

    def _on_event(self, event: events.Event):
        # Only a connection request moves us out of this state.
        if not isinstance(event, events.Up):
            return self
        return Connecting(
            StateContext(event=event, connection=event.context.connection)
        )

    async def run_tasks(self):
        ctx = self.context

        # While disconnected, a VPN connection object may not exist yet.
        if ctx.connection:
            await ctx.connection.remove_persistence()

        if ctx.reconnection:
            # Switching servers: keep the kill switch on to avoid leaks, even
            # when the kill switch setting is off. Then hand back an Up event
            # so that the new connection is started straight away.
            await ctx.kill_switch.enable()
            return events.Up(EventContext(connection=ctx.reconnection))

        if ctx.kill_switch_setting != KillSwitchSetting.PERMANENT:
            await ctx.kill_switch.disable()
            await ctx.kill_switch.disable_ipv6_leak_protection()
            return None

        # Abstraction leak of the network manager kill switch. Enabling the
        # permanent kill switch here only matters when the user cancelled
        # while Connecting: it swaps the routed kill switch for the full one.
        # Otherwise, the full kill switch should already be in place.
        await ctx.kill_switch.enable(permanent=True)
        return None
class Connecting(State):
    """
    State reached as soon as a VPN connection has been requested.
    """
    type = ConnectionStateEnum.CONNECTING
    _counter = 0

    def _on_event(self, event: events.Event):
        # These transitions simply carry the event's connection forward.
        # They are checked in this exact order, before the special cases below.
        for event_type, next_state_cls in (
            (events.Connected, Connected),
            (events.Down, Disconnecting),
            (events.Error, Error),
        ):
            if isinstance(event, event_type):
                return next_state_cls(
                    StateContext(event=event, connection=event.context.connection)
                )

        if isinstance(event, events.Up):
            # A different connection was requested mid-connect. Tear the
            # current one down and remember the requested one, so that it is
            # started once the current connection is down.
            return Disconnecting(
                StateContext(
                    event=event,
                    connection=self.context.connection,
                    reconnection=event.context.connection,
                )
            )

        if isinstance(event, events.Disconnected):
            # Someone else (another process) dropped the VPN. A requested
            # disconnection would have been seen by Disconnecting instead.
            return Disconnected(
                StateContext(event=event, connection=event.context.connection)
            )

        return self

    async def run_tasks(self):
        ctx = self.context
        # The kill switch is always enabled here, whatever the setting, so
        # that switching servers does not leak traffic. When the setting is
        # off, Connected removes it again once the connection is up.
        await ctx.kill_switch.enable(
            ctx.connection.server,
            permanent=(ctx.kill_switch_setting == KillSwitchSetting.PERMANENT),
        )
        await ctx.connection.start()
class Connected(State):
    """
    State reached once the VPN connection has been successfully established.
    """
    type = ConnectionStateEnum.CONNECTED

    def _on_event(self, event: events.Event):
        incoming = event.context.connection

        if isinstance(event, events.Down):
            next_state = Disconnecting(StateContext(event=event, connection=incoming))
        elif isinstance(event, events.Up):
            # A new connection was requested while connected. Stop the current
            # one and keep the requested one, so it starts once we are down.
            next_state = Disconnecting(
                StateContext(
                    event=event,
                    connection=self.context.connection,
                    reconnection=incoming,
                )
            )
        elif isinstance(event, events.Error):
            next_state = Error(StateContext(event=event, connection=incoming))
        elif isinstance(event, events.Disconnected):
            # Another process dropped the VPN. A requested disconnection would
            # have been handled by the Disconnecting state.
            next_state = Disconnected(StateContext(event=event, connection=incoming))
        else:
            next_state = self

        return next_state

    async def run_tasks(self):
        ctx = self.context
        setting = ctx.kill_switch_setting

        if setting != KillSwitchSetting.OFF:
            # Specific to the routing-table kill switch implementation (should
            # eventually be removed): swap the routed kill switch for the
            # full-on one now that the connection is up.
            await ctx.kill_switch.enable(
                permanent=(setting == KillSwitchSetting.PERMANENT)
            )
        else:
            await ctx.kill_switch.enable_ipv6_leak_protection()
            await ctx.kill_switch.disable()

        await ctx.connection.add_persistence()
class Disconnecting(State):
    """
    State reached when a VPN disconnection has been requested.
    """
    type = ConnectionStateEnum.DISCONNECTING

    def _on_event(self, event: events.Event):
        reached_down = isinstance(event, (events.Disconnected, events.Error))

        if reached_down:
            # An error event means the VPN dropped for an unexpected reason.
            # Since this state only exists to reach Disconnected, both event
            # kinds lead to the same place; errors are just logged first.
            if isinstance(event, events.Error):
                logger.warning(
                    "Error event while disconnecting: %s (%s)",
                    type(event).__name__,
                    event.context.error
                )
            return Disconnected(
                StateContext(
                    event=event,
                    connection=event.context.connection,
                    reconnection=self.context.reconnection
                )
            )

        if isinstance(event, events.Up):
            # A new connection was requested while going down. Remember it in
            # this state's context so it is started once the current
            # connection is down.
            # NOTE(review): returning self below makes State.on_event log an
            # "unexpected event" warning even though Up is handled here.
            self.context.reconnection = event.context.connection

        return self

    async def run_tasks(self):
        current_connection = self.context.connection
        await current_connection.stop()
class Error(State):
    """
    State reached after a connection error.
    """
    type = ConnectionStateEnum.ERROR

    def _on_event(self, event: events.Event):
        def context_for(connection, **extra):
            # Every transition out of Error records the triggering event.
            return StateContext(event=event, connection=connection, **extra)

        if isinstance(event, events.Down):
            # NOTE(review): goes straight to Disconnected; no connection.stop()
            # is awaited on this path. Presumably the connection is already
            # down after an error; confirm.
            return Disconnected(context_for(event.context.connection))

        if isinstance(event, events.Up):
            return Disconnecting(
                context_for(
                    self.context.connection,
                    reconnection=event.context.connection,
                )
            )

        if isinstance(event, events.Connected):
            return Connected(context_for(self.context.connection))

        if isinstance(event, events.Error):
            return Error(context_for(event.context.connection))

        return self

    async def run_tasks(self):
        failing_event = self.context.event
        logger.warning(
            "Reached connection error state: %s (%s)",
            type(failing_event).__name__,
            failing_event.context.error
        )
|