1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714
|
"""
transitions.extensions.asyncio
------------------------------
This module contains machine, state and event implementations for asynchronous callback processing.
`AsyncMachine` and `HierarchicalAsyncMachine` use `asyncio` for concurrency. The extension `transitions-anyio`
found at https://github.com/pytransitions/transitions-anyio illustrates how they can be extended to
make use of other concurrency libraries.
The module also contains the state mixin `AsyncTimeout` to asynchronously trigger timeout-related callbacks.
"""
# Overriding base methods of states, transitions and machines with async variants is not considered good practise.
# However, the alternative would mean to either increase the complexity of the base classes or copy code fragments
# and thus increase code complexity and reduce maintainability. If you know a better solution, please file an issue.
# pylint: disable=invalid-overridden-method
import logging
import asyncio
import contextvars
import inspect
from collections import deque
from functools import partial, reduce
import copy
from ..core import State, Condition, Transition, EventData, listify
from ..core import Event, MachineError, Machine
from .nesting import HierarchicalMachine, NestedState, NestedEvent, NestedTransition, resolve_order
# Module-level logger. The NullHandler keeps this library from emitting log output through
# logging's last-resort handler unless the application configures logging itself.
_LOGGER = logging.getLogger(__name__)
_LOGGER.addHandler(logging.NullHandler())
class AsyncState(State):
    """A persistent representation of a state managed by a ``Machine``.

    Enter and exit callbacks are awaited through ``machine.callbacks``, so the state
    must be used with a machine whose ``callbacks`` is a coroutine (e.g. ``AsyncMachine``).
    """

    async def enter(self, event_data):
        """Run the ``on_enter`` callbacks of this state.

        Args:
            event_data (AsyncEventData): The currently processed event.
        """
        machine = event_data.machine
        _LOGGER.debug("%sEntering state %s. Processing callbacks...", machine.name, self.name)
        await machine.callbacks(self.on_enter, event_data)
        _LOGGER.info("%sFinished processing state %s enter callbacks.", machine.name, self.name)

    async def exit(self, event_data):
        """Run the ``on_exit`` callbacks of this state.

        Args:
            event_data (AsyncEventData): The currently processed event.
        """
        machine = event_data.machine
        _LOGGER.debug("%sExiting state %s. Processing callbacks...", machine.name, self.name)
        await machine.callbacks(self.on_exit, event_data)
        _LOGGER.info("%sFinished processing state %s exit callbacks.", machine.name, self.name)
class NestedAsyncState(NestedState, AsyncState):
    """A state that allows substates. Callback execution is done asynchronously."""

    async def scoped_enter(self, event_data, scope=None):
        """Enter this state while temporarily applying ``scope``.

        ``self._scope`` is restored to an empty list afterwards even when an enter callback
        raises or the running task is cancelled, so the state never keeps a stale scope.

        Args:
            event_data (AsyncEventData): The currently processed event.
            scope (list, optional): Scope applied while the callbacks run; ``None`` means no scope.
        """
        self._scope = scope or []
        try:
            await self.enter(event_data)
        finally:
            self._scope = []

    async def scoped_exit(self, event_data, scope=None):
        """Exit this state while temporarily applying ``scope``.

        ``self._scope`` is restored to an empty list afterwards even when an exit callback
        raises or the running task is cancelled, so the state never keeps a stale scope.

        Args:
            event_data (AsyncEventData): The currently processed event.
            scope (list, optional): Scope applied while the callbacks run; ``None`` means no scope.
        """
        self._scope = scope or []
        try:
            await self.exit(event_data)
        finally:
            self._scope = []
class AsyncCondition(Condition):
    """A condition whose check may return either a plain value or an awaitable."""

    async def check(self, event_data):
        """Evaluate the condition and compare its outcome with ``self.target``.

        Args:
            event_data (EventData): Passed to the condition when the machine sends events;
                otherwise its ``args``/``kwargs`` are unpacked into the call. Also holds the
                machine used to resolve ``self.func``.
        Returns:
            bool: ``True`` when the (awaited, if necessary) result equals ``self.target``.
        """
        machine = event_data.machine
        condition = machine.resolve_callable(self.func, event_data)
        if machine.send_event:
            outcome = condition(event_data)
        else:
            outcome = condition(*event_data.args, **event_data.kwargs)
        # Coroutine conditions are awaited; synchronous ones are used as-is.
        if inspect.isawaitable(outcome):
            outcome = await outcome
        return outcome == self.target
