File: scheduler.py

package info (click to toggle)
python-proton-vpn-api-core 0.39.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 892 kB
  • sloc: python: 6,582; makefile: 8
file content (218 lines) | stat: -rw-r--r-- 8,511 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
"""
Copyright (c) 2024 Proton AG

This file is part of Proton VPN.

Proton VPN is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Proton VPN is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with ProtonVPN.  If not, see <https://www.gnu.org/licenses/>.
"""
import asyncio
import inspect
import time
from asyncio import CancelledError

from dataclasses import dataclass
from typing import Optional, Coroutine, List, Callable


@dataclass
class RunAgain:
    """Object to be returned by a task to be run again after a certain amount of time."""
    # Delay, in milliseconds, before the task that returned this object runs again.
    delay_in_ms: int

    @staticmethod
    def after_seconds(seconds: float) -> "RunAgain":
        """
        Returns a RunAgain object to be run after a certain amount of seconds.

        The delay is rounded to the nearest millisecond instead of being truncated, so
        that float representation errors (``1.001 * 1000`` evaluates to
        ``1000.9999999999999``) do not shave a millisecond off the requested delay.

        :param seconds: delay in seconds; fractions are allowed.
        :returns: a RunAgain whose ``delay_in_ms`` is an int.
        """
        return RunAgain(delay_in_ms=round(seconds * 1000))


@dataclass
class TaskRecord:
    """
    Record with details of the task to be executed and when.

    The scheduler keeps one record per scheduled task in its task list.
    """
    # Identifier returned to the caller when the task is scheduled; used to cancel it.
    id: int  # pylint: disable=invalid-name
    # Unix time (compared against time.time()) from which the task is due to run.
    timestamp: float
    # Coroutine function, called without arguments, that produces the work to run.
    async_function: Callable[[], Coroutine]
    # asyncio task running async_function; None while the task has not been started.
    background_task: Optional[asyncio.Task] = None


