File: peer_manager.py

package info (click to toggle)
python-snitun 0.45.1-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 664 kB
  • sloc: python: 6,681; sh: 5; makefile: 3
file content (139 lines) | stat: -rw-r--r-- 4,679 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""Manage peer connections."""

from __future__ import annotations

import asyncio
from collections.abc import Callable, ValuesView
from datetime import UTC, datetime
from enum import Enum
import json
import logging

from cryptography.fernet import Fernet, InvalidToken, MultiFernet

from ..exceptions import SniTunInvalidPeer
from ..utils.asyncio import asyncio_timeout
from ..utils.server import TokenData
from .peer import Peer

_LOGGER = logging.getLogger(__name__)


class PeerManagerEvent(str, Enum):
    """Peer Manager event flags.

    The class mixes in ``str``, so each member is also a string equal to
    its value.
    """

    # Handed to the event callback by PeerManager.add_peer.
    CONNECTED = "connected"
    # Handed to the event callback by PeerManager.remove_peer.
    DISCONNECTED = "disconnected"


class PeerManager:
    """Manage Peer connections."""

    def __init__(
        self,
        fernet_tokens: list[str],
        throttling: int | None = None,
        event_callback: Callable[[Peer, PeerManagerEvent], None] | None = None,
    ) -> None:
        """Initialize Peer Manager."""
        self._fernet = MultiFernet([Fernet(key) for key in fernet_tokens])
        self._loop = asyncio.get_event_loop()
        self._throttling = throttling
        self._event_callback = event_callback
        self._peers: dict[str, Peer] = {}

    @property
    def connections(self) -> int:
        """Return count of connected devices."""
        return len(self._peers)

    def iter_peers(self) -> ValuesView[Peer]:
        """Iterate over all peers."""
        return self._peers.values()

    def create_peer(self, fernet_data: bytes) -> Peer:
        """Create a new peer from crypt config."""
        try:
            data = self._fernet.decrypt(fernet_data).decode("utf-8")
            config: TokenData = json.loads(data)
        except (InvalidToken, json.JSONDecodeError, UnicodeDecodeError) as err:
            raise SniTunInvalidPeer("Invalid fernet token") from err

        # Check if token is valid
        valid = datetime.fromtimestamp(config["valid"], tz=UTC)
        if valid < datetime.now(tz=UTC):
            raise SniTunInvalidPeer("Token was expired")

        # Extract configuration
        hostname = config["hostname"]
        aes_key = bytes.fromhex(config["aes_key"])
        aes_iv = bytes.fromhex(config["aes_iv"])

        return Peer(
            hostname,
            valid,
            aes_key,
            aes_iv,
            protocol_version=config.get("protocol_version", 0),
            throttling=self._throttling,
            alias=config.get("alias", []),
        )

    def add_peer(self, peer: Peer) -> None:
        """Register peer to internal hostname list."""
        if self.peer_available(peer.hostname) and (
            multiplexer := self._peers[peer.hostname].multiplexer
        ):
            _LOGGER.warning("Found stale peer connection")
            multiplexer.shutdown()

        _LOGGER.debug("New peer connection: %s", peer.hostname)
        self._peers[peer.hostname] = peer
        for alias in peer.alias:
            _LOGGER.debug("New peer connection alias: %s for %s", alias, peer.hostname)
            self._peers[alias] = peer

        if self._event_callback:
            self._loop.call_soon(self._event_callback, peer, PeerManagerEvent.CONNECTED)

    def remove_peer(self, peer: Peer) -> None:
        """Remove peer from list."""
        if self._peers.get(peer.hostname) != peer:
            return
        _LOGGER.debug("Close peer connection: %s", peer.hostname)
        for hostname in peer.all_hostnames:
            self._peers.pop(hostname, None)

        if self._event_callback:
            self._loop.call_soon(
                self._event_callback,
                peer,
                PeerManagerEvent.DISCONNECTED,
            )

    def peer_available(self, hostname: str) -> bool:
        """Check if peer available and return True or False."""
        if hostname in self._peers:
            return self._peers[hostname].is_ready
        return False

    def get_peer(self, hostname: str) -> Peer | None:
        """Get peer."""
        return self._peers.get(hostname)

    async def close_connections(self, timeout: int = 10) -> None:  # noqa: ASYNC109
        """Close all peer connections.

        Use this function only if you do not controll the server socket.
        """
        peers = list(self._peers.values())
        for peer in peers:
            if peer.is_connected and peer.multiplexer:
                peer.multiplexer.shutdown()

        if waiters := [peer.wait_disconnect() for peer in peers]:
            try:
                async with asyncio_timeout.timeout(timeout):
                    await asyncio.gather(*waiters, return_exceptions=True)
            except TimeoutError:
                _LOGGER.error("Timeout while waiting for peer disconnect")