class AsyncTransition(Transition):
    """Representation of an asynchronous transition managed by an ``AsyncMachine`` instance."""

    condition_cls = AsyncCondition

    async def _eval_conditions(self, event_data):
        """Check all conditions via ``machine.await_all``; True only if every check passed."""
        machine = event_data.machine
        checks = [partial(cond.check, event_data) for cond in self.conditions]
        outcomes = await machine.await_all(checks)
        if all(outcomes):
            return True
        _LOGGER.debug("%sTransition condition failed: Transition halted.", machine.name)
        return False

    async def execute(self, event_data):
        """Execute the transition.

        Order: ``prepare`` callbacks, conditions, cancellation of other running tasks of the
        model, ``before_state_change``/``before`` callbacks, the state change (unless ``dest``
        is empty), then ``after``/``after_state_change`` callbacks.

        Args:
            event_data (EventData): An instance of class EventData.
        Returns:
            bool: ``True`` if the transition was executed, ``False`` if a condition failed.
        """
        machine = event_data.machine
        _LOGGER.debug("%sInitiating transition from state %s to state %s...",
                      machine.name, self.source, self.dest)
        await machine.callbacks(self.prepare, event_data)
        _LOGGER.debug("%sExecuted callbacks before conditions.", machine.name)

        if not await self._eval_conditions(event_data):
            return False

        # The transition is going to happen, so other running tasks of this model are cancelled.
        await machine.switch_model_context(event_data.model)

        await machine.callbacks(machine.before_state_change, event_data)
        await machine.callbacks(self.before, event_data)
        _LOGGER.debug("%sExecuted callback before transition.", machine.name)

        # An empty dest (None) marks an internal transition: callbacks run but the state stays.
        if self.dest:
            await self._change_state(event_data)

        await machine.callbacks(self.after, event_data)
        await machine.callbacks(machine.after_state_change, event_data)
        _LOGGER.debug("%sExecuted callback after transition.", machine.name)
        return True

    async def _change_state(self, event_data):
        """Exit the source state, update the model and enter the destination state."""
        machine = event_data.machine
        model = event_data.model
        if hasattr(machine, "model_graphs"):
            # Graph-enabled machines also track the last transition per model for styling.
            graph = machine.model_graphs[id(model)]
            graph.reset_styling()
            graph.set_previous_transition(self.source, self.dest)
        await machine.get_state(self.source).exit(event_data)
        machine.set_state(self.dest, model)
        event_data.update(getattr(model, machine.model_attribute))
        target = machine.get_state(self.dest)
        await target.enter(event_data)
        if target.final:
            await machine.callbacks(machine.on_final, event_data)
class NestedAsyncTransition(AsyncTransition, NestedTransition):
    """Representation of an asynchronous transition managed by a ``HierarchicalAsyncMachine`` instance."""

    async def _change_state(self, event_data):
        """Await the exit partials, update the model, await the enter partials, then run final checks."""
        machine = event_data.machine
        if hasattr(machine, "model_graphs"):
            graph = machine.model_graphs[id(event_data.model)]
            graph.reset_styling()
            graph.set_previous_transition(self.source, self.dest)

        state_tree, exit_partials, enter_partials = self._resolve_transition(event_data)
        # Partials are awaited one after another; the model is updated only after all exits ran.
        for exit_step in exit_partials:
            await exit_step()
        self._update_model(event_data, state_tree)
        for enter_step in enter_partials:
            await enter_step()

        # Final-state evaluation happens in the machine's root scope.
        with machine():
            final_callbacks, _ = self._final_check(event_data, state_tree, enter_partials)
            for final_step in final_callbacks:
                await final_step()
