File: callback.py

package info (click to toggle)
python-moto 5.1.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 116,520 kB
  • sloc: python: 636,725; javascript: 181; makefile: 39; sh: 3
file content (263 lines) | stat: -rw-r--r-- 8,340 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
import abc
from collections import OrderedDict
from threading import Event, Lock
from typing import Final, Optional

from moto.stepfunctions.parser.api import ActivityDoesNotExist, Arn
from moto.stepfunctions.parser.backend.activity import Activity, ActivityTask
from moto.stepfunctions.parser.utils import long_uid

# Alias for the string identifier under which a callback endpoint is registered
# (it is the key type of CallbackPoolManager._pool below).
CallbackId = str


class CallbackOutcome(abc.ABC):
    """Base type for the result delivered to a callback endpoint.

    Every outcome records the id of the callback it belongs to; subclasses
    add whatever payload they carry.
    """

    # Id of the callback this outcome was produced for.
    callback_id: Final[CallbackId]

    def __init__(self, callback_id: str):
        """Store ``callback_id`` on the new outcome."""
        self.callback_id = callback_id


class CallbackOutcomeSuccess(CallbackOutcome):
    """Outcome that carries an ``output`` string alongside the callback id."""

    # Output payload supplied with the outcome.
    output: Final[str]

    def __init__(self, callback_id: CallbackId, output: str):
        """Record the callback id (via the base class) and the output payload."""
        super().__init__(callback_id)
        self.output = output


class CallbackOutcomeFailure(CallbackOutcome):
    """Outcome that carries an optional ``error`` and an optional ``cause``."""

    # Both fields may be None; they are stored exactly as passed in.
    error: Final[Optional[str]]
    cause: Final[Optional[str]]

    def __init__(
        self, callback_id: CallbackId, error: Optional[str], cause: Optional[str]
    ):
        """Record the callback id (via the base class), the error and the cause."""
        super().__init__(callback_id)
        self.error, self.cause = error, cause


class CallbackOutcomeTimedOut(CallbackOutcome):
    """Outcome subtype with no payload beyond the callback id.

    NOTE(review): by its name this marks a callback that timed out; the code
    that produces it is not in this module.
    """


class CallbackTimeoutError(TimeoutError):
    """``TimeoutError`` subtype for callbacks.

    Not raised anywhere in this module; presumably raised by code that waits
    on a callback endpoint -- confirm against callers.
    """


class CallbackConsumerError(abc.ABC):
    """Base type for errors reported on a callback endpoint by its consumer.

    Instances are stored on ``CallbackEndpoint.consumer_error`` through
    ``CallbackEndpoint.report``. This is a plain marker object, not an
    exception class.
    """


class CallbackConsumerTimeout(CallbackConsumerError):
    """Consumer-error marker with no extra state.

    NOTE(review): by its name, reported when the consumer timed out waiting;
    the reporting code is not in this module.
    """


class CallbackConsumerLeft(CallbackConsumerError):
    """Consumer-error marker with no extra state.

    NOTE(review): by its name, reported when the consumer stopped waiting;
    the reporting code is not in this module.
    """


class HeartbeatEndpoint:
    """Event-based rendezvous between a heartbeat sender and a waiter.

    ``notify`` sets an event; ``clear_and_wait`` either consumes an event that
    is already set or blocks for up to ``heartbeat_seconds`` for one to arrive.
    """

    # Guards the check-and-clear in clear_and_wait and the set in notify.
    _mutex: Final[Lock]
    _next_heartbeat_event: Final[Event]
    # Upper bound, in seconds, for a single clear_and_wait call.
    _heartbeat_seconds: Final[int]

    def __init__(self, heartbeat_seconds: int):
        """Create an endpoint whose waits time out after ``heartbeat_seconds``."""
        self._heartbeat_seconds = heartbeat_seconds
        self._next_heartbeat_event = Event()
        self._mutex = Lock()

    def clear_and_wait(self) -> bool:
        """Consume a pending heartbeat, or wait for the next one.

        Returns:
            True if the event was already set (it is cleared in that case) or
            became set before the timeout; False if the wait timed out.

        Note that an event set *during* the wait is not cleared here, so the
        following call returns True immediately and clears it.
        """
        event = self._next_heartbeat_event
        with self._mutex:
            pending = event.is_set()
            if pending:
                event.clear()
        if pending:
            return True
        # The wait happens outside the mutex so notify() is never blocked by it.
        return event.wait(timeout=self._heartbeat_seconds)

    def notify(self):
        """Set the heartbeat event, waking a waiter blocked in clear_and_wait."""
        with self._mutex:
            self._next_heartbeat_event.set()


