1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
|
"""
Implementation of the Publisher/Subscriber used to signal VPN connection
state changes.
Copyright (c) 2023 Proton AG
This file is part of Proton VPN.
Proton VPN is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
Proton VPN is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with ProtonVPN. If not, see <https://www.gnu.org/licenses/>.
"""
import asyncio
import inspect
from typing import Callable, List, Optional
from proton.vpn import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class Publisher:
    """Simple generic implementation of the publish-subscribe pattern.

    Subscribers are callables. They are called directly when notified; if a
    subscriber returns a coroutine (e.g. it is an ``async def`` function or an
    object with an ``async def __call__``), that coroutine is scheduled as a
    task on the running asyncio event loop instead of being awaited.
    """

    def __init__(self, subscribers: Optional[List[Callable]] = None):
        """
        :param subscribers: optional initial subscribers. The sequence is
            copied, so later (un)registrations never mutate the caller's list.
        """
        # Copy so that register()/unregister() do not modify a list owned by
        # the caller.
        self._subscribers = list(subscribers) if subscribers else []
        # Strong references to in-flight notification tasks. The event loop
        # only keeps weak references to tasks, so without this set a task
        # could be garbage-collected before it finishes.
        self._pending_tasks = set()

    def register(self, subscriber: Callable):
        """
        Registers a subscriber to be notified of new updates.

        The subscribers are not expected to block, as they will be notified
        sequentially, one after the other in the order in which they were
        registered. Registering an already-registered subscriber is a no-op.

        :param subscriber: callback that will be called with the expected
            args/kwargs whenever there is an update.
        :raises ValueError: if the subscriber is not callable.
        """
        if not callable(subscriber):
            raise ValueError(f"Subscriber to register is not callable: {subscriber}")
        if subscriber not in self._subscribers:
            self._subscribers.append(subscriber)

    def unregister(self, subscriber: Callable):
        """
        Unregisters a subscriber. Unknown subscribers are ignored.

        :param subscriber: the subscriber to be unregistered.
        """
        if subscriber in self._subscribers:
            self._subscribers.remove(subscriber)

    def notify(self, *args, **kwargs):
        """
        Notifies all registered subscribers about a new update.

        Each subscriber is called with the given ``args``/``kwargs``, in
        registration order. If a subscriber returns a coroutine, the coroutine
        is scheduled as a task on the running event loop rather than awaited,
        so this method never blocks on asynchronous subscribers.

        Exceptions raised while calling a subscriber, or while scheduling its
        coroutine (e.g. when no event loop is running), are logged and do not
        prevent the remaining subscribers from being notified.

        Each backend and/or protocol has to call this method whenever the
        connection state changes, so that each subscriber can receive state
        changes whenever they occur.
        """
        # Iterate over a snapshot: a subscriber may (un)register subscribers,
        # including itself, while being notified, and mutating a list while
        # iterating over it would silently skip the next subscriber.
        for subscriber in list(self._subscribers):
            try:
                result = subscriber(*args, **kwargs)
                if inspect.iscoroutine(result):
                    self._schedule_notification(result)
            except Exception:  # pylint: disable=broad-except
                logger.exception(f"An error occurred notifying subscriber {subscriber}.")

    def _schedule_notification(self, coroutine):
        """Schedules a subscriber's coroutine as a task on the running loop."""
        try:
            notification_task = asyncio.create_task(coroutine)
        except RuntimeError:
            # No running event loop: close the coroutine so it does not emit a
            # "coroutine was never awaited" warning; notify() logs the error.
            coroutine.close()
            raise
        self._pending_tasks.add(notification_task)
        notification_task.add_done_callback(self._on_notification_task_done)

    def _on_notification_task_done(self, task: asyncio.Task):
        """Releases a finished notification task and surfaces its exception."""
        self._pending_tasks.discard(task)
        if task.cancelled():
            # Cancellation is not a subscriber error; result() would raise
            # CancelledError here.
            return
        # Force the subscriber's exception (if any) to bubble up to the event
        # loop's exception handler instead of being silently dropped.
        task.result()

    def is_subscriber_registered(self, subscriber: Callable) -> bool:
        """Returns whether a subscriber is registered or not."""
        return subscriber in self._subscribers

    @property
    def number_of_subscribers(self) -> int:
        """Number of currently registered subscribers."""
        return len(self._subscribers)
|