1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
|
import enum
import functools
from typing import Final

from moto.stepfunctions.parser.asl.antlr.runtime.ASLParser import ASLParser
from moto.stepfunctions.parser.asl.component.common.parargs import Parameters
from moto.stepfunctions.parser.asl.component.common.path.input_path import InputPath
from moto.stepfunctions.parser.asl.component.common.path.result_path import ResultPath
from moto.stepfunctions.parser.asl.component.common.query_language import QueryLanguage
from moto.stepfunctions.parser.asl.component.common.result_selector import (
    ResultSelector,
)
from moto.stepfunctions.parser.asl.component.state.choice.state_choice import (
    StateChoice,
)
from moto.stepfunctions.parser.asl.component.state.exec.execute_state import (
    ExecutionState,
)
from moto.stepfunctions.parser.asl.component.state.state import CommonStateField
from moto.stepfunctions.parser.asl.component.state.state_pass.result import Result
from moto.stepfunctions.parser.asl.component.test_state.program.test_state_program import (
    TestStateProgram,
)
from moto.stepfunctions.parser.asl.component.test_state.state.test_state_state_props import (
    TestStateStateProps,
)
from moto.stepfunctions.parser.asl.eval.test_state.environment import (
    TestStateEnvironment,
)
from moto.stepfunctions.parser.asl.parse.preprocessor import Preprocessor
from moto.stepfunctions.parser.asl.utils.encoding import to_json_str
class InspectionDataKey(enum.Enum):
    """Keys under which test-state inspection data is recorded.

    Each member's ``value`` is the literal string used as the key when an
    entry is written into the environment's ``inspection_data`` mapping.
    """

    # Input-side processing.
    INPUT = "input"
    AFTER_INPUT_PATH = "afterInputPath"
    AFTER_PARAMETERS = "afterParameters"

    # Result-side processing.
    RESULT = "result"
    AFTER_RESULT_SELECTOR = "afterResultSelector"
    AFTER_RESULT_PATH = "afterResultPath"

    # Request/response pair.
    REQUEST = "request"
    RESPONSE = "response"
def _decorated_updated_choice_inspection_data(method):
    """Wrap an evaluation callable so the selected next state is recorded.

    The returned callable invokes ``method`` with the arguments it receives and
    afterwards passes ``env.next_state_name`` to ``env.set_choice_selected``.

    Args:
        method: The callable to wrap; invoked as ``method(env, *args, **kwargs)``.

    Returns:
        A wrapper carrying ``method``'s metadata that forwards ``method``'s
        return value to the caller.
    """

    @functools.wraps(method)
    def wrapper(env: "TestStateEnvironment", *args, **kwargs):
        outcome = method(env, *args, **kwargs)
        # Read the next state name only after the wrapped call has completed;
        # presumably it is that call which selects the next state.
        env.set_choice_selected(env.next_state_name)
        # Forward the wrapped callable's result instead of silently dropping it.
        return outcome

    return wrapper
def _decorated_updates_inspection_data(method, inspection_data_key: InspectionDataKey):
    """Wrap an evaluation callable so its outcome is stored as inspection data.

    The returned callable invokes ``method`` with the arguments it receives,
    then serialises the value on top of ``env.stack`` with ``to_json_str`` and
    stores it in ``env.inspection_data`` under ``inspection_data_key.value``.

    Args:
        method: The callable to wrap; invoked as ``method(env, *args, **kwargs)``.
        inspection_data_key: Member whose value names the inspection data entry
            to write.

    Returns:
        A wrapper carrying ``method``'s metadata that forwards ``method``'s
        return value to the caller.
    """

    @functools.wraps(method)
    def wrapper(env: "TestStateEnvironment", *args, **kwargs):
        outcome = method(env, *args, **kwargs)
        # The stack top is read after the wrapped call has completed, so it is
        # the value left there once that call has run.
        serialised = to_json_str(env.stack[-1])
        # We know that the enum value used here corresponds to a supported inspection data field by design.
        env.inspection_data[inspection_data_key.value] = serialised  # noqa
        # Forward the wrapped callable's result instead of silently dropping it.
        return outcome

    return wrapper