class AsyncEventData(EventData):
    """Event data created by the asyncio extension.

    It adds no behaviour to ``EventData``; the dedicated subclass exists to ease type checking.
    """
class AsyncEvent(Event):
    """A collection of transitions assigned to the same trigger. Processing is asynchronous."""

    async def trigger(self, model, *args, **kwargs):
        """Serially execute all transitions that match the current state,
        halting as soon as one successfully completes. Note that `AsyncEvent` triggers must be awaited.

        Args:
            model (object): The model the event is triggered on.
            args and kwargs: Optional positional or named arguments that will
                be passed onto the event data object, enabling arbitrary state
                information to be passed on to downstream triggered functions.
        Returns:
            boolean indicating whether or not a transition was
            successfully executed (True if successful, False if not).
        """
        # AsyncEventData is used here for consistency with the rest of this module
        # (HierarchicalAsyncMachine.trigger_event and the _can_trigger implementations).
        event_data = AsyncEventData(None, self, self.machine, model, args=args, kwargs=kwargs)
        func = partial(self._trigger, event_data)
        # process_context registers the running task for the model and applies the machine's queueing.
        return await self.machine.process_context(func, model)

    async def _trigger(self, event_data):
        """Process the event for ``event_data.model`` and always run ``finalize_event`` callbacks.

        Exceptions are forwarded to ``on_exception`` callbacks when configured, otherwise re-raised.
        Exceptions from ``finalize_event`` callbacks are logged and suppressed.
        """
        event_data.state = self.machine.get_state(getattr(event_data.model, self.machine.model_attribute))
        try:
            if self._is_valid_source(event_data.state):
                await self._process(event_data)
        except BaseException as err:  # pylint: disable=broad-except; Exception will be handled elsewhere
            _LOGGER.error("%sException was raised while processing the trigger: %s", self.machine.name, err)
            event_data.error = err
            if self.machine.on_exception:
                await self.machine.callbacks(self.machine.on_exception, event_data)
            else:
                raise
        finally:
            try:
                await self.machine.callbacks(self.machine.finalize_event, event_data)
                _LOGGER.debug("%sExecuted machine finalize callbacks", self.machine.name)
            except BaseException as err:  # pylint: disable=broad-except; Exception will be handled elsewhere
                _LOGGER.error("%sWhile executing finalize callbacks a %s occurred: %s.",
                              self.machine.name,
                              type(err).__name__,
                              str(err))
        return event_data.result

    async def _process(self, event_data):
        """Run ``prepare_event`` callbacks, then try transitions until one succeeds."""
        await self.machine.callbacks(self.machine.prepare_event, event_data)
        _LOGGER.debug("%sExecuted machine preparation callbacks before conditions.", self.machine.name)
        for trans in self.transitions[event_data.state.name]:
            event_data.transition = trans
            event_data.result = await trans.execute(event_data)
            if event_data.result:
                break
class NestedAsyncEvent(NestedEvent):
    """A collection of transitions assigned to the same trigger.
    This Event requires a (subclass of) `HierarchicalAsyncMachine`.
    """

    async def trigger_nested(self, event_data):
        """Serially execute all transitions that match the current state,
        halting as soon as one successfully completes. NOTE: This should only
        be called by HierarchicalMachine instances.

        States are visited in the order given by ``resolve_order``. Once a transition from a
        state succeeds, that state and all of its ancestors are skipped for the rest of the walk.

        Args:
            event_data (AsyncEventData): The currently processed event.
        Returns:
            boolean indicating whether or not a transition was
            successfully executed (True if successful, False if not).
        """
        machine = event_data.machine
        separator = machine.state_cls.separator
        current = getattr(event_data.model, machine.model_attribute)
        tree = machine.build_state_tree(current, separator)
        # Narrow the tree down to the subtree of the machine's current scope.
        tree = reduce(dict.get, machine.get_global_name(join=False), tree)
        paths = resolve_order(tree)
        handled = set()
        event_data.event = self
        for path in paths:
            name = separator.join(path)
            if name in handled or name not in self.transitions:
                continue
            event_data.state = machine.get_state(name)
            event_data.source_name = name
            event_data.source_path = copy.copy(path)
            await self._process(event_data)
            if event_data.result:
                # Mark this state and every ancestor (each shorter prefix) as handled.
                for depth in range(len(path), 0, -1):
                    handled.add(separator.join(path[:depth]))
        return event_data.result

    async def _process(self, event_data):
        """Run ``prepare_event`` callbacks, then try the transitions of ``event_data.source_name``."""
        machine = event_data.machine
        await machine.callbacks(machine.prepare_event, event_data)
        _LOGGER.debug("%sExecuted machine preparation callbacks before conditions.", machine.name)
        for candidate in self.transitions[event_data.source_name]:
            event_data.transition = candidate
            event_data.result = await candidate.execute(event_data)
            if event_data.result:
                break
