File: execute_state.py

package info (click to toggle)
python-moto 5.1.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 116,520 kB
  • sloc: python: 636,725; javascript: 181; makefile: 39; sh: 3
file content (293 lines) | stat: -rw-r--r-- 12,989 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
import abc
import copy
import logging
import threading
from threading import Thread
from typing import Any, Optional

from moto.stepfunctions.parser.api import HistoryEventType, TaskFailedEventDetails
from moto.stepfunctions.parser.asl.component.common.catch.catch_decl import CatchDecl
from moto.stepfunctions.parser.asl.component.common.catch.catch_outcome import (
    CatchOutcome,
)
from moto.stepfunctions.parser.asl.component.common.error_name.failure_event import (
    FailureEvent,
    FailureEventException,
)
from moto.stepfunctions.parser.asl.component.common.error_name.states_error_name import (
    StatesErrorName,
)
from moto.stepfunctions.parser.asl.component.common.error_name.states_error_name_type import (
    StatesErrorNameType,
)
from moto.stepfunctions.parser.asl.component.common.path.result_path import ResultPath
from moto.stepfunctions.parser.asl.component.common.result_selector import (
    ResultSelector,
)
from moto.stepfunctions.parser.asl.component.common.retry.retry_decl import RetryDecl
from moto.stepfunctions.parser.asl.component.common.retry.retry_outcome import (
    RetryOutcome,
)
from moto.stepfunctions.parser.asl.component.common.timeouts.heartbeat import (
    Heartbeat,
    HeartbeatSeconds,
)
from moto.stepfunctions.parser.asl.component.common.timeouts.timeout import (
    EvalTimeoutError,
    Timeout,
    TimeoutSeconds,
)
from moto.stepfunctions.parser.asl.component.state.state import CommonStateField
from moto.stepfunctions.parser.asl.component.state.state_props import StateProps
from moto.stepfunctions.parser.asl.eval.environment import Environment
from moto.stepfunctions.parser.asl.eval.event.event_detail import EventDetails
from moto.stepfunctions.parser.utils import TMP_THREADS

LOG = logging.getLogger(__name__)