def _decorate_state_field(state_field: CommonStateField) -> None:
    """Attach inspection-data recording to the given state, in place.

    * ``ExecutionState``: ``_eval_execution`` is wrapped so that the value on
      top of the environment stack is recorded under the ``RESULT`` key.
    * ``StateChoice``: ``_eval_body`` is wrapped so that the selected next
      state is recorded.

    Any other kind of state is left untouched.
    """
    # As part of the decoration process, we intentionally access protected
    # members to facilitate the decorators' functionality.
    if isinstance(state_field, ExecutionState):
        eval_execution = state_field._eval_execution  # noqa
        state_field._eval_execution = _decorated_updates_inspection_data(
            method=eval_execution,
            inspection_data_key=InspectionDataKey.RESULT,
        )
        return

    if isinstance(state_field, StateChoice):
        eval_body = state_field._eval_body  # noqa
        state_field._eval_body = _decorated_updated_choice_inspection_data(
            method=eval_body
        )
class TestStatePreprocessor(Preprocessor):
    """Preprocessor that turns a single state declaration into a ``TestStateProgram``.

    The overridden ``visit*_decl`` methods take the component built by the base
    class and wrap its ``_eval_body`` so that, after it runs, the value on top
    of the environment stack is stored under the matching inspection data key.
    """

    STATE_NAME: Final[str] = "TestState"

    @staticmethod
    def _with_inspection(component, key: InspectionDataKey):
        """Wrap ``component._eval_body`` to record under ``key``; return ``component``."""
        # The protected member is accessed intentionally: wrapping it is the
        # whole purpose of this helper.
        component._eval_body = _decorated_updates_inspection_data(
            method=component._eval_body,  # noqa
            inspection_data_key=key,
        )
        return component

    def visitState_decl_body(
        self, ctx: ASLParser.State_decl_bodyContext
    ) -> TestStateProgram:
        """Build the test-state program from the state declaration body."""
        self._open_query_language_scope(ctx)

        props = TestStateStateProps()
        props.name = self.STATE_NAME
        for node in ctx.children:
            props.add(self.visit(node))

        # NOTE(review): the state field is built before the default query
        # language is added to the props below; this ordering is preserved
        # as-is - confirm it is intentional.
        field = self._common_state_field_of(state_props=props)
        if props.get(QueryLanguage) is None:
            props.add(self._get_current_query_language())

        _decorate_state_field(field)
        self._close_query_language_scope()
        return TestStateProgram(field)

    def visitInput_path_decl(self, ctx: ASLParser.Input_path_declContext) -> InputPath:
        """Build the ``InputPath`` and record its outcome as ``afterInputPath``."""
        built: InputPath = super().visitInput_path_decl(ctx=ctx)
        return self._with_inspection(built, InspectionDataKey.AFTER_INPUT_PATH)

    def visitParameters_decl(self, ctx: ASLParser.Parameters_declContext) -> Parameters:
        """Build the ``Parameters`` and record its outcome as ``afterParameters``."""
        built: Parameters = super().visitParameters_decl(ctx=ctx)
        return self._with_inspection(built, InspectionDataKey.AFTER_PARAMETERS)

    def visitResult_selector_decl(
        self, ctx: ASLParser.Result_selector_declContext
    ) -> ResultSelector:
        """Build the ``ResultSelector`` and record its outcome as ``afterResultSelector``."""
        built: ResultSelector = super().visitResult_selector_decl(ctx=ctx)
        return self._with_inspection(built, InspectionDataKey.AFTER_RESULT_SELECTOR)

    def visitResult_path_decl(
        self, ctx: ASLParser.Result_path_declContext
    ) -> ResultPath:
        """Build the ``ResultPath`` and record its outcome as ``afterResultPath``."""
        built: ResultPath = super().visitResult_path_decl(ctx=ctx)
        return self._with_inspection(built, InspectionDataKey.AFTER_RESULT_PATH)

    def visitResult_decl(self, ctx: ASLParser.Result_declContext) -> Result:
        """Build the ``Result`` and record its outcome as ``result``."""
        built: Result = super().visitResult_decl(ctx=ctx)
        return self._with_inspection(built, InspectionDataKey.RESULT)
|