import random
import struct
import zlib
from unittest import mock

import pytest

import aiohttp
from aiohttp import http_websocket
from aiohttp.http import WebSocketError, WSCloseCode, WSMessage, WSMsgType
from aiohttp.http_websocket import (_WS_DEFLATE_TRAILING, PACK_CLOSE_CODE,
                                    PACK_LEN1, PACK_LEN2, PACK_LEN3,
                                    WebSocketReader, _websocket_mask)


def build_frame(message, opcode, use_mask=False, noheader=False, is_fin=True,
                compress=False):
    """Build the raw bytes of one websocket frame carrying ``message``.

    Nothing is sent anywhere; the encoded frame is simply returned.

    :param message: payload bytes.
    :param opcode: frame opcode, stored in the first header byte.
    :param use_mask: set the mask bit (0x80) in the second header byte,
        mask the payload with a random 4-byte key and emit that key
        between header and payload.
    :param noheader: return only the (possibly compressed / masked)
        payload, without header or mask key.
    :param is_fin: set the fin bit (0x80) in the first header byte.
    :param compress: deflate the payload with ``wbits=-9`` and set the
        0x40 bit in the first header byte.
    :return: ``bytes``, or a ``bytearray`` when both ``use_mask`` and
        ``noheader`` are true.
    """
    if compress:
        deflater = zlib.compressobj(wbits=-9)
        message = (deflater.compress(message) +
                   deflater.flush(zlib.Z_SYNC_FLUSH))
        # Drop the last four bytes when the stream ends with the trailer.
        if message.endswith(_WS_DEFLATE_TRAILING):
            message = message[:-4]
    payload_len = len(message)

    mask_bit = 0
    if use_mask:  # pragma: no cover
        mask_bit = 0x80

    first_byte = opcode
    if is_fin:
        first_byte = 0x80 | opcode
    if compress:
        first_byte |= 0x40

    # Pick the header layout from the payload length.
    if payload_len < 126:
        header = PACK_LEN1(first_byte, payload_len | mask_bit)
    elif payload_len < (1 << 16):  # pragma: no cover
        header = PACK_LEN2(first_byte, 126 | mask_bit, payload_len)
    else:
        header = PACK_LEN3(first_byte, 127 | mask_bit, payload_len)

    prefix = header
    if use_mask:  # pragma: no cover
        mask = random.randrange(0, 0xffffffff).to_bytes(4, 'big')
        # _websocket_mask works in place, so it needs a mutable buffer.
        message = bytearray(message)
        _websocket_mask(mask, message)
        prefix = header + mask

    if noheader:
        return message
    return prefix + message


def build_close_frame(code=1000, message=b'', noheader=False):
    """Build a CLOSE frame whose payload is the packed code plus message.

    A ``str`` message is encoded as UTF-8 first.  ``noheader`` is passed
    straight through to :func:`build_frame`.
    """
    reason = message
    if isinstance(reason, str):  # pragma: no cover
        reason = reason.encode('utf-8')
    payload = PACK_CLOSE_CODE(code) + reason
    return build_frame(payload, opcode=WSMsgType.CLOSE, noheader=noheader)


@pytest.fixture()
def out(loop):
    """Provide a new ``aiohttp.DataQueue`` built from the ``loop`` fixture."""
    queue = aiohttp.DataQueue(loop)
    return queue


@pytest.fixture()
def parser(out):
    """Provide a ``WebSocketReader`` feeding ``out`` with a 4 MiB size arg."""
    size_limit = 4 * 1024 * 1024
    return WebSocketReader(out, size_limit)


def test_parse_frame(parser) -> None:
    """A two-byte header followed by its one-byte payload yields a frame."""
    header = struct.pack('!BB', 0b00000001, 0b00000001)
    parser.parse_frame(header)
    frames = parser.parse_frame(b'1')
    fin, opcode, payload, compress = frames[0]

    assert (0, 1, b'1', False) == (fin, opcode, payload, bool(compress))


