from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from shutil import which
from subprocess import run
from typing import Optional, Tuple

import pytest

import aristaproto
from tests.output_aristaproto import (
    map,
    nested,
    oneof,
    repeated,
    repeatedpacked,
)


# Example messages shared by the tests below. Several tests compare serialized
# output byte-for-byte against the golden files under `streams_path`, so these
# field values must stay in sync with those `.expected` / `.in` files.
oneof_example = oneof.Test().from_dict(
    {"pitied": 1, "just_a_regular_field": 123456789, "bar_name": "Testing"}
)

# Serialized size of `oneof_example` (checked against len(bytes(...)) in
# test_message_len); reused as the explicit `size` argument to `Message.load`.
len_oneof = len(oneof_example)

nested_example = nested.Test().from_dict(
    {
        "nested": {"count": 1},
        "sibling": {"foo": 2},
        "sibling2": {"foo": 3},
        "msg": nested.TestMsg.THIS,
    }
)

# The next three messages are used by the test_message_len_*_field tests to
# check len() for repeated, packed and map fields.
repeated_example = repeated.Test().from_dict({"names": ["blah", "Blah2"]})

packed_example = repeatedpacked.Test().from_dict(
    {"counts": [1, 2, 3], "signed": [-1, 2, -3], "fixed": [1.2, -2.3, 3.4]}
)

map_example = map.Test().from_dict({"counts": {"blah": 1, "Blah2": 2}})

# Relative path: presumably pytest is run from the repository root — TODO confirm.
streams_path = Path("tests/streams/")

# Path to the `java` executable, or None when it is not on PATH. In that case the
# `compile_jar` fixture skips the Java compatibility tests.
java = which("java")


def test_load_varint_too_long():
    """Reject an 11-byte varint with ValueError; accept a 10-byte (64-bit) one."""
    eleven_bytes = b"\x80" * 10 + b"\x01"
    with BytesIO(eleven_bytes) as stream:
        with pytest.raises(ValueError):
            aristaproto.load_varint(stream)

    ten_bytes = b"\x80" * 9 + b"\x01"
    with BytesIO(ten_bytes) as stream:
        # Still within 64 bits, so this read must not raise ValueError.
        aristaproto.load_varint(stream)


def test_load_varint_file():
    """Read a single-byte and then a multi-byte varint from the golden file."""
    expected_single = (8, b"\x08")
    expected_multi = (123456789, b"\x95\x9A\xEF\x3A")
    golden = streams_path / "message_dump_file_single.expected"

    with open(golden, "rb") as stream:
        assert aristaproto.load_varint(stream) == expected_single
        # Skip ahead to the first multi-byte varint in the file.
        stream.read(2)
        assert aristaproto.load_varint(stream) == expected_multi


def test_load_varint_cutoff():
    """A truncated varint raises EOFError, whether read from offset 0 or 1."""
    with open(streams_path / "load_varint_cutoff.in", "rb") as stream:
        for offset in (0, 1):
            stream.seek(offset)
            with pytest.raises(EOFError):
                aristaproto.load_varint(stream)


def test_dump_varint_file(tmp_path):
    """Dump two varints and check they read back like the golden file's varints.

    The output must contain exactly the two dumped varints, so once both have
    been read back nothing may remain in the file.
    """
    out_path = tmp_path / "dump_varint_file.out"

    # Dump test varints to file
    with open(out_path, "wb") as stream:
        aristaproto.dump_varint(8, stream)  # Single-byte varint
        aristaproto.dump_varint(123456789, stream)  # Multi-byte varint

    # Check that file contents are as expected
    with open(out_path, "rb") as test_stream, open(
        streams_path / "message_dump_file_single.expected", "rb"
    ) as exp_stream:
        assert aristaproto.load_varint(test_stream) == aristaproto.load_varint(
            exp_stream
        )
        exp_stream.read(2)  # Skip to the first multi-byte varint in the golden file
        assert aristaproto.load_varint(test_stream) == aristaproto.load_varint(
            exp_stream
        )
        # dump_varint must not have written anything beyond the two varints.
        assert test_stream.read() == b""


