# Tests for http/websocket.py

import base64
import os

import pytest

from aiohttp import web
from aiohttp.test_utils import make_mocked_request


def gen_ws_headers(
    protocols="",
    compress=0,
    extension_text="",
    server_notakeover=False,
    client_notakeover=False,
):
    """Build the request headers of a client WebSocket upgrade handshake.

    Returns a ``(headers, key)`` pair. ``headers`` is a list of
    ``(name, value)`` tuples. ``key`` is the freshly generated,
    base64-encoded 16-byte ``Sec-Websocket-Key`` also present in ``headers``.

    A ``Sec-Websocket-Protocol`` header is added only for a non-empty
    *protocols*. A ``Sec-Websocket-Extensions`` header is added only for a
    non-zero *compress*. In that header, ``server_max_window_bits`` is emitted
    only when *compress* is below 15. The no-context-takeover flags follow,
    and the raw *extension_text* (if any) comes last.
    """
    key = base64.b64encode(os.urandom(16)).decode()
    headers = [
        ("Upgrade", "websocket"),
        ("Connection", "upgrade"),
        ("Sec-Websocket-Version", "13"),
        ("Sec-Websocket-Key", key),
    ]
    if protocols:
        headers.append(("Sec-Websocket-Protocol", protocols))
    if not compress:
        return headers, key

    # Extension parameters are joined with "; " after the extension name.
    ext_parts = ["permessage-deflate"]
    if compress < 15:
        ext_parts.append("server_max_window_bits=" + str(compress))
    if server_notakeover:
        ext_parts.append("server_no_context_takeover")
    if client_notakeover:
        ext_parts.append("client_no_context_takeover")
    if extension_text:
        ext_parts.append(extension_text)
    headers.append(("Sec-Websocket-Extensions", "; ".join(ext_parts)))
    return headers, key


async def test_no_upgrade() -> None:
    """A plain GET without any upgrade headers is rejected with 400."""
    request = make_mocked_request("GET", "/")
    response = web.WebSocketResponse()
    with pytest.raises(web.HTTPBadRequest):
        await response.prepare(request)


async def test_no_connection() -> None:
    """``Upgrade: websocket`` with a non-``upgrade`` Connection header fails."""
    request_headers = {"Upgrade": "websocket", "Connection": "keep-alive"}
    request = make_mocked_request("GET", "/", headers=request_headers)
    response = web.WebSocketResponse()
    with pytest.raises(web.HTTPBadRequest):
        await response.prepare(request)


async def test_protocol_version_unset() -> None:
    """An upgrade request lacking ``Sec-Websocket-Version`` is rejected."""
    upgrade_only = {"Upgrade": "websocket", "Connection": "upgrade"}
    request = make_mocked_request("GET", "/", headers=upgrade_only)
    response = web.WebSocketResponse()
    with pytest.raises(web.HTTPBadRequest):
        await response.prepare(request)


async def test_protocol_version_not_supported() -> None:
    """An unsupported ``Sec-Websocket-Version`` (here ``1``) is rejected."""
    request_headers = {
        "Upgrade": "websocket",
        "Connection": "upgrade",
        "Sec-Websocket-Version": "1",
    }
    request = make_mocked_request("GET", "/", headers=request_headers)
    response = web.WebSocketResponse()
    with pytest.raises(web.HTTPBadRequest):
        await response.prepare(request)


async def test_protocol_key_not_present() -> None:
    """A version-13 upgrade request without ``Sec-Websocket-Key`` fails."""
    request_headers = {
        "Upgrade": "websocket",
        "Connection": "upgrade",
        "Sec-Websocket-Version": "13",
    }
    request = make_mocked_request("GET", "/", headers=request_headers)
    response = web.WebSocketResponse()
    with pytest.raises(web.HTTPBadRequest):
        await response.prepare(request)


async def test_protocol_key_invalid() -> None:
    """A malformed ``Sec-Websocket-Key`` (the literal ``"123"``) fails."""
    request_headers = {
        "Upgrade": "websocket",
        "Connection": "upgrade",
        "Sec-Websocket-Version": "13",
        "Sec-Websocket-Key": "123",
    }
    request = make_mocked_request("GET", "/", headers=request_headers)
    response = web.WebSocketResponse()
    with pytest.raises(web.HTTPBadRequest):
        await response.prepare(request)


