#!/usr/bin/python3
# -*- coding: utf-8 -*-

from gi.repository import GLib

import apt_pkg
import aptsources.distro
import aptsources.sourceslist

import dbus
import logging
import glob
import os
import subprocess
import sys
import threading
import time
import unittest
import pathlib

from dbus.mainloop.glib import DBusGMainLoop

sys.path.insert(0, "../")
from softwareproperties.dbus.SoftwarePropertiesDBus import (
    SoftwarePropertiesDBus, DBUS_BUS_NAME, DBUS_PATH, DBUS_INTERFACE_NAME)
from softwareproperties import (
    UPDATE_INST_SEC, UPDATE_DOWNLOAD, UPDATE_NOTIFY)

try:
    DPKG_ARCH = subprocess.check_output(
        ["dpkg", "--print-architecture"]).strip().decode("utf-8")
except (subprocess.CalledProcessError, OSError):
    # OSError also covers dpkg not being installed at all
    # (FileNotFoundError); fall back to the ports mirror below instead of
    # failing at import time.
    sys.stderr.write("WARNING: Failed to read dpkg arch\n")
    DPKG_ARCH = None

# i386/amd64 use the primary archive; every other (or unknown) architecture
# uses the ports mirror.
if DPKG_ARCH in ("i386", "amd64"):
    PRIMARY_MIRROR = "http://archive.ubuntu.com/ubuntu"
else:
    PRIMARY_MIRROR = "http://ports.ubuntu.com/ubuntu-ports"


def get_test_source_line():
    """Return the single sources.list line used as the test fixture.

    The line points at PRIMARY_MIRROR for the release returned by
    get_distro_release() and carries a trailing comment with non-ASCII
    characters.
    """
    release = get_distro_release()
    template = "deb %s %s main restricted # comment with unicode äöü"
    return template % (PRIMARY_MIRROR, release)


def get_dpkg_arch():
    """Return the output of ``dpkg --print-architecture`` as a stripped str.

    Raises whatever subprocess.check_output raises if dpkg fails.
    """
    raw_output = subprocess.check_output(["dpkg", "--print-architecture"])
    return raw_output.strip().decode("utf-8")


def get_distro_release():
    """Return the distribution codename the tests are pinned to."""
    codename = "bionic"
    return codename


def clear_apt_config():
    """Empty the test aptroot's etc/apt tree and recreate its .d directories.

    Every entry below aptroot/etc/apt for which os.path.isfile() is true is
    unlinked; directories themselves are left in place. Afterwards the
    apt.conf.d, sources.list.d, trusted.gpg.d and auth.conf.d subdirectories
    are guaranteed to exist (they are created, along with any missing
    parents, if absent).
    """
    apt_etc_dir = os.path.join(os.path.dirname(__file__), "aptroot", "etc", "apt")
    for root, _subdirs, files in os.walk(apt_etc_dir):
        for candidate in (os.path.join(root, entry) for entry in files):
            if os.path.isfile(candidate):
                os.unlink(candidate)

    required = ("apt.conf.d", "sources.list.d", "trusted.gpg.d", "auth.conf.d")
    for subdir in required:
        os.makedirs(os.path.join(apt_etc_dir, subdir), exist_ok=True)

def create_sources_list():
    """Write the single test source line to aptroot/etc/apt/sources.list.

    The parent directory is created if needed and any existing file is
    overwritten.

    Returns:
        The path of the written sources.list file.
    """
    s = get_test_source_line() + "\n"
    name = os.path.join(os.path.dirname(__file__),
                        "aptroot", "etc", "apt", "sources.list")
    # exist_ok avoids the check-then-create race of a separate exists() test
    pathlib.Path(os.path.dirname(name)).mkdir(parents=True, exist_ok=True)
    # The test line contains non-ASCII characters, so don't rely on the
    # locale's preferred encoding (which may be ASCII, e.g. under LANG=C).
    with open(name, "w", encoding="utf-8") as f:
        f.write(s)
    return name


