# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of logilab-common.
#
# logilab-common is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option) any
# later version.
#
# logilab-common is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with logilab-common.  If not, see <http://www.gnu.org/licenses/>.
"""unit tests for logilab.common.shellutils"""

from io import StringIO
from os.path import join, dirname, abspath
from unittest.mock import patch

from logilab.common.testlib import TestCase, unittest_main

from logilab.common.shellutils import globfind, find, ProgressBar, RawInput


# Fixture directory ("data/find_test" next to this test module) walked by the FindTC tests.
DATA_DIR = join(dirname(abspath(__file__)), "data", "find_test")


class FindTC(TestCase):
    """Run ``find`` and ``globfind`` over the ``DATA_DIR`` fixture tree."""

    # ``.py`` files expected directly under DATA_DIR (not in the "sub" directory).
    _TOP_LEVEL_PY = (
        "__init__.py",
        "module.py",
        "module2.py",
        "noendingnewline.py",
        "nonregr.py",
    )

    @staticmethod
    def _under_data_dir(*relative_paths):
        """Return the set of ``relative_paths``, each joined onto ``DATA_DIR``."""
        return {join(DATA_DIR, relative) for relative in relative_paths}

    def test_include(self):
        """``find`` with an extension (or tuple of extensions) and an optional blacklist."""
        everything = set(find(DATA_DIR, ".py"))
        self.assertSetEqual(
            everything,
            self._under_data_dir(*self._TOP_LEVEL_PY, join("sub", "momo.py")),
        )

        # With "sub" blacklisted, only the top-level modules are expected.
        without_sub = set(find(DATA_DIR, (".py",), blacklist=("sub",)))
        self.assertSetEqual(without_sub, self._under_data_dir(*self._TOP_LEVEL_PY))

    def test_exclude(self):
        """``find`` with ``exclude=True`` is expected to return the non-matching files."""
        not_python = set(find(DATA_DIR, (".py", ".pyc"), exclude=True))
        self.assertSetEqual(
            not_python,
            self._under_data_dir(
                "foo.txt",
                "newlines.txt",
                "normal_file.txt",
                "test.ini",
                "test1.msg",
                "test2.msg",
                "spam.txt",
                join("sub", "doc.txt"),
                "write_protected_file.txt",
            ),
        )

    def test_globfind(self):
        """``globfind`` with glob patterns, with and without a blacklist."""
        all_python = set(globfind(DATA_DIR, "*.py"))
        self.assertSetEqual(
            all_python,
            self._under_data_dir(*self._TOP_LEVEL_PY, join("sub", "momo.py")),
        )

        mo_prefixed = set(globfind(DATA_DIR, "mo*.py"))
        self.assertSetEqual(
            mo_prefixed,
            self._under_data_dir("module.py", "module2.py", join("sub", "momo.py")),
        )

        mo_prefixed_no_sub = set(globfind(DATA_DIR, "mo*.py", blacklist=("sub",)))
        self.assertSetEqual(mo_prefixed_no_sub, self._under_data_dir("module.py", "module2.py"))


