# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.  See http://www.gnu.org/copyleft/lgpl.html for the full text
# of the license.

__author__ = "Martin Pitt"
__copyright__ = """
(c) 2013 Canonical Ltd.
(c) 2017 - 2022 Martin Pitt <martin@piware.de>
"""

import shutil
import subprocess
import sys
import unittest

import dbus
import dbus.mainloop.glib

import dbusmock

dbus.mainloop.glib.DBusGMainLoop(set_as_default=True)
have_pkcheck = shutil.which("pkcheck")


@unittest.skipUnless(have_pkcheck, "pkcheck not installed")
class TestPolkit(dbusmock.DBusTestCase):
    """Test mocking polkitd"""

    @classmethod
    def setUpClass(cls):
        cls.start_system_bus()
        cls.dbus_con = cls.get_dbus(True)

    def setUp(self):
        (self.p_mock, self.obj_polkitd) = self.spawn_server_template("polkitd", {}, stdout=subprocess.PIPE)
        self.dbusmock = dbus.Interface(self.obj_polkitd, dbusmock.MOCK_IFACE)

    def tearDown(self):
        self.p_mock.stdout.close()
        self.p_mock.terminate()
        self.p_mock.wait()

    def test_default(self):
        self.check_action("org.freedesktop.test.frobnicate", False)

    def test_allow_unknown(self):
        self.dbusmock.AllowUnknown(True)
        self.check_action("org.freedesktop.test.frobnicate", True)
        self.dbusmock.AllowUnknown(False)
        self.check_action("org.freedesktop.test.frobnicate", False)

    def test_set_allowed(self):
        self.dbusmock.SetAllowed(["org.freedesktop.test.frobnicate", "org.freedesktop.test.slap"])
        self.check_action("org.freedesktop.test.frobnicate", True)
        self.check_action("org.freedesktop.test.slap", True)
        self.check_action("org.freedesktop.test.wobble", False)

    def test_hanging_call(self):
        self.dbusmock.SimulateHang(True)
        self.assertFalse(self.dbusmock.HaveHangingCalls())
        pkcheck = self.check_action_run("org.freedesktop.test.frobnicate")
        with self.assertRaises(subprocess.TimeoutExpired):
            pkcheck.wait(0.8)

        self.assertTrue(self.dbusmock.HaveHangingCalls())
        pkcheck.stdout.close()
        pkcheck.kill()
        pkcheck.wait()

    def test_hanging_call_return(self):
        self.dbusmock.SetAllowed(["org.freedesktop.test.frobnicate"])
        self.dbusmock.SimulateHangActions(["org.freedesktop.test.frobnicate", "org.freedesktop.test.slap"])
        self.assertFalse(self.dbusmock.HaveHangingCalls())

        frobnicate_pkcheck = self.check_action_run("org.freedesktop.test.frobnicate")
        slap_pkcheck = self.check_action_run("org.freedesktop.test.slap")

        with self.assertRaises(subprocess.TimeoutExpired):
            frobnicate_pkcheck.wait(0.3)
        with self.assertRaises(subprocess.TimeoutExpired):
            slap_pkcheck.wait(0.3)

        self.assertTrue(self.dbusmock.HaveHangingCalls())
        self.dbusmock.ReleaseHangingCalls()

        self.check_action_result(frobnicate_pkcheck, True)
        self.check_action_result(slap_pkcheck, False)

    def test_delayed_call(self):
        self.dbusmock.SetDelay(3)
        pkcheck = self.check_action_run("org.freedesktop.test.frobnicate")
        with self.assertRaises(subprocess.TimeoutExpired):
            pkcheck.wait(0.8)
        pkcheck.stdout.close()
        pkcheck.kill()
        pkcheck.wait()

    def test_delayed_call_return(self):
        self.dbusmock.SetDelay(1)
        self.dbusmock.SetAllowed(["org.freedesktop.test.frobnicate"])
        pkcheck = self.check_action_run("org.freedesktop.test.frobnicate")
        with self.assertRaises(subprocess.TimeoutExpired):
            pkcheck.wait(0.8)
        self.check_action_result(pkcheck, True)

    @staticmethod
    def check_action_run(action):
        # pylint: disable=consider-using-with
        return subprocess.Popen(
            ["pkcheck", "--action-id", action, "--process", "123"],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            universal_newlines=True,
        )

    def check_action_result(self, pkcheck, expect_allow):
        out = pkcheck.communicate()[0]
        if expect_allow:
            self.assertEqual(pkcheck.returncode, 0)
            self.assertEqual(out, "test=test\n")
        else:
            self.assertNotEqual(pkcheck.returncode, 0)
            self.assertEqual(out, "test=test\nNot authorized.\n")

    def check_action(self, action, expect_allow):
        self.check_action_result(self.check_action_run(action), expect_allow)


if __name__ == "__main__":
    # avoid writing to stderr
    unittest.main(testRunner=unittest.TextTestRunner(stream=sys.stdout))
