1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
|
from collections import deque, defaultdict
from transitions.core import listify
from transitions.extensions.markup import HierarchicalMarkupMachine
# Source text placed after ``-> bool:`` in every generated stub method, so a stub that was
# never overridden raises instead of silently returning None.
_placeholder_body = "raise RuntimeError('This should be overridden')"


def generate_base_model(config):
    """Generate the source code of an abstract ``BaseModel`` for a machine configuration.

    ``config`` is passed as keyword arguments to ``HierarchicalMarkupMachine`` and the
    resulting ``markup`` is walked to emit:

    * ``<trigger>`` / ``may_<trigger>`` stubs for every trigger found in the transitions,
    * ``is_<state>`` stubs for every (nested) state, plus ``to_<state>`` /
      ``may_to_<state>`` when the machine has ``auto_transitions`` enabled
      (nested names are joined with ``_``, e.g. ``is_A_B``),
    * an ``@abstractmethod`` for every callback referenced by name (non-string
      callbacks are skipped).

    Triggers and callbacks are emitted in sorted order so the generated source is
    identical across runs (set iteration order of strings varies with hash randomization).

    :param config: keyword arguments for ``HierarchicalMarkupMachine``.
    :return: the generated Python module source as a string.
    """
    machine = HierarchicalMarkupMachine(**config)
    markup = machine.markup
    model_attribute = markup.get("model_attribute", "state")

    # Machine-level callbacks; transition-level ones are added while walking the states.
    callbacks = set()
    for key in ("prepare_event", "before_state_change", "after_state_change",
                "on_exception", "on_final", "finalize_event"):
        callbacks.update(markup[key])

    triggers = set()
    state_lines = []
    has_nested_states = any("children" in state for state in markup["states"])
    is_params = "(self, allow_substates=False)" if has_nested_states else "(self)"

    # Depth-first walk over nesting levels: (states, transitions of that level, name prefix).
    stack = [(markup["states"], markup["transitions"], "")]
    while stack:
        states, transitions, prefix = stack.pop()
        # Transitions belong to the level, not to a single state: scan them once per level
        # (instead of once per state) so that levels without states are not skipped either.
        for tran in transitions:
            triggers.add(tran["trigger"])
            for key in ("prepare", "conditions", "unless", "before", "after"):
                callbacks.update(tran.get(key, []))
        for state in states:
            path = f"{prefix}{state['name']}"
            state_lines.append(f"    def is_{path}{is_params} -> bool: {_placeholder_body}\n")
            if machine.auto_transitions:
                state_lines.append(f"    def to_{path}(self) -> bool: {_placeholder_body}\n")
                state_lines.append(f"    def may_to_{path}(self) -> bool: {_placeholder_body}\n")
            state_lines.append("\n")
            if "children" in state:
                stack.append((state["children"], state.get("transitions", []), path + "_"))
    state_block = "".join(state_lines)

    trigger_block = "".join(
        f"    def {name}(self) -> bool: {_placeholder_body}\n"
        f"    def may_{name}(self) -> bool: {_placeholder_body}\n"
        for name in sorted(triggers)
    )

    # ``EventData`` is only imported under TYPE_CHECKING in the generated module, so the
    # annotation must be a string: without ``from __future__ import annotations`` an
    # unquoted name is evaluated at definition time (before Python 3.14) -> NameError.
    extra_params = ('event_data: "EventData"' if machine.send_event
                    else "*args: List[Any], **kwargs: Dict[str, Any]")
    callback_block = "".join(
        f"    @abstractmethod\n"
        f"    def {name}(self, {extra_params}) -> Optional[bool]: ...\n"
        for name in sorted(cb for cb in callbacks if isinstance(cb, str))
    )

    return f"""# autogenerated by transitions
from abc import ABCMeta, abstractmethod
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from transitions.core import CallbacksArg, StateIdentifier, EventData


class BaseModel(metaclass=ABCMeta):
    {model_attribute}: "StateIdentifier" = ""
    def trigger(self, name: str) -> bool: {_placeholder_body}

{trigger_block}
{state_block}\
{callback_block}"""
def with_model_definitions(cls):
    """Patch ``cls.add_model`` so transitions declared on model classes are registered.

    Transitions declared on a model class through ``TriggerPlaceholder`` attributes
    (collected in ``TriggerPlaceholder.definitions``) are added to the machine via
    ``add_transition`` before each model is handed to the original ``add_model``.
    The patched method also sets ``self.model_override = True`` on every call.

    :param cls: machine class to patch in place.
    :return: ``cls`` itself, so this can be used as a class decorator.
    :raises ValueError: if a stored spec is neither a list/tuple nor a dict.
    """
    add_model = getattr(cls, "add_model")

    def add_model_override(self, model, initial=None):
        self.model_override = True
        for mod in listify(model):
            mod = self if mod == "self" else mod
            # ``.get`` does not trigger the defaultdict factory: classes without any
            # placeholder have no entry, which previously returned None and crashed.
            definitions = TriggerPlaceholder.definitions.get(mod.__class__, {})
            for name, specs in definitions.items():
                for spec in specs:
                    if isinstance(spec, (list, tuple)):
                        # positional add_transition arguments, e.g. ["A", "B"]
                        self.add_transition(name, *spec)
                    elif isinstance(spec, dict):
                        # keyword arguments, e.g. produced by ``transition(...)``
                        self.add_transition(name, **spec)
                    else:
                        raise ValueError(f"Cannot add {spec} for event {name} to machine")
            add_model(self, mod, initial)

    setattr(cls, "add_model", add_model_override)
    return cls