def test_parse_frame_length0(parser) -> None:
    """A header announcing a zero-length payload yields an empty frame."""
    header = struct.pack('!BB', 0b00000001, 0b00000000)
    frames = parser.parse_frame(header)
    fin, opcode, payload, compress = frames[0]

    assert (0, 1, b'', False) == (fin, opcode, payload, bool(compress))


def test_parse_frame_length2(parser) -> None:
    """Length marker 126 is followed by a 2-byte ('!H') payload length."""
    header = struct.pack('!BB', 0b00000001, 126)
    extended_length = struct.pack('!H', 4)

    parser.parse_frame(header)
    parser.parse_frame(extended_length)
    frames = parser.parse_frame(b'1234')
    fin, opcode, payload, compress = frames[0]

    assert (0, 1, b'1234', False) == (fin, opcode, payload, bool(compress))


def test_parse_frame_length4(parser) -> None:
    """Length marker 127 is followed by an 8-byte ('!Q') payload length."""
    header = struct.pack('!BB', 0b00000001, 127)
    extended_length = struct.pack('!Q', 4)

    parser.parse_frame(header)
    parser.parse_frame(extended_length)
    frames = parser.parse_frame(b'1234')
    fin, opcode, payload, compress = frames[0]

    assert (0, 1, b'1234', False) == (fin, opcode, payload, bool(compress))


def test_parse_frame_mask(parser) -> None:
    """With the mask bit set, the payload comes back XOR-ed with the key."""
    header = struct.pack('!BB', 0b00000001, 0b10000001)
    mask_key = b'0001'

    parser.parse_frame(header)
    parser.parse_frame(mask_key)
    frames = parser.parse_frame(b'1')
    fin, opcode, payload, compress = frames[0]

    # b'1' (0x31) XOR b'0' (0x30), the first key byte, is 0x01.
    assert (0, 1, b'\x01', False) == (fin, opcode, payload, bool(compress))


def test_parse_frame_header_reversed_bits(out, parser) -> None:
    """A first header byte with bits 0x40 and 0x20 set is rejected."""
    header = struct.pack('!BB', 0b01100000, 0b00000000)

    with pytest.raises(WebSocketError):
        parser.parse_frame(header)
        # Only reached if parse_frame returned instead of raising:
        # surface whatever error was recorded on the queue.
        raise out.exception()


def test_parse_frame_header_control_frame(out, parser) -> None:
    """Opcode 0x8 with the fin bit (0x80) clear is rejected."""
    header = struct.pack('!BB', 0b00001000, 0b00000000)

    with pytest.raises(WebSocketError):
        parser.parse_frame(header)
        # Only reached if parse_frame returned instead of raising:
        # surface whatever error was recorded on the queue.
        raise out.exception()


def _test_parse_frame_header_new_data_err(out, parser):
    """An all-zero header is expected to raise ``WebSocketError``.

    The leading underscore keeps this function out of pytest's default
    ``test*`` collection pattern.
    """
    header = struct.pack('!BB', 0b00000000, 0b00000000)

    with pytest.raises(WebSocketError):
        parser.parse_frame(header)
        # Only reached if parse_frame returned instead of raising:
        # surface whatever error was recorded on the queue.
        raise out.exception()


def test_parse_frame_header_payload_size(out, parser) -> None:
    """A fin frame with opcode 0x8 and length marker 126 is rejected."""
    header = struct.pack('!BB', 0b10001000, 0b01111110)

    with pytest.raises(WebSocketError):
        parser.parse_frame(header)
        # Only reached if parse_frame returned instead of raising:
        # surface whatever error was recorded on the queue.
        raise out.exception()


def test_ping_frame(out, parser) -> None:
    """A PING frame is queued as a PING message with size 4."""
    parser.parse_frame = mock.Mock(
        return_value=[(1, WSMsgType.PING, b'data', False)])

    parser.feed_data(b'')

    queued = out._buffer[0]
    assert queued == ((WSMsgType.PING, b'data', ''), 4)