def test_parse_fields():
    """Check that `parse_fields` (bytes) and `load_fields` (stream) yield the same fields.

    Both results are turned into lists before comparing, so the test also fails
    when either side yields extra or missing fields. The old lockstep `next()`
    loop could not catch that.
    """
    golden = streams_path / "message_dump_file_single.expected"

    with open(golden, "rb") as stream:
        parsed_bytes = list(aristaproto.parse_fields(stream.read()))

    # load_fields reads lazily from the stream, so drain it while the file is open.
    with open(golden, "rb") as stream:
        parsed_stream = list(aristaproto.load_fields(stream))

    # Guard against both sides being trivially empty.
    assert parsed_bytes
    assert parsed_bytes == parsed_stream


def test_message_dump_file_single(tmp_path):
    """Dumping one message must reproduce the single-message golden file exactly."""
    out_path = tmp_path / "message_dump_file_single.out"
    with open(out_path, "wb") as out:
        oneof_example.dump(out)

    expected = (streams_path / "message_dump_file_single.expected").read_bytes()
    assert out_path.read_bytes() == expected


def test_message_dump_file_multiple(tmp_path):
    """Dump one message twice, then a different one, and match the multi-message golden file."""
    out_path = tmp_path / "message_dump_file_multiple.out"
    with open(out_path, "wb") as out:
        for message in (oneof_example, oneof_example, nested_example):
            message.dump(out)

    expected = (streams_path / "message_dump_file_multiple.expected").read_bytes()
    assert out_path.read_bytes() == expected


def test_message_dump_delimited(tmp_path):
    """Size-delimited dumps of several messages must match the delimited fixture."""
    out_path = tmp_path / "message_dump_delimited.out"
    with open(out_path, "wb") as out:
        for message in (oneof_example, oneof_example, nested_example):
            message.dump(out, aristaproto.SIZE_DELIMITED)

    expected = (streams_path / "delimited_messages.in").read_bytes()
    assert out_path.read_bytes() == expected


def test_message_len():
    """`len(message)` must equal the length of the message's serialized bytes."""
    oneof_wire = bytes(oneof_example)
    assert len(oneof_wire) == len_oneof

    nested_wire = bytes(nested_example)
    assert len(nested_wire) == len(nested_example)


def test_message_load_file_single():
    """The single-message fixture loads correctly with and without an explicit size."""
    with open(streams_path / "message_dump_file_single.expected", "rb") as stream:
        # First load uses the default size, second passes the exact byte count.
        for size_args in ((), (len_oneof,)):
            stream.seek(0)
            assert oneof.Test().load(stream, *size_args) == oneof_example


def test_message_load_file_multiple():
    """Load three concatenated messages from one stream, then expect end of file."""
    with open(streams_path / "message_dump_file_multiple.expected", "rb") as stream:
        for _ in range(2):
            assert oneof.Test().load(stream, len_oneof) == oneof_example
        assert nested.Test().load(stream) == nested_example
        # Every byte of the fixture must have been consumed.
        assert stream.read(1) == b""


def test_message_load_too_small():
    """Asking `load` for fewer bytes than the message occupies raises ValueError."""
    truncated_size = len_oneof - 1
    fixture = streams_path / "message_dump_file_single.expected"
    with open(fixture, "rb") as stream:
        with pytest.raises(ValueError):
            oneof.Test().load(stream, truncated_size)


def test_message_load_delimited():
    """Read the size-delimited fixture back one message at a time."""
    delimited = aristaproto.SIZE_DELIMITED
    expected_sequence = (
        (oneof.Test, oneof_example),
        (oneof.Test, oneof_example),
        (nested.Test, nested_example),
    )
    with open(streams_path / "delimited_messages.in", "rb") as stream:
        for message_cls, expected in expected_sequence:
            assert message_cls().load(stream, delimited) == expected
        # The fixture must hold nothing beyond the three messages.
        assert stream.read(1) == b""


def test_message_load_too_large():
    """Asking `load` for more bytes than the single-message file holds raises ValueError."""
    oversized = len_oneof + 1
    fixture = streams_path / "message_dump_file_single.expected"
    with open(fixture, "rb") as stream:
        with pytest.raises(ValueError):
            oneof.Test().load(stream, oversized)


def test_message_len_optional_field():
    """`len()` of a message with a wrapped optional bool matches its wire size."""

    @dataclass
    class Request(aristaproto.Message):
        flag: Optional[bool] = aristaproto.message_field(1, wraps=aristaproto.TYPE_BOOL)

    # Each message paired with the encoding whose length it must report.
    cases = (
        (Request(), b""),
        (Request(flag=True), b"\n\x02\x08\x01"),
        (Request(flag=False), b"\n\x00"),
    )
    for message, wire in cases:
        assert len(message) == len(wire)


