import enum
import functools
from typing import Final

from moto.stepfunctions.parser.asl.antlr.runtime.ASLParser import ASLParser
from moto.stepfunctions.parser.asl.component.common.parargs import Parameters
from moto.stepfunctions.parser.asl.component.common.path.input_path import InputPath
from moto.stepfunctions.parser.asl.component.common.path.result_path import ResultPath
from moto.stepfunctions.parser.asl.component.common.query_language import QueryLanguage
from moto.stepfunctions.parser.asl.component.common.result_selector import (
    ResultSelector,
)
from moto.stepfunctions.parser.asl.component.state.choice.state_choice import (
    StateChoice,
)
from moto.stepfunctions.parser.asl.component.state.exec.execute_state import (
    ExecutionState,
)
from moto.stepfunctions.parser.asl.component.state.state import CommonStateField
from moto.stepfunctions.parser.asl.component.state.state_pass.result import Result
from moto.stepfunctions.parser.asl.component.test_state.program.test_state_program import (
    TestStateProgram,
)
from moto.stepfunctions.parser.asl.component.test_state.state.test_state_state_props import (
    TestStateStateProps,
)
from moto.stepfunctions.parser.asl.eval.test_state.environment import (
    TestStateEnvironment,
)
from moto.stepfunctions.parser.asl.parse.preprocessor import Preprocessor
from moto.stepfunctions.parser.asl.utils.encoding import to_json_str


class InspectionDataKey(enum.Enum):
    """Keys under which the TestState run stores inspection data.

    Each member's value is the exact key written into
    ``env.inspection_data`` by the decorators in this module.
    """

    # Data flowing into the state and through the input-processing stages.
    INPUT = "input"
    AFTER_INPUT_PATH = "afterInputPath"
    AFTER_PARAMETERS = "afterParameters"
    # Data produced by the state and through the output-processing stages.
    RESULT = "result"
    AFTER_RESULT_SELECTOR = "afterResultSelector"
    AFTER_RESULT_PATH = "afterResultPath"
    # Request/response payload keys.
    REQUEST = "request"
    RESPONSE = "response"


def _decorated_updated_choice_inspection_data(method):
    """Wrap a Choice state's evaluation so the selected next state is recorded.

    Args:
        method: Callable whose first argument is the evaluation environment
            (in this module it is a ``StateChoice._eval_body`` bound method).
            It presumably sets ``env.next_state_name`` to the chosen branch.

    Returns:
        A wrapper with the same call signature. After ``method`` runs, the
        wrapper passes ``env.next_state_name`` to ``env.set_choice_selected``
        and then returns whatever ``method`` returned.
    """

    # functools.wraps keeps __name__/__doc__/__wrapped__ of the original,
    # so the instrumented method stays identifiable when debugging.
    @functools.wraps(method)
    def wrapper(env: TestStateEnvironment, *args, **kwargs):
        result = method(env, *args, **kwargs)
        env.set_choice_selected(env.next_state_name)
        # Propagate the original return value instead of silently dropping it.
        return result

    return wrapper


def _decorated_updates_inspection_data(method, inspection_data_key: InspectionDataKey):
    """Wrap ``method`` so its output is captured as inspection data.

    Args:
        method: Callable whose first argument is the evaluation environment
            (an ``_eval_body``/``_eval_execution`` bound method in this module).
        inspection_data_key: Inspection-data field to populate; its ``value``
            is used as the key in ``env.inspection_data``.

    Returns:
        A wrapper with the same call signature. After ``method`` runs, the
        wrapper JSON-encodes the value on top of ``env.stack`` with
        ``to_json_str`` and stores it under ``inspection_data_key.value``. It
        then returns whatever ``method`` returned.

    NOTE(review): assumes ``method`` leaves at least one value on ``env.stack``;
    ``env.stack[-1]`` would fail otherwise.
    """

    # functools.wraps keeps __name__/__doc__/__wrapped__ of the original,
    # so the instrumented method stays identifiable when debugging.
    @functools.wraps(method)
    def wrapper(env: TestStateEnvironment, *args, **kwargs):
        outcome = method(env, *args, **kwargs)
        serialized = to_json_str(env.stack[-1])
        # We know that the enum value used here corresponds to a supported inspection data field by design.
        env.inspection_data[inspection_data_key.value] = serialized  # noqa
        # Propagate the original return value instead of silently dropping it.
        return outcome

    return wrapper