def test_pong_frame(out, parser) -> None:
    """A PONG frame is queued as a PONG message with size 4."""
    parser.parse_frame = mock.Mock(
        return_value=[(1, WSMsgType.PONG, b'data', False)])

    parser.feed_data(b'')

    queued = out._buffer[0]
    assert queued == ((WSMsgType.PONG, b'data', ''), 4)


def test_close_frame(out, parser) -> None:
    """A CLOSE frame with no payload is queued with code 0 and size 0."""
    parser.parse_frame = mock.Mock(
        return_value=[(1, WSMsgType.CLOSE, b'', False)])

    parser.feed_data(b'')

    queued = out._buffer[0]
    assert queued == ((WSMsgType.CLOSE, 0, ''), 0)


def test_close_frame_info(out, parser) -> None:
    """The first two CLOSE payload bytes become the code, the rest the text.

    ``b'01'`` read as a big-endian 16-bit integer is 0x3031 == 12337.
    """
    parser.parse_frame = mock.Mock(
        return_value=[(1, WSMsgType.CLOSE, b'0112345', False)])

    parser.feed_data(b'')

    queued = out._buffer[0]
    assert queued == (WSMessage(WSMsgType.CLOSE, 12337, '12345'), 0)


def test_close_frame_invalid(out, parser) -> None:
    """A one-byte CLOSE payload records a PROTOCOL_ERROR on the queue."""
    parser.parse_frame = mock.Mock(
        return_value=[(1, WSMsgType.CLOSE, b'1', False)])

    parser.feed_data(b'')

    assert isinstance(out.exception(), WebSocketError)
    assert out.exception().code == WSCloseCode.PROTOCOL_ERROR


def test_close_frame_invalid_2(out, parser) -> None:
    """A CLOSE frame carrying code 1 fails with PROTOCOL_ERROR."""
    frame = build_close_frame(code=1)

    with pytest.raises(WebSocketError) as excinfo:
        parser._feed_data(frame)

    assert excinfo.value.code == WSCloseCode.PROTOCOL_ERROR


def test_close_frame_unicode_err(parser) -> None:
    """A CLOSE reason of b'\\xf4\\x90\\x80\\x80' fails with INVALID_TEXT."""
    bad_reason = b'\xf4\x90\x80\x80'
    frame = build_close_frame(code=1000, message=bad_reason)

    with pytest.raises(WebSocketError) as excinfo:
        parser._feed_data(frame)

    assert excinfo.value.code == WSCloseCode.INVALID_TEXT


def test_unknown_frame(out, parser) -> None:
    """A lone fin CONTINUATION frame results in a ``WebSocketError``."""
    parser.parse_frame = mock.Mock(
        return_value=[(1, WSMsgType.CONTINUATION, b'', False)])

    with pytest.raises(WebSocketError):
        parser.feed_data(b'')
        # If feed_data returned normally, re-raise the error it recorded
        # on the queue.
        raise out.exception()


def test_simple_text(out, parser) -> None:
    """A single TEXT frame is decoded to ``str`` and queued with size 4."""
    frame = build_frame(b'text', WSMsgType.TEXT)

    parser._feed_data(frame)

    queued = out._buffer[0]
    assert queued == ((WSMsgType.TEXT, 'text', ''), 4)


def test_simple_text_unicode_err(parser) -> None:
    """A TEXT payload of b'\\xf4\\x90\\x80\\x80' fails with INVALID_TEXT."""
    frame = build_frame(b'\xf4\x90\x80\x80', WSMsgType.TEXT)

    with pytest.raises(WebSocketError) as excinfo:
        parser._feed_data(frame)

    assert excinfo.value.code == WSCloseCode.INVALID_TEXT


def test_simple_binary(out, parser) -> None:
    """A BINARY frame is queued unchanged as bytes with size 6."""
    parser.parse_frame = mock.Mock(
        return_value=[(1, WSMsgType.BINARY, b'binary', False)])

    parser.feed_data(b'')

    queued = out._buffer[0]
    assert queued == ((WSMsgType.BINARY, b'binary', ''), 6)