def test_message_len_repeated_field():
    """For a repeated field, the message length equals its serialized size."""
    serialized = bytes(repeated_example)
    assert len(serialized) == len(repeated_example)


def test_message_len_packed_field():
    """For packed repeated fields, the message length equals its serialized size."""
    serialized = bytes(packed_example)
    assert len(serialized) == len(packed_example)


def test_message_len_map_field():
    """For a map field, the message length equals its serialized size."""
    serialized = bytes(map_example)
    assert len(serialized) == len(map_example)


def test_message_len_empty_string():
    """A message whose grouped string field is set to "" reports its serialized length.

    Both fields share the `"group"` group argument (presumably a oneof).
    """

    @dataclass
    class Empty(aristaproto.Message):
        string: str = aristaproto.string_field(1, "group")
        integer: int = aristaproto.int32_field(2, "group")

    message = Empty().from_dict({"string": ""})
    serialized = bytes(message)
    assert len(message) == len(serialized)


def test_calculate_varint_size_negative():
    """Negative values down to -2**63 take 10 bytes; anything lower raises ValueError."""
    in_range = (
        -1,  # single_byte
        -10000000,  # multi_byte
        -(1 << 63),  # edge: the smallest accepted value
        -(1 << 63) + 1,  # just above the edge
    )
    for value in in_range:
        encoded_len = len(aristaproto.encode_varint(value))
        assert aristaproto.size_varint(value) == encoded_len == 10

    below_range = -(1 << 63) - 1
    with pytest.raises(ValueError):
        aristaproto.size_varint(below_range)


def test_calculate_varint_size_positive():
    """`size_varint` agrees with the length of `encode_varint` for positive values."""
    for value in (1, 10000000):  # single-byte, then multi-byte
        encoded = aristaproto.encode_varint(value)
        assert aristaproto.size_varint(value) == len(encoded)


def test_dump_varint_negative(tmp_path):
    """Dumped negative varints match the golden file; a value below -2**63 raises ValueError."""
    out_path = tmp_path / "dump_varint_negative.out"
    in_range = (
        -1,  # single_byte
        -10000000,  # multi_byte
        -(1 << 63),  # edge
        -(1 << 63) + 1,  # just above the edge
    )

    with open(out_path, "wb") as out:
        for value in in_range:
            aristaproto.dump_varint(value, out)
        with pytest.raises(ValueError):
            aristaproto.dump_varint(-(1 << 63) - 1, out)

    expected = (streams_path / "dump_varint_negative.expected").read_bytes()
    assert out_path.read_bytes() == expected


def test_dump_varint_positive(tmp_path):
    """Dumped positive varints must match the positive-varint golden file."""
    out_path = tmp_path / "dump_varint_positive.out"
    with open(out_path, "wb") as out:
        for value in (1, 10000000):  # single-byte, then multi-byte
            aristaproto.dump_varint(value, out)

    expected = (streams_path / "dump_varint_positive.expected").read_bytes()
    assert out_path.read_bytes() == expected


# Java compatibility tests


@pytest.fixture(scope="module")
def compile_jar():
    """Build the Java compatibility JAR with Maven, once per module.

    Skips the requesting test (rather than failing it) when `java` or `mvn` is
    not on PATH, or when the Maven build exits with a non-zero status.
    """
    if java is None:
        pytest.skip("`java` command is absent and is required")

    maven = which("mvn")
    if maven is None:
        pytest.skip("Maven is absent and is required")

    build = run([maven, "clean", "install", "-f", "tests/streams/java/pom.xml"])
    if build.returncode:
        pytest.skip(
            "Maven compatibility-test.jar build failed (maybe Java version <11?)"
        )


# The JAR built by the `compile_jar` fixture's Maven build of
# tests/streams/java/pom.xml. Relative path, so tests presumably run from the
# repository root — TODO confirm.
jar = "tests/streams/java/target/compatibility-test.jar"


def run_jar(command: str, tmp_path):
    """Run one command of the Java compatibility JAR on the files in ``tmp_path``.

    Every caller in this module requests the ``compile_jar`` fixture, which skips
    the test when ``java`` is None. Raises ``subprocess.CalledProcessError`` if
    the JAR exits with a non-zero status.
    """
    argv = [java, "-jar", jar, command, tmp_path]
    return run(argv, check=True)


