File: branches_decl.py

package info (click to toggle)
python-moto 5.1.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 116,520 kB
  • sloc: python: 636,725; javascript: 181; makefile: 39; sh: 3
file content (118 lines) | stat: -rw-r--r-- 4,431 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import datetime
import threading
from typing import Final, Optional

from moto.stepfunctions.parser.api import ExecutionFailedEventDetails, HistoryEventType
from moto.stepfunctions.parser.asl.component.common.error_name.custom_error_name import (
    CustomErrorName,
)
from moto.stepfunctions.parser.asl.component.common.error_name.failure_event import (
    FailureEvent,
    FailureEventException,
)
from moto.stepfunctions.parser.asl.component.eval_component import EvalComponent
from moto.stepfunctions.parser.asl.component.program.program import Program
from moto.stepfunctions.parser.asl.component.state.exec.state_parallel.branch_worker import (
    BranchWorker,
)
from moto.stepfunctions.parser.asl.eval.environment import Environment
from moto.stepfunctions.parser.asl.eval.event.event_detail import EventDetails
from moto.stepfunctions.parser.asl.eval.program_state import ProgramError, ProgramState
from moto.utilities.collections import select_from_typed_dict


class BranchWorkerPool(BranchWorker.BranchWorkerComm):
    """Termination tracker shared by the branch workers of one Parallel state.

    Workers report completion through ``on_terminated``. The pool's
    termination event is set either when the first reported program state is
    a ``ProgramError`` (its error details are recorded), or when every worker
    has reported a non-error termination. ``wait`` blocks until that event is
    set, after which ``get_exit_event_details`` returns the recorded error
    details, or ``None`` if no branch failed.
    """

    # Guards the counter, the recorded error and the decision to set the event.
    _mutex: Final[threading.Lock]
    # Set exactly once: on the first error, or when the last worker finishes.
    _termination_event: Final[threading.Event]
    # Number of workers that have not yet reported a non-error termination.
    _active_workers_num: int

    # Details of the first failing branch; None while no branch has failed.
    _terminated_with_error: Optional[ExecutionFailedEventDetails]

    def __init__(self, workers_num: int):
        """Create a pool expecting ``workers_num`` termination reports."""
        self._mutex = threading.Lock()
        self._termination_event = threading.Event()
        self._active_workers_num = workers_num

        self._terminated_with_error = None

        # With no workers nobody will ever call on_terminated, so wait() would
        # block forever; an empty pool is therefore terminated from the start.
        if workers_num <= 0:
            self._termination_event.set()

    def on_terminated(self, env: Environment):
        """Record that the worker whose frame is ``env`` has terminated.

        NOTE(review): presumably invoked by each ``BranchWorker`` when its
        program ends — confirm against ``BranchWorker``.

        If ``env``'s program state is a ``ProgramError``, its error details are
        stored and the termination event is set. Otherwise the active-worker
        count is decremented, and the event is set once it reaches zero.
        Reports arriving after the event has been set are ignored.
        """
        with self._mutex:
            # The check must happen under the lock: checked outside it, two
            # workers could both pass, letting a late error overwrite the first
            # one after wait() has already returned.
            if self._termination_event.is_set():
                return
            end_program_state: ProgramState = env.program_state()
            if isinstance(end_program_state, ProgramError):
                self._terminated_with_error = select_from_typed_dict(
                    typed_dict=ExecutionFailedEventDetails,
                    obj=end_program_state.error or {},
                )
                self._termination_event.set()
            else:
                self._active_workers_num -= 1
                if self._active_workers_num == 0:
                    self._termination_event.set()

    def wait(self):
        """Block until the pool has terminated (first error or all done)."""
        self._termination_event.wait()

    def get_exit_event_details(self) -> Optional[ExecutionFailedEventDetails]:
        """Return the first failing branch's error details, or ``None``."""
        return self._terminated_with_error


class BranchesDecl(EvalComponent):
    """The list of branch programs evaluated by a Parallel state.

    Evaluation pops one input value from the stack and runs every program in
    its own ``BranchWorker`` on a fresh inner frame seeded with that value.
    On success, a list with one entry per branch (in program order) is pushed
    back onto the stack. If any branch fails, all workers are stopped and a
    ``FailureEventException`` carrying the branch's error details is raised.
    """

    def __init__(self, programs: list[Program]):
        self.programs: Final[list[Program]] = programs

    def _eval_body(self, env: Environment) -> None:
        # Every branch receives the same value: the current top of the stack.
        shared_input = env.stack.pop()

        pool = BranchWorkerPool(workers_num=len(self.programs))
        workers: list[BranchWorker] = []
        for branch_program in self.programs:
            # Each branch gets its own inner frame, reset to the shared input.
            branch_env: Environment = env.open_inner_frame()
            branch_env.states.reset(input_value=shared_input)

            branch_worker = BranchWorker(
                branch_worker_comm=pool, program=branch_program, env=branch_env
            )
            workers.append(branch_worker)
            branch_worker.start()

        # Returns once the pool reports the first branch error or that all
        # branches have terminated.
        pool.wait()

        failure_details: Optional[ExecutionFailedEventDetails] = (
            pool.get_exit_event_details()
        )
        if failure_details is None:
            # Success: read each branch frame's states.get_input() in program
            # order, closing each frame right after reading it.
            outputs = []
            for branch_worker in workers:
                branch_env = branch_worker.env
                outputs.append(branch_env.states.get_input())
                env.close_frame(branch_env)
            env.stack.append(outputs)
            return

        # Failure: stop every worker and close its frame, then surface the
        # branch's error as a failure of this state.
        for branch_worker in workers:
            branch_worker.stop(
                stop_date=datetime.datetime.now(), cause=None, error=None
            )
            env.close_frame(branch_worker.env)

        raise FailureEventException(
            failure_event=FailureEvent(
                env=env,
                error_name=CustomErrorName(error_name=failure_details.get("error")),
                event_type=HistoryEventType.ExecutionFailed,
                event_details=EventDetails(
                    executionFailedEventDetails=failure_details
                ),
            )
        )