File: state_task.py

package info (click to toggle)
python-moto 5.1.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 116,520 kB
  • sloc: python: 636,725; javascript: 181; makefile: 39; sh: 3
file content (97 lines) | stat: -rw-r--r-- 3,646 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from __future__ import annotations

import abc
from typing import Optional

from moto.stepfunctions.parser.api import HistoryEventType, TaskTimedOutEventDetails
from moto.stepfunctions.parser.asl.component.common.error_name.failure_event import (
    FailureEvent,
)
from moto.stepfunctions.parser.asl.component.common.error_name.states_error_name import (
    StatesErrorName,
)
from moto.stepfunctions.parser.asl.component.common.error_name.states_error_name_type import (
    StatesErrorNameType,
)
from moto.stepfunctions.parser.asl.component.common.parargs import Parargs
from moto.stepfunctions.parser.asl.component.state.exec.execute_state import (
    ExecutionState,
)
from moto.stepfunctions.parser.asl.component.state.exec.state_task.credentials import (
    Credentials,
    StateCredentials,
)
from moto.stepfunctions.parser.asl.component.state.exec.state_task.service.resource import (
    Resource,
)
from moto.stepfunctions.parser.asl.component.state.state_props import StateProps
from moto.stepfunctions.parser.asl.eval.environment import Environment
from moto.stepfunctions.parser.asl.eval.event.event_detail import EventDetails


class StateTask(ExecutionState, abc.ABC):
    """Abstract base class for Task states.

    Stores the state's ``Resource`` plus its optional ``Parargs`` and
    ``Credentials`` components, and offers helpers that evaluate those
    components against an ``Environment`` and that turn a ``TimeoutError``
    into a ``StatesTimeout`` failure event.
    """

    # All three are assigned in from_state_props(); the Optional ones may be
    # None (presumably when the state definition omits them).
    resource: Resource
    parargs: Optional[Parargs]
    credentials: Optional[Credentials]

    def __init__(self):
        """Initialise the state with the Task entered/exited history event types."""
        super().__init__(
            state_entered_event_type=HistoryEventType.TaskStateEntered,
            state_exited_event_type=HistoryEventType.TaskStateExited,
        )

    def from_state_props(self, state_props: StateProps) -> None:
        """Populate this state from ``state_props``.

        Delegates to the parent implementation first, then looks up the
        ``Resource``, ``Parargs`` and ``Credentials`` components.
        """
        super().from_state_props(state_props)
        self.resource = state_props.get(Resource)
        self.parargs = state_props.get(Parargs)
        self.credentials = state_props.get(Credentials)

    def _get_supported_parameters(self) -> Optional[set[str]]:  # noqa
        """Return the parameter names to keep, or ``None`` for no filtering.

        This base implementation returns ``None``; ``_eval_parameters`` only
        filters when the returned value is truthy.
        """
        return None

    def _eval_parameters(self, env: Environment) -> dict:
        """Evaluate the state's parameters and strip unsupported keys.

        Returns the value popped from ``env.stack`` after evaluating
        ``self.parargs``, or an empty dict when ``self.parargs`` is ``None``.
        If ``_get_supported_parameters()`` yields a non-empty collection,
        every key that is not a member of it is removed in place.
        """
        if self.parargs is None:
            evaluated = {}
        else:
            # Evaluation leaves its result on top of the environment stack.
            self.parargs.eval(env=env)
            evaluated = env.stack.pop()

        allowed = self._get_supported_parameters()
        if not allowed:
            # None or empty: nothing to filter.
            return evaluated

        # Iterate over a snapshot of the keys so entries can be removed
        # while looping.
        for key in list(evaluated.keys()):
            if key not in allowed:
                evaluated.pop(key, None)
        return evaluated

    def _eval_state_credentials(self, env: Environment) -> StateCredentials:
        """Resolve the credentials to use for this state.

        When ``self.credentials`` is truthy it is evaluated and the result is
        popped from ``env.stack``; otherwise a ``StateCredentials`` is built
        from the role ARN in ``env.aws_execution_details``.
        """
        if self.credentials:
            self.credentials.eval(env=env)
            return env.stack.pop()
        return StateCredentials(role_arn=env.aws_execution_details.role_arn)

    def _get_timed_out_failure_event(self, env: Environment) -> FailureEvent:
        """Build the ``TaskTimedOut`` failure event carrying the ``StatesTimeout`` error."""
        timeout_error = StatesErrorName(typ=StatesErrorNameType.StatesTimeout)
        timed_out_details = TaskTimedOutEventDetails(
            error=StatesErrorNameType.StatesTimeout.to_name(),
        )
        return FailureEvent(
            env=env,
            error_name=timeout_error,
            event_type=HistoryEventType.TaskTimedOut,
            event_details=EventDetails(taskTimedOutEventDetails=timed_out_details),
        )

    def _from_error(self, env: Environment, ex: Exception) -> FailureEvent:
        """Translate ``ex`` into a ``FailureEvent``.

        ``TimeoutError`` instances map to the timed-out failure event; every
        other exception is handed to the parent implementation.
        """
        if not isinstance(ex, TimeoutError):
            return super()._from_error(env=env, ex=ex)
        return self._get_timed_out_failure_event(env)