# coding: utf-8

# Copyright 2014-2025 Álvaro Justen <https://github.com/turicas/rows/>
#    This program is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General
#    Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option)
#    any later version.
#    This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied
#    warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for
#    more details.
#    You should have received a copy of the GNU Lesser General Public License along with this program.  If not, see
#    <http://www.gnu.org/licenses/>.

from __future__ import unicode_literals

import csv
import io
import tempfile
import textwrap
import unittest
from collections import OrderedDict
from textwrap import dedent

import mock
import pytest

import rows
import rows.plugins.plugin_csv
import tests.utils as utils
from rows.utils import Source


def make_csv_data(quote_char, field_delimiter, line_delimiter):
    data = [["field1", "field2", "field3"], ["value1", "value2", "value3"]]
    lines = [
        ["{}{}{}".format(quote_char, value, quote_char) for value in line]
        for line in data
    ]
    lines = line_delimiter.join([field_delimiter.join(line) for line in data])
    return data, lines


class PluginCsvTestCase(utils.RowsTestMixIn, unittest.TestCase):

    plugin_name = "csv"
    file_extension = "csv"
    filename = "tests/data/all-field-types.csv"
    encoding = "utf-8"
    expected_meta = {
        "imported_from": "csv",
        "source": Source(uri=filename, plugin_name=plugin_name, encoding=encoding),
    }

    def test_imports(self):
        # The order must be this one to force loading the lazy module (first the real function, then the alias)
        assert rows.plugins.plugin_csv.import_from_csv is rows.import_from_csv
        assert rows.plugins.plugin_csv.export_to_csv is rows.export_to_csv

    @mock.patch("rows.plugins.utils.create_table")
    def test_import_from_csv_uses_create_table(self, mocked_create_table):
        mocked_create_table.return_value = 42
        kwargs = {"some_key": 123, "other": 456}
        result = rows.import_from_csv(self.filename, encoding="utf-8", **kwargs)
        assert mocked_create_table.called
        assert mocked_create_table.call_count == 1
        assert result == 42

    @mock.patch("rows.plugins.utils.create_table")
    def test_import_from_csv_retrieve_desired_data(self, mocked_create_table):
        mocked_create_table.return_value = 42

        # import using filename
        rows.import_from_csv(self.filename)
        call_args = mocked_create_table.call_args_list[0]
        self.assert_create_table_data(call_args, expected_meta=self.expected_meta)

        # import using fobj
        with open(self.filename, "rb") as fobj:
            rows.import_from_csv(fobj)
            call_args = mocked_create_table.call_args_list[1]
            self.assert_create_table_data(call_args, expected_meta=self.expected_meta)

    @mock.patch("rows.plugins.utils.create_table")
    def test_import_from_csv_discover_dialect(self, mocked_create_table):
        data, lines = make_csv_data(
            quote_char="'", field_delimiter=";", line_delimiter="\r\n"
        )
        fobj = io.BytesIO()
        fobj.write(lines.encode("utf-8"))
        fobj.seek(0)

        rows.import_from_csv(fobj)
        call_args = mocked_create_table.call_args_list[0]
        assert data == list(call_args[0][0])

    def test_import_from_csv_discover_dialect_decode_error(self):

        # Create a 1024-bytes line (if encoded to ASCII, UTF-8 etc.)
        line = '"' + ("a" * 508) + '", "' + ("b" * 508) + '"\r\n'
        lines = 256 * line  # 256KiB

        # Now change the last byte (in the 256KiB sample) to have half of a
        # character representation (when encoded to UTF-8)
        data = lines[:-3] + '++Á"\r\n'
        data = data.encode("utf-8")

        # Should not raise `UnicodeDecodeError`
        table = rows.import_from_csv(
            io.BytesIO(data), encoding="utf-8", sample_size=262144
        )

        last_row = table[-1]
        last_column = "b" * 508
        assert getattr(last_row, last_column) == "b" * 508 + "++Á"

    def test_import_from_csv_impossible_dialect(self):
        # Fix a bug from: https://github.com/turicas/rows/issues/214
        # The following CSV will make the `csv`'s sniff to return an impossible
        # dialect to be used (having doublequote = False and escapechar =
        # None). See more at:
        # https://docs.python.org/3/library/csv.html#csv.Dialect.doublequote

        encoding = "utf-8"
        data = dedent(
            """
            field1,field2
            1,2
            3,4
            5,6
            """.strip()
        ).encode(encoding)

        dialect = rows.plugins.plugin_csv.discover_dialect(data, encoding)
        assert dialect.doublequote is True
        assert dialect.escapechar is None

    def test_import_from_csv_excel_semicolon_dialect(self):
        encoding = "utf-8"
        data = dedent(
            """
            field1;field2
            1;2
            3;4
            5;6
            """.strip()
        ).encode(encoding)

        table = rows.import_from_csv(io.BytesIO(data), dialect="excel-semicolon")
        assert table.field_names == ["field1", "field2"]
        assert table[0].field1 == 1
        assert table[0].field2 == 2
        assert table[1].field1 == 3
        assert table[1].field2 == 4
        assert table[2].field1 == 5
        assert table[2].field2 == 6

    @mock.patch("rows.plugins.utils.create_table")
    def test_import_from_csv_force_dialect(self, mocked_create_table):
        data, lines = make_csv_data(
            quote_char="'", field_delimiter="\t", line_delimiter="\r\n"
        )
        fobj = io.BytesIO()
        fobj.write(lines.encode("utf-8"))
        fobj.seek(0)

        rows.import_from_csv(fobj, dialect="excel-tab")
        call_args = mocked_create_table.call_args_list[0]
        assert data == list(call_args[0][0])

    def test_detect_dialect_more_data(self):
        temp = tempfile.NamedTemporaryFile(delete=False)
        filename = "{}.{}".format(temp.name, self.file_extension)
        self.files_to_delete.append(filename)

        # If the sniffer reads only the first line, it will think the delimiter
        # is ',' instead of ';'
        data = textwrap.dedent(
            """
            field1,samefield;field2,other
            row1value1;row1value2
            row2value1;row2value2
            """
        ).strip()
        with open(filename, "wb") as fobj:
            fobj.write(data.encode("utf-8"))

        table = rows.import_from_csv(filename, encoding="utf-8")
        assert table.field_names == ["field1_samefield", "field2_other"]
        assert table[0].field1_samefield == "row1value1"
        assert table[0].field2_other == "row1value2"
        assert table[1].field1_samefield == "row2value1"
        assert table[1].field2_other == "row2value2"

    def test_detect_weird_dialect(self):
        temp = tempfile.NamedTemporaryFile(delete=False)
        filename = "{}.{}".format(temp.name, self.file_extension)
        self.files_to_delete.append(filename)

        # If the sniffer reads only the first line, it will think the delimiter
        # is ',' instead of ';'
        encoding = "utf-8"
        data = io.BytesIO(
            textwrap.dedent(
                """
            field1|field2|field3|field4
            1|2|3|4
            5|6|7|8
            9|0|1|2
            """
            )
            .strip()
            .encode(encoding)
        )

        table = rows.import_from_csv(data, encoding=encoding, lazy=False)
        assert table.field_names == ["field1", "field2", "field3", "field4"]

        expected = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2]]
        for expected_data, row in zip(expected, table):
            row = [row.field1, row.field2, row.field3, row.field4]
            assert expected_data == row

    def test_detect_dialect_using_json(self):
        temp = tempfile.NamedTemporaryFile(delete=False)
        filename = "{}.{}".format(temp.name, self.file_extension)
        encoding = "utf-8"
        self.files_to_delete.append(filename)

        # Using JSON will force the sniffer to do not include ':', '}' in the
        # possible delimiters
        table = rows.Table(
            fields=OrderedDict(
                [
                    ("jsoncolumn1", rows.fields.JSONField),
                    ("jsoncolumn2", rows.fields.JSONField),
                ]
            )
        )
        table.append({"jsoncolumn1": '{"a": 42}', "jsoncolumn2": '{"b": 43}'})
        table.append({"jsoncolumn1": '{"c": 44}', "jsoncolumn2": '{"d": 45}'})
        rows.export_to_csv(table, filename, encoding=encoding)

        table = rows.import_from_csv(filename, encoding=encoding)

        assert table.field_names == ["jsoncolumn1", "jsoncolumn2"]
        self.assertDictEqual(table[0].jsoncolumn1, {"a": 42})
        self.assertDictEqual(table[0].jsoncolumn2, {"b": 43})
        self.assertDictEqual(table[1].jsoncolumn1, {"c": 44})
        self.assertDictEqual(table[1].jsoncolumn2, {"d": 45})

    @mock.patch("rows.plugins.utils.serialize")
    def test_export_to_csv_uses_serialize(self, mocked_serialize):
        temp = tempfile.NamedTemporaryFile(delete=False)
        self.files_to_delete.append(temp.name)
        kwargs = {"test": 123, "parameter": 3.14}
        mocked_serialize.return_value = iter([utils.table.fields.keys()])

        rows.export_to_csv(utils.table, temp.name, encoding="utf-8", **kwargs)
        assert mocked_serialize.called
        assert mocked_serialize.call_count == 1

        call = mocked_serialize.call_args
        assert call[0] == (utils.table,)
        assert call[1] == kwargs

    def test_export_to_csv_filename(self):
        temp = tempfile.NamedTemporaryFile(delete=False)
        self.files_to_delete.append(temp.name)
        rows.export_to_csv(utils.table, temp.name)
        # TODO: test file contents instead of this side-effect
        table = rows.import_from_csv(temp.name)
        self.assert_table_equal(table, utils.table)
        temp.file.seek(0)
        result = temp.file.read()
        export_in_memory = rows.export_to_csv(utils.table, None)
        assert result == export_in_memory

    def test_export_to_csv_fobj_binary(self):
        temp = tempfile.NamedTemporaryFile(delete=False, mode="wb")
        self.files_to_delete.append(temp.name)
        fobj = temp.file
        result = rows.export_to_csv(utils.table, fobj, encoding="utf-8")
        assert result is fobj
        assert not fobj.closed
        # TODO: test file contents instead of this side-effect
        table = rows.import_from_csv(temp.name)
        self.assert_table_equal(table, utils.table)

    def test_export_to_csv_fobj_binary_without_encoding(self):
        temp = tempfile.NamedTemporaryFile(delete=False, mode="wb")
        self.files_to_delete.append(temp.name)
        fobj = temp.file
        with pytest.raises(ValueError, match="export_to_csv must receive an encoding when file is in binary mode"):
            rows.export_to_csv(utils.table, fobj, encoding=None)

    def test_export_to_csv_fobj_text(self):
        tmp = tempfile.NamedTemporaryFile(delete=False)
        tmp.close()
        temp = io.TextIOWrapper(io.open(tmp.name, mode="wb"), encoding="utf-8")
        self.files_to_delete.append(tmp.name)
        result = rows.export_to_csv(utils.table, temp)
        assert result is temp
        assert not temp.closed
        # TODO: test file contents instead of this side-effect
        table = rows.import_from_csv(tmp.name)
        self.assert_table_equal(table, utils.table)

    def test_issue_168(self):
        temp = tempfile.NamedTemporaryFile(delete=False)
        filename = "{}.{}".format(temp.name, self.file_extension)
        self.files_to_delete.append(filename)

        table = rows.Table(fields=OrderedDict([("jsoncolumn", rows.fields.JSONField)]))
        table.append({"jsoncolumn": '{"python": 42}'})
        rows.export_to_csv(table, filename)

        table2 = rows.import_from_csv(filename)
        self.assert_table_equal(table, table2)

    def test_quotes(self):
        temp = tempfile.NamedTemporaryFile(delete=False)
        filename = "{}.{}".format(temp.name, self.file_extension)
        self.files_to_delete.append(filename)

        table = rows.Table(
            fields=OrderedDict(
                [
                    ("field_1", rows.fields.TextField),
                    ("field_2", rows.fields.TextField),
                    ("field_3", rows.fields.TextField),
                    ("field_4", rows.fields.TextField),
                ]
            )
        )
        table.append(
            {
                "field_1": '"quotes"',
                "field_2": 'test "quotes"',
                "field_3": '"quotes" test',
                "field_4": 'test "quotes" test',
            }
        )
        # we need this line row since `"quotes"` on `field_1` could be
        # `JSONField` or `TextField`
        table.append(
            {
                "field_1": "noquotes",
                "field_2": 'test "quotes"',
                "field_3": '"quotes" test',
                "field_4": 'test "quotes" test',
            }
        )
        rows.export_to_csv(table, filename)

        table2 = rows.import_from_csv(filename)
        self.assert_table_equal(table, table2)

    def test_export_to_csv_accepts_dialect(self):
        result_1 = rows.export_to_csv(utils.table, dialect=csv.excel_tab)
        result_2 = rows.export_to_csv(utils.table, dialect=csv.excel)
        assert result_1.replace(b"\t", b",") == result_2

    def test_export_callback(self):
        table = rows.import_from_dicts([{"id": number} for number in range(10)])
        myfunc = mock.Mock()
        rows.export_to_csv(table, callback=myfunc, batch_size=3)
        assert myfunc.call_count == 4
        assert [x[0][0] for x in myfunc.call_args_list] == [3, 6, 9, 10]

    def test_import_field_limit(self):
        temp = tempfile.NamedTemporaryFile(delete=False)
        filename = "{}.{}".format(temp.name, self.file_extension)
        self.files_to_delete.append(filename)

        table = rows.import_from_dicts([{"f1": "a" * 132000}])
        rows.export_to_csv(table, filename)

        # The following line must not raise the exception:
        # `_csv.Error: field larger than field limit (131072)`
        rows.import_from_csv(filename)