class AsyncMachine(Machine):
    """Machine manages states, transitions and models. In case it is initialized without a specific model
    (or specifically no model), it will also act as a model itself. Machine takes also care of decorating
    models with conveniences functions related to added transitions and states during runtime.

    Callbacks and conditions may be plain callables or return awaitables; awaitables are awaited.

    Attributes:
        states (OrderedDict): Collection of all registered states.
        events (dict): Collection of transitions ordered by trigger/event.
        models (list): List of models attached to the machine.
        initial (str): Name of the initial state for new models.
        prepare_event (list): Callbacks executed when an event is triggered.
        before_state_change (list): Callbacks executed after condition checks but before transition is conducted.
            Callbacks will be executed BEFORE the custom callbacks assigned to the transition.
        after_state_change (list): Callbacks executed after the transition has been conducted.
            Callbacks will be executed AFTER the custom callbacks assigned to the transition.
        finalize_event (list): Callbacks will be executed after all transitions callbacks have been executed.
            Callbacks mentioned here will also be called if a transition or condition check raised an error.
        on_exception: A callable called when an event raises an exception. If not set,
            the Exception will be raised instead.
        queued (bool or str): Whether transitions in callbacks should be executed immediately (False) or
            sequentially. ``True`` uses one queue shared by all models, ``'model'`` one queue per model.
        send_event (bool): When True, any arguments passed to trigger methods will be wrapped in an EventData
            object, allowing indirect and encapsulated access to data. When False, all positional and keyword
            arguments will be passed directly to all callback methods.
        auto_transitions (bool): When True (default), every state will automatically have an associated
            to_{state}() convenience trigger in the base model.
        ignore_invalid_triggers (bool): When True, any calls to trigger methods that are not valid for the
            present state (e.g., calling an a_to_b() trigger when the current state is c) will be silently
            ignored rather than raising an invalid transition exception.
        name (str): Name of the ``Machine`` instance mainly used for easier log message distinction.
    """

    state_cls = AsyncState
    transition_cls = AsyncTransition
    event_cls = AsyncEvent
    # Class-level, shared by all AsyncMachine instances: id(model) -> tasks currently processing an event.
    async_tasks = {}
    # Class-level, shared: tasks that ``switch_model_context`` never cancels.
    protected_tasks = []
    # The task that processes an event in the current context; None while no event is being processed.
    current_context = contextvars.ContextVar('current_context', default=None)

    def __init__(self, model=Machine.self_literal, states=None, initial='initial', transitions=None,
                 send_event=False, auto_transitions=True,
                 ordered_transitions=False, ignore_invalid_triggers=None,
                 before_state_change=None, after_state_change=None, name=None,
                 queued=False, prepare_event=None, finalize_event=None, model_attribute='state',
                 model_override=False, on_exception=None, on_final=None, **kwargs):
        # Models are not passed to the base class; they are added below once the queue dict exists,
        # because add_model may need to populate it.
        super().__init__(model=None, states=states, initial=initial, transitions=transitions,
                         send_event=send_event, auto_transitions=auto_transitions,
                         ordered_transitions=ordered_transitions, ignore_invalid_triggers=ignore_invalid_triggers,
                         before_state_change=before_state_change, after_state_change=after_state_change, name=name,
                         queued=bool(queued), prepare_event=prepare_event, finalize_event=finalize_event,
                         model_attribute=model_attribute, model_override=model_override,
                         on_exception=on_exception, on_final=on_final, **kwargs)
        # queued=True: every key maps to the machine's single queue; queued='model': one deque per model.
        self._transition_queue_dict = _DictionaryMock(self._transition_queue) if queued is True else {}
        self._queued = queued
        for model in listify(model):
            self.add_model(model)

    def add_model(self, model, initial=None):
        """Add model(s) to the machine; with per-model queueing, create an empty queue for each."""
        super().add_model(model, initial)
        if self.has_queue == 'model':
            for mod in listify(model):
                # The self literal stands for the machine itself.
                self._transition_queue_dict[id(self) if mod is self.self_literal else id(mod)] = deque()

    async def dispatch(self, trigger, *args, **kwargs):
        """Trigger an event on all models assigned to the machine.
        Args:
            trigger (str): Event name
            *args (list): List of arguments passed to the event trigger
            **kwargs (dict): Dictionary of keyword arguments passed to the event trigger
        Returns:
            bool The truth value of all triggers combined with AND
        """
        results = await self.await_all([partial(getattr(model, trigger), *args, **kwargs) for model in self.models])
        return all(results)

    async def callbacks(self, funcs, event_data):
        """Triggers a list of callbacks (gathered through ``await_all``)."""
        await self.await_all([partial(event_data.machine.callback, func, event_data) for func in funcs])

    async def callback(self, func, event_data):
        """Trigger a callback function with passed event_data parameters. In case func is a string,
        the callable will be resolved from the passed model in event_data. This function is not intended to
        be called directly but through state and transition callback definitions.
        Args:
            func (string, callable): The callback function.
                1. First, if the func is callable, just call it
                2. Second, we try to import string assuming it is a path to a func
                3. Fallback to a model attribute
            event_data (EventData): An EventData instance to pass to the
                callback (if event sending is enabled) or to extract arguments
                from (if event sending is disabled).
        """
        func = self.resolve_callable(func, event_data)
        res = func(event_data) if self.send_event else func(*event_data.args, **event_data.kwargs)
        if inspect.isawaitable(res):
            await res

    @staticmethod
    async def await_all(callables):
        """
        Executes callables without parameters in parallel and collects their results.
        Args:
            callables (list): A list of callable functions
        Returns:
            list: A list of results. Using asyncio the list will be in the same order as the passed callables.
        """
        return await asyncio.gather(*[func() for func in callables])

    async def switch_model_context(self, model):
        """
        This method is called by an `AsyncTransition` when all conditional tests have passed
        and the transition will happen. This requires already running tasks to be cancelled.
        The task of the current context and tasks in ``protected_tasks`` are left untouched.
        Args:
            model (object): The currently processed model
        """
        for running_task in self.async_tasks.get(id(model), []):
            if self.current_context.get() == running_task or running_task in self.protected_tasks:
                continue
            if running_task.done() is False:
                _LOGGER.debug("Cancel running tasks...")
                running_task.cancel()

    async def process_context(self, func, model):
        """
        This function is called by an `AsyncEvent` to make callbacks processed in Event._trigger cancellable.
        Using asyncio this will result in a try-catch block catching CancelledEvents.

        The outermost call in a context registers the current task in ``async_tasks`` and marks it as
        ``current_context``. Both are undone when processing ends, so the next trigger issued by the same
        task is registered (and therefore cancellable) again.
        Args:
            func (partial): The partial of Event._trigger with all parameters already assigned
            model (object): The currently processed model
        Returns:
            bool: returns the success state of the triggered event (False if the task was cancelled)
        """
        if self.current_context.get() is not None:
            # A task is already processing an event in this context (e.g. a trigger issued from within
            # a callback); it has been registered by the outer call.
            return await self._process_async(func, model)

        current_task = asyncio.current_task()
        token = self.current_context.set(current_task)
        self.async_tasks.setdefault(id(model), []).append(current_task)
        try:
            res = await self._process_async(func, model)
        except asyncio.CancelledError:
            res = False
        finally:
            # Without this reset, later triggers of this task would skip registration above.
            self.current_context.reset(token)
            model_tasks = self.async_tasks[id(model)]
            model_tasks.remove(current_task)
            if not model_tasks:
                del self.async_tasks[id(model)]
        return res

    def remove_model(self, model):
        """Remove a model from the state machine. The model will still contain all previously added triggers
        and callbacks, but will not receive updates when states or transitions are added to the Machine.
        If an event queue is used, all queued events of that model will be removed."""
        models = listify(model)
        if self.has_queue == 'model':
            for mod in models:
                del self._transition_queue_dict[id(mod)]
                self.models.remove(mod)
        else:
            for mod in models:
                self.models.remove(mod)
        if len(self._transition_queue) > 0:
            queue = self._transition_queue
            # The head of the queue is the event currently being processed and is always kept.
            new_queue = [queue.popleft()] + [e for e in queue if e.args[0].model not in models]
            self._transition_queue.clear()
            self._transition_queue.extend(new_queue)

    async def _can_trigger(self, model, trigger, *args, **kwargs):
        """Return True if a transition of ``trigger`` from the model's state has all conditions passing.

        ``prepare_event`` and the transition's ``prepare`` callbacks are run before each check.
        """
        state = self.get_model_state(model)
        event_data = AsyncEventData(state, AsyncEvent(name=trigger, machine=self), self, model, args, kwargs)

        for trigger_name in self.get_triggers(state):
            if trigger_name != trigger:
                continue
            for transition in self.events[trigger_name].transitions[state.name]:
                try:
                    _ = self.get_state(transition.dest) if transition.dest is not None else transition.source
                except ValueError:
                    continue

                event_data.transition = transition
                try:
                    await self.callbacks(self.prepare_event, event_data)
                    await self.callbacks(transition.prepare, event_data)
                    if all(await self.await_all([partial(c.check, event_data) for c in transition.conditions])):
                        return True
                except BaseException as err:
                    event_data.error = err
                    if self.on_exception:
                        await self.callbacks(self.on_exception, event_data)
                    else:
                        raise
        return False

    def _process(self, trigger):
        # Synchronous queue processing cannot await coroutine triggers.
        raise RuntimeError("AsyncMachine should not call `Machine._process`. Use `Machine._process_async` instead.")

    async def _process_async(self, trigger, model):
        """Run ``trigger`` immediately or through the (per-model) queue."""
        # default processing
        if not self.has_queue:
            if not self._transition_queue:
                # if trigger raises an Error, it has to be handled by the Machine.process caller
                return await trigger()
            raise MachineError("Attempt to process events synchronously while transition queue is not empty!")

        self._transition_queue_dict[id(model)].append(trigger)
        # another entry in the queue implies a running transition; skip immediate execution
        if len(self._transition_queue_dict[id(model)]) > 1:
            return True

        while self._transition_queue_dict[id(model)]:
            try:
                await self._transition_queue_dict[id(model)][0]()
            except BaseException:
                # if a transition raises an exception, clear queue and delegate exception handling
                self._transition_queue_dict[id(model)].clear()
                raise
            try:
                self._transition_queue_dict[id(model)].popleft()
            except KeyError:
                # The model's queue was removed (remove_model) while its event was processed.
                return True
        return True