def run_java_single_varint(value: int, tmp_path) -> Tuple[int, bytes]:
    """Round-trip one varint through the Java compatibility JAR.

    Writes ``value`` to ``py_single_varint.out`` in ``tmp_path``. The JAR's
    ``single_varint`` command then reads it and writes it back, and the JAR's
    output file is read here.

    Returns:
        The ``(value, raw_bytes)`` pair produced by ``aristaproto.load_varint``.
        (The previous ``-> int`` annotation was wrong.)

    The Java output must contain exactly one varint; a second read must raise
    EOFError.
    """
    # Write single varint to file
    with open(tmp_path / "py_single_varint.out", "wb") as stream:
        aristaproto.dump_varint(value, stream)

    # Have Java read this varint and write it back
    run_jar("single_varint", tmp_path)

    # Read single varint from Java output file
    with open(tmp_path / "java_single_varint.out", "rb") as stream:
        returned = aristaproto.load_varint(stream)
        # Java must echo exactly one varint and nothing more.
        with pytest.raises(EOFError):
            aristaproto.load_varint(stream)

    return returned


def test_single_varint(compile_jar, tmp_path):
    """Single- and multi-byte varints survive a round trip through Java."""
    cases = (
        (1, b"\x01"),  # single byte
        (123456789, b"\x95\x9A\xEF\x3A"),  # multi byte
    )
    for expected in cases:
        assert run_java_single_varint(expected[0], tmp_path) == expected


def test_multiple_varints(compile_jar, tmp_path):
    """Varints written by Python come back unchanged after Java reads and rewrites them."""
    expected = [
        (1, b"\x01"),  # single byte
        (123456789, b"\x95\x9A\xEF\x3A"),  # multi byte
        (3000000000, b"\x80\xBC\xC1\x96\x0B"),  # wider than 32 bits
    ]

    with open(tmp_path / "py_multiple_varints.out", "wb") as out:
        for value, _ in expected:
            aristaproto.dump_varint(value, out)

    run_jar("multiple_varints", tmp_path)

    with open(tmp_path / "java_multiple_varints.out", "rb") as echoed:
        returned = [aristaproto.load_varint(echoed) for _ in expected]
        # Java must not emit anything beyond the echoed varints.
        with pytest.raises(EOFError):
            aristaproto.load_varint(echoed)

    assert returned == expected


def test_single_message(compile_jar, tmp_path):
    """A single undelimited message survives a round trip through Java."""
    with open(tmp_path / "py_single_message.out", "wb") as out:
        oneof_example.dump(out)

    run_jar("single_message", tmp_path)

    size = len(bytes(oneof_example))
    with open(tmp_path / "java_single_message.out", "rb") as echoed:
        returned = oneof.Test().load(echoed, size)
        # Nothing may follow the message in Java's output.
        assert echoed.read() == b""

    assert returned == oneof_example


def test_multiple_messages(compile_jar, tmp_path):
    """Size-delimited messages of two different types survive a round trip through Java."""
    delimited = aristaproto.SIZE_DELIMITED
    originals = (oneof_example, nested_example)

    with open(tmp_path / "py_multiple_messages.out", "wb") as out:
        for message in originals:
            message.dump(out, delimited)

    run_jar("multiple_messages", tmp_path)

    with open(tmp_path / "java_multiple_messages.out", "rb") as echoed:
        returned = (
            oneof.Test().load(echoed, delimited),
            nested.Test().load(echoed, delimited),
        )
        assert echoed.read() == b""

    assert returned == originals


def test_infinite_messages(compile_jar, tmp_path):
    """Java echoes back every size-delimited message, and each one arrives intact.

    Messages are read back until ``load`` raises EOFError, so the reader does
    not need to know the count in advance. Both the count and the content of
    every message are checked; counting alone would miss corrupted payloads.
    """
    num_messages = 5

    # Write delimited messages to file
    with open(tmp_path / "py_infinite_messages.out", "wb") as stream:
        for _ in range(num_messages):
            oneof_example.dump(stream, aristaproto.SIZE_DELIMITED)

    # Have Java read and return the messages
    run_jar("infinite_messages", tmp_path)

    # Read and check the returned messages
    messages = []
    with open(tmp_path / "java_infinite_messages.out", "rb") as stream:
        while True:
            try:
                message = oneof.Test().load(stream, aristaproto.SIZE_DELIMITED)
            except EOFError:
                break
            messages.append(message)

    assert len(messages) == num_messages
    assert messages == [oneof_example] * num_messages
