import copy
import datetime
import json
import uuid
from typing import Any, Callable, Optional

from moto.core.common_models import BackendDict
from moto.stepfunctions.models import StateMachine, StepFunctionBackend
from moto.stepfunctions.parser.api import (
    Definition,
    EncryptionConfiguration,
    ExecutionStatus,
    GetExecutionHistoryOutput,
    InvalidDefinition,
    InvalidExecutionInput,
    InvalidToken,
    LoggingConfiguration,
    MissingRequiredParameter,
    Name,
    ResourceNotFound,
    SendTaskFailureOutput,
    SendTaskHeartbeatOutput,
    SendTaskSuccessOutput,
    SensitiveCause,
    SensitiveData,
    SensitiveError,
    TaskDoesNotExist,
    TaskTimedOut,
    TaskToken,
    TraceHeader,
    TracingConfiguration,
)
from moto.stepfunctions.parser.asl.component.state.exec.state_map.iteration.itemprocessor.map_run_record import (
    MapRunRecord,
)
from moto.stepfunctions.parser.asl.eval.callback.callback import (
    CallbackConsumerTimeout,
    CallbackNotifyConsumerError,
    CallbackOutcomeFailure,
    CallbackOutcomeSuccess,
)
from moto.stepfunctions.parser.asl.parse.asl_parser import (
    AmazonStateLanguageParser,
    ASLParserException,
)
from moto.stepfunctions.parser.backend.execution import Execution


