1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
|
"""Manages global tasks."""
from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine, Generator
import logging
from typing import TYPE_CHECKING, Any
from xknx.core import XknxConnectionState
AsyncCallbackType = Callable[[], Coroutine[Any, Any, None]]
if TYPE_CHECKING:
from xknx import XKNX
logger = logging.getLogger("xknx.log")
class Task:
    """Wrap a coroutine factory so it can be started, awaited and cancelled by name."""

    def __init__(
        self,
        name: str,
        async_func: AsyncCallbackType,
        restart_after_reconnect: bool = False,
    ) -> None:
        """Initialize Task class.

        :param name: name given to the underlying asyncio task when started.
        :param async_func: zero-argument callable returning the coroutine to run.
        :param restart_after_reconnect: if True, the task is cancelled on
            connection loss and started again once reconnected.
        """
        self.name = name
        self.async_func = async_func
        self.restart_after_reconnect = restart_after_reconnect
        # Handle of the currently scheduled asyncio task; None while idle.
        self._task: asyncio.Task[None] | None = None

    def start(self) -> Task:
        """Schedule a fresh run of ``async_func`` and return ``self`` for chaining."""
        coroutine = self.async_func()
        self._task = asyncio.create_task(coroutine, name=self.name)
        return self

    def __await__(self) -> Generator[None, None, None]:
        """Wait for the scheduled run to finish; returns at once if nothing is scheduled."""
        scheduled = self._task
        if scheduled is not None:
            yield from scheduled

    def cancel(self) -> None:
        """Request cancellation of the scheduled run (if any) and forget its handle."""
        scheduled = self._task
        if scheduled is not None:
            scheduled.cancel()
        self._task = None

    def done(self) -> bool:
        """Return True when nothing is scheduled or the scheduled run has finished."""
        if self._task is None:
            return True
        return self._task.done()

    def connection_lost(self) -> None:
        """Cancel the running task on connection loss if it opted into reconnect restarts."""
        if not self.restart_after_reconnect or self._task is None:
            return
        logger.debug("Stopping task %s because of connection loss.", self.name)
        self.cancel()

    def reconnected(self) -> None:
        """Start the task again after reconnection if it opted in and is not scheduled."""
        if not self.restart_after_reconnect or self._task is not None:
            return
        logger.debug(
            "Restarting task %s as the connection to the bus was reestablished.",
            self.name,
        )
        self.start()
class TaskRegistry:
    """Manages async tasks in XKNX.

    Tracked tasks live in ``self.tasks``; they can be replaced or cancelled by
    name, awaited together, and stopped/restarted on connection state changes.
    Coroutines passed to ``background`` are kept separately and not tracked.
    """

    def __init__(self, xknx: XKNX) -> None:
        """Initialize TaskRegistry class.

        :param xknx: owning XKNX instance; its ``connection_manager`` is used by
            ``start``/``stop`` to (un)subscribe from connection state changes.
        """
        self.xknx = xknx
        self.tasks: list[Task] = []
        # Strong references to untracked background tasks (see ``background``).
        self._background_task: set[asyncio.Task[None]] = set()

    def register(
        self,
        name: str,
        async_func: AsyncCallbackType,
        track_task: bool = True,
        restart_after_reconnect: bool = False,
    ) -> Task:
        """Register new task.

        Any tracked task already registered under ``name`` is cancelled and
        removed first. The returned ``Task`` is not started yet.

        :param name: task name, later used as the asyncio task name.
        :param async_func: zero-argument callable returning the coroutine to run.
        :param track_task: if True, keep the task in ``self.tasks``.
        :param restart_after_reconnect: cancel the task on connection loss and
            restart it once the connection is re-established.
        :return: the newly created ``Task``.
        """
        self.unregister(name)
        _task: Task = Task(
            name=name,
            async_func=async_func,
            restart_after_reconnect=restart_after_reconnect,
        )
        if track_task:
            self.tasks.append(_task)
        return _task

    def unregister(self, name: str) -> None:
        """Cancel and remove every tracked task registered under ``name``."""
        # Collect matches first: removing from ``self.tasks`` while iterating
        # it would skip the element following each removed one.
        matching = [task for task in self.tasks if task.name == name]
        for task in matching:
            task.cancel()
            self.tasks.remove(task)

    def start(self) -> None:
        """Start task registry by subscribing to connection state changes."""
        self.xknx.connection_manager.register_connection_state_changed_cb(
            self.connection_state_changed_cb
        )

    def stop(self) -> None:
        """Stop task registry: unsubscribe from state changes and cancel all tracked tasks."""
        self.xknx.connection_manager.unregister_connection_state_changed_cb(
            self.connection_state_changed_cb
        )
        for task in self.tasks:
            task.cancel()
        self.tasks = []

    async def block_till_done(self) -> None:
        """Await all tracked tasks.

        Uses ``asyncio.gather`` with default arguments, so the first exception
        raised by any tracked task propagates to the caller.
        """
        await asyncio.gather(*self.tasks)

    def connection_state_changed_cb(self, state: XknxConnectionState) -> None:
        """Handle connection state changes.

        ``CONNECTED`` lets each tracked task restart; any other state is
        treated as a connection loss.
        """
        for task in self.tasks:
            if state == XknxConnectionState.CONNECTED:
                task.reconnected()
            else:
                task.connection_lost()

    def background(self, async_func: Coroutine[Any, Any, None]) -> None:
        """Run a task in the background. This task will not be tracked by the TaskRegistry.

        :param async_func: a coroutine object (already created, not a factory).
        """
        # Keep a strong reference in the set so the running task can't be
        # garbage collected; each task discards its own reference on completion
        # so finished tasks are not kept alive forever.
        task = asyncio.create_task(async_func)
        self._background_task.add(task)
        task.add_done_callback(self._background_task.discard)
|