def _decorate_state_field(state_field: CommonStateField) -> None:
    """Instrument ``state_field`` in place so the test run captures inspection data.

    * ``ExecutionState``: ``_eval_execution`` is wrapped so that the top of the
      stack after execution is stored under the ``result`` inspection-data key.
    * ``StateChoice``: ``_eval_body`` is wrapped so that the chosen next state
      is reported through ``env.set_choice_selected``.
    * Any other state type is left untouched.
    """
    # Both branches deliberately reach into protected members: the decorators
    # need the original bound methods in order to wrap them.
    if isinstance(state_field, ExecutionState):
        original_execution = state_field._eval_execution  # noqa
        state_field._eval_execution = _decorated_updates_inspection_data(
            method=original_execution,
            inspection_data_key=InspectionDataKey.RESULT,
        )
        return

    if isinstance(state_field, StateChoice):
        original_body = state_field._eval_body  # noqa
        state_field._eval_body = _decorated_updated_choice_inspection_data(
            method=original_body
        )


class TestStatePreprocessor(Preprocessor):
    """Preprocessor that turns a single state declaration body into a ``TestStateProgram``.

    The parsed state is named ``STATE_NAME``. Its input/output processing
    components (InputPath, Parameters, ResultSelector, ResultPath, Result) are
    instrumented so that, after each one evaluates, the JSON-encoded top of the
    environment stack is written into ``env.inspection_data``.
    """

    STATE_NAME: Final[str] = "TestState"

    @staticmethod
    def _track_eval_body(component, inspection_data_key: InspectionDataKey):
        """Wrap ``component._eval_body`` to record its output under ``inspection_data_key``.

        Returns the same ``component`` (mutated in place) so visitors can
        return the call's result directly.
        """
        # As part of the decoration process, we intentionally access this
        # protected member to facilitate the decorator's functionality.
        component._eval_body = _decorated_updates_inspection_data(
            method=component._eval_body,  # noqa
            inspection_data_key=inspection_data_key,
        )
        return component

    def visitState_decl_body(
        self, ctx: ASLParser.State_decl_bodyContext
    ) -> TestStateProgram:
        """Build the ``TestStateProgram`` for the single state declared by ``ctx``.

        Opens a query-language scope for the state, collects the visited
        children into ``TestStateStateProps``, and fills in the scope's
        ``QueryLanguage`` when the state does not declare one. It then builds
        and instruments the state field and closes the scope.
        """
        self._open_query_language_scope(ctx)
        state_props = TestStateStateProps()
        state_props.name = self.STATE_NAME
        for child in ctx.children:
            state_props.add(self.visit(child))
        # Fill in the inherited QueryLanguage *before* the state field is built
        # from these props. Previously it was added only after
        # _common_state_field_of had already run, where it presumably had no
        # effect on the resulting state.
        if state_props.get(QueryLanguage) is None:
            state_props.add(self._get_current_query_language())
        state_field = self._common_state_field_of(state_props=state_props)
        _decorate_state_field(state_field)
        self._close_query_language_scope()
        return TestStateProgram(state_field)

    def visitInput_path_decl(self, ctx: ASLParser.Input_path_declContext) -> InputPath:
        """Parse an InputPath and record its output as ``afterInputPath``."""
        input_path: InputPath = super().visitInput_path_decl(ctx=ctx)
        return self._track_eval_body(input_path, InspectionDataKey.AFTER_INPUT_PATH)

    def visitParameters_decl(self, ctx: ASLParser.Parameters_declContext) -> Parameters:
        """Parse Parameters and record their output as ``afterParameters``."""
        parameters: Parameters = super().visitParameters_decl(ctx=ctx)
        return self._track_eval_body(parameters, InspectionDataKey.AFTER_PARAMETERS)

    def visitResult_selector_decl(
        self, ctx: ASLParser.Result_selector_declContext
    ) -> ResultSelector:
        """Parse a ResultSelector and record its output as ``afterResultSelector``."""
        result_selector: ResultSelector = super().visitResult_selector_decl(ctx=ctx)
        return self._track_eval_body(
            result_selector, InspectionDataKey.AFTER_RESULT_SELECTOR
        )

    def visitResult_path_decl(
        self, ctx: ASLParser.Result_path_declContext
    ) -> ResultPath:
        """Parse a ResultPath and record its output as ``afterResultPath``."""
        result_path: ResultPath = super().visitResult_path_decl(ctx=ctx)
        return self._track_eval_body(result_path, InspectionDataKey.AFTER_RESULT_PATH)

    def visitResult_decl(self, ctx: ASLParser.Result_declContext) -> Result:
        """Parse a (Pass state) Result and record its output as ``result``."""
        result: Result = super().visitResult_decl(ctx=ctx)
        return self._track_eval_body(result, InspectionDataKey.RESULT)