async def test_protocol_key_bad_size() -> None:
    """A key that is valid base64 but encodes only 2 bytes is rejected."""
    # 2 random bytes, versus the 16 that gen_ws_headers() uses for a good key.
    short_key = base64.b64encode(os.urandom(2)).decode()
    request_headers = {
        "Upgrade": "websocket",
        "Connection": "upgrade",
        "Sec-Websocket-Version": "13",
        "Sec-Websocket-Key": short_key,
    }
    request = make_mocked_request("GET", "/", headers=request_headers)
    response = web.WebSocketResponse()
    with pytest.raises(web.HTTPBadRequest):
        await response.prepare(request)


async def test_handshake_ok() -> None:
    """A well-formed handshake succeeds and selects no subprotocol.

    No ``Sec-Websocket-Protocol`` header is sent, so none can be chosen.
    """
    request_headers, _key = gen_ws_headers()
    response = web.WebSocketResponse()

    await response.prepare(make_mocked_request("GET", "/", headers=request_headers))

    assert response.ws_protocol is None


async def test_handshake_protocol() -> None:
    """One protocol offered by the client and known to the server is chosen."""
    offered = "chat"
    request = make_mocked_request("GET", "/", headers=gen_ws_headers(offered)[0])
    response = web.WebSocketResponse(protocols={"chat"})

    await response.prepare(request)

    assert response.ws_protocol == offered


async def test_handshake_protocol_agreement() -> None:
    """The handshake picks the first client-offered protocol the server knows.

    The server supports ``best``, ``chat`` and ``worse_proto``. The client
    offers ``worse_proto,chat``. Both client protocols are supported, and
    ``worse_proto`` is expected to win: it is listed first by the client
    even though the server lists it last.
    """
    # Protocols the server is configured with (passed to WebSocketResponse).
    server_protos = ["best", "chat", "worse_proto"]
    # Protocols the client offers in its Sec-Websocket-Protocol header.
    client_protos = "worse_proto,chat"
    expected_proto = "worse_proto"

    ws = web.WebSocketResponse(protocols=server_protos)
    req = make_mocked_request("GET", "/", headers=gen_ws_headers(client_protos)[0])

    await ws.prepare(req)

    assert ws.ws_protocol == expected_proto


async def test_handshake_protocol_unsupported(caplog) -> None:
    """With no protocol overlap, the handshake logs a mismatch and picks none.

    The client offers only ``test`` while the server knows only ``chat``.
    The handshake still completes, the last captured log record is the
    "don't overlap" message, and ``ws_protocol`` stays ``None``.
    """
    request = make_mocked_request("GET", "/", headers=gen_ws_headers("test")[0])
    response = web.WebSocketResponse(protocols=["chat"])

    await response.prepare(request)

    expected = "Client protocols %r don’t overlap server-known ones %r"
    assert caplog.records[-1].msg == expected
    assert response.ws_protocol is None


async def test_handshake_compress() -> None:
    """Offering a bare ``permessage-deflate`` enables 15-bit compression."""
    request_headers, _key = gen_ws_headers(compress=15)
    response = web.WebSocketResponse()

    await response.prepare(make_mocked_request("GET", "/", headers=request_headers))

    assert response.compress == 15


def test_handshake_compress_server_notakeover() -> None:
    """``server_no_context_takeover`` is accepted and echoed in the response.

    ``_handshake`` should report 15-bit compression with ``notakeover`` set.
    """
    request_headers, _key = gen_ws_headers(compress=15, server_notakeover=True)
    request = make_mocked_request("GET", "/", headers=request_headers)

    ws = web.WebSocketResponse()
    response_headers, _, compress, notakeover = ws._handshake(request)

    assert compress == 15
    assert notakeover is True
    expected_ext = "permessage-deflate; server_no_context_takeover"
    assert "Sec-Websocket-Extensions" in response_headers
    assert response_headers["Sec-Websocket-Extensions"] == expected_ext


