import json
from unittest import SkipTest

from moto import mock_aws, settings

from . import verify_execution_result


@mock_aws(config={"stepfunctions": {"execute_state_machine": True}})
def test_state_machine_with_parallel_states():
    if settings.TEST_SERVER_MODE:
        raise SkipTest("Don't need to test this in ServerMode")
    tmpl_name = "parallel_states"

    def _verify_result(client, execution, execution_arn):
        assert "stopDate" in execution
        assert json.loads(execution["output"]) == [
            [1, 2, 3, 4],
            [1, 2, 3, 4],
            [1, 2, 3, 4],
        ]

        history = client.get_execution_history(executionArn=execution_arn)
        assert len(history["events"]) == 20
        assert "ParallelStateEntered" in [e["type"] for e in history["events"]]
        assert "ParallelStateStarted" in [e["type"] for e in history["events"]]
        assert "PassStateExited" in [e["type"] for e in history["events"]]

    verify_execution_result(_verify_result, "SUCCEEDED", tmpl_name)


@mock_aws(config={"stepfunctions": {"execute_state_machine": True}})
def test_state_machine_with_parallel_states_catch_error():
    if settings.TEST_SERVER_MODE:
        raise SkipTest("Don't need to test this in ServerMode")
    expected_status = "SUCCEEDED"
    tmpl_name = "parallel_states_catch"

    def _verify_result(client, execution, execution_arn):
        assert "stopDate" in execution
        assert json.loads(execution["output"])["Cause"] == "Cause description"
        history = client.get_execution_history(executionArn=execution_arn)
        assert len(history["events"]) == 9
        assert "ParallelStateEntered" in [e["type"] for e in history["events"]]
        assert "ParallelStateStarted" in [e["type"] for e in history["events"]]
        assert "ParallelStateFailed" in [e["type"] for e in history["events"]]

    verify_execution_result(_verify_result, expected_status, tmpl_name)


@mock_aws(config={"stepfunctions": {"execute_state_machine": True}})
def test_state_machine_with_parallel_states_error():
    if settings.TEST_SERVER_MODE:
        raise SkipTest("Don't need to test this in ServerMode")

    def _verify_result(client, execution, execution_arn):
        assert "stopDate" in execution
        assert execution["error"] == "FailureType"

    verify_execution_result(_verify_result, "FAILED", "parallel_fail")


@mock_aws(config={"stepfunctions": {"execute_state_machine": True}})
def test_state_machine_with_parallel_states_order():
    if settings.TEST_SERVER_MODE:
        raise SkipTest("Don't need to test this in ServerMode")

    def _verify_result(client, execution, execution_arn):
        assert json.loads(execution["output"]) == [
            {"branch": 0},
            {"branch": 1},
            {"branch": 2},
            {"branch": 3},
        ]

    verify_execution_result(_verify_result, "SUCCEEDED", "parallel_states_order")