def test_fragmentation_header(out, parser) -> None:
    """A frame fed in two chunks, split after its first byte, still parses."""
    frame = build_frame(b'a', WSMsgType.TEXT)
    first_byte, remainder = frame[:1], frame[1:]

    parser._feed_data(first_byte)
    parser._feed_data(remainder)

    queued = out._buffer[0]
    assert queued == (WSMessage(WSMsgType.TEXT, 'a', ''), 1)


def test_continuation(out, parser) -> None:
    """A non-fin TEXT frame plus a CONTINUATION frame form one message."""
    opening = build_frame(b'line1', WSMsgType.TEXT, is_fin=False)
    closing = build_frame(b'line2', WSMsgType.CONTINUATION)

    parser._feed_data(opening)
    parser._feed_data(closing)

    queued = out._buffer[0]
    assert queued == (WSMessage(WSMsgType.TEXT, 'line1line2', ''), 10)


def test_continuation_with_ping(out, parser) -> None:
    """A PING between text fragments is queued before the joined TEXT."""
    # NOTE(review): parse_frame is replaced by a mock with a fixed return
    # value, so the frame bytes built below presumably do not influence
    # what the reader sees -- confirm whether the mock is intended.
    parser.parse_frame = mock.Mock(return_value=[
        (0, WSMsgType.TEXT, b'line1', False),
        (0, WSMsgType.PING, b'', False),
        (1, WSMsgType.CONTINUATION, b'line2', False),
    ])

    parser._feed_data(build_frame(b'line1', WSMsgType.TEXT, is_fin=False))
    parser._feed_data(build_frame(b'', WSMsgType.PING))
    parser._feed_data(build_frame(b'line2', WSMsgType.CONTINUATION))

    ping_entry = out._buffer[0]
    assert ping_entry == (WSMessage(WSMsgType.PING, b'', ''), 0)
    text_entry = out._buffer[1]
    assert text_entry == (WSMessage(WSMsgType.TEXT, 'line1line2', ''), 10)


def test_continuation_err(out, parser) -> None:
    """A new TEXT frame after a non-fin TEXT frame raises an error."""
    frames = [
        (0, WSMsgType.TEXT, b'line1', False),
        (1, WSMsgType.TEXT, b'line2', False),
    ]
    parser.parse_frame = mock.Mock(return_value=frames)

    with pytest.raises(WebSocketError):
        parser._feed_data(b'')


def test_continuation_with_close(out, parser) -> None:
    """A CLOSE between text fragments is queued before the joined TEXT."""
    parser.parse_frame = mock.Mock()
    parser.parse_frame.return_value = [
        (0, WSMsgType.TEXT, b'line1', False),
        (0, WSMsgType.CLOSE,
         build_close_frame(1002, b'test', noheader=True), False),
        (1, WSMsgType.CONTINUATION, b'line2', False),
    ]

    parser.feed_data(b'')
    res = out._buffer[0]
    # Compare with ``==``: ``assert res, expected`` would only check that
    # ``res`` is truthy and use ``expected`` as the failure message.
    assert res == (WSMessage(WSMsgType.CLOSE, 1002, 'test'), 0)
    res = out._buffer[1]
    assert res == (WSMessage(WSMsgType.TEXT, 'line1line2', ''), 10)


def test_continuation_with_close_unicode_err(out, parser) -> None:
    """An undecodable CLOSE reason between fragments gives INVALID_TEXT."""
    close_payload = build_close_frame(
        1000, b'\xf4\x90\x80\x80', noheader=True)
    parser.parse_frame = mock.Mock(return_value=[
        (0, WSMsgType.TEXT, b'line1', False),
        (0, WSMsgType.CLOSE, close_payload, False),
        (1, WSMsgType.CONTINUATION, b'line2', False),
    ])

    with pytest.raises(WebSocketError) as excinfo:
        parser._feed_data(b'')

    assert excinfo.value.code == WSCloseCode.INVALID_TEXT