class HeartbeatTimeoutError(TimeoutError):
    """``TimeoutError`` subtype for heartbeats.

    Not raised anywhere in this module; presumably raised by code that waits
    on a heartbeat endpoint -- confirm against callers.
    """


class HeartbeatTimedOut(CallbackConsumerError):
    """Consumer-error marker with no extra state.

    NOTE(review): by its name, reported when an expected heartbeat did not
    arrive in time; the reporting code is not in this module.
    """


class ActivityTaskStartOutcome:
    """Value passed to ``ActivityTaskStartEndpoint.notify`` when a task starts."""

    # Name of the worker given at notification time; None when none was given.
    worker_name: Optional[str]

    def __init__(self, worker_name: Optional[str] = None):
        """Store the optional ``worker_name``."""
        self.worker_name = worker_name


class ActivityTaskStartEndpoint:
    """One-shot rendezvous signalling that an activity task has been started.

    ``notify`` stores an outcome and sets an event; ``wait`` blocks on that
    event for a bounded time and returns whatever outcome is stored.
    """

    _next_activity_task_start_event: Final[Event]
    # None until notify() has been called.
    _outcome: Optional[ActivityTaskStartOutcome]

    def __init__(self):
        """Create an endpoint with no outcome recorded yet."""
        self._next_activity_task_start_event = Event()
        # Must be initialised here: wait() reads it even when the event was
        # never set (timeout), which would otherwise raise AttributeError.
        self._outcome = None

    def wait(self, timeout_seconds: float) -> Optional[ActivityTaskStartOutcome]:
        """Block until notified or until ``timeout_seconds`` have elapsed.

        Returns:
            The outcome passed to ``notify``, or None if the wait timed out
            before any notification arrived.
        """
        self._next_activity_task_start_event.wait(timeout=timeout_seconds)
        return self._outcome

    def notify(self, activity_task: ActivityTaskStartOutcome) -> None:
        """Record ``activity_task`` as the outcome and wake any waiter."""
        # Store the outcome before setting the event so a woken waiter sees it.
        self._outcome = activity_task
        self._next_activity_task_start_event.set()


class CallbackEndpoint:
    """Per-callback rendezvous point between an outcome producer and a waiter.

    The endpoint holds an event that waiters block on, the outcome delivered
    through ``notify``, an optional consumer-side error recorded through
    ``report``, and an optional heartbeat endpoint.
    """

    callback_id: Final[CallbackId]
    _notify_event: Final[Event]
    # Outcome delivered through notify(); None until then.
    _outcome: Optional[CallbackOutcome]
    # Error recorded through report(); None until then.
    consumer_error: Optional[CallbackConsumerError]
    # Created on demand by setup_heartbeat_endpoint(); None until then.
    _heartbeat_endpoint: Optional[HeartbeatEndpoint]

    def __init__(self, callback_id: CallbackId):
        """Create an endpoint for ``callback_id`` with nothing recorded yet."""
        self.callback_id = callback_id
        self._notify_event = Event()
        self._heartbeat_endpoint = None
        self.consumer_error = None
        self._outcome = None

    def _signal_heartbeat(self) -> bool:
        """Notify the heartbeat endpoint if one exists; report whether it did."""
        endpoint = self._heartbeat_endpoint
        if endpoint is None:
            return False
        endpoint.notify()
        return True

    def setup_heartbeat_endpoint(self, heartbeat_seconds: int) -> HeartbeatEndpoint:
        """Create, attach and return a heartbeat endpoint.

        Any previously attached heartbeat endpoint is replaced.
        """
        endpoint = HeartbeatEndpoint(heartbeat_seconds=heartbeat_seconds)
        self._heartbeat_endpoint = endpoint
        return endpoint

    def interrupt_all(self) -> None:
        """Wake every waiter on this endpoint without recording an outcome."""
        self._notify_event.set()
        self._signal_heartbeat()

    def notify(self, outcome: CallbackOutcome):
        """Record ``outcome`` and wake both outcome and heartbeat waiters."""
        # The outcome is stored first so a woken waiter observes it.
        self._outcome = outcome
        self._notify_event.set()
        self._signal_heartbeat()

    def notify_heartbeat(self) -> bool:
        """Forward a heartbeat.

        Returns:
            False when no heartbeat endpoint has been set up, True otherwise.
        """
        return self._signal_heartbeat()

    def wait(self, timeout: Optional[float] = None) -> Optional[CallbackOutcome]:
        """Block until the endpoint is notified/interrupted or ``timeout`` elapses.

        Returns:
            The recorded outcome, or None if none has been delivered (for
            example on timeout or after ``interrupt_all``).
        """
        self._notify_event.wait(timeout=timeout)
        return self._outcome

    def get_outcome(self) -> Optional[CallbackOutcome]:
        """Return the recorded outcome without blocking (None if absent)."""
        return self._outcome

    def report(self, consumer_error: CallbackConsumerError) -> None:
        """Record an error raised on the consumer side of this callback."""
        self.consumer_error = consumer_error