class HierarchicalAsyncMachine(HierarchicalMachine, AsyncMachine):
    """Asynchronous variant of transitions.extensions.nesting.HierarchicalMachine.
    An asynchronous hierarchical machine REQUIRES NestedAsyncState, NestedAsyncEvent and NestedAsyncTransition
    (or any subclass of them) to operate.
    """

    state_cls = NestedAsyncState
    transition_cls = NestedAsyncTransition
    event_cls = NestedAsyncEvent

    async def trigger_event(self, model, trigger, *args, **kwargs):
        """Processes events recursively and forwards arguments if suitable events are found.
        This function is usually bound to models with model and trigger arguments already
        resolved as a partial. Execution will halt when a nested transition has been executed
        successfully.
        Args:
            model (object): targeted model
            trigger (str): event name
            *args: positional parameters passed to the event and its callbacks
            **kwargs: keyword arguments passed to the event and its callbacks
        Returns:
            bool: whether a transition has been executed successfully
        Raises:
            MachineError: When no suitable transition could be found and ignore_invalid_trigger
                          is not True. Note that a transition which is not executed due to conditions
                          is still considered valid.
        """
        event_data = AsyncEventData(state=None, event=None, machine=self, model=model, args=args, kwargs=kwargs)
        event_data.result = None
        return await self.process_context(partial(self._trigger_event, event_data, trigger), model)

    async def _trigger_event(self, event_data, trigger):
        """Walk the model's state tree from the root scope; always run ``finalize_event`` callbacks."""
        try:
            with self():
                res = await self._trigger_event_nested(event_data, trigger, None)
            event_data.result = self._check_event_result(res, event_data.model, trigger)
        except BaseException as err:  # pylint: disable=broad-except; Exception will be handled elsewhere
            event_data.error = err
            if self.on_exception:
                await self.callbacks(self.on_exception, event_data)
            else:
                raise
        finally:
            try:
                await self.callbacks(self.finalize_event, event_data)
                _LOGGER.debug("%sExecuted machine finalize callbacks", self.name)
            except BaseException as err:  # pylint: disable=broad-except; Exception will be handled elsewhere
                _LOGGER.error("%sWhile executing finalize callbacks a %s occurred: %s.",
                              self.name,
                              type(err).__name__,
                              str(err))
        return event_data.result

    async def _trigger_event_nested(self, event_data, _trigger, _state_tree):
        """Depth-first trigger processing; children are tried before their parent state.

        Returns None when nothing in this subtree handled the event, otherwise whether any
        handled attempt succeeded.
        """
        model = event_data.model
        if _state_tree is None:
            _state_tree = self.build_state_tree(listify(getattr(model, self.model_attribute)),
                                                self.state_cls.separator)
        res = {}
        for key, value in _state_tree.items():
            if value:
                with self(key):
                    tmp = await self._trigger_event_nested(event_data, _trigger, value)
                    if tmp is not None:
                        res[key] = tmp
            if not res.get(key, None) and _trigger in self.events:
                tmp = await self.events[_trigger].trigger_nested(event_data)
                if tmp is not None:
                    res[key] = tmp
        return None if not res or all(v is None for v in res.values()) else any(res.values())

    async def _can_trigger(self, model, trigger, *args, **kwargs):
        """Return True if ``trigger`` could be executed from any of the model's (possibly parallel) states.

        Every state path produced by ``resolve_order`` is checked, not only the first one, so that
        transitions defined on any parallel branch are considered. Returns False if none applies.
        """
        state_tree = self.build_state_tree(getattr(model, self.model_attribute), self.state_cls.separator)
        for state_path in resolve_order(state_tree):
            with self():
                # Pass a copy because _can_trigger_nested consumes ``path`` via pop(0).
                if await self._can_trigger_nested(model, trigger, list(state_path), *args, **kwargs):
                    return True
        return False

    async def _can_trigger_nested(self, model, trigger, path, *args, **kwargs):
        """Check transitions of ``trigger`` whose source is ``path`` or one of its prefixes in the
        current scope, then descend into the scope ``path[0]`` and repeat with the remainder.

        ``path`` is consumed (``pop(0)``) while descending.
        """
        if trigger in self.events:
            source_path = copy.copy(path)
            while source_path:
                event_data = AsyncEventData(self.get_state(source_path), AsyncEvent(name=trigger, machine=self), self,
                                            model, args, kwargs)
                state_name = self.state_cls.separator.join(source_path)
                for transition in self.events[trigger].transitions.get(state_name, []):
                    try:
                        _ = self.get_state(transition.dest) if transition.dest is not None else transition.source
                    except ValueError:
                        continue
                    event_data.transition = transition
                    try:
                        await self.callbacks(self.prepare_event, event_data)
                        await self.callbacks(transition.prepare, event_data)
                        if all(await self.await_all([partial(c.check, event_data) for c in transition.conditions])):
                            return True
                    except BaseException as err:
                        event_data.error = err
                        if self.on_exception:
                            await self.callbacks(self.on_exception, event_data)
                        else:
                            raise
                source_path.pop(-1)
        if path:
            with self(path.pop(0)):
                return await self._can_trigger_nested(model, trigger, path, *args, **kwargs)
        return False
