File: preprocessor.py

package info (click to toggle)
python-moto 5.1.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 116,520 kB
  • sloc: python: 636,725; javascript: 181; makefile: 39; sh: 3
file content (139 lines) | stat: -rw-r--r-- 5,867 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import enum
import functools
from typing import Final

from moto.stepfunctions.parser.asl.antlr.runtime.ASLParser import ASLParser
from moto.stepfunctions.parser.asl.component.common.parargs import Parameters
from moto.stepfunctions.parser.asl.component.common.path.input_path import InputPath
from moto.stepfunctions.parser.asl.component.common.path.result_path import ResultPath
from moto.stepfunctions.parser.asl.component.common.query_language import QueryLanguage
from moto.stepfunctions.parser.asl.component.common.result_selector import (
    ResultSelector,
)
from moto.stepfunctions.parser.asl.component.state.choice.state_choice import (
    StateChoice,
)
from moto.stepfunctions.parser.asl.component.state.exec.execute_state import (
    ExecutionState,
)
from moto.stepfunctions.parser.asl.component.state.state import CommonStateField
from moto.stepfunctions.parser.asl.component.state.state_pass.result import Result
from moto.stepfunctions.parser.asl.component.test_state.program.test_state_program import (
    TestStateProgram,
)
from moto.stepfunctions.parser.asl.component.test_state.state.test_state_state_props import (
    TestStateStateProps,
)
from moto.stepfunctions.parser.asl.eval.test_state.environment import (
    TestStateEnvironment,
)
from moto.stepfunctions.parser.asl.parse.preprocessor import Preprocessor
from moto.stepfunctions.parser.asl.utils.encoding import to_json_str


class InspectionDataKey(enum.Enum):
    """Keys under which TestState evaluation records intermediate values.

    ``_decorated_updates_inspection_data`` stores its captured JSON in
    ``env.inspection_data[key.value]``, so each member's value is the literal
    field name written there. The camelCase names presumably mirror the AWS
    TestState ``inspectionData`` response fields — confirm against the API
    reference.
    """

    # NOTE(review): INPUT, REQUEST and RESPONSE are not referenced elsewhere in
    # this file; presumably they are populated by other TestState code — confirm.
    INPUT = "input"
    # Recorded after the InputPath component has been evaluated.
    AFTER_INPUT_PATH = "afterInputPath"
    # Recorded after the Parameters component has been evaluated.
    AFTER_PARAMETERS = "afterParameters"
    # Shared by execution states (after _eval_execution) and the Result field.
    RESULT = "result"
    # Recorded after the ResultSelector component has been evaluated.
    AFTER_RESULT_SELECTOR = "afterResultSelector"
    # Recorded after the ResultPath component has been evaluated.
    AFTER_RESULT_PATH = "afterResultPath"
    REQUEST = "request"
    RESPONSE = "response"


def _decorated_updated_choice_inspection_data(method):
    """Wrap a Choice state's ``_eval_body`` so the selected branch is recorded.

    After ``method`` runs, ``env.next_state_name`` is passed to
    ``env.set_choice_selected`` so the TestState output can report which
    choice was taken.

    Args:
        method: The (bound) evaluation method to wrap; called as
            ``method(env, *args, **kwargs)``.

    Returns:
        A wrapper with the same call signature. It preserves the wrapped
        method's metadata and passes its return value through unchanged.
    """

    @functools.wraps(method)
    def wrapper(env: TestStateEnvironment, *args, **kwargs):
        return_value = method(env, *args, **kwargs)
        # Read next_state_name only after the body ran, since the body is what
        # decides the branch.
        env.set_choice_selected(env.next_state_name)
        return return_value

    return wrapper


def _decorated_updates_inspection_data(method, inspection_data_key: InspectionDataKey):
    """Wrap ``method`` so its output is captured as TestState inspection data.

    After ``method`` runs, the value on top of ``env.stack`` is serialized with
    ``to_json_str`` and stored in ``env.inspection_data`` under
    ``inspection_data_key.value``.

    Args:
        method: The (bound) evaluation method to wrap; called as
            ``method(env, *args, **kwargs)``.
        inspection_data_key: Which inspection-data field receives the output.

    Returns:
        A wrapper with the same call signature. It preserves the wrapped
        method's metadata and passes its return value through unchanged.
    """

    @functools.wraps(method)
    def wrapper(env: TestStateEnvironment, *args, **kwargs):
        return_value = method(env, *args, **kwargs)
        # Assumes the wrapped method leaves its output on top of env.stack
        # (presumably the evaluation convention for these components — confirm).
        output_json = to_json_str(env.stack[-1])
        # We know that the enum value used here corresponds to a supported inspection data field by design.
        env.inspection_data[inspection_data_key.value] = output_json  # noqa
        return return_value

    return wrapper


