import csv
import io
import sys
from typing import AsyncIterator, Callable, List, Protocol, Type

import pytest

from aiocsv._parser import Parser as CParser
from aiocsv.parser import Parser as PyParser
from aiocsv.protocols import DialectLike, WithAsyncRead


class Parser(Protocol):
    """Structural type satisfied by both the pure-Python and the C parser.

    An implementation is asynchronously iterable, yielding rows as lists of
    strings, and exposes a read-only ``line_num`` property.
    """

    def __aiter__(self) -> AsyncIterator[List[str]]: ...

    @property
    def line_num(self) -> int: ...


# Every test below is parametrized over both implementations so they are held
# to identical behavior; PARSER_NAMES supplies the matching pytest ids, in the
# same order as PARSERS.
PARSERS: List[Callable[[WithAsyncRead, DialectLike], Parser]] = [
    PyParser,
    CParser,
]
PARSER_NAMES: List[str] = [
    "pure_python_parser",
    "c_parser",
]


class AsyncStringIO:
    """Simple wrapper to fulfill WithAsyncRead around a string.

    ``read(size)`` follows :meth:`io.StringIO.read` semantics. It returns at
    most ``size`` characters from the current position and advances past them.
    A negative ``size``, which is also the default, returns everything that
    remains. Once the data is exhausted, ``read`` returns ``""``.
    """

    def __init__(self, data: str = "") -> None:
        self.ptr = 0
        self.data = data

    async def read(self, size: int = -1) -> str:
        start = self.ptr
        if size < 0:
            # io.StringIO convention: a negative size means "read to the end".
            # Without this, the slice below would be data[start:start-1].
            end = len(self.data)
        else:
            # Clamp so the pointer never runs past the end of the data.
            end = min(start + size, len(self.data))
        self.ptr = end
        return self.data[start:end]


@pytest.mark.asyncio
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_simple(parser: Type[Parser]):
    """Excel-dialect rows with quoted, doubled-quote and bare fields match ``csv``."""
    data = 'abc,"def",ghi\r\n' '"j""k""l",mno,pqr\r\n' 'stu,vwx,"yz"\r\n'
    excel = csv.get_dialect("excel")

    reference = list(csv.reader(io.StringIO(data, newline="")))

    parsed = []
    async for row in parser(AsyncStringIO(data), excel):  # type: ignore
        parsed.append(row)

    assert reference == parsed
    assert parsed == [
        ["abc", "def", "ghi"],
        ['j"k"l', "mno", "pqr"],
        ["stu", "vwx", "yz"],
    ]


@pytest.mark.asyncio
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_escapes(parser: Type[Parser]):
    """``$`` escapes in quoted and unquoted fields are resolved like ``csv`` does."""
    data = 'ab$"c,de$\nf\r\n' '"$"",$$gh$"\r\n' '"i\nj",k$,\r\n'
    reference_reader = csv.reader(io.StringIO(data, newline=""), escapechar="$", strict=True)
    dialect = reference_reader.dialect
    reference = list(reference_reader)

    parsed = []
    async for row in parser(AsyncStringIO(data), dialect):  # type: ignore
        parsed.append(row)

    assert reference == parsed
    assert parsed == [['ab"c', "de\nf"], ['"', '$gh"'], ["i\nj", "k,"]]


@pytest.mark.asyncio
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_empty(parser: Type[Parser]):
    """Blank lines, empty fields and whitespace-only lines with skipinitialspace."""
    data = "\r\n  a,,\r\n,\r\n  "

    reference_reader = csv.reader(io.StringIO(data, newline=""), skipinitialspace=True, strict=True)
    dialect = reference_reader.dialect
    reference = list(reference_reader)

    parsed = []
    async for row in parser(AsyncStringIO(data), dialect):  # type: ignore
        parsed.append(row)

    assert reference == parsed
    assert parsed == [[], ["a", "", ""], ["", ""], [""]]


@pytest.mark.asyncio
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_nonnumeric(parser: Type[Parser]):
    """QUOTE_NONNUMERIC converts unquoted fields to float and keeps quoted ones."""
    data = '1,2\n"a",,3.14'

    reference_reader = csv.reader(
        io.StringIO(data, newline=""),
        quoting=csv.QUOTE_NONNUMERIC,
        strict=True,
    )
    dialect = reference_reader.dialect
    reference = list(reference_reader)

    parsed = []
    async for row in parser(AsyncStringIO(data), dialect):  # type: ignore
        parsed.append(row)

    assert reference == parsed
    assert parsed == [[1.0, 2.0], ["a", "", 3.14]]