class Scheduler:
    """
    Task scheduler.

    The goal of this implementation is to improve the accuracy of the built-in scheduler
    when the system is suspended/resumed. The built-in scheduler does not take into account
    the time the system has been suspended after a task has been scheduled to run after a
    certain amount of time. In this case, the clock is paused and then resumed.

    This implementation works around the issue by keeping a record of the tasks to be
    executed and the unix timestamp at which each of them should run. It then periodically
    checks the list for tasks that are due and runs them.
    """

    def __init__(self, check_interval_in_ms: int = 10_000):
        """
        :param check_interval_in_ms: time, in milliseconds, to sleep between two
            consecutive checks of the task list once the scheduler is started.
        """
        self._check_interval_in_ms = check_interval_in_ms
        self._error_callback: Optional[Callable[[Exception], None]] = None
        self._last_task_id: int = 0
        self._task_list: List["TaskRecord"] = []
        self._scheduler_task: Optional[asyncio.Task] = None

    def set_error_callback(
            self, error_callback: Optional[Callable[[Exception], None]] = None
    ):
        """Sets the error callback to be called when an error occurs while executing a task."""
        self._error_callback = error_callback

    def unset_error_callback(self):
        """Unsets the error callback."""
        self._error_callback = None

    @property
    def task_list(self):
        """Returns the list of tasks currently scheduled (the internal list, not a copy)."""
        return self._task_list

    @property
    def is_started(self) -> bool:
        """Returns whether the scheduler has been started or not."""
        return self._scheduler_task is not None

    @property
    def number_of_remaining_tasks(self) -> int:
        """Returns the number of remaining tasks to be executed (not yet started)."""
        return sum(1 for record in self._task_list if not record.background_task)

    def get_tasks_ready_to_fire(self) -> List["TaskRecord"]:
        """
        Returns the tasks that are ready to fire, that is the tasks with a timestamp lower
        than or equal to the current unix time that have not been started yet.
        """
        now = time.time()
        return [
            record for record in self._task_list
            if record.timestamp <= now and not record.background_task
        ]

    def start(self):
        """
        Starts the scheduler.

        Must be called while an event loop is running, since it creates the asyncio task
        that periodically checks the task list.

        :raises RuntimeError: if the scheduler was already started.
        """
        if self.is_started:  # noqa: E501 # pylint: disable=line-too-long # nosemgrep: python.lang.maintainability.is-function-without-parentheses.is-function-without-parentheses
            raise RuntimeError("Scheduler was already started.")

        self._scheduler_task = asyncio.create_task(self._run_periodic_task_list_check())

    async def stop(self):
        """Stops the scheduler and discards all remaining tasks."""
        if not self.is_started:  # noqa: E501 # pylint: disable=line-too-long # nosemgrep: python.lang.maintainability.is-function-without-parentheses.is-function-without-parentheses
            return

        self._scheduler_task.cancel()

        # Tasks already running are cancelled; tasks not started yet are simply dropped.
        for record in self._task_list:
            if record.background_task:
                record.background_task.cancel()
        self._task_list = []

        try:
            await self.wait_for_shutdown()
        finally:
            # Reset the state even if waiting fails (e.g. on timeout); otherwise the
            # scheduler would keep reporting itself as started and could not be restarted.
            self._scheduler_task = None

    async def wait_for_shutdown(self, timeout=1):
        """
        Waits for the scheduler to be stopped.

        The cancellation of the scheduler task is not propagated to the caller.

        :param timeout: maximum time, in seconds, to wait for the scheduler task.
            When it expires, asyncio.wait_for cancels the scheduler task and the
            resulting timeout error is propagated.
        """
        if self.is_started:  # noqa: E501 # pylint: disable=line-too-long # nosemgrep: python.lang.maintainability.is-function-without-parentheses.is-function-without-parentheses
            try:
                await asyncio.wait_for(self._scheduler_task, timeout)
            except CancelledError:
                pass

    def run_soon(self, async_function: Callable[[], Coroutine]) -> int:
        """
        Runs the coroutine as soon as possible.
        :returns: the scheduled task id.
        """
        return self.run_after(0, async_function)

    def run_after(
            self, delay_in_seconds: float, async_function: Callable[[], Coroutine]
    ) -> int:
        """
        Runs the coroutine after a delay specified in seconds.
        :returns: the scheduled task id.
        """
        return self.run_at(time.time() + delay_in_seconds, async_function)

    def run_at(
            self, timestamp: float, async_function: Callable[[], Coroutine]
    ) -> int:
        """
        Runs the task at the specified timestamp.

        The task is only recorded here; it is started by the next check of the task
        list that finds it due.

        :param timestamp: unix time from which the task is due.
        :param async_function: coroutine function to be called without arguments.
        :returns: the scheduled task id.
        :raises ValueError: if async_function is not a coroutine function.
        """
        if not inspect.iscoroutinefunction(async_function):
            raise ValueError("A coroutine function was expected.")

        self._last_task_id += 1

        record = TaskRecord(
            id=self._last_task_id,
            timestamp=timestamp,
            async_function=async_function
        )
        self._task_list.append(record)

        return record.id

    def cancel_task(self, task_id):
        """
        Cancels a task to be executed given its task id.

        Unknown task ids are ignored.
        """
        # Look the record up first so that the list is never modified while iterating it.
        record = next((rec for rec in self._task_list if rec.id == task_id), None)
        if record is None:
            return

        if record.background_task:
            # Already running: the record is removed once the cancelled task is done.
            record.background_task.cancel()
        else:
            self._task_list.remove(record)

    async def _run_periodic_task_list_check(self):
        """Starts the due tasks, sleeps for the check interval and repeats forever."""
        while True:
            self.run_tasks_ready_to_fire()
            await asyncio.sleep(self._check_interval_in_ms / 1000)

    def run_tasks_ready_to_fire(self):
        """
        Starts the tasks ready to be executed, that is the tasks with a timestamp lower
        than or equal to the current unix time.

        Each started task stays in the list, flagged with its background task, until
        that background task is done.
        """
        for task_record in self.get_tasks_ready_to_fire():
            task = asyncio.create_task(task_record.async_function())
            task_record.background_task = task
            task.add_done_callback(self._on_task_done)

    def _on_task_done(self, task: asyncio.Task):
        """
        Done-callback of the background tasks: reports errors and removes or reschedules
        the record associated with the finished task.
        """
        # The record may be gone already, e.g. when stop() discarded the task list while
        # the cancelled background task was still finishing.
        task_record = next(
            (record for record in self._task_list if record.background_task is task),
            None
        )

        result = None
        try:
            # Bubble up exceptions, if any.
            result = task.result()
        except CancelledError:
            # CancelledError is raised when the task is cancelled.
            pass
        except Exception as exc:  # pylint: disable=broad-except
            if task_record is not None:
                self._task_list.remove(task_record)
            if not self._error_callback:
                raise
            self._error_callback(exc)
            return

        if task_record is None:
            # Nothing left to remove or reschedule.
            return

        if isinstance(result, RunAgain):
            # if the task record is to be run again then it's rescheduled.
            task_record.timestamp = time.time() + result.delay_in_ms / 1000
            task_record.background_task = None
        else:
            self._task_list.remove(task_record)