def test_continuation_with_close_bad_code(out, parser) -> None:
    """A CLOSE with code 1 between fragments gives PROTOCOL_ERROR."""
    close_payload = build_close_frame(1, b'test', noheader=True)
    parser.parse_frame = mock.Mock(return_value=[
        (0, WSMsgType.TEXT, b'line1', False),
        (0, WSMsgType.CLOSE, close_payload, False),
        (1, WSMsgType.CONTINUATION, b'line2', False),
    ])

    with pytest.raises(WebSocketError) as excinfo:
        parser._feed_data(b'')

    assert excinfo.value.code == WSCloseCode.PROTOCOL_ERROR


def test_continuation_with_close_bad_payload(out, parser) -> None:
    """A one-byte CLOSE payload between fragments gives PROTOCOL_ERROR."""
    parser.parse_frame = mock.Mock()
    parser.parse_frame.return_value = [
        (0, WSMsgType.TEXT, b'line1', False),
        (0, WSMsgType.CLOSE, b'1', False),
        (1, WSMsgType.CONTINUATION, b'line2', False)]

    with pytest.raises(WebSocketError) as ctx:
        parser._feed_data(b'')

    # Compare with ``==``: ``assert code, expected`` would only check that
    # the code is truthy and use ``expected`` as the failure message.
    assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR


def test_continuation_with_close_empty(out, parser) -> None:
    """An empty CLOSE between fragments is queued before the joined TEXT."""
    parser.parse_frame = mock.Mock()
    parser.parse_frame.return_value = [
        (0, WSMsgType.TEXT, b'line1', False),
        (0, WSMsgType.CLOSE, b'', False),
        (1, WSMsgType.CONTINUATION, b'line2', False),
    ]

    parser.feed_data(b'')
    res = out._buffer[0]
    # Compare with ``==``: ``assert res, expected`` would only check that
    # ``res`` is truthy and use ``expected`` as the failure message.
    assert res == (WSMessage(WSMsgType.CLOSE, 0, ''), 0)
    res = out._buffer[1]
    assert res == (WSMessage(WSMsgType.TEXT, 'line1line2', ''), 10)


# Shared vectors for the masking tests below: a payload, a 4-byte mask, and
# the payload XOR-ed byte by byte with the mask repeated cyclically
# (e.g. b's' ^ b'1' == b'B', b'o' ^ b'2' == b']').
websocket_mask_data = b'some very long data for masking by websocket'
websocket_mask_mask = b'1234'
websocket_mask_masked = (b'B]^Q\x11DVFH\x12_[_U\x13PPFR\x14W]A\x14\\S@_X'
                         b'\\T\x14SK\x13CTP@[RYV@')


def test_websocket_mask_python() -> None:
    """The Python masker transforms the buffer in place to the vector."""
    buf = bytearray(websocket_mask_data)

    http_websocket._websocket_mask_python(websocket_mask_mask, buf)

    assert buf == websocket_mask_masked


@pytest.mark.skipif(not hasattr(http_websocket, '_websocket_mask_cython'),
                    reason='Requires Cython')
def test_websocket_mask_cython() -> None:
    """The Cython masker transforms the buffer in place to the vector."""
    buf = bytearray(websocket_mask_data)

    http_websocket._websocket_mask_cython(websocket_mask_mask, buf)

    assert buf == websocket_mask_masked


def test_websocket_mask_python_empty() -> None:
    """Masking an empty buffer with the Python masker leaves it empty."""
    buf = bytearray()

    http_websocket._websocket_mask_python(websocket_mask_mask, buf)

    assert buf == bytearray()


@pytest.mark.skipif(not hasattr(http_websocket, '_websocket_mask_cython'),
                    reason='Requires Cython')
def test_websocket_mask_cython_empty() -> None:
    """Masking an empty buffer with the Cython masker leaves it empty."""
    buf = bytearray()

    http_websocket._websocket_mask_cython(websocket_mask_mask, buf)

    assert buf == bytearray()