class TriggerPlaceholder:
    """Class attribute standing in for a trigger whose transitions are declared on a model.

    When the owning class body is created, ``__set_name__`` records the collected
    transition configs under ``definitions[owner_class][attribute_name]``. Calling the
    placeholder itself raises ``RuntimeError``; it is meant to be replaced by a real trigger.
    """

    # owner class -> trigger name -> list of transition configs (shared by all placeholders)
    definitions = defaultdict(lambda: defaultdict(list))

    def __init__(self, configs):
        # A deque so that ``add_transitions`` can cheaply prepend configs from outer decorators.
        self.configs = deque(configs)

    def __set_name__(self, owner, name):
        registered = TriggerPlaceholder.definitions[owner][name]
        registered.extend(self.configs)

    def __call__(self, *args, **kwargs):
        raise RuntimeError("Trigger was not initialized correctly!")
def event(*configs):
    """Declare a trigger on a model class together with its transition configs.

    Each config is either a list of positional ``add_transition`` arguments or a dict
    of keyword arguments (e.g. built with ``transition``); they are kept in the given order.
    """
    placeholder = TriggerPlaceholder(configs)
    return placeholder
def add_transitions(*configs):
    """Decorator factory attaching transition configs to a trigger.

    Applied to an existing ``TriggerPlaceholder`` (from ``event`` or another
    ``add_transitions``), ``configs`` are prepended in their given order, so stacked
    decorators register transitions top to bottom. Applied to anything else, the
    decorated object is discarded and replaced by a new placeholder holding ``configs``.
    """
    def _decorate(target):
        if not isinstance(target, TriggerPlaceholder):
            return TriggerPlaceholder(configs)
        # extendleft reverses its input, so feed it reversed configs to keep their order.
        target.configs.extendleft(reversed(configs))
        return target
    return _decorate
def transition(source, dest=None, conditions=None, unless=None, before=None, after=None, prepare=None):
    """Bundle transition options into a dict usable as an ``event``/``add_transitions`` config.

    Every option is present in the result; unspecified ones map to ``None``.
    """
    return dict(
        source=source,
        dest=dest,
        conditions=conditions,
        unless=unless,
        before=before,
        after=after,
        prepare=prepare,
    )
|