class ProgressBarTC(TestCase):
    """Check the text ``ProgressBar`` writes to its stream on ``refresh`` and ``update``."""

    @staticmethod
    def _bar(filled, size):
        """Return the text expected for one redraw.

        The result is a carriage return followed by ``filled`` '=' characters
        padded with spaces up to ``size``, all wrapped in square brackets.
        """
        return "\r[" + ("=" * filled) + (" " * (size - filled)) + "]"

    def test_refresh(self):
        """Nothing is written at construction; ``refresh`` draws an empty bar."""
        pgb_stream = StringIO()
        pgb = ProgressBar(20, stream=pgb_stream)
        # Nothing is printed before the first refresh.
        self.assertEqual(pgb_stream.getvalue(), "")
        pgb.refresh()
        self.assertEqual(pgb_stream.getvalue(), self._bar(0, 20))

    def test_refresh_g_size(self):
        """``refresh`` with an explicit size (35) larger than nbops (20)."""
        pgb_stream = StringIO()
        pgb = ProgressBar(20, 35, stream=pgb_stream)
        pgb.refresh()
        self.assertEqual(pgb_stream.getvalue(), self._bar(0, 35))

    def test_refresh_l_size(self):
        """``refresh`` with an explicit size (3) smaller than nbops (20)."""
        pgb_stream = StringIO()
        pgb = ProgressBar(20, 3, stream=pgb_stream)
        pgb.refresh()
        self.assertEqual(pgb_stream.getvalue(), self._bar(0, 3))

    def _update_test(self, nbops, expected, size=None):
        """Call ``update()`` once per item of ``expected`` and check the stream each time.

        :param nbops: first argument given to ``ProgressBar``.
        :param expected: iterable with one item per ``update()`` call. An item is
            either an int, the number of '=' expected, in which case a redraw is
            expected only if it differs from the count of the last expected
            redraw; or a ``(dots, redraw)`` pair, where ``redraw`` states
            explicitly whether a redraw is expected.
        :param size: bar width passed to ``ProgressBar``; when ``None`` the bar is
            built without a size and a width of 20 is assumed.
        """
        pgb_stream = StringIO()
        if size is None:
            pgb = ProgressBar(nbops, stream=pgb_stream)
            size = 20  # width these tests expect when no size is given
        else:
            pgb = ProgressBar(nbops, size, stream=pgb_stream)

        expected_output = ""
        last = 0
        for step in expected:
            if hasattr(step, "__int__"):
                dots, redraw = step, None
            else:
                dots, redraw = step
            pgb.update()
            if redraw or (redraw is None and dots != last):
                last = dots
                expected_output += self._bar(dots, size)
            self.assertEqual(pgb_stream.getvalue(), expected_output)

    def test_default(self):
        """One '=' per update when nbops equals the assumed default width."""
        self._update_test(20, range(1, 21))

    def test_nbops_gt_size(self):
        """Test the progress bar for nbops > size"""

        def half(total):
            for counter in range(1, total + 1):
                yield counter // 2

        self._update_test(40, half(40))

    def test_nbops_lt_size(self):
        """Test the progress bar for nbops < size"""

        def double(total):
            for counter in range(1, total + 1):
                yield counter * 2

        self._update_test(10, double(10))

    def test_nbops_nomul_size(self):
        """Test the progress bar for size % nbops !=0 (non int number of dots per update)"""
        self._update_test(3, (6, 13, 20))

    def test_overflow(self):
        """A sixth update on a 5-operation bar is still expected to redraw the full bar."""
        self._update_test(5, (8, 16, 25, 33, 42, (42, True)), size=42)

    def test_update_exact(self):
        """``update(value, exact=True)`` is expected to draw ``value // 5`` dots."""
        pgb_stream = StringIO()
        size = 20
        pgb = ProgressBar(100, size, stream=pgb_stream)
        expected_output = ""
        for done in range(10, 105, 15):
            pgb.update(done, exact=True)
            expected_output += self._bar(done // 5, size)
            self.assertEqual(pgb_stream.getvalue(), expected_output)

    def test_update_relative(self):
        """Repeated ``update(5, exact=False)`` is expected to add one dot per call."""
        pgb_stream = StringIO()
        size = 20
        pgb = ProgressBar(100, size, stream=pgb_stream)
        expected_output = ""
        for done in range(5, 105, 5):
            pgb.update(5, exact=False)
            expected_output += self._bar(done // 5, size)
            self.assertEqual(pgb_stream.getvalue(), expected_output)


class RawInputTC(TestCase):
    """Drive ``RawInput`` with a scripted input function and check answers and prompts."""

    def auto_input(self, *args):
        """Stand-in input function: record the call arguments, return the canned answer."""
        self.input_args = args
        return self.input_answer

    def setUp(self):
        # The second argument is a callable that ignores its argument and returns None.
        self.qa = RawInput(self.auto_input, lambda x: None)

    def _ask_yes_no(self):
        """Ask the yes/no question (default "yes") shared by several tests."""
        return self.qa.ask("text", ("yes", "no"), "yes")

    def test_ask_using_builtin_input(self):
        """A ``RawInput`` built without arguments is expected to read from ``input``."""
        with patch("builtins.input", return_value="no"):
            answer = RawInput().ask("text", ("yes", "no"), "yes")
        self.assertEqual(answer, "no")

    def test_ask_default(self):
        """An empty or whitespace-only reply is expected to yield the default."""
        for reply in ("", "  "):
            self.input_answer = reply
            self.assertEqual(self._ask_yes_no(), "yes")

    def test_ask_case(self):
        """Replies are expected to match options regardless of letter case."""
        cases = (
            ("no", "no"),
            ("No", "no"),
            ("NO", "no"),
            ("nO", "no"),
            ("YES", "yes"),
        )
        for reply, wanted in cases:
            self.input_answer = reply
            self.assertEqual(self._ask_yes_no(), wanted)

    def test_ask_prompt(self):
        """Check the prompt string handed to the input function."""
        self.input_answer = ""
        cases = (
            (("yes", "no"), "yes", "text [Y(es)/n(o)]: "),
            (("y", "n"), "y", "text [Y/n]: "),
            (("n", "y"), "y", "text [n/Y]: "),
            (("yes", "no", "maybe", "1"), "yes", "text [Y(es)/n(o)/m(aybe)/1]: "),
        )
        for options, default, wanted_prompt in cases:
            self.qa.ask("text", options, default)
            self.assertEqual(self.input_args[0], wanted_prompt)

    def test_ask_ambiguous(self):
        """A reply that prefixes two options ("y" for yes/yep) is expected to raise."""
        self.input_answer = "y"
        self.assertRaises(Exception, self.qa.ask, "text", ("yes", "yep"), "yes")

    def test_confirm(self):
        """``confirm`` with explicit replies and with an empty reply (default applies)."""
        # (reply, question, result with default settings, result with default_is_yes=False)
        cases = (
            ("y", "Say yes", True, True),
            ("n", "Say yes", False, False),
            ("", "Say default", True, False),
        )
        for reply, question, wanted, wanted_default_no in cases:
            self.input_answer = reply
            self.assertEqual(self.qa.confirm(question), wanted)
            self.assertEqual(self.qa.confirm(question, default_is_yes=False), wanted_default_no)


# Run this module's tests through logilab's unittest_main when executed as a script.
if __name__ == "__main__":
    unittest_main()
