File: publisher.py

package info (click to toggle)
python-proton-vpn-api-core 0.39.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 892 kB
  • sloc: python: 6,582; makefile: 8
file content (101 lines) | stat: -rw-r--r-- 3,660 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
Implementation of the publish/subscribe pattern used to signal VPN connection
state changes.


Copyright (c) 2023 Proton AG

This file is part of Proton VPN.

Proton VPN is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Proton VPN is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with ProtonVPN.  If not, see <https://www.gnu.org/licenses/>.
"""
import asyncio
import inspect
from typing import Callable, List, Optional

from proton.vpn import logging

# Module-level logger; Publisher.notify uses it to report subscribers that raise.
logger = logging.getLogger(__name__)


class Publisher:
    """
    Simple generic implementation of the publish-subscribe pattern.

    Subscribers are callables. Synchronous subscribers are invoked directly
    from :meth:`notify`. When a subscriber returns a coroutine (``async def``
    functions, or objects with an async ``__call__``), that coroutine is
    scheduled as a task on the running event loop.
    """

    def __init__(self, subscribers: Optional[List[Callable]] = None):
        """
        :param subscribers: optional initial subscribers, in notification
            order. The list is copied, so registering/unregistering on this
            publisher never mutates the caller's list (and vice versa).
        """
        self._subscribers = list(subscribers) if subscribers else []
        # Strong references to in-flight notification tasks: the event loop
        # only keeps weak references to tasks, so without this set they could
        # be garbage-collected before they finish.
        self._pending_tasks = set()

    def register(self, subscriber: Callable):
        """
        Registers a subscriber to be notified of new updates.

        The subscribers are not expected to block, as they will be notified
        sequentially, one after the other in the order in which they were
        registered. Registering an already registered subscriber is a no-op.

        :param subscriber: callback that will be called with the expected
            args/kwargs whenever there is an update.
        :raises ValueError: if the subscriber is not callable.
        """
        if not callable(subscriber):
            raise ValueError(f"Subscriber to register is not callable: {subscriber}")

        if subscriber not in self._subscribers:
            self._subscribers.append(subscriber)

    def unregister(self, subscriber: Callable):
        """
        Unregisters a subscriber. Unknown subscribers are ignored.

        :param subscriber: the subscriber to be unregistered.
        """
        if subscriber in self._subscribers:
            self._subscribers.remove(subscriber)

    def notify(self, *args, **kwargs):
        """
        Notifies all registered subscribers about a new update.

        Every subscriber is called with ``*args`` and ``**kwargs`` in
        registration order. If a subscriber returns a coroutine, the coroutine
        is scheduled as a task on the running event loop rather than awaited,
        so this method never waits for async subscribers to finish.

        An exception raised while calling a subscriber (or while scheduling
        its coroutine, e.g. when no event loop is running) is logged and does
        not prevent the remaining subscribers from being notified.

        Backends and protocols must call this method whenever the connection
        state changes, so that subscribers receive every state change.
        """
        # Iterate over a snapshot: a subscriber may register/unregister
        # subscribers (including itself) while being notified, and mutating
        # the list being iterated would silently skip the next subscriber.
        for subscriber in tuple(self._subscribers):
            try:
                result = subscriber(*args, **kwargs)
                if inspect.iscoroutine(result):
                    self._schedule_notification(result)
            except Exception:  # pylint: disable=broad-except
                logger.exception(f"An error occurred notifying subscriber {subscriber}.")

    def _schedule_notification(self, coroutine):
        """
        Runs ``coroutine`` as a task on the running event loop and keeps a
        reference to the task until it is done.

        :raises RuntimeError: if there is no running event loop.
        """
        try:
            task = asyncio.create_task(coroutine)
        except RuntimeError:
            # The coroutine will never run: close it so it does not emit a
            # "coroutine ... was never awaited" RuntimeWarning later on.
            coroutine.close()
            raise
        self._pending_tasks.add(task)
        task.add_done_callback(self._on_notification_task_done)

    def _on_notification_task_done(self, task: asyncio.Task):
        """Forgets a finished notification task and surfaces its failure, if any."""
        self._pending_tasks.discard(task)
        if task.cancelled():
            # Cancellation (e.g. of leftover tasks when the loop shuts down)
            # is not a subscriber failure, so it is not reported.
            return
        # Re-raise the subscriber's exception, if any: raised from a done
        # callback, it is reported through the event loop's exception handler.
        task.result()

    def is_subscriber_registered(self, subscriber: Callable) -> bool:
        """Returns whether a subscriber is registered or not."""
        return subscriber in self._subscribers

    @property
    def number_of_subscribers(self) -> int:
        """Number of currently registered subscribers."""
        return len(self._subscribers)