import json
from typing import Dict

import pytest
from fastapi import BackgroundTasks, Depends, FastAPI
from fastapi.responses import StreamingResponse
from fastapi.testclient import TestClient

app = FastAPI()
state = {
    "/async": "asyncgen not started",
    "/sync": "generator not started",
    "/async_raise": "asyncgen raise not started",
    "/sync_raise": "generator raise not started",
    "context_a": "not started a",
    "context_b": "not started b",
    "bg": "not set",
    "sync_bg": "not set",
}

errors = []


async def get_state():
    return state


class AsyncDependencyError(Exception):
    pass


class SyncDependencyError(Exception):
    pass


class OtherDependencyError(Exception):
    pass


async def asyncgen_state(state: Dict[str, str] = Depends(get_state)):
    state["/async"] = "asyncgen started"
    yield state["/async"]
    state["/async"] = "asyncgen completed"


def generator_state(state: Dict[str, str] = Depends(get_state)):
    state["/sync"] = "generator started"
    yield state["/sync"]
    state["/sync"] = "generator completed"


async def asyncgen_state_try(state: Dict[str, str] = Depends(get_state)):
    state["/async_raise"] = "asyncgen raise started"
    try:
        yield state["/async_raise"]
    except AsyncDependencyError:
        errors.append("/async_raise")
        raise
    finally:
        state["/async_raise"] = "asyncgen raise finalized"


def generator_state_try(state: Dict[str, str] = Depends(get_state)):
    state["/sync_raise"] = "generator raise started"
    try:
        yield state["/sync_raise"]
    except SyncDependencyError:
        errors.append("/sync_raise")
        raise
    finally:
        state["/sync_raise"] = "generator raise finalized"


async def context_a(state: dict = Depends(get_state)):
    state["context_a"] = "started a"
    try:
        yield state
    finally:
        state["context_a"] = "finished a"


async def context_b(state: dict = Depends(context_a)):
    state["context_b"] = "started b"
    try:
        yield state
    finally:
        state["context_b"] = f"finished b with a: {state['context_a']}"


@app.get("/async")
async def get_async(state: str = Depends(asyncgen_state)):
    return state


@app.get("/sync")
async def get_sync(state: str = Depends(generator_state)):
    return state


@app.get("/async_raise")
async def get_async_raise(state: str = Depends(asyncgen_state_try)):
    assert state == "asyncgen raise started"
    raise AsyncDependencyError()


@app.get("/sync_raise")
async def get_sync_raise(state: str = Depends(generator_state_try)):
    assert state == "generator raise started"
    raise SyncDependencyError()


@app.get("/async_raise_other")
async def get_async_raise_other(state: str = Depends(asyncgen_state_try)):
    assert state == "asyncgen raise started"
    raise OtherDependencyError()


@app.get("/sync_raise_other")
async def get_sync_raise_other(state: str = Depends(generator_state_try)):
    assert state == "generator raise started"
    raise OtherDependencyError()


@app.get("/context_b")
async def get_context_b(state: dict = Depends(context_b)):
    return state


@app.get("/context_b_raise")
async def get_context_b_raise(state: dict = Depends(context_b)):
    assert state["context_b"] == "started b"
    assert state["context_a"] == "started a"
    raise OtherDependencyError()


@app.get("/context_b_bg")
async def get_context_b_bg(tasks: BackgroundTasks, state: dict = Depends(context_b)):
    async def bg(state: dict):
        state["bg"] = f"bg set - b: {state['context_b']} - a: {state['context_a']}"

    tasks.add_task(bg, state)
    return state


# Sync versions


@app.get("/sync_async")
def get_sync_async(state: str = Depends(asyncgen_state)):
    return state


@app.get("/sync_sync")
def get_sync_sync(state: str = Depends(generator_state)):
    return state


@app.get("/sync_async_raise")
def get_sync_async_raise(state: str = Depends(asyncgen_state_try)):
    assert state == "asyncgen raise started"
    raise AsyncDependencyError()


@app.get("/sync_sync_raise")
def get_sync_sync_raise(state: str = Depends(generator_state_try)):
    assert state == "generator raise started"
    raise SyncDependencyError()


@app.get("/sync_async_raise_other")
def get_sync_async_raise_other(state: str = Depends(asyncgen_state_try)):
    assert state == "asyncgen raise started"
    raise OtherDependencyError()


@app.get("/sync_sync_raise_other")
def get_sync_sync_raise_other(state: str = Depends(generator_state_try)):
    assert state == "generator raise started"
    raise OtherDependencyError()


@app.get("/sync_context_b")
def get_sync_context_b(state: dict = Depends(context_b)):
    return state


@app.get("/sync_context_b_raise")
def get_sync_context_b_raise(state: dict = Depends(context_b)):
    assert state["context_b"] == "started b"
    assert state["context_a"] == "started a"
    raise OtherDependencyError()


@app.get("/sync_context_b_bg")
async def get_sync_context_b_bg(
    tasks: BackgroundTasks, state: dict = Depends(context_b)
):
    async def bg(state: dict):
        state["sync_bg"] = (
            f"sync_bg set - b: {state['context_b']} - a: {state['context_a']}"
        )

    tasks.add_task(bg, state)
    return state


@app.middleware("http")
async def middleware(request, call_next):
    response: StreamingResponse = await call_next(request)
    response.headers["x-state"] = json.dumps(state.copy())
    return response


client = TestClient(app)


def test_async_state():
    assert state["/async"] == "asyncgen not started"
    response = client.get("/async")
    assert response.status_code == 200, response.text
    assert response.json() == "asyncgen started"
    assert state["/async"] == "asyncgen completed"