@pytest.mark.asyncio
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_nonnumeric_invalid(parser: Type[Parser]):
    """An unquoted non-number under QUOTE_NONNUMERIC raises ValueError in both."""
    data = "1,2\na,3.14\n"

    reference_reader = csv.reader(
        io.StringIO(data, newline=""),
        quoting=csv.QUOTE_NONNUMERIC,
        strict=True,
    )

    with pytest.raises(ValueError):
        list(reference_reader)

    with pytest.raises(ValueError):
        async for _ in parser(AsyncStringIO(data), reference_reader.dialect):  # type: ignore
            pass


@pytest.mark.asyncio
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_none_quoting(parser: Type[Parser]):
    """With QUOTE_NONE, quote characters are kept as ordinary field content."""
    data = '1" hello,"2\na","3.14"'

    reference_reader = csv.reader(io.StringIO(data, newline=""), quoting=csv.QUOTE_NONE, strict=True)
    dialect = reference_reader.dialect
    reference = list(reference_reader)

    parsed = []
    async for row in parser(AsyncStringIO(data), dialect):  # type: ignore
        parsed.append(row)

    assert reference == parsed
    assert parsed == [['1" hello', '"2'], ['a"', '"3.14"']]


@pytest.mark.asyncio
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_weird_quotes(parser: Type[Parser]):
    """Non-strict parsing of misplaced quotes and escapes mirrors ``csv``."""
    data = 'a"b,$"cd"\r\n' '"ef"g",\r\n' '"$"""","e"$f"\r\n'

    reference_reader = csv.reader(io.StringIO(data, newline=""), escapechar="$", strict=False)
    dialect = reference_reader.dialect
    reference = list(reference_reader)

    parsed = []
    async for row in parser(AsyncStringIO(data), dialect):  # type: ignore
        parsed.append(row)

    assert reference == parsed
    assert parsed == [['a"b', '"cd"'], ['efg"', ""], ['""', 'e$f"']]


@pytest.mark.asyncio
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_strict_quoting(parser: Type[Parser]):
    """In strict mode, text after a closing quote raises the same csv.Error."""
    data = '"ab"c,def\r\n'
    expected_message = "',' expected after '\"'"

    reference_reader = csv.reader(io.StringIO(data, newline=""), strict=True)

    with pytest.raises(csv.Error, match=expected_message):
        list(reference_reader)

    with pytest.raises(csv.Error, match=expected_message):
        async for _ in parser(AsyncStringIO(data), reference_reader.dialect):  # type: ignore
            pass


@pytest.mark.asyncio
@pytest.mark.skipif(
    sys.version_info < (3, 12, 9),
    reason="CPython bug gh-113785 was fixed in 3.12.9",
)
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_weird_quotes_nonnumeric(parser: Type[Parser]):
    """Partially quoted fields stay strings under non-strict QUOTE_NONNUMERIC."""
    data = '3.0,\r\n"1."5,"15"\r\n$2,"-4".5\r\n-5$.2,-11'

    reference_reader = csv.reader(
        io.StringIO(data, newline=""),
        quoting=csv.QUOTE_NONNUMERIC,
        escapechar="$",
        strict=False,
    )
    dialect = reference_reader.dialect
    reference = list(reference_reader)

    parsed = []
    async for row in parser(AsyncStringIO(data), dialect):  # type: ignore
        parsed.append(row)

    assert reference == parsed
    assert parsed == [[3.0, ""], ["1.5", "15"], [2.0, "-4.5"], [-5.2, -11.0]]


@pytest.mark.asyncio
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_escape_after_quote_in_quoted(parser: Type[Parser]):
    """An escape char right after a closing quote is kept literally."""
    data = '"fo"$o\r\n'
    expected = [["fo$o"]]

    reference_reader = csv.reader(io.StringIO(data, newline=""), escapechar="$")
    dialect = reference_reader.dialect
    reference = list(reference_reader)

    parsed = []
    async for row in parser(AsyncStringIO(data), dialect):  # type: ignore
        parsed.append(row)

    assert reference == expected
    assert parsed == expected


@pytest.mark.asyncio
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_escaped_crlf(parser: Type[Parser]):
    """An escaped CR in an unquoted field is kept; the following LF ends the row."""
    data = "foo$\r\nbar\r\n"
    expected = [["foo\r"], ["bar"]]

    reference_reader = csv.reader(io.StringIO(data, newline=""), escapechar="$")
    dialect = reference_reader.dialect
    reference = list(reference_reader)

    parsed = []
    async for row in parser(AsyncStringIO(data), dialect):  # type: ignore
        parsed.append(row)

    assert reference == expected
    assert parsed == expected


@pytest.mark.asyncio
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_escaped_crlf_in_quoted(parser: Type[Parser]):
    """An escaped CRLF inside a quoted field stays part of the field."""
    data = '"foo$\r\n",bar\r\n'
    expected = [["foo\r\n", "bar"]]

    reference_reader = csv.reader(io.StringIO(data, newline=""), escapechar="$")
    dialect = reference_reader.dialect
    reference = list(reference_reader)

    parsed = []
    async for row in parser(AsyncStringIO(data), dialect):  # type: ignore
        parsed.append(row)

    assert reference == expected
    assert parsed == expected


