1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293
|
import abc
import copy
import logging
import threading
from threading import Thread
from typing import Any, Optional
from moto.stepfunctions.parser.api import HistoryEventType, TaskFailedEventDetails
from moto.stepfunctions.parser.asl.component.common.catch.catch_decl import CatchDecl
from moto.stepfunctions.parser.asl.component.common.catch.catch_outcome import (
CatchOutcome,
)
from moto.stepfunctions.parser.asl.component.common.error_name.failure_event import (
FailureEvent,
FailureEventException,
)
from moto.stepfunctions.parser.asl.component.common.error_name.states_error_name import (
StatesErrorName,
)
from moto.stepfunctions.parser.asl.component.common.error_name.states_error_name_type import (
StatesErrorNameType,
)
from moto.stepfunctions.parser.asl.component.common.path.result_path import ResultPath
from moto.stepfunctions.parser.asl.component.common.result_selector import (
ResultSelector,
)
from moto.stepfunctions.parser.asl.component.common.retry.retry_decl import RetryDecl
from moto.stepfunctions.parser.asl.component.common.retry.retry_outcome import (
RetryOutcome,
)
from moto.stepfunctions.parser.asl.component.common.timeouts.heartbeat import (
Heartbeat,
HeartbeatSeconds,
)
from moto.stepfunctions.parser.asl.component.common.timeouts.timeout import (
EvalTimeoutError,
Timeout,
TimeoutSeconds,
)
from moto.stepfunctions.parser.asl.component.state.state import CommonStateField
from moto.stepfunctions.parser.asl.component.state.state_props import StateProps
from moto.stepfunctions.parser.asl.eval.environment import Environment
from moto.stepfunctions.parser.asl.eval.event.event_detail import EventDetails
from moto.stepfunctions.parser.utils import TMP_THREADS
# Module-level logger; ExecutionState._from_error uses it to warn about
# exceptions that are not FailureEventExceptions.
LOG = logging.getLogger(__name__)
class ExecutionState(CommonStateField, abc.ABC):
    """Abstract base for states whose work is delegated to ``_eval_execution``.

    On top of ``CommonStateField`` this wires up the ResultPath, ResultSelector,
    Retry, Catch, TimeoutSeconds(Path) and HeartbeatSeconds(Path) fields. It runs
    the subclass's execution on a worker thread bounded by the configured
    timeout, and routes failures through the retry and catch policies.
    """

    def __init__(
        self,
        state_entered_event_type: HistoryEventType,
        state_exited_event_type: Optional[HistoryEventType],
    ):
        super().__init__(
            state_entered_event_type=state_entered_event_type,
            state_exited_event_type=state_exited_event_type,
        )
        # ResultPath (Optional)
        # Specifies where (in the input) to place the results of executing the state_task that's specified in Resource.
        # The input is then filtered as specified by the OutputPath field (if present) before being used as the
        # state's output.
        self.result_path: Optional[ResultPath] = None

        # ResultSelector (Optional)
        # Pass a collection of key value pairs, where the values are static or selected from the result.
        self.result_selector: Optional[ResultSelector] = None

        # Retry (Optional)
        # An array of objects, called Retriers, that define a retry policy if the state encounters runtime errors.
        self.retry: Optional[RetryDecl] = None

        # Catch (Optional)
        # An array of objects, called Catchers, that define a fallback state. This state is executed if the state
        # encounters runtime errors and its retry policy is exhausted or isn't defined.
        self.catch: Optional[CatchDecl] = None

        # TimeoutSeconds (Optional)
        # If the state_task runs longer than the specified seconds, this state fails with a States.Timeout error name.
        # Must be a positive, non-zero integer. If not provided, the default value is 99999999. The count begins after
        # the state_task has been started, for example, when ActivityStarted or LambdaFunctionStarted are logged in the
        # Execution event history.
        # TimeoutSecondsPath (Optional)
        # If you want to provide a timeout value dynamically from the state input using a reference path, use
        # TimeoutSecondsPath. When resolved, the reference path must select fields whose values are positive integers.
        # A Task state cannot include both TimeoutSeconds and TimeoutSecondsPath
        # TimeoutSeconds and TimeoutSecondsPath fields are encoded by the timeout type.
        self.timeout: Timeout = TimeoutSeconds(
            timeout_seconds=TimeoutSeconds.DEFAULT_TIMEOUT_SECONDS
        )

        # HeartbeatSeconds (Optional)
        # If more time than the specified seconds elapses between heartbeats from the task, this state fails with a
        # States.Timeout error name. Must be a positive, non-zero integer less than the number of seconds specified in
        # the TimeoutSeconds field. If not provided, the default value is 99999999. For Activities, the count begins
        # when GetActivityTask receives a token and ActivityStarted is logged in the Execution event history.
        # HeartbeatSecondsPath (Optional)
        # If you want to provide a heartbeat value dynamically from the state input using a reference path, use
        # HeartbeatSecondsPath. When resolved, the reference path must select fields whose values are positive integers.
        # A Task state cannot include both HeartbeatSeconds and HeartbeatSecondsPath
        # HeartbeatSeconds and HeartbeatSecondsPath fields are encoded by the Heartbeat type.
        self.heartbeat: Optional[Heartbeat] = None

    def from_state_props(self, state_props: StateProps) -> None:
        """Populate the execution-specific fields from the parsed state properties.

        ResultPath falls back to ``ResultPath.DEFAULT_PATH`` when absent.

        Raises:
            RuntimeError: if both a literal TimeoutSeconds and a literal
                HeartbeatSeconds are given and the heartbeat is not strictly
                smaller than the timeout.
        """
        super().from_state_props(state_props=state_props)
        self.result_path = state_props.get(ResultPath) or ResultPath(
            result_path_src=ResultPath.DEFAULT_PATH
        )
        self.result_selector = state_props.get(ResultSelector)
        self.retry = state_props.get(RetryDecl)
        self.catch = state_props.get(CatchDecl)

        # If provided, the "HeartbeatSeconds" interval MUST be smaller than the "TimeoutSeconds" value.
        # This is only checked here when both are literal values (TimeoutSeconds / HeartbeatSeconds)
        # rather than their *Path variants.
        timeout = state_props.get(Timeout)
        heartbeat = state_props.get(Heartbeat)
        if isinstance(timeout, TimeoutSeconds) and isinstance(
            heartbeat, HeartbeatSeconds
        ):
            if timeout.timeout_seconds <= heartbeat.heartbeat_seconds:
                raise RuntimeError(
                    f"'HeartbeatSeconds' interval MUST be smaller than the 'TimeoutSeconds' value, "
                    f"got '{timeout.timeout_seconds}' and '{heartbeat.heartbeat_seconds}' respectively."
                )

        # A heartbeat declared without any timeout gets a 60 second timeout; with
        # neither declared, the DEFAULT_TIMEOUT_SECONDS value set in __init__ is kept.
        if heartbeat is not None and timeout is None:
            timeout = TimeoutSeconds(timeout_seconds=60, is_default=True)
        if timeout is not None:
            self.timeout = timeout
        if heartbeat is not None:
            self.heartbeat = heartbeat

    def _from_error(self, env: Environment, ex: Exception) -> FailureEvent:
        """Translate an exception raised during evaluation into a FailureEvent.

        A ``FailureEventException`` already carries its event and is unwrapped.
        Any other exception is logged and reported as a TaskFailed event with
        the ``StatesRuntime`` error name and ``str(ex)`` as its cause.
        """
        if isinstance(ex, FailureEventException):
            return ex.failure_event
        LOG.warning(
            "State Task encountered an unhandled exception that lead to a State.Runtime error."
        )
        return FailureEvent(
            env=env,
            error_name=StatesErrorName(typ=StatesErrorNameType.StatesRuntime),
            event_type=HistoryEventType.TaskFailed,
            event_details=EventDetails(
                taskFailedEventDetails=TaskFailedEventDetails(
                    error=StatesErrorNameType.StatesRuntime.to_name(),
                    cause=str(ex),
                )
            ),
        )

    @abc.abstractmethod
    def _eval_execution(self, env: Environment) -> None:
        """Run the state's actual work against ``env``.

        Implementations leave their result on ``env.stack``. ``_evaluate_with_timeout``
        takes the last value left there as the execution output.
        """

    def _handle_retry(
        self, env: Environment, failure_event: FailureEvent
    ) -> RetryOutcome:
        """Evaluate the Retry policy against ``failure_event``'s error name.

        When another attempt is allowed, the context object's
        ``State.RetryCount`` is incremented.
        """
        env.stack.append(failure_event.error_name)
        self.retry.eval(env)
        res: RetryOutcome = env.stack.pop()
        if res == RetryOutcome.CanRetry:
            retry_count = env.states.context_object.context_object_data["State"][
                "RetryCount"
            ]
            env.states.context_object.context_object_data["State"]["RetryCount"] = (
                retry_count + 1
            )
        return res

    def _handle_catch(self, env: Environment, failure_event: FailureEvent) -> None:
        """Evaluate the Catch policy.

        The resulting CatchOutcome is left on top of ``env.stack``.
        """
        env.stack.append(failure_event)
        self.catch.eval(env)

    def _handle_uncaught(self, env: Environment, failure_event: FailureEvent) -> None:
        """Terminate the state for a failure that was neither retried nor caught."""
        self._terminate_with_event(env=env, failure_event=failure_event)

    @staticmethod
    def _terminate_with_event(env: Environment, failure_event: FailureEvent) -> None:
        """Raise ``failure_event`` wrapped in a ``FailureEventException``."""
        raise FailureEventException(failure_event=failure_event)

    def _evaluate_with_timeout(self, env: Environment) -> None:
        """Run ``_eval_execution`` on a daemon thread, bounded by the state timeout.

        The execution runs in a child frame seeded with the current input and a
        deep copy of the stack. On success, the last value it left on its stack
        is pushed onto ``env.stack``. That value is then post-processed by the
        optional Assign, ResultSelector and ResultPath declarations.

        Raises:
            Exception: whatever ``_eval_execution`` raised; this takes
                precedence over a timeout.
            EvalTimeoutError: if the execution did not finish within the timeout.
        """
        self.timeout.eval(env=env)
        timeout_seconds: int = env.stack.pop()

        frame: Environment = env.open_frame()
        frame.states.reset(input_value=env.states.get_input())
        frame.stack = copy.deepcopy(env.stack)
        execution_outputs: list[Any] = []
        execution_exceptions: list[Optional[Exception]] = [None]
        terminated_event = threading.Event()

        def _exec_and_notify():
            try:
                self._eval_execution(frame)
                execution_outputs.extend(frame.stack)
            except Exception as ex:
                execution_exceptions.append(ex)
            finally:
                # Always wake the waiting thread, even if a BaseException escapes,
                # so it is not left blocking for the whole timeout interval.
                terminated_event.set()

        thread = Thread(target=_exec_and_notify, daemon=True)
        TMP_THREADS.append(thread)
        thread.start()

        finished_on_time: bool = terminated_event.wait(timeout_seconds)
        frame.set_ended()
        env.close_frame(frame)

        execution_exception = execution_exceptions.pop()
        if execution_exception:
            raise execution_exception

        if not finished_on_time:
            raise EvalTimeoutError()

        execution_output = execution_outputs.pop()
        env.stack.append(execution_output)

        if not self._is_language_query_jsonpath():
            env.states.set_result(execution_output)

        if self.assign_decl:
            self.assign_decl.eval(env=env)

        if self.result_selector:
            self.result_selector.eval(env=env)

        if self.result_path:
            self.result_path.eval(env)
        else:
            res = env.stack.pop()
            env.states.reset(input_value=res)

    @staticmethod
    def _construct_error_output_value(failure_event: FailureEvent) -> Any:
        """Build the ``{"Error": ..., "Cause": ...}`` value for a failure.

        ``failure_event.event_details`` must hold exactly one details spec
        (e.g. ``taskFailedEventDetails``). Its ``error`` and ``cause`` fields
        are copied into the output.

        Raises:
            RuntimeError: if the failure event does not carry exactly one
                details spec (including when it carries none at all).
        """
        specs_event_details = list((failure_event.event_details or {}).values())
        # Exactly one details spec is expected; anything else is an internal
        # declaration error rather than a user-facing failure.
        if len(specs_event_details) != 1:
            raise RuntimeError(
                f"Internal Error: invalid event details declaration in FailureEvent: '{failure_event}'."
            )
        spec_event_details: dict = specs_event_details[0]
        return {
            # If no cause or error fields are given, AWS binds an empty string; otherwise it attaches the value.
            "Error": spec_event_details.get("error", ""),
            "Cause": spec_event_details.get("cause", ""),
        }

    def _eval_state(self, env: Environment) -> None:
        """Evaluate the state, applying Retry and Catch policies on failure.

        The loop ends when the execution succeeds, a catcher handles the error,
        or the environment stops running. An unretried, uncaught failure is
        re-raised as a ``FailureEventException``.
        """
        # Initialise the retry counter for execution states.
        env.states.context_object.context_object_data["State"]["RetryCount"] = 0

        # Attempt to evaluate the state's logic through until it's successful, caught, or retries have run out.
        while env.is_running():
            try:
                self._evaluate_with_timeout(env)
                break
            except Exception as ex:
                failure_event: FailureEvent = self._from_error(env=env, ex=ex)
                env.event_manager.add_event(
                    context=env.event_history_context,
                    event_type=failure_event.event_type,
                    event_details=failure_event.event_details,
                )
                # Expose the error as both the error output and the current
                # result so that Retry/Catch (and ResultPath on Catch) can see it.
                error_output = self._construct_error_output_value(
                    failure_event=failure_event
                )
                env.states.set_error_output(error_output)
                env.states.set_result(error_output)

                if self.retry:
                    retry_outcome: RetryOutcome = self._handle_retry(
                        env=env, failure_event=failure_event
                    )
                    if retry_outcome == RetryOutcome.CanRetry:
                        continue

                if self.catch:
                    self._handle_catch(env=env, failure_event=failure_event)
                    catch_outcome: CatchOutcome = env.stack[-1]
                    if catch_outcome == CatchOutcome.Caught:
                        break

                self._handle_uncaught(env=env, failure_event=failure_event)

    def _eval_state_output(self, env: Environment) -> None:
        """Apply output processing unless the state's output is a Catch outcome."""
        # Obtain a reference to the state output.
        output = env.stack[-1]

        # CatcherOutputs (i.e. outputs of Catch blocks) are never subjects of output normalisers,
        # the entire value is instead passed by value as input to the next state, or program output.
        if not isinstance(output, CatchOutcome):
            super()._eval_state_output(env=env)
|