class StepFunctionsParserBackend(StepFunctionBackend):
    """Step Functions backend that validates and runs ASL definitions.

    On top of the base ``StepFunctionBackend`` this backend:

    * parses state machine definitions with ``AmazonStateLanguageParser``
      before creating/updating a state machine,
    * creates executions as parser ``Execution`` objects and starts them
      immediately,
    * serves the task-token callback APIs (``SendTask*``) and the Map Run
      APIs from the environments of the executions it has started.
    """

    def _get_executions(
        self, execution_status: Optional[ExecutionStatus] = None
    ) -> list[Execution]:
        """Return the executions of every state machine, optionally filtered.

        Args:
            execution_status: When given, only executions whose ``status``
                equals it are returned; when ``None``, all executions are.
        """
        executions = []
        for state_machine in self.state_machines:
            for execution in state_machine.executions:
                if execution_status is None or execution.status == execution_status:
                    executions.append(execution)
        return executions

    def _revision_by_name(self, name: str) -> Optional[StateMachine]:
        """Return the first state machine called ``name``, or ``None``."""
        for state_machine in self.state_machines:
            if state_machine.name == name:
                return state_machine
        return None

    @staticmethod
    def _validate_definition(definition: str) -> None:
        """Parse ``definition`` as Amazon States Language; reject it if invalid.

        Raises:
            InvalidDefinition: if parsing fails. Parser errors are reported by
                their ``repr``; any other exception is reported by class name
                and arguments together with the offending definition.
        """
        # TODO: pass through static analyser.
        try:
            AmazonStateLanguageParser.parse(definition)
        except ASLParserException as asl_parser_exception:
            raise InvalidDefinition(
                message=repr(asl_parser_exception)
            ) from asl_parser_exception
        except Exception as exception:
            # Deliberately broad: any failure while parsing means the definition
            # cannot be run, so surface it as an API-level InvalidDefinition.
            exception_name = exception.__class__.__name__
            exception_args = list(exception.args)
            raise InvalidDefinition(
                message=f"Error={exception_name} Args={exception_args} in definition '{definition}'."
            ) from exception

    def create_state_machine(
        self,
        name: str,
        definition: str,
        roleArn: str,
        tags: Optional[list[dict[str, str]]] = None,
        publish: Optional[bool] = None,
        loggingConfiguration: Optional[LoggingConfiguration] = None,
        tracingConfiguration: Optional[TracingConfiguration] = None,
        encryptionConfiguration: Optional[EncryptionConfiguration] = None,
        version_description: Optional[str] = None,
    ) -> StateMachine:
        """Validate ``definition`` and delegate creation to the base backend.

        Raises:
            InvalidDefinition: if ``definition`` does not parse.
        """
        self._validate_definition(definition=definition)

        return super().create_state_machine(
            name=name,
            definition=definition,
            roleArn=roleArn,
            tags=tags,
            publish=publish,
            loggingConfiguration=loggingConfiguration,
            tracingConfiguration=tracingConfiguration,
            encryptionConfiguration=encryptionConfiguration,
            version_description=version_description,
        )

    @staticmethod
    def _deliver_to_callback_pools(
        executions: list[Execution], deliver: Callable[[Any], Any]
    ) -> bool:
        """Offer a callback event to each execution's callback pool manager.

        ``deliver`` is called with an execution's ``callback_pool_manager``.
        Iteration stops at the first execution for which it returns a truthy
        value (presumably: that pool owns the task token and accepted the
        event).

        Returns:
            True if some execution accepted the event, False otherwise.

        Raises:
            TaskTimedOut: if the pool reports ``CallbackConsumerTimeout``.
            TaskDoesNotExist: for any other ``CallbackNotifyConsumerError``.
        """
        for execution in executions:
            try:
                if deliver(execution.exec_worker.env.callback_pool_manager):
                    return True
            except CallbackNotifyConsumerError as consumer_error:
                if isinstance(consumer_error, CallbackConsumerTimeout):
                    raise TaskTimedOut() from consumer_error
                raise TaskDoesNotExist() from consumer_error
        return False

    def send_task_heartbeat(self, task_token: TaskToken) -> SendTaskHeartbeatOutput:
        """Record a heartbeat for the callback task identified by ``task_token``.

        Only RUNNING executions are searched.

        Raises:
            TaskTimedOut: if the callback has already timed out.
            TaskDoesNotExist: if the owning pool rejects the heartbeat.
            InvalidToken: if no running execution accepts the token.
        """
        if self._deliver_to_callback_pools(
            self._get_executions(ExecutionStatus.RUNNING),
            lambda pool_manager: pool_manager.heartbeat(callback_id=task_token),
        ):
            return SendTaskHeartbeatOutput()
        raise InvalidToken()

    def send_task_success(
        self, task_token: TaskToken, outcome: str
    ) -> SendTaskSuccessOutput:
        """Resume the callback task identified by ``task_token`` successfully.

        Only RUNNING executions are searched.

        Args:
            task_token: Token of the waiting callback task.
            outcome: Output passed to ``CallbackOutcomeSuccess``.

        Raises:
            TaskTimedOut: if the callback has already timed out.
            TaskDoesNotExist: if the owning pool rejects the outcome.
            InvalidToken: if no running execution accepts the token.
        """
        success = CallbackOutcomeSuccess(callback_id=task_token, output=outcome)
        if self._deliver_to_callback_pools(
            self._get_executions(ExecutionStatus.RUNNING),
            lambda pool_manager: pool_manager.notify(
                callback_id=task_token, outcome=success
            ),
        ):
            return SendTaskSuccessOutput()
        raise InvalidToken()

    def send_task_failure(
        self,
        task_token: TaskToken,
        error: SensitiveError = None,
        cause: SensitiveCause = None,
    ) -> SendTaskFailureOutput:
        """Fail the callback task identified by ``task_token``.

        Raises:
            TaskTimedOut: if the callback has already timed out.
            TaskDoesNotExist: if the owning pool rejects the outcome.
            InvalidToken: if no execution accepts the token.
        """
        failure = CallbackOutcomeFailure(
            callback_id=task_token, error=error, cause=cause
        )
        # NOTE(review): unlike heartbeat/success, executions of every status are
        # searched here; kept as-is - confirm whether that is intentional.
        if self._deliver_to_callback_pools(
            self._get_executions(),
            lambda pool_manager: pool_manager.notify(
                callback_id=task_token, outcome=failure
            ),
        ):
            return SendTaskFailureOutput()
        raise InvalidToken()

    def start_execution(
        self,
        state_machine_arn: str,
        name: Name = None,
        execution_input: SensitiveData = None,
        trace_header: TraceHeader = None,
    ) -> Execution:
        """Create and start an execution of the given state machine.

        If the state machine reports an existing execution for the same
        ``name``/``execution_input`` pair, that execution is returned instead.

        Args:
            state_machine_arn: ARN of the state machine to run.
            name: Execution name; a random UUID is used when omitted.
            execution_input: JSON input text; defaults to ``"{}"``.
            trace_header: Passed through to the execution.

        Raises:
            InvalidExecutionInput: if ``execution_input`` is not valid JSON.
        """
        state_machine = self.describe_state_machine(state_machine_arn)
        existing_execution = state_machine._handle_name_input_idempotency(
            name, execution_input
        )
        if existing_execution is not None:
            # If we found a match for the name and input, return the existing execution.
            return existing_execution

        # Later changes to the state machine must not affect this execution, so
        # run against a snapshot.
        state_machine_clone = copy.deepcopy(state_machine)

        if execution_input is None:
            input_data = "{}"
        else:
            try:
                # Make sure input is valid json
                json.loads(execution_input)
            except (TypeError, ValueError) as ex:
                # TODO: report parsing error like AWS.
                raise InvalidExecutionInput(str(ex)) from ex
            input_data = execution_input

        # The name is optional: without a fallback every unnamed execution would
        # get the literal name None and share an ARN ending in ":None".
        # TODO: validate name format
        exec_name = name or str(uuid.uuid4())

        execution_arn = (
            f"arn:{self.partition}:states:{self.region_name}:{self.account_id}"
            f":execution:{state_machine.name}:{exec_name}"
        )

        execution = Execution(
            name=exec_name,
            sm_type=state_machine_clone.sm_type,
            role_arn=state_machine_clone.roleArn,
            exec_arn=execution_arn,
            account_id=self.account_id,
            region_name=self.region_name,
            state_machine=state_machine_clone,
            start_date=datetime.datetime.now(tz=datetime.timezone.utc),
            cloud_watch_logging_session=None,
            input_data=input_data,
            trace_header=trace_header,
            activity_store={},
        )
        state_machine.executions.append(execution)

        execution.start()
        return execution

    def update_state_machine(
        self,
        arn: str,
        definition: Definition = None,
        role_arn: str = None,
        logging_configuration: LoggingConfiguration = None,
        tracing_configuration: TracingConfiguration = None,
        encryption_configuration: EncryptionConfiguration = None,
        publish: Optional[bool] = None,
        version_description: str = None,
    ) -> StateMachine:
        """Validate the update request and delegate to the base backend.

        Raises:
            MissingRequiredParameter: if none of definition, role ARN, logging,
                tracing or encryption configuration is provided (truthy).
            InvalidDefinition: if a new ``definition`` does not parse.
        """
        if not any(
            (
                definition,
                role_arn,
                logging_configuration,
                tracing_configuration,
                encryption_configuration,
            )
        ):
            raise MissingRequiredParameter(
                "Either the definition, the role ARN, the LoggingConfiguration, the EncryptionConfiguration or the TracingConfiguration must be specified"
            )

        if definition is not None:
            self._validate_definition(definition=definition)

        return super().update_state_machine(
            arn,
            definition,
            role_arn,
            logging_configuration=logging_configuration,
            tracing_configuration=tracing_configuration,
            encryption_configuration=encryption_configuration,
            publish=publish,
            version_description=version_description,
        )

    def _find_map_run_record(self, map_run_arn: str) -> MapRunRecord:
        """Return the map run record with ``map_run_arn`` from any execution.

        Raises:
            ResourceNotFound: if no execution holds a map run with that ARN.
        """
        for execution in self._get_executions():
            map_run_record: Optional[MapRunRecord] = (
                execution.exec_worker.env.map_run_record_pool_manager.get(map_run_arn)
            )
            if map_run_record is not None:
                return map_run_record
        raise ResourceNotFound()

    def describe_map_run(self, map_run_arn: str) -> dict[str, Any]:
        """Describe the map run ``map_run_arn``.

        Raises:
            ResourceNotFound: if the map run is unknown.
        """
        return self._find_map_run_record(map_run_arn).describe()

    def list_map_runs(self, execution_arn: str) -> dict[str, Any]:
        """
        List the map runs recorded by the given execution.

        Pagination is not yet implemented
        """
        execution = self.describe_execution(execution_arn=execution_arn)
        map_run_records: list[MapRunRecord] = (
            execution.exec_worker.env.map_run_record_pool_manager.get_all()
        )
        return {
            "mapRuns": [
                map_run_record.list_item() for map_run_record in map_run_records
            ]
        }

    def update_map_run(
        self,
        map_run_arn: str,
        max_concurrency: int,
        tolerated_failure_count: str,
        tolerated_failure_percentage: str,
    ) -> None:
        """Update the concurrency / failure-tolerance settings of a map run.

        Raises:
            ResourceNotFound: if the map run is unknown.
        """
        # TODO: investigate behaviour of empty requests.
        self._find_map_run_record(map_run_arn).update(
            max_concurrency=max_concurrency,
            tolerated_failure_count=tolerated_failure_count,
            tolerated_failure_percentage=tolerated_failure_percentage,
        )

    def get_execution_history(self, execution_arn: str) -> GetExecutionHistoryOutput:
        """Return the event history of the execution ``execution_arn``."""
        execution = self.describe_execution(execution_arn=execution_arn)
        return execution.to_history_output()


# Module-level registry of StepFunctionsParserBackend instances for the
# "stepfunctions" service. NOTE(review): BackendDict presumably creates one
# backend lazily per account/region - confirm in moto.core.common_models.
stepfunctions_parser_backends = BackendDict(StepFunctionsParserBackend, "stepfunctions")