class ActivityCallbackEndpoint(CallbackEndpoint):
    """Callback endpoint for an activity task.

    On top of the base endpoint it keeps the activity input string and an
    ``ActivityTaskStartEndpoint`` used to signal that the task was started.
    """

    _activity_task_start_endpoint: Final[ActivityTaskStartEndpoint]
    _activity_input: Final[str]

    def __init__(self, callback_id: str, activity_input: str):
        """Create the endpoint with a fresh task-start endpoint."""
        super().__init__(callback_id=callback_id)
        self._activity_input = activity_input
        self._activity_task_start_endpoint = ActivityTaskStartEndpoint()

    def get_activity_input(self) -> str:
        """Return the activity input given at construction."""
        return self._activity_input

    def get_activity_task_start_endpoint(self) -> ActivityTaskStartEndpoint:
        """Return the endpoint on which the task start is signalled."""
        return self._activity_task_start_endpoint

    def notify_activity_task_start(self, worker_name: Optional[str]) -> None:
        """Signal that the task was started, optionally naming the worker."""
        start_outcome = ActivityTaskStartOutcome(worker_name=worker_name)
        self._activity_task_start_endpoint.notify(start_outcome)


class CallbackNotifyConsumerError(RuntimeError):
    """Raised when notifying a callback whose consumer has reported an error.

    The reported error is available as ``callback_consumer_error`` and is
    also the exception's sole ``args`` entry.
    """

    callback_consumer_error: CallbackConsumerError

    def __init__(self, callback_consumer_error: CallbackConsumerError):
        """Wrap ``callback_consumer_error``.

        The payload is forwarded to ``RuntimeError.__init__`` so that
        ``args`` is populated even when the argument is passed by keyword;
        without this, ``args`` stays empty, ``str(exc)`` is blank and
        rebuilding the exception from ``args`` (copy/pickle) fails.
        """
        super().__init__(callback_consumer_error)
        self.callback_consumer_error = callback_consumer_error


class CallbackOutcomeFailureError(RuntimeError):
    """Exception wrapper around a ``CallbackOutcomeFailure``.

    The wrapped outcome is available as ``callback_outcome_failure`` and is
    also the exception's sole ``args`` entry.
    """

    callback_outcome_failure: CallbackOutcomeFailure

    def __init__(self, callback_outcome_failure: CallbackOutcomeFailure):
        """Wrap ``callback_outcome_failure``.

        The payload is forwarded to ``RuntimeError.__init__`` so that
        ``args`` is populated even when the argument is passed by keyword;
        without this, ``args`` stays empty, ``str(exc)`` is blank and
        rebuilding the exception from ``args`` (copy/pickle) fails.
        """
        super().__init__(callback_outcome_failure)
        self.callback_outcome_failure = callback_outcome_failure