def session_bus_thread():
    """Run a SoftwarePropertiesDBus service on a private session bus.

    This is used as a threading.Thread target. The GLib main loop is
    published as ``loop`` on the current thread object so the spawning side
    can poll for startup and later quit it. Polkit enforcement is switched
    off for the test service. The private bus connection is closed once the
    loop returns.
    """
    DBusGMainLoop(set_as_default=True)
    main_loop = GLib.MainLoop()
    threading.current_thread().loop = main_loop

    private_bus = dbus.SessionBus(private=True)
    private_bus.set_exit_on_disconnect(False)

    aptroot = os.path.join(os.path.dirname(__file__), "aptroot")
    # holding the service object in a local keeps it alive while the loop runs
    service = SoftwarePropertiesDBus(private_bus, rootdir=aptroot)
    service.enforce_polkit = False
    main_loop.run()
    private_bus.close()


class TestDBus(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        for k in apt_pkg.config.keys():
            apt_pkg.config.clear(k)
        apt_pkg.init()
        clear_apt_config()
        # create sources.list file
        create_sources_list()
        apt_pkg.config.set("Dir::Etc::sourcelist", "sources.list")
        cls.get_distro = aptsources.distro.get_distro
        aptsources.distro.get_distro = lambda *a, **b: aptsources.distro.UbuntuDistribution(id="Ubuntu", codename="bionic", description="18.04", release="bionic")
        cls.thread = threading.Thread(target=session_bus_thread)
        cls.thread.loop = None
        cls.thread.start()
        while not cls.thread.loop or not cls.thread.loop.is_running():
            time.sleep(1)

    @classmethod
    def tearDownClass(cls):
        cls.thread.loop.quit()
        cls.thread.join()
        aptsources.distro.get_distro = cls.get_distro

    def setUp(self):
        # keep track of signal emissions
        self.sources_list_count = 0
        self.distro_release = get_distro_release()
        create_sources_list()
        self._sourceslist = aptsources.sourceslist.SourcesList()
        # create the client proxy
        self.bus = dbus.SessionBus(private=True, mainloop=DBusGMainLoop())
        self.bus.set_exit_on_disconnect(False)
        self.iface = dbus.Interface(self.bus.get_object(DBUS_BUS_NAME, DBUS_PATH),
                                    DBUS_INTERFACE_NAME)
        self._signal_id = self.iface.connect_to_signal(
            "SourcesListModified", self._on_sources_list_modified)

    def tearDown(self):
        # ensure we remove the "modified" signal again
        self._signal_id.remove()
        self.bus.close()

    def _on_sources_list_modified(self):
        #print("_on_modified_sources_list")
        self.sources_list_count += 1

    @property
    def sourceslist(self):
        self._sourceslist.refresh()
        return ''.join([str(e) for e in self._sourceslist])

    @property
    def enabled_sourceslist(self):
        self._sourceslist.refresh()
        return ''.join([str(e) for e in self._sourceslist
                        if not e.invalid and not e.disabled])

    def _debug_sourceslist(self, text=""):
        logging.debug("sourceslist: %s '%s'" % (text, self.sourceslist))

    # this is an async call - give it a few seconds to catch up with what we expect
    def _assert_eventually(self, prop, n):
        for i in range(9):
            if getattr(self, prop) == n:
                self.assertEqual(getattr(self, prop), n)
            else:
                time.sleep(1)
        # nope, you die now
        self.assertEqual(getattr(self, prop), n)

    def test_enable_disable_component(self):
        # ensure its not there
        self.assertNotIn("universe", self.sourceslist)
        # enable
        self.iface.EnableComponent("universe")
        self._debug_sourceslist("2")
        self.assertIn("universe", self.sourceslist)
        # disable again
        self.iface.DisableComponent("universe")
        self._debug_sourceslist("3")
        self.assertNotIn("universe", self.sourceslist)
        self._assert_eventually("sources_list_count", 2)

    def test_enable_enable_disable_source_code_sources(self):
        # ensure its not there
        self._debug_sourceslist("4")
        self.assertNotIn('deb-src', self.enabled_sourceslist)
        # enable
        self.iface.EnableSourceCodeSources()
        self._debug_sourceslist("5")
        self.assertIn('deb-src', self.enabled_sourceslist)
        # disable again
        self.iface.DisableSourceCodeSources()
        self._debug_sourceslist("6")
        self.assertNotIn('deb-src', self.enabled_sourceslist)
        self._assert_eventually("sources_list_count", 3)

    def test_enable_child_source(self):
        child_source = "%s-updates" % self.distro_release
        # ensure its not there
        self._debug_sourceslist("7")
        self.assertNotIn(child_source, self.sourceslist)
        # enable
        self.iface.EnableChildSource(child_source)
        self._debug_sourceslist("8")
        self.assertIn(child_source, self.sourceslist)
        # disable again
        self.iface.DisableChildSource(child_source)
        self._debug_sourceslist("9")
        self.assertNotIn(child_source, self.sourceslist)
        self._assert_eventually("sources_list_count", 2)

    def test_toggle_source(self):
        # test toggle
        source = get_test_source_line()
        self.iface.ToggleSourceUse(source)
        self._debug_sourceslist("10")
        primary_debline = "# deb %s" % PRIMARY_MIRROR
        self.assertIn(primary_debline, self.sourceslist)
        # to disable the line again, we need to match the new "#"
        source = "# " + source
        self.iface.ToggleSourceUse(source)
        self._debug_sourceslist("11")
        self.assertNotIn(primary_debline, self.sourceslist)

        self._assert_eventually("sources_list_count", 2)

    def test_replace_entry(self):
        # test toggle
        source = get_test_source_line()
        source_new = "deb http://xxx/ %s" % self.distro_release
        self.iface.ReplaceSourceEntry(source, source_new)
        self._debug_sourceslist("11")
        self.assertIn(source_new, self.sourceslist)
        self.assertNotIn(source, self.sourceslist)
        self._assert_eventually("sources_list_count", 1)
        self.iface.ReplaceSourceEntry(source_new, source)
        self._assert_eventually("sources_list_count", 2)

    def test_popcon(self):
        # ensure its set to no
        popcon_p = os.path.join(os.path.dirname(__file__),
                              "aptroot", "etc", "popularity-contest.conf")
        with open(popcon_p) as f:
            popcon = f.read()
            self.assertIn('PARTICIPATE="no"', popcon)
        # toggle
        self.iface.SetPopconPariticipation(True)
        with open(popcon_p) as f:
            popcon = f.read()
            self.assertIn('PARTICIPATE="yes"', popcon)
            self.assertNotIn('PARTICIPATE="no"', popcon)
        # and back
        self.iface.SetPopconPariticipation(False)
        with open(popcon_p) as f:
            popcon = f.read()
            self.assertNotIn('PARTICIPATE="yes"', popcon)
            self.assertIn('PARTICIPATE="no"', popcon)

    def test_updates_automation(self):
        states = [UPDATE_INST_SEC, UPDATE_DOWNLOAD, UPDATE_NOTIFY]
        # security
        self.iface.SetUpdateAutomationLevel(states[0])
        cfg = os.path.join(os.path.dirname(__file__),
                           "aptroot", "etc", "apt", "apt.conf.d",
                           "10periodic")
        with open(cfg) as f:
            config = f.read()
            self.assertIn('APT::Periodic::Unattended-Upgrade "1";', config)
        # download
        self.iface.SetUpdateAutomationLevel(states[1])
        with open(cfg) as f:
            config = f.read()
            self.assertIn('APT::Periodic::Unattended-Upgrade "0";', config)
            self.assertIn('APT::Periodic::Download-Upgradeable-Packages "1";', config)
        # notify
        self.iface.SetUpdateAutomationLevel(states[2])
        with open(cfg) as f:
            config = f.read()
            self.assertIn('APT::Periodic::Unattended-Upgrade "0";', config)
            self.assertIn('APT::Periodic::Download-Upgradeable-Packages "0";', config)

    def test_updates_interval(self):
        # interval
        self.iface.SetUpdateInterval(0)
        cfg = os.path.join(os.path.dirname(__file__),
                           "aptroot", "etc", "apt", "apt.conf.d",
                           "10periodic")
        with open(cfg) as f:
            config = f.read()
            self.assertTrue(
                'APT::Periodic::Update-Package-Lists' not in config or
                'APT::Periodic::Update-Package-Lists "0";' in config)
        self.iface.SetUpdateInterval(1)
        with open(cfg) as f:
            config = f.read()
            self.assertIn('APT::Periodic::Update-Package-Lists "1";', config)
        self.iface.SetUpdateInterval(0)
        with open(cfg) as f:
            config = f.read()
            self.assertIn('APT::Periodic::Update-Package-Lists "0";', config)

    def test_add_remove_source_by_line(self):
        # add invalid
        res = self.iface.AddSourceFromLine("xxx")
        self.assertFalse(res)
        # add real
        s = "deb https://ppa.launchpadcontent.net/ foo bar"
        self.iface.AddSourceFromLine(s)
        self.assertIn(s, self.sourceslist)
        self.assertIn(s.replace("deb", "# deb-src"), self.sourceslist)
        # remove again
        self.iface.RemoveSource(s)
        self.iface.RemoveSource(s.replace("deb", "deb-src"))
        self.assertNotIn(s, self.sourceslist)
        self.assertNotIn(s.replace("deb", "# deb-src"), self.sourceslist)
        self._assert_eventually("sources_list_count", 4)

    def test_add_gpg_key(self):
        # clean
        gpg_glob = os.path.join(os.path.dirname(__file__),
                              "aptroot", "etc", "apt", "trusted.gpg.d", "*.asc")
        trusted_gpg_d = os.path.join(os.path.dirname(__file__),
                                     "aptroot", "etc", "apt", "trusted.gpg.d/")
        testkey = os.path.join(os.path.dirname(__file__),
                               "data", "testkey.asc")
        for f in glob.glob(gpg_glob):
            os.remove(f)
        self.assertTrue(len(glob.glob(gpg_glob)) == 0)
        # add key from file
        res = self.iface.AddKey(os.path.join(os.path.dirname(__file__),
                                             "data", "testkey.asc"))
        self.assertTrue(res)
        self.assertEqual(len(glob.glob(gpg_glob)), 1)
        self.assertNotEqual(os.path.getsize(trusted_gpg_d + "testkey.asc"), 0)
        # remove the key
        res = self.iface.RemoveKey(trusted_gpg_d + "testkey.asc")
        self.assertTrue(res)
        self.assertEqual(len(glob.glob(gpg_glob)), 0)
        # add from data
        with open(testkey) as keyfile:
            data = keyfile.read()
            res = self.iface.AddKeyFromData(data)
            self.assertTrue(res)
            self.assertEqual(len(os.listdir(trusted_gpg_d)), 1)
        # remove the key
        res = self.iface.RemoveKey(trusted_gpg_d + "/46caf96d27ee1eebcf139483275a5c1bb41590474e35f8bb4ae3cceb413f7518.gpg")
        self.assertTrue(res)
        self.assertEqual(len(os.listdir(trusted_gpg_d)), 1)
        # test nonsense
        res = self.iface.AddKeyFromData("nonsens")
        self.assertFalse(res)


if __name__ == "__main__":
    # "-d" is this script's own switch for debug logging; unittest's argument
    # parser does not know it, so strip it before handing argv over.
    if "-d" in sys.argv:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)
    unittest.main(argv=[arg for arg in sys.argv if arg != "-d"])