def _decorate_state_field(state_field: CommonStateField) -> None:
    """Install inspection-data hooks on the state under test, in place.

    - ``ExecutionState``: ``_eval_execution`` is wrapped so that, after it runs,
      the top of ``env.stack`` is recorded under ``InspectionDataKey.RESULT``.
    - ``StateChoice``: ``_eval_body`` is wrapped so the selected next state is
      recorded.
    - Any other state type is left untouched.
    """
    # Protected members are accessed on purpose below: decoration works by
    # replacing the evaluation hook on the instance.
    if isinstance(state_field, ExecutionState):
        original_execution = state_field._eval_execution  # noqa
        state_field._eval_execution = _decorated_updates_inspection_data(
            method=original_execution,
            inspection_data_key=InspectionDataKey.RESULT,
        )
        return

    if isinstance(state_field, StateChoice):
        original_body = state_field._eval_body  # noqa
        state_field._eval_body = _decorated_updated_choice_inspection_data(
            method=original_body
        )


class TestStatePreprocessor(Preprocessor):
    """Preprocessor that turns a single state declaration into a ``TestStateProgram``.

    Besides building the program, it wraps the evaluation hooks of the parsed
    components (InputPath, Parameters, ResultSelector, ResultPath, Result and
    the state itself). As a result, the values they produce are recorded in
    ``env.inspection_data`` while the state runs.
    """

    # Fixed name given to the single state under test.
    STATE_NAME: Final[str] = "TestState"

    @staticmethod
    def _track_inspection_data(
        component, inspection_data_key: InspectionDataKey
    ) -> None:
        """Wrap ``component._eval_body`` so its output is recorded under ``inspection_data_key``.

        ``component`` is any parsed component exposing an ``_eval_body(env)``
        hook. The hook is replaced on the instance.
        """
        component._eval_body = _decorated_updates_inspection_data(
            # As part of the decoration process, we intentionally access this protected member
            # to facilitate the decorator's functionality.
            method=component._eval_body,  # noqa
            inspection_data_key=inspection_data_key,
        )

    def visitState_decl_body(
        self, ctx: ASLParser.State_decl_bodyContext
    ) -> TestStateProgram:
        """Build the ``TestStateProgram`` for the single state declared in ``ctx``.

        Every child of the declaration is visited and collected into a
        ``TestStateStateProps`` named ``STATE_NAME``. The resulting state field
        is then decorated so its output is captured as inspection data.
        """
        self._open_query_language_scope(ctx)
        try:
            state_props = TestStateStateProps()
            state_props.name = self.STATE_NAME
            for child in ctx.children:
                state_props.add(self.visit(child))
            state_field = self._common_state_field_of(state_props=state_props)
            # Default the query language when the state does not declare one
            # (presumably to the enclosing scope's language).
            # NOTE(review): this runs after the state field was built from
            # state_props; confirm whether _common_state_field_of should see
            # the resolved language instead.
            if state_props.get(QueryLanguage) is None:
                state_props.add(self._get_current_query_language())
            _decorate_state_field(state_field)
        finally:
            # Keep open/close paired even if visiting a child raises.
            self._close_query_language_scope()
        return TestStateProgram(state_field)

    def visitInput_path_decl(self, ctx: ASLParser.Input_path_declContext) -> InputPath:
        """Parse ``InputPath`` and record its output as ``afterInputPath``."""
        input_path: InputPath = super().visitInput_path_decl(ctx=ctx)
        self._track_inspection_data(input_path, InspectionDataKey.AFTER_INPUT_PATH)
        return input_path

    def visitParameters_decl(self, ctx: ASLParser.Parameters_declContext) -> Parameters:
        """Parse ``Parameters`` and record its output as ``afterParameters``."""
        parameters: Parameters = super().visitParameters_decl(ctx=ctx)
        self._track_inspection_data(parameters, InspectionDataKey.AFTER_PARAMETERS)
        return parameters

    def visitResult_selector_decl(
        self, ctx: ASLParser.Result_selector_declContext
    ) -> ResultSelector:
        """Parse ``ResultSelector`` and record its output as ``afterResultSelector``."""
        result_selector: ResultSelector = super().visitResult_selector_decl(ctx=ctx)
        self._track_inspection_data(
            result_selector, InspectionDataKey.AFTER_RESULT_SELECTOR
        )
        return result_selector

    def visitResult_path_decl(
        self, ctx: ASLParser.Result_path_declContext
    ) -> ResultPath:
        """Parse ``ResultPath`` and record its output as ``afterResultPath``."""
        result_path: ResultPath = super().visitResult_path_decl(ctx=ctx)
        self._track_inspection_data(result_path, InspectionDataKey.AFTER_RESULT_PATH)
        return result_path

    def visitResult_decl(self, ctx: ASLParser.Result_declContext) -> Result:
        """Parse ``Result`` and record its output as ``result``."""
        result: Result = super().visitResult_decl(ctx=ctx)
        self._track_inspection_data(result, InspectionDataKey.RESULT)
        return result