File: diagrams_base.py

package info (click to toggle)
python-transitions 0.9.2-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,728 kB
  • sloc: python: 8,765; makefile: 10; sh: 7
file content (154 lines) | stat: -rw-r--r-- 5,963 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""
    transitions.extensions.diagrams_base
    ------------------------------------

    The class BaseGraph implements the common ground for Graphviz backends.
"""

import copy
import abc
import logging

_LOGGER = logging.getLogger(__name__)
_LOGGER.addHandler(logging.NullHandler())


class BaseGraph(metaclass=abc.ABCMeta):
    """Provides the common foundation for graphs generated either with pygraphviz or graphviz. This abstract class
    should not be instantiated directly. Use .(py)graphviz.(Nested)Graph instead.
    Attributes:
        machine (GraphMachine): The associated GraphMachine
        fsm_graph (object): The AGraph-like object that holds the graphviz information
    """

    def __init__(self, machine):
        self.machine = machine
        self.fsm_graph = None
        self.generate()

    @abc.abstractmethod
    def generate(self):
        """Triggers the generation of a graph."""

    @abc.abstractmethod
    def set_previous_transition(self, src, dst):
        """Sets the styling of an edge to 'previous'
        Args:
            src (str): Name of the source state
            dst (str): Name of the destination
        """

    @abc.abstractmethod
    def reset_styling(self):
        """Resets the styling of the currently generated graph."""

    @abc.abstractmethod
    def set_node_style(self, state, style):
        """Sets the style of nodes associated with a model state
        Args:
            state (str, Enum or list): Name of the state(s) or Enum(s)
            style (str): Name of the style
        """

    @abc.abstractmethod
    def get_graph(self, title=None, roi_state=None):
        """Returns a graph object.
        Args:
            title (str): Title of the generated graph
            roi_state (State): If not None, the returned graph will only contain edges and states connected to it.
        Returns:
             A graph instance with a `draw` that allows to render the graph.
        """

    def _convert_state_attributes(self, state):
        label = state.get("label", state["name"])
        if self.machine.show_state_attributes:
            if "tags" in state:
                label += " [" + ", ".join(state["tags"]) + "]"
            if "on_enter" in state:
                label += r"\l- enter:\l  + " + r"\l  + ".join(state["on_enter"])
            if "on_exit" in state:
                label += r"\l- exit:\l  + " + r"\l  + ".join(state["on_exit"])
            if "timeout" in state:
                label += r'\l- timeout(' + state['timeout'] + 's) -> (' + ', '.join(state['on_timeout']) + ')'
        # end each label with a left-aligned newline
        return label + r"\l"

    def _get_state_names(self, state):
        if isinstance(state, (list, tuple, set)):
            for res in state:
                for inner in self._get_state_names(res):
                    yield inner
        else:
            yield self.machine.state_cls.separator.join(self.machine._get_enum_path(state))\
                if hasattr(state, "name") else state

    def _transition_label(self, tran):
        edge_label = tran.get("label", tran["trigger"])
        if "dest" not in tran:
            edge_label += " [internal]"
        if self.machine.show_conditions and any(prop in tran for prop in ["conditions", "unless"]):
            edge_label = "{edge_label} [{conditions}]".format(
                edge_label=edge_label,
                conditions=" & ".join(
                    tran.get("conditions", []) + ["!" + u for u in tran.get("unless", [])]
                ),
            )
        return edge_label

    def _get_global_name(self, path):
        if path:
            state = path.pop(0)
            with self.machine(state):
                return self._get_global_name(path)
        else:
            return self.machine.get_global_name()

    def _flatten(self, *lists):
        return (e for a in lists for e in
                (self._flatten(*a)
                 if isinstance(a, (tuple, list))
                 else (a.name if hasattr(a, 'name') else a,)))

    def _get_elements(self):
        states = []
        transitions = []
        try:
            markup = self.machine.get_markup_config()
            queue = [([], markup)]

            while queue:
                prefix, scope = queue.pop(0)
                for transition in scope.get("transitions", []):
                    if prefix:
                        tran = copy.copy(transition)
                        tran["source"] = self.machine.state_cls.separator.join(
                            prefix + [tran["source"]]
                        )
                        if "dest" in tran:  # don't do this for internal transitions
                            tran["dest"] = self.machine.state_cls.separator.join(
                                prefix + [tran["dest"]]
                            )
                    else:
                        tran = transition
                    transitions.append(tran)
                for state in scope.get("children", []) + scope.get("states", []):
                    if not prefix:
                        states.append(state)

                    ini = state.get("initial", [])
                    if not isinstance(ini, list):
                        ini = ini.name if hasattr(ini, "name") else ini
                        tran = dict(
                            trigger="",
                            source=self.machine.state_cls.separator.join(prefix + [state["name"]]),
                            dest=self.machine.state_cls.separator.join(
                                prefix + [state["name"], ini]
                            ),
                        )
                        transitions.append(tran)
                    if state.get("children", []):
                        queue.append((prefix + [state["name"]], state))
        except KeyError:
            _LOGGER.error("Graph creation incomplete!")
        return states, transitions