def test_handshake_compress_client_notakeover() -> None:
    """``client_no_context_takeover`` does not set the server's notakeover flag.

    Compression is negotiated at 15 bits, and the response carries a bare
    ``permessage-deflate`` with no ``server_no_context_takeover``. The
    ``notakeover`` value returned by ``_handshake`` stays ``False``.
    """
    hdrs, _ = gen_ws_headers(compress=15, client_notakeover=True)

    req = make_mocked_request("GET", "/", headers=hdrs)

    ws = web.WebSocketResponse()
    headers, _, compress, notakeover = ws._handshake(req)

    assert "Sec-Websocket-Extensions" in headers
    # The request headers are attached as the assertion message to ease debugging.
    assert headers["Sec-Websocket-Extensions"] == "permessage-deflate", hdrs

    assert compress == 15
    # The client-side flag must not leak into the server-side takeover setting.
    assert notakeover is False


def test_handshake_compress_wbits() -> None:
    """A supported ``server_max_window_bits`` (9) is honoured and echoed back."""
    request_headers, _key = gen_ws_headers(compress=9)
    request = make_mocked_request("GET", "/", headers=request_headers)

    ws = web.WebSocketResponse()
    response_headers, _, compress, _notakeover = ws._handshake(request)

    expected_ext = "permessage-deflate; server_max_window_bits=9"
    assert "Sec-Websocket-Extensions" in response_headers
    assert response_headers["Sec-Websocket-Extensions"] == expected_ext
    assert compress == 9


def test_handshake_compress_wbits_error() -> None:
    """An unsupported window size (6 bits) disables compression entirely."""
    request_headers, _key = gen_ws_headers(compress=6)
    request = make_mocked_request("GET", "/", headers=request_headers)

    ws = web.WebSocketResponse()
    response_headers, _, compress, _notakeover = ws._handshake(request)

    assert "Sec-Websocket-Extensions" not in response_headers
    assert compress == 0


def test_handshake_compress_bad_ext() -> None:
    """An unrecognised parameter (``bad``) makes the only offer unusable.

    No extension header is returned and compression stays off.
    """
    request_headers, _key = gen_ws_headers(compress=15, extension_text="bad")
    request = make_mocked_request("GET", "/", headers=request_headers)

    ws = web.WebSocketResponse()
    response_headers, _, compress, _notakeover = ws._handshake(request)

    assert "Sec-Websocket-Extensions" not in response_headers
    assert compress == 0


def test_handshake_compress_multi_ext_bad() -> None:
    """An unusable first extension offer falls back to a later valid one.

    The client sends ``permessage-deflate; bad, permessage-deflate``. The
    first offer carries an unknown parameter, so the server should accept
    the second, bare offer and negotiate 15-bit compression.
    """
    hdrs, _ = gen_ws_headers(compress=15, extension_text="bad, permessage-deflate")

    req = make_mocked_request("GET", "/", headers=hdrs)

    ws = web.WebSocketResponse()
    headers, _, compress, notakeover = ws._handshake(req)

    assert "Sec-Websocket-Extensions" in headers
    assert headers["Sec-Websocket-Extensions"] == "permessage-deflate"
    # Check the negotiated value too, as test_handshake_compress_multi_ext_wbits does.
    assert compress == 15


def test_handshake_compress_multi_ext_wbits() -> None:
    """A later bare offer is used when the first has an unsupported window.

    The client sends
    ``permessage-deflate; server_max_window_bits=6; , permessage-deflate``.
    The 6-bit first offer is rejected and the trailing bare offer yields
    15-bit compression.
    """
    request_headers, _key = gen_ws_headers(
        compress=6, extension_text=", permessage-deflate"
    )
    request = make_mocked_request("GET", "/", headers=request_headers)

    ws = web.WebSocketResponse()
    response_headers, _, compress, _notakeover = ws._handshake(request)

    assert "Sec-Websocket-Extensions" in response_headers
    assert response_headers["Sec-Websocket-Extensions"] == "permessage-deflate"
    assert compress == 15


def test_handshake_no_transfer_encoding() -> None:
    """The handshake response headers never include ``Transfer-Encoding``."""
    request_headers, _key = gen_ws_headers()
    request = make_mocked_request("GET", "/", headers=request_headers)

    response_headers = web.WebSocketResponse()._handshake(request)[0]

    assert "Transfer-Encoding" not in response_headers