@pytest.mark.asyncio
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_consecutive_newlines(parser: Type[Parser]):
    """Mixed CR/LF sequences that are not CRLF each produce an empty row."""
    data = "foo\r\rbar\n\rbaz\n\nspam\r\n\neggs"
    expected = [["foo"], [], ["bar"], [], ["baz"], [], ["spam"], [], ["eggs"]]

    reference_reader = csv.reader(io.StringIO(data, newline=""), escapechar="$")
    dialect = reference_reader.dialect
    reference = list(reference_reader)

    parsed = []
    async for row in parser(AsyncStringIO(data), dialect):  # type: ignore
        parsed.append(row)

    assert reference == expected
    assert parsed == expected


@pytest.mark.asyncio
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_line_num(parser: Type[Parser]):
    """``line_num`` after each row counts physical lines, including embedded CR."""
    data = 'foo,bar,baz\r\nspam,"egg\reggs",milk\r\n'
    expected = [
        (1, ["foo", "bar", "baz"]),
        (3, ["spam", "egg\reggs", "milk"]),
    ]

    reference_reader = csv.reader(io.StringIO(data, newline=""))
    reference = []
    for row in reference_reader:
        # line_num is sampled only after the row has been produced.
        reference.append((reference_reader.line_num, row))

    async_reader = parser(AsyncStringIO(data), reference_reader.dialect)  # type: ignore
    observed = []
    async for row in async_reader:
        observed.append((async_reader.line_num, row))

    assert reference == expected
    assert observed == expected


@pytest.mark.asyncio
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_field_size_limit(parser: Type[Parser]):
    """A field longer than ``csv.field_size_limit()`` raises csv.Error in both.

    ``csv.field_size_limit`` is process-global state. The previous limit (the
    return value of the setter call) is restored in ``finally`` so the tiny
    limit does not leak into tests that run afterwards.
    """
    previous_limit = csv.field_size_limit(64)
    try:
        data = "a" * 65 + "\r\n"

        csv_parser = csv.reader(io.StringIO(data, newline=""), strict=True)

        with pytest.raises(csv.Error, match=r"field larger than field limit \(64\)"):
            list(csv_parser)

        with pytest.raises(csv.Error, match=r"field larger than field limit \(64\)"):
            [r async for r in parser(AsyncStringIO(data), csv_parser.dialect)]  # type: ignore
    finally:
        csv.field_size_limit(previous_limit)


@pytest.mark.asyncio
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_unterminated_quote(parser: Type[Parser]):
    """In strict mode, EOF inside an open quote raises "unexpected end of data"."""
    data = '"abc\r\n'
    expected_message = r"unexpected end of data"

    reference_reader = csv.reader(io.StringIO(data, newline=""), strict=True)

    with pytest.raises(csv.Error, match=expected_message):
        list(reference_reader)

    with pytest.raises(csv.Error, match=expected_message):
        async for _ in parser(AsyncStringIO(data), reference_reader.dialect):  # type: ignore
            pass


@pytest.mark.asyncio
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_unterminated_quote_non_strict(parser: Type[Parser]):
    """In non-strict mode, an unterminated quote keeps the rest as the field."""
    data = '"abc\r\n'
    expected = [["abc\r\n"]]

    reference_reader = csv.reader(io.StringIO(data, newline=""), strict=False)
    reference = list(reference_reader)

    async_reader = parser(AsyncStringIO(data), reference_reader.dialect)  # type: ignore
    parsed = []
    async for row in async_reader:
        parsed.append(row)

    assert reference == expected
    assert parsed == expected


@pytest.mark.asyncio
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_eof_in_escape(parser: Type[Parser]):
    """In strict mode, EOF right after an escape char raises csv.Error."""
    data = "a$"
    expected_message = r"unexpected end of data"

    reference_reader = csv.reader(io.StringIO(data, newline=""), escapechar="$", strict=True)

    with pytest.raises(csv.Error, match=expected_message):
        list(reference_reader)

    with pytest.raises(csv.Error, match=expected_message):
        async for _ in parser(AsyncStringIO(data), reference_reader.dialect):  # type: ignore
            pass


@pytest.mark.asyncio
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_eof_in_escape_non_strict(parser: Type[Parser]):
    """In non-strict mode, a trailing escape char yields a newline in the field."""
    data = "a$"
    expected = [["a\n"]]

    reference_reader = csv.reader(io.StringIO(data, newline=""), escapechar="$", strict=False)
    reference = list(reference_reader)

    async_reader = parser(AsyncStringIO(data), reference_reader.dialect)  # type: ignore
    parsed = []
    async for row in async_reader:
        parsed.append(row)

    assert reference == expected
    assert parsed == expected