class AsyncTimeout(AsyncState):
    """
    Adds timeout functionality to an asynchronous state. Timeouts are handled model-specific.
    Attributes:
        timeout (float): Seconds after which the timeout callbacks are called (0 disables the timer).
        on_timeout (list): Functions to call when a timeout is triggered.
        runner (dict): Timer tasks keyed by ``id(model)``; a still-running timer is cancelled when
            the state is exited.
    """

    dynamic_methods = ["on_timeout"]

    def __init__(self, *args, **kwargs):
        """
        Args:
            **kwargs: ``timeout`` (default 0) and ``on_timeout`` are consumed here; everything else is
                forwarded to ``AsyncState``. When ``timeout`` is larger than 0, ``on_timeout`` is required.
        Raises:
            AttributeError: ``timeout`` is larger than 0 but no ``on_timeout`` was passed.
        """
        self.timeout = kwargs.pop("timeout", 0)
        self._on_timeout = None
        if self.timeout > 0 and "on_timeout" not in kwargs:
            raise AttributeError("Timeout state requires 'on_timeout' when timeout is set.") from None
        self.on_timeout = kwargs.pop("on_timeout", None)
        self.runner = {}
        super().__init__(*args, **kwargs)

    async def enter(self, event_data):
        """
        Start a timeout timer for the entering model (only when ``timeout`` > 0), then run the
        regular enter callbacks.
        Args:
            event_data (EventData): The currently processed event.
        """
        model_key = id(event_data.model)
        if self.timeout > 0:
            self.runner[model_key] = self.create_timer(event_data)
        await super().enter(event_data)

    async def exit(self, event_data):
        """
        Cancel the model's timer if it is still running, then run the regular exit callbacks.
        Args:
            event_data (EventData): The currently processed event.
        """
        pending = self.runner.get(id(event_data.model))
        if pending is not None and not pending.done():
            pending.cancel()
        await super().exit(event_data)

    def create_timer(self, event_data):
        """
        Create and schedule the timer task for ``event_data.model``.

        Once the delay has elapsed, the timeout callbacks run under ``asyncio.shield``. Exiting the
        state cancels the timer task but does not interrupt callbacks that have already started.
        Args:
            event_data (EventData): The currently processed event.
        Returns:
            asyncio.Future: The scheduled timer; it has a ``cancel`` method.
        """
        async def _expire():
            try:
                await asyncio.sleep(self.timeout)
                await asyncio.shield(self._process_timeout(event_data))
            except asyncio.CancelledError:
                # Cancellation is the normal way a timer ends when the state is left early.
                pass

        return asyncio.ensure_future(_expire())

    async def _process_timeout(self, event_data):
        """Run the ``on_timeout`` callbacks."""
        machine = event_data.machine
        _LOGGER.debug("%sTimeout state %s. Processing callbacks...", machine.name, self.name)
        await machine.callbacks(self.on_timeout, event_data)
        _LOGGER.info("%sTimeout state %s processed.", machine.name, self.name)

    @property
    def on_timeout(self):
        """
        List of strings and callables to be called when the state timeouts.
        """
        return self._on_timeout

    @on_timeout.setter
    def on_timeout(self, value):
        """Listifies passed values and assigns them to on_timeout."""
        self._on_timeout = listify(value)
class _DictionaryMock(dict):
    """A dict stand-in that answers every key with one shared value.

    ``AsyncMachine`` uses it for ``queued=True`` so that per-model lookups
    (``queue_dict[id(model)]``) all resolve to the machine's single transition queue.
    The underlying dict itself stays empty.
    """

    def __init__(self, item):
        super().__init__()
        self._value = item

    def __getitem__(self, key):
        # The key is ignored: every lookup yields the shared value.
        return self._value

    def __setitem__(self, key, item):
        # Assigning to any key replaces the shared value.
        self._value = item

    def __repr__(self):
        return repr(f"{{'*': {self._value}}}")
|