class ExecutionState(CommonStateField, abc.ABC):
    """Abstract base for states that run an execution body with shared error handling.

    Subclasses implement ``_eval_execution``. This class wraps that call with the
    machinery declared on the state: TimeoutSeconds/HeartbeatSeconds, ResultSelector,
    ResultPath, Retry and Catch.
    """

    def __init__(
        self,
        state_entered_event_type: HistoryEventType,
        state_exited_event_type: Optional[HistoryEventType],
    ):
        super().__init__(
            state_entered_event_type=state_entered_event_type,
            state_exited_event_type=state_exited_event_type,
        )
        # ResultPath (Optional)
        # Specifies where (in the input) to place the results of executing the state_task that's specified in Resource.
        # The input is then filtered as specified by the OutputPath field (if present) before being used as the
        # state's output.
        self.result_path: Optional[ResultPath] = None

        # ResultSelector (Optional)
        # Pass a collection of key value pairs, where the values are static or selected from the result.
        self.result_selector: Optional[ResultSelector] = None

        # Retry (Optional)
        # An array of objects, called Retriers, that define a retry policy if the state encounters runtime errors.
        self.retry: Optional[RetryDecl] = None

        # Catch (Optional)
        # An array of objects, called Catchers, that define a fallback state. This state is executed if the state
        # encounters runtime errors and its retry policy is exhausted or isn't defined.
        self.catch: Optional[CatchDecl] = None

        # TimeoutSeconds (Optional)
        # If the state_task runs longer than the specified seconds, this state fails with a States.Timeout error name.
        # Must be a positive, non-zero integer. If not provided, the default value is 99999999. The count begins after
        # the state_task has been started, for example, when ActivityStarted or LambdaFunctionStarted are logged in the
        # Execution event history.
        # TimeoutSecondsPath (Optional)
        # If you want to provide a timeout value dynamically from the state input using a reference path, use
        # TimeoutSecondsPath. When resolved, the reference path must select fields whose values are positive integers.
        # A Task state cannot include both TimeoutSeconds and TimeoutSecondsPath
        # TimeoutSeconds and TimeoutSecondsPath fields are encoded by the timeout type.
        self.timeout: Timeout = TimeoutSeconds(
            timeout_seconds=TimeoutSeconds.DEFAULT_TIMEOUT_SECONDS
        )

        # HeartbeatSeconds (Optional)
        # If more time than the specified seconds elapses between heartbeats from the task, this state fails with a
        # States.Timeout error name. Must be a positive, non-zero integer less than the number of seconds specified in
        # the TimeoutSeconds field. If not provided, the default value is 99999999. For Activities, the count begins
        # when GetActivityTask receives a token and ActivityStarted is logged in the Execution event history.
        # HeartbeatSecondsPath (Optional)
        # If you want to provide a heartbeat value dynamically from the state input using a reference path, use
        # HeartbeatSecondsPath. When resolved, the reference path must select fields whose values are positive integers.
        # A Task state cannot include both HeartbeatSeconds and HeartbeatSecondsPath
        # HeartbeatSeconds and HeartbeatSecondsPath fields are encoded by the Heartbeat type.
        # In this implementation the attribute stays None unless a heartbeat is declared.
        self.heartbeat: Optional[Heartbeat] = None

    def from_state_props(self, state_props: StateProps) -> None:
        """Populate the execution-related fields from the parsed state properties.

        A missing ResultPath falls back to ``ResultPath.DEFAULT_PATH``.

        Raises:
            RuntimeError: if both a literal TimeoutSeconds and a literal HeartbeatSeconds
                are declared and the heartbeat is not strictly smaller than the timeout.
        """
        super().from_state_props(state_props=state_props)
        self.result_path = state_props.get(ResultPath) or ResultPath(
            result_path_src=ResultPath.DEFAULT_PATH
        )
        self.result_selector = state_props.get(ResultSelector)
        self.retry = state_props.get(RetryDecl)
        self.catch = state_props.get(CatchDecl)

        # If provided, the "HeartbeatSeconds" interval MUST be smaller than the "TimeoutSeconds" value.
        # This is only validated here when both are literal seconds values (TimeoutSeconds and
        # HeartbeatSeconds instances); the dynamically resolved variants are not checked.
        timeout = state_props.get(Timeout)
        heartbeat = state_props.get(Heartbeat)
        if isinstance(timeout, TimeoutSeconds) and isinstance(
            heartbeat, HeartbeatSeconds
        ):
            if timeout.timeout_seconds <= heartbeat.heartbeat_seconds:
                raise RuntimeError(
                    f"'HeartbeatSeconds' interval MUST be smaller than the 'TimeoutSeconds' value, "
                    f"got '{timeout.timeout_seconds}' and '{heartbeat.heartbeat_seconds}' respectively."
                )
        # A 60-second default timeout is applied only when a heartbeat is declared without any
        # timeout. In every other case with no declared timeout, the default set in __init__ is kept.
        if heartbeat is not None and timeout is None:
            timeout = TimeoutSeconds(timeout_seconds=60, is_default=True)

        if timeout is not None:
            self.timeout = timeout
        if heartbeat is not None:
            self.heartbeat = heartbeat

    def _from_error(self, env: Environment, ex: Exception) -> FailureEvent:
        """Map an exception raised during execution to a FailureEvent.

        A ``FailureEventException`` carries its own event, which is returned unchanged.
        Any other exception is logged and wrapped as a ``States.Runtime`` TaskFailed
        event whose cause is ``str(ex)``.
        """
        if isinstance(ex, FailureEventException):
            return ex.failure_event
        LOG.warning(
            "State Task encountered an unhandled exception that lead to a State.Runtime error."
        )
        return FailureEvent(
            env=env,
            error_name=StatesErrorName(typ=StatesErrorNameType.StatesRuntime),
            event_type=HistoryEventType.TaskFailed,
            event_details=EventDetails(
                taskFailedEventDetails=TaskFailedEventDetails(
                    error=StatesErrorNameType.StatesRuntime.to_name(),
                    cause=str(ex),
                )
            ),
        )

    @abc.abstractmethod
    def _eval_execution(self, env: Environment) -> None:
        """Run the state's body against ``env``.

        ``_evaluate_with_timeout`` takes the last value left on the environment's
        stack as the execution result.
        """

    def _handle_retry(
        self, env: Environment, failure_event: FailureEvent
    ) -> RetryOutcome:
        """Evaluate the Retry declaration for ``failure_event`` and return its outcome.

        When a retry is allowed, the context object's ``State.RetryCount`` is incremented.
        """
        env.stack.append(failure_event.error_name)
        self.retry.eval(env)
        res: RetryOutcome = env.stack.pop()
        if res == RetryOutcome.CanRetry:
            retry_count = env.states.context_object.context_object_data["State"][
                "RetryCount"
            ]
            env.states.context_object.context_object_data["State"]["RetryCount"] = (
                retry_count + 1
            )
        return res

    def _handle_catch(self, env: Environment, failure_event: FailureEvent) -> None:
        """Evaluate the Catch declaration for ``failure_event``.

        The resulting CatchOutcome is left on top of ``env.stack`` for the caller to inspect.
        """
        env.stack.append(failure_event)
        self.catch.eval(env)

    def _handle_uncaught(self, env: Environment, failure_event: FailureEvent) -> None:
        """Terminate the state with ``failure_event`` (always raises)."""
        self._terminate_with_event(env=env, failure_event=failure_event)

    @staticmethod
    def _terminate_with_event(env: Environment, failure_event: FailureEvent) -> None:
        """Raise ``failure_event`` wrapped in a FailureEventException."""
        raise FailureEventException(failure_event=failure_event)

    def _evaluate_with_timeout(self, env: Environment) -> None:
        """Run ``_eval_execution`` on a child frame, bounded by the state's timeout.

        The body runs in a daemon thread while this thread waits up to the resolved
        timeout. On success, the result is pushed onto ``env.stack``, and Assign,
        ResultSelector and ResultPath processing is applied.

        Raises:
            Exception: whatever the execution body raised.
            EvalTimeoutError: if the body did not finish within the timeout.
        """
        self.timeout.eval(env=env)
        timeout_seconds: int = env.stack.pop()

        frame: Environment = env.open_frame()
        frame.states.reset(input_value=env.states.get_input())
        frame.stack = copy.deepcopy(env.stack)
        execution_outputs: list[Any] = []
        # Seeded with None so that pop() yields None when the worker raised nothing.
        execution_exceptions: list[Optional[Exception]] = [None]
        terminated_event = threading.Event()

        def _exec_and_notify() -> None:
            try:
                self._eval_execution(frame)
                execution_outputs.extend(frame.stack)
            except Exception as ex:
                execution_exceptions.append(ex)
            terminated_event.set()

        thread = Thread(target=_exec_and_notify, daemon=True)
        TMP_THREADS.append(thread)
        thread.start()

        finished_on_time: bool = terminated_event.wait(timeout_seconds)
        # The frame is ended and closed whether or not the worker finished; on timeout the
        # daemon thread is not joined.
        frame.set_ended()
        env.close_frame(frame)

        execution_exception = execution_exceptions.pop()
        if execution_exception:
            raise execution_exception

        if not finished_on_time:
            raise EvalTimeoutError()

        execution_output = execution_outputs.pop()
        env.stack.append(execution_output)

        # For non-JSONPath query languages, also record the raw result on env.states.
        if not self._is_language_query_jsonpath():
            env.states.set_result(execution_output)

        if self.assign_decl:
            self.assign_decl.eval(env=env)

        if self.result_selector:
            self.result_selector.eval(env=env)

        if self.result_path:
            self.result_path.eval(env)
        else:
            res = env.stack.pop()
            env.states.reset(input_value=res)

    @staticmethod
    def _construct_error_output_value(failure_event: FailureEvent) -> Any:
        """Build the ``{"Error": ..., "Cause": ...}`` value for a failure event.

        Raises:
            RuntimeError: if the event does not declare exactly one event-details entry.
        """
        specs_event_details = list((failure_event.event_details or {}).values())
        # A FailureEvent must carry exactly one typed details entry (e.g. taskFailedEventDetails).
        # The previous check tested string membership in a list of dicts, so it could never fire,
        # and an empty declaration crashed with an opaque IndexError instead.
        if len(specs_event_details) != 1:
            raise RuntimeError(
                f"Internal Error: invalid event details declaration in FailureEvent: '{failure_event}'."
            )
        spec_event_details: dict = specs_event_details[0]
        return {
            # If no cause or error fields are given, AWS binds an empty string; otherwise it attaches the value.
            "Error": spec_event_details.get("error", ""),
            "Cause": spec_event_details.get("cause", ""),
        }

    def _eval_state(self, env: Environment) -> None:
        """Evaluate the state, applying Retry and Catch when the execution fails.

        While the environment is running, the execution is attempted until one of
        the following happens:
        - it succeeds;
        - a Catcher handles the failure (its CatchOutcome is left on ``env.stack``);
        - the failure is raised as a FailureEventException.
        """
        # Initialise the retry counter for execution states.
        env.states.context_object.context_object_data["State"]["RetryCount"] = 0

        # Attempt to evaluate the state's logic through until it's successful, caught, or retries have run out.
        while env.is_running():
            try:
                self._evaluate_with_timeout(env)
                break
            except Exception as ex:
                failure_event: FailureEvent = self._from_error(env=env, ex=ex)
                env.event_manager.add_event(
                    context=env.event_history_context,
                    event_type=failure_event.event_type,
                    event_details=failure_event.event_details,
                )
                error_output = self._construct_error_output_value(
                    failure_event=failure_event
                )
                env.states.set_error_output(error_output)
                env.states.set_result(error_output)

                if self.retry:
                    retry_outcome: RetryOutcome = self._handle_retry(
                        env=env, failure_event=failure_event
                    )
                    if retry_outcome == RetryOutcome.CanRetry:
                        continue

                if self.catch:
                    self._handle_catch(env=env, failure_event=failure_event)
                    catch_outcome: CatchOutcome = env.stack[-1]
                    if catch_outcome == CatchOutcome.Caught:
                        break

                self._handle_uncaught(env=env, failure_event=failure_event)

    def _eval_state_output(self, env: Environment) -> None:
        """Apply the parent class's output processing unless the output is a CatchOutcome."""
        # Obtain a reference to the state output.
        output = env.stack[-1]
        # CatcherOutputs (i.e. outputs of Catch blocks) are never subjects of output normalisers,
        # the entire value is instead passed by value as input to the next state, or program output.
        if not isinstance(output, CatchOutcome):
            super()._eval_state_output(env=env)