@pytest.mark.asyncio
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_eof_in_quoted_escape(parser: Type[Parser]):
    """In strict mode, EOF after an escape inside quotes raises csv.Error."""
    data = '"a$'
    expected_message = r"unexpected end of data"

    reference_reader = csv.reader(io.StringIO(data, newline=""), escapechar="$", strict=True)

    with pytest.raises(csv.Error, match=expected_message):
        list(reference_reader)

    with pytest.raises(csv.Error, match=expected_message):
        async for _ in parser(AsyncStringIO(data), reference_reader.dialect):  # type: ignore
            pass


@pytest.mark.asyncio
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_eof_in_quoted_escape_non_strict(parser: Type[Parser]):
    """In non-strict mode, a trailing escape inside quotes yields a newline."""
    data = '"a$'
    expected = [["a\n"]]

    reference_reader = csv.reader(io.StringIO(data, newline=""), escapechar="$", strict=False)
    reference = list(reference_reader)

    async_reader = parser(AsyncStringIO(data), reference_reader.dialect)  # type: ignore
    parsed = []
    async for row in async_reader:
        parsed.append(row)

    assert reference == expected
    assert parsed == expected


@pytest.mark.asyncio
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_no_newline_at_the_end(parser: Type[Parser]):
    """The final row is still emitted when the input lacks a trailing newline."""
    data = "pi,3.1416\r\nsqrt2,1.4142\r\nphi,1.618\r\ne,2.7183"
    excel = csv.get_dialect("excel")

    reference = list(csv.reader(io.StringIO(data, newline="")))

    parsed = []
    async for row in parser(AsyncStringIO(data), excel):  # type: ignore
        parsed.append(row)

    assert reference == parsed
    assert parsed == [
        ["pi", "3.1416"],
        ["sqrt2", "1.4142"],
        ["phi", "1.618"],
        ["e", "2.7183"],
    ]


@pytest.mark.asyncio
@pytest.mark.skipif(sys.version_info < (3, 12), reason="csv.QUOTE_STRINGS was added in 3.12")
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_quote_strings(parser: Type[Parser]):
    """QUOTE_STRINGS: unquoted fields become floats, empty unquoted ones None."""
    data = '3.14,,"abc",""\r\n'

    reference_reader = csv.reader(io.StringIO(data, newline=""), quoting=csv.QUOTE_STRINGS, strict=True)  # type: ignore
    dialect = reference_reader.dialect
    reference = list(reference_reader)

    parsed = []
    async for row in parser(AsyncStringIO(data), dialect):  # type: ignore
        parsed.append(row)

    if sys.version_info >= (3, 13):
        assert reference == [[3.14, None, "abc", ""]]
    else:
        # https://github.com/python/cpython/issues/113732
        assert reference == [["3.14", "", "abc", ""]]
    assert parsed == [[3.14, None, "abc", ""]]


@pytest.mark.asyncio
@pytest.mark.skipif(sys.version_info < (3, 12), reason="csv.QUOTE_STRINGS was added in 3.12")
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_quote_strings_non_float(parser: Type[Parser]):
    """Under QUOTE_STRINGS, an unquoted non-number raises ValueError."""
    data = "abc"

    reference_reader = csv.reader(io.StringIO(data, newline=""), quoting=csv.QUOTE_STRINGS, strict=True)  # type: ignore
    if sys.version_info >= (3, 13):
        with pytest.raises(ValueError):
            list(reference_reader)
    else:
        # https://github.com/python/cpython/issues/113732
        assert list(reference_reader) == [["abc"]]

    with pytest.raises(ValueError):
        async for _ in parser(AsyncStringIO(data), reference_reader.dialect):  # type: ignore
            pass


@pytest.mark.asyncio
@pytest.mark.skipif(sys.version_info < (3, 12), reason="csv.QUOTE_NOTNULL was added in 3.12")
@pytest.mark.parametrize("parser", PARSERS, ids=PARSER_NAMES)
async def test_parsing_quote_not_null(parser: Type[Parser]):
    """QUOTE_NOTNULL: empty unquoted fields become None, others stay strings."""
    data = '3.14,,abc,""\r\n'

    reference_reader = csv.reader(io.StringIO(data, newline=""), quoting=csv.QUOTE_NOTNULL, strict=True)  # type: ignore
    dialect = reference_reader.dialect
    reference = list(reference_reader)

    parsed = []
    async for row in parser(AsyncStringIO(data), dialect):  # type: ignore
        parsed.append(row)

    if sys.version_info >= (3, 13):
        assert reference == [["3.14", None, "abc", ""]]
    else:
        # https://github.com/python/cpython/issues/113732
        assert reference == [["3.14", "", "abc", ""]]
    assert parsed == [["3.14", None, "abc", ""]]