class CallbackPoolManager:
    """Registry of callback endpoints, keyed by callback id.

    The manager also holds a reference to the activity store so that
    activity-backed callbacks can enqueue their task on the right activity.
    """

    # Keyed by activity ARN: this is the caller-supplied mapping, and
    # add_activity_task looks entries up by ARN.
    _activity_store: Final[dict[Arn, Activity]]
    _pool: Final[dict[CallbackId, CallbackEndpoint]]

    def __init__(self, activity_store: dict[Arn, Activity]):
        """Create an empty pool bound to ``activity_store`` (kept by reference)."""
        self._activity_store = activity_store
        self._pool = OrderedDict()

    def get(self, callback_id: CallbackId) -> Optional[CallbackEndpoint]:
        """Return the endpoint registered for ``callback_id``, or None."""
        return self._pool.get(callback_id)

    def add(self, callback_id: CallbackId) -> CallbackEndpoint:
        """Register and return a new plain endpoint for ``callback_id``.

        Raises:
            ValueError: if ``callback_id`` is already registered.
        """
        if callback_id in self._pool:
            raise ValueError("Duplicate callback token id value.")
        callback_endpoint = CallbackEndpoint(callback_id=callback_id)
        self._pool[callback_id] = callback_endpoint
        return callback_endpoint

    def add_activity_task(
        self, callback_id: CallbackId, activity_arn: Arn, activity_input: str
    ) -> ActivityCallbackEndpoint:
        """Enqueue a task on an activity and register its callback endpoint.

        Raises:
            ValueError: if ``callback_id`` is already registered.
            ActivityDoesNotExist: if ``activity_arn`` is not in the store.
        """
        if callback_id in self._pool:
            raise ValueError("Duplicate callback token id value.")

        maybe_activity: Optional[Activity] = self._activity_store.get(activity_arn)
        if maybe_activity is None:
            raise ActivityDoesNotExist()

        # The callback id doubles as the task token handed to the activity.
        maybe_activity.add_task(
            ActivityTask(task_token=callback_id, task_input=activity_input)
        )

        callback_endpoint = ActivityCallbackEndpoint(
            callback_id=callback_id, activity_input=activity_input
        )
        self._pool[callback_id] = callback_endpoint
        return callback_endpoint

    def generate(self) -> CallbackEndpoint:
        """Register and return an endpoint under a fresh ``long_uid()`` id."""
        return self.add(long_uid())

    def _get_notifiable_endpoint(
        self, callback_id: CallbackId
    ) -> Optional[CallbackEndpoint]:
        """Look up an endpoint that may still be notified.

        Returns:
            The endpoint, or None if ``callback_id`` is not registered.

        Raises:
            CallbackNotifyConsumerError: if the endpoint's consumer has
                reported an error.
        """
        callback_endpoint = self._pool.get(callback_id)
        if callback_endpoint is None:
            return None

        consumer_error: Optional[CallbackConsumerError] = (
            callback_endpoint.consumer_error
        )
        if consumer_error is not None:
            raise CallbackNotifyConsumerError(callback_consumer_error=consumer_error)
        return callback_endpoint

    def notify(self, callback_id: CallbackId, outcome: CallbackOutcome) -> bool:
        """Deliver ``outcome`` to the endpoint registered for ``callback_id``.

        Returns:
            False if no such endpoint exists, True once it has been notified.

        Raises:
            CallbackNotifyConsumerError: if the endpoint's consumer has
                reported an error.
        """
        callback_endpoint = self._get_notifiable_endpoint(callback_id)
        if callback_endpoint is None:
            return False
        callback_endpoint.notify(outcome=outcome)
        return True

    def heartbeat(self, callback_id: CallbackId) -> bool:
        """Forward a heartbeat to the endpoint registered for ``callback_id``.

        Returns:
            False if no such endpoint exists or it has no heartbeat endpoint
            set up; True if the heartbeat was forwarded.

        Raises:
            CallbackNotifyConsumerError: if the endpoint's consumer has
                reported an error.
        """
        callback_endpoint = self._get_notifiable_endpoint(callback_id)
        if callback_endpoint is None:
            return False
        return callback_endpoint.notify_heartbeat()