def test_sync_state():
    assert state["/sync"] == "generator not started"
    response = client.get("/sync")
    assert response.status_code == 200, response.text
    assert response.json() == "generator started"
    assert state["/sync"] == "generator completed"


def test_async_raise_other():
    assert state["/async_raise"] == "asyncgen raise not started"
    with pytest.raises(OtherDependencyError):
        client.get("/async_raise_other")
    assert state["/async_raise"] == "asyncgen raise finalized"
    assert "/async_raise" not in errors


def test_sync_raise_other():
    assert state["/sync_raise"] == "generator raise not started"
    with pytest.raises(OtherDependencyError):
        client.get("/sync_raise_other")
    assert state["/sync_raise"] == "generator raise finalized"
    assert "/sync_raise" not in errors


def test_async_raise_raises():
    with pytest.raises(AsyncDependencyError):
        client.get("/async_raise")
    assert state["/async_raise"] == "asyncgen raise finalized"
    assert "/async_raise" in errors
    errors.clear()


def test_async_raise_server_error():
    client = TestClient(app, raise_server_exceptions=False)
    response = client.get("/async_raise")
    assert response.status_code == 500, response.text
    assert state["/async_raise"] == "asyncgen raise finalized"
    assert "/async_raise" in errors
    errors.clear()


def test_context_b():
    response = client.get("/context_b")
    data = response.json()
    assert data["context_b"] == "started b"
    assert data["context_a"] == "started a"
    assert state["context_b"] == "finished b with a: started a"
    assert state["context_a"] == "finished a"


def test_context_b_raise():
    with pytest.raises(OtherDependencyError):
        client.get("/context_b_raise")
    assert state["context_b"] == "finished b with a: started a"
    assert state["context_a"] == "finished a"


def test_background_tasks():
    response = client.get("/context_b_bg")
    data = response.json()
    assert data["context_b"] == "started b"
    assert data["context_a"] == "started a"
    assert data["bg"] == "not set"
    middleware_state = json.loads(response.headers["x-state"])
    assert middleware_state["context_b"] == "finished b with a: started a"
    assert middleware_state["context_a"] == "finished a"
    assert middleware_state["bg"] == "not set"
    assert state["context_b"] == "finished b with a: started a"
    assert state["context_a"] == "finished a"
    assert state["bg"] == "bg set - b: finished b with a: started a - a: finished a"


def test_sync_raise_raises():
    with pytest.raises(SyncDependencyError):
        client.get("/sync_raise")
    assert state["/sync_raise"] == "generator raise finalized"
    assert "/sync_raise" in errors
    errors.clear()


def test_sync_raise_server_error():
    client = TestClient(app, raise_server_exceptions=False)
    response = client.get("/sync_raise")
    assert response.status_code == 500, response.text
    assert state["/sync_raise"] == "generator raise finalized"
    assert "/sync_raise" in errors
    errors.clear()


def test_sync_async_state():
    response = client.get("/sync_async")
    assert response.status_code == 200, response.text
    assert response.json() == "asyncgen started"
    assert state["/async"] == "asyncgen completed"


def test_sync_sync_state():
    response = client.get("/sync_sync")
    assert response.status_code == 200, response.text
    assert response.json() == "generator started"
    assert state["/sync"] == "generator completed"


def test_sync_async_raise_other():
    with pytest.raises(OtherDependencyError):
        client.get("/sync_async_raise_other")
    assert state["/async_raise"] == "asyncgen raise finalized"
    assert "/async_raise" not in errors


def test_sync_sync_raise_other():
    with pytest.raises(OtherDependencyError):
        client.get("/sync_sync_raise_other")
    assert state["/sync_raise"] == "generator raise finalized"
    assert "/sync_raise" not in errors


def test_sync_async_raise_raises():
    with pytest.raises(AsyncDependencyError):
        client.get("/sync_async_raise")
    assert state["/async_raise"] == "asyncgen raise finalized"
    assert "/async_raise" in errors
    errors.clear()


def test_sync_async_raise_server_error():
    client = TestClient(app, raise_server_exceptions=False)
    response = client.get("/sync_async_raise")
    assert response.status_code == 500, response.text
    assert state["/async_raise"] == "asyncgen raise finalized"
    assert "/async_raise" in errors
    errors.clear()


def test_sync_sync_raise_raises():
    with pytest.raises(SyncDependencyError):
        client.get("/sync_sync_raise")
    assert state["/sync_raise"] == "generator raise finalized"
    assert "/sync_raise" in errors
    errors.clear()


def test_sync_sync_raise_server_error():
    client = TestClient(app, raise_server_exceptions=False)
    response = client.get("/sync_sync_raise")
    assert response.status_code == 500, response.text
    assert state["/sync_raise"] == "generator raise finalized"
    assert "/sync_raise" in errors
    errors.clear()


def test_sync_context_b():
    response = client.get("/sync_context_b")
    data = response.json()
    assert data["context_b"] == "started b"
    assert data["context_a"] == "started a"
    assert state["context_b"] == "finished b with a: started a"
    assert state["context_a"] == "finished a"


def test_sync_context_b_raise():
    with pytest.raises(OtherDependencyError):
        client.get("/sync_context_b_raise")
    assert state["context_b"] == "finished b with a: started a"
    assert state["context_a"] == "finished a"


def test_sync_background_tasks():
    response = client.get("/sync_context_b_bg")
    data = response.json()
    assert data["context_b"] == "started b"
    assert data["context_a"] == "started a"
    assert data["sync_bg"] == "not set"
    assert state["context_b"] == "finished b with a: started a"
    assert state["context_a"] == "finished a"
    assert (
        state["sync_bg"]
        == "sync_bg set - b: finished b with a: started a - a: finished a"
    )