def test_msgtype_aliases() -> None:
    """Each upper-case ``WSMsgType`` member equals its lower-case alias."""
    msg_type = aiohttp.WSMsgType
    names = ('TEXT', 'BINARY', 'PING', 'PONG', 'CLOSE', 'CLOSED', 'ERROR')

    for name in names:
        assert getattr(msg_type, name) == getattr(msg_type, name.lower())


def test_parse_compress_frame_single(parser) -> None:
    """A fin frame with the 0x40 bit set is reported as compressed."""
    header = struct.pack('!BB', 0b11000001, 0b00000001)
    parser.parse_frame(header)
    frames = parser.parse_frame(b'1')
    fin, opcode, payload, compress = frames[0]

    assert (1, 1, b'1', True) == (fin, opcode, payload, bool(compress))


def test_parse_compress_frame_multi(parser) -> None:
    """The compress flag carries over to the fin frame, then resets.

    The first (non-fin) frame sets the 0x40 bit; the following fin frame
    without that bit is still reported compressed, and the frame after
    that is reported uncompressed.
    """
    def parse_one(first_byte):
        # Header with length marker 126, 2-byte length of 4, then payload.
        parser.parse_frame(struct.pack('!BB', first_byte, 126))
        parser.parse_frame(struct.pack('!H', 4))
        return parser.parse_frame(b'1234')[0]

    cases = (
        (0b01000001, (0, 1, b'1234', True)),
        (0b10000001, (1, 1, b'1234', True)),
        (0b10000001, (1, 1, b'1234', False)),
    )
    for first_byte, expected in cases:
        fin, opcode, payload, compress = parse_one(first_byte)
        assert expected == (fin, opcode, payload, bool(compress))


def test_parse_compress_error_frame(parser) -> None:
    """Setting the 0x40 bit again after a non-fin compressed frame fails."""
    opening_header = struct.pack('!BB', 0b01000001, 0b00000001)
    second_header = struct.pack('!BB', 0b11000001, 0b00000001)

    parser.parse_frame(opening_header)
    parser.parse_frame(b'1')

    with pytest.raises(WebSocketError) as excinfo:
        parser.parse_frame(second_header)
        parser.parse_frame(b'1')

    assert excinfo.value.code == WSCloseCode.PROTOCOL_ERROR


def test_parse_no_compress_frame_single(out) -> None:
    """With ``compress=False`` a frame with the 0x40 bit set is rejected."""
    # ``out`` must be requested as a fixture; otherwise the name resolves
    # to the module-level fixture function rather than a DataQueue.
    parser_no_compress = WebSocketReader(out, 0, compress=False)
    with pytest.raises(WebSocketError) as ctx:
        parser_no_compress.parse_frame(struct.pack(
            '!BB', 0b11000001, 0b00000001))
        parser_no_compress.parse_frame(b'1')

    assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR


def test_msg_too_large(out) -> None:
    """A 1024-byte TEXT frame exceeds a reader limit of 256."""
    reader = WebSocketReader(out, 256, compress=False)
    frame = build_frame(b'text' * 256, WSMsgType.TEXT)

    with pytest.raises(WebSocketError) as excinfo:
        reader._feed_data(frame)

    assert excinfo.value.code == WSCloseCode.MESSAGE_TOO_BIG


def test_msg_too_large_not_fin(out) -> None:
    """A non-fin 1024-byte TEXT frame also exceeds a limit of 256."""
    reader = WebSocketReader(out, 256, compress=False)
    frame = build_frame(b'text' * 256, WSMsgType.TEXT, is_fin=False)

    with pytest.raises(WebSocketError) as excinfo:
        reader._feed_data(frame)

    assert excinfo.value.code == WSCloseCode.MESSAGE_TOO_BIG


def test_compressed_msg_too_large(out) -> None:
    """A compressed frame built from 768 bytes exceeds a limit of 256."""
    reader = WebSocketReader(out, 256, compress=True)
    frame = build_frame(b'aaa' * 256, WSMsgType.TEXT, compress=True)

    with pytest.raises(WebSocketError) as excinfo:
        reader._feed_data(frame)

    assert excinfo.value.code == WSCloseCode.MESSAGE_TOO_BIG
