import sys
from collections import defaultdict
import apt_pkg
from debian import deb822
from functools import cmp_to_key
import subprocess
import io
import json
import networkx as nx
import yaml
try:
    from yaml import CBaseLoader as yamlLoader
except ImportError:
    from yaml import BaseLoader as yamlLoader
import argparse
import xml


def write_plain(path):
    """Return a context manager that opens *path* for text writing.

    A path of "-" selects sys.stdout; anything else is opened as a UTF-8
    text file.  The context manager yields the raw file object.
    """
    class fh_out:
        def __init__(self, path):
            self.path = path
            self.fd = None

        def __enter__(self):
            if self.path == "-":
                self.fd = sys.stdout
            else:
                self.fd = open(self.path, "w", encoding="utf8")
            return self.fd

        def __exit__(self, type, value, traceback):
            # bug fix: never close sys.stdout itself -- code running after
            # the with-block may still need to print to it
            if self.fd is not sys.stdout:
                self.fd.close()
    return fh_out(path)


def read_yaml_file(path):
    """Load a YAML document from *path* ("-" reads standard input).

    Raises argparse.ArgumentTypeError for unreadable paths, YAML parse
    failures or an empty document, so the function can be used directly
    as an argparse type= callback.
    """
    # message templates for the OSError subclasses open() can raise
    oserror_msgs = {
        IsADirectoryError: "\"%s\" is a directory",
        PermissionError: "\"%s\" permission denied",
        FileNotFoundError: "\"%s\" does not exist",
    }
    try:
        if path == '-':
            data = yaml.load(sys.stdin, Loader=yamlLoader)
        else:
            try:
                with open(path) as fobj:
                    data = yaml.load(fobj, Loader=yamlLoader)
            except (IsADirectoryError, PermissionError,
                    FileNotFoundError) as exc:
                raise argparse.ArgumentTypeError(
                    oserror_msgs[type(exc)] % path)
    except yaml.scanner.ScannerError as exc:
        raise argparse.ArgumentTypeError("failed to parse yaml: %s%s"
                                         % (exc.problem, exc.problem_mark))
    if data is None:
        raise argparse.ArgumentTypeError("input yaml file is empty")
    return data


def read_json_file(path):
    """Parse and return the JSON document at *path* ("-" reads stdin)."""
    if path == '-':
        return json.load(sys.stdin)
    with open(path) as fobj:
        return json.load(fobj)


def read_graphml(path):
    """Read a GraphML graph from *path* ("-" reads standard input).

    Raises argparse.ArgumentTypeError for unreadable paths or input that
    is not valid GraphML/XML.
    """
    # keep *path* untouched so the error messages can quote it verbatim
    source = sys.stdin if path == '-' else path
    try:
        return nx.read_graphml(source)
    except IsADirectoryError:
        raise argparse.ArgumentTypeError("\"%s\" is a directory" % path)
    except PermissionError:
        raise argparse.ArgumentTypeError("\"%s\" permission denied" % path)
    except FileNotFoundError:
        raise argparse.ArgumentTypeError("\"%s\" does not exist" % path)
    except xml.etree.ElementTree.ParseError:
        raise argparse.ArgumentTypeError(
            "Input graph \"%s\" is not in GraphML format" % path)


def read_graph(path):
    """Read a graph in GraphML or dot format from *path* ("-" reads stdin).

    The returned networkx graph gets an extra attribute input_file_type
    (either "graphml" or "dot") so that write_graph() can reproduce the
    input format.

    Raises argparse.ArgumentTypeError for unreadable paths or input that
    is in neither supported format.
    """
    origpath = path
    data = None
    # we have to slurp all data from stdin as we potentially have to give the
    # data to read_graphml() *and* to read_dot() if the former failed
    if path == '-':
        data = sys.stdin.read()
    # try reading graphml first because xml.etree.ElementTree will properly
    # fail if it cannot read the input while pygraphviz will parse XML
    # without raising an exception
    # monkey patching the output graph to not lose the input file type
    try:
        if data is None:
            g = nx.read_graphml(path)
        else:
            g = nx.read_graphml(io.StringIO(data))
        g.input_file_type = "graphml"
    except IsADirectoryError:
        raise argparse.ArgumentTypeError("\"%s\" is a directory" % origpath)
    except PermissionError:
        raise argparse.ArgumentTypeError("\"%s\" permission denied" % origpath)
    except FileNotFoundError:
        raise argparse.ArgumentTypeError("\"%s\" does not exist" % origpath)
    except xml.etree.ElementTree.ParseError:
        try:
            if data is None:
                g = nx.nx_agraph.read_dot(path)
            else:
                # we cannot use nx.nx_agraph.read_dot because that uses the
                # deprecated "file" argument when calling pygraphviz.AGraph()
                # which does not allow one to pass a file object or a string
                # with the data
                import pygraphviz
                A = pygraphviz.AGraph(string=data)
                g = nx.nx_agraph.from_agraph(A)
            g.input_file_type = "dot"
        # bug fix: a bare "except:" would also swallow KeyboardInterrupt and
        # SystemExit; only convert genuine parse/OS errors into an
        # ArgumentTypeError
        except Exception:
            raise argparse.ArgumentTypeError(
                "Input graph \"%s\" is neither in GraphML nor in dot format"
                % origpath)
    return g


def write_graph(path):
    """Return a callable writing a graph to *path* ("-" writes to stdout).

    The output format is chosen by file extension (.xml -> GraphML,
    .dot -> dot); otherwise the graph's input_file_type attribute (as set
    by read_graph()) decides.
    """
    def write_as_input_type(g, graphml_target, dot_target):
        # dispatch on the format recorded when the graph was read
        if not hasattr(g, "input_file_type"):
            raise Exception("this function needs the input graph to have "
                            "the attribute input_file_type")
        if g.input_file_type == "graphml":
            nx.write_graphml(g, graphml_target)
        elif g.input_file_type == "dot":
            nx.nx_agraph.write_dot(g, dot_target)
        else:
            raise Exception("input_file_type attribute must either be dot "
                            "or graphml")

    if path == '-':
        # GraphML wants a binary stream, dot a text stream
        return lambda g: write_as_input_type(g, sys.stdout.buffer, sys.stdout)

    def helper(g):
        if path.endswith(".xml"):
            nx.write_graphml(g, path)
        elif path.endswith(".dot"):
            nx.nx_agraph.write_dot(g, path)
        else:
            # cannot determine desired output file type from output file
            # extension, fall back to input file type
            write_as_input_type(g, path, path)
    return helper


def write_graphml(path):
    """Return a callable writing a graph in GraphML format to *path*
    ("-" writes to standard output)."""
    def writer(g):
        if path == '-':
            nx.write_graphml(g, sys.stdout.buffer)
        else:
            nx.write_graphml(g, path)
    return writer


def write_dot(path):
    """Return a callable writing a graph in dot format to *path*
    ("-" writes to standard output)."""
    def writer(g):
        if path == '-':
            nx.nx_agraph.write_dot(g, sys.stdout)
        else:
            nx.nx_agraph.write_dot(g, path)
    return writer


class fh_out:
    """Context manager yielding a binary file object for writing to *path*.

    A path of "-" selects standard output; paths ending in .gz, .bz2 or
    .xz are compressed transparently; anything else is opened as a plain
    binary file.
    """

    def __init__(self, path):
        self.path = path
        self.fd = None

    def __enter__(self):
        if self.path == "-":
            # fall back to sys.stdout itself when it has no .buffer
            # (e.g. when stdout has been replaced by a text stream)
            self.fd = getattr(sys.stdout, 'buffer', sys.stdout)
        elif self.path.endswith(".gz"):
            import gzip
            self.fd = gzip.GzipFile(self.path, "w")
        elif self.path.endswith(".bz2"):
            import bz2
            self.fd = bz2.BZ2File(self.path, "w")
        elif self.path.endswith(".xz"):
            import lzma
            self.fd = lzma.LZMAFile(self.path, "w")
        else:
            self.fd = open(self.path, "wb")
        return self.fd

    def __exit__(self, type, value, traceback):
        # bug fix: never close the interpreter's stdout; callers may still
        # need it after the with-block
        if self.path != "-":
            self.fd.close()


class fh_out_read_write(fh_out):
    """An fh_out that additionally carries the packages parsed from the
    same path, so the object can be iterated (yielding the packages) and
    used as a writing context manager.
    """

    def __init__(self, path, pkgs):
        # bug fix: dropped the mutable class attribute "pkgs = []" which
        # would be shared between instances if __init__ ever failed to
        # assign; the instance attribute below is the only storage now
        fh_out.__init__(self, path)
        self.pkgs = pkgs

    def __iter__(self):
        # iterate over the packages read from the input file
        return iter(self.pkgs)


def get_fh_out(path):
    # factory wrapper returning an fh_out context manager for *path*;
    # kept as a function so it can be used e.g. as an argparse type= callback
    return fh_out(path)


def read_tag_file(path):
    """Read a Debian control ("tag") file from *path* ("-" reads stdin)
    and return the list of its paragraphs as deb822.Deb822 objects.

    gzip, bzip2 and xz compressed input is detected by magic bytes and
    decompressed transparently.
    """
    # we read all the input into memory because some of the decompressors can
    # only create file like objects from filenames anyway
    # juggling with multiple file descriptors and making sure they are all
    # cleaned up afterwards seems hard enough to just waste some megs of
    # memory instead
    if path == '-':
        raw = getattr(sys.stdin, 'buffer', sys.stdin).read()
    else:
        with open(path, "rb") as fobj:
            raw = fobj.read()
    # detect the compression format from the leading magic bytes
    if raw.startswith(b"\x1f\x8b"):
        import gzip
        with io.BytesIO(raw) as fobj:
            raw = gzip.GzipFile(fileobj=fobj).read()
    elif raw.startswith(b"BZh"):
        import bz2
        raw = bz2.decompress(raw)
    elif raw.startswith(b"\xfd7zXZ"):
        import lzma
        raw = lzma.decompress(raw)
    with io.BytesIO(raw) as fobj:
        # not using apt_pkg because that will leave file handles open
        # (see bug#748922)
        # not using apt_pkg because it can't transparently decompress data
        # from filehandles
        return list(deb822.Deb822.iter_paragraphs(fobj, use_apt_pkg=False))


def read_write_tag_file(path):
    """Parse the tag file at *path* and return an object that iterates over
    the parsed packages and doubles as an output context manager for the
    same path."""
    return fh_out_read_write(path, read_tag_file(path))


def read_fas(filename):
    """Parse a feedback-arc-set file into a dict mapping source name to a
    set of dependency names.

    Each non-empty, non-comment line has the form
    "<src> <dep> [<dep> ...]"; '#' starts a comment.
    """
    fas = defaultdict(set)
    with open(filename) as f:
        for line in f:
            # remove everything after first '#'
            line = line.split('#', 1)[0]
            line = line.strip()
            if not line:
                continue
            src, deps = line.split(' ', 1)
            # bug fix: split the dependency list into words first --
            # updating a set with the raw string would add its individual
            # characters (compare read_reduced_deps below)
            fas[src].update(deps.split(' '))
    return fas


def read_weak_deps(filename):
    """Read a file with one weak dependency name per line and return them
    as a set.  '#' starts a comment; blank lines are skipped."""
    deps = set()
    with open(filename) as fobj:
        for raw in fobj:
            # strip comments (everything after the first '#') and whitespace
            entry = raw.split('#', 1)[0].strip()
            if entry:
                deps.add(entry)
    return deps


def read_reduced_deps(filenames):
    """Read one or more reduced-dependency files (comma-separated list of
    paths) and merge them into a dict mapping source name to the set of
    droppable package names.

    Each non-empty, non-comment line: "<src> <pkg> [<pkg> ...]".
    """
    reduced = defaultdict(set)
    for fname in filenames.split(','):
        with open(fname) as fobj:
            for raw in fobj:
                # strip comments and surrounding whitespace
                stripped = raw.split('#', 1)[0].strip()
                if not stripped:
                    continue
                src, pkglist = stripped.split(' ', 1)
                reduced[src].update(pkglist.split(' '))
    return reduced


def graph_remove_weak(g, weak_deps):
    """Remove from graph *g* all non-source-package nodes whose name is in
    the set *weak_deps*."""
    to_remove = [
        node for node, attrs in g.nodes_iter(data=True)
        # source package nodes are never weak; everything else is weak if
        # its name appears in weak_deps
        if attrs['kind'] != 'SrcPkg' and attrs['name'] in weak_deps
    ]
    g.remove_nodes_from(to_remove)


def graph_remove_droppable(g, reduced_deps):
    """Remove from graph *g* all edges that *reduced_deps* marks droppable.

    reduced_deps maps "src:<name>" to a set of droppable dependency names.
    Only edges originating at source-package nodes are considered.
    """
    def is_droppable(edge):
        v1, v2 = edge
        if g.node[v1]['kind'] == 'InstSet':
            # bug fix: this branch was a bare "False" expression with no
            # "return", so edges starting at installation-set nodes were
            # wrongly tested (and possibly dropped) as well
            return False
        n1 = "src:" + g.node[v1]['name']
        n2 = g.node[v2]['name']
        return n1 in reduced_deps and n2 in reduced_deps[n1]
    droppable_edges = [e for e in g.edges_iter() if is_droppable(e)]
    g.remove_edges_from(droppable_edges)


apt_pkg.init()


def cmp(a, b):
    """Three-way comparison (Python 2 style cmp()): return a negative,
    zero or positive integer for a < b, a == b, a > b respectively."""
    if a < b:
        return -1
    if a > b:
        return 1
    return 0

# sort by name, then by version, then by arch


def sort_pkgs(pkg1, pkg2):
    """Comparator for (name, arch, version) package triples: order by name,
    then by Debian version (via apt_pkg), then by architecture."""
    n1, a1, v1 = pkg1
    n2, a2, v2 = pkg2
    # "or" falls through to the next criterion whenever the previous
    # comparison returned 0 (equal)
    return (cmp(n1, n2)
            or apt_pkg.version_compare(v1, v2)
            or cmp(a1, a2))


# key function for sorted()/list.sort(), derived from the comparator above
sort_pkgs_key = cmp_to_key(sort_pkgs)

# memoizes dpkg-architecture results: (arch, wildcard) -> bool
_arch_matches_cache = dict()


def arch_matches(arch, wildcard):
    """Return True iff Debian architecture *arch* matches the architecture
    *wildcard*, delegating to dpkg-architecture; results are memoized in
    the module-level _arch_matches_cache."""
    # 'any' and 'all' trivially match every architecture
    if wildcard in ('any', 'all'):
        return True
    key = (arch, wildcard)
    cached = _arch_matches_cache.get(key)
    if cached is not None:
        return cached
    # environment must be empty or otherwise the DEB_HOST_ARCH environment
    # variable will influence the result
    result = subprocess.call(
        ['dpkg-architecture', '-i%s' % wildcard, '-a%s' % arch],
        env={}) == 0
    _arch_matches_cache[key] = result
    return result


def parse_dose_yaml(yamlin):
    """Summarize a parsed dose distcheck YAML report.

    Returns a dict {"bin": ..., "src": ...} where "bin" maps binary
    package names to the set of problem kinds ("missing"/"conflict") they
    are involved in, and "src" does the same for the reported (broken)
    packages themselves.

    Raises Exception when the report lacks an output-version field or has
    an unsupported one.
    """
    if yamlin.get('output-version') is None:
        raise Exception('missing yaml field output-version')

    if yamlin['output-version'] not in ['1.2']:
        raise Exception('yaml output-version is unsupported: ' +
                        yamlin['output-version'])

    if yamlin.get('report') is None:
        return {"bin": {}, "src": {}}

    data = {"bin": defaultdict(set), "src": defaultdict(set)}

    def strip_arch(name):
        # "pkg:arch" -> "pkg"
        return name.split(':', 1)[0]

    for pkg in yamlin['report']:
        if pkg['status'] != 'broken':
            continue
        for reason in pkg['reasons']:
            if reason.get('missing'):
                unsat = reason['missing']['pkg']['unsat-dependency']
                # the dependency string is "name[:arch] [version]"
                data['bin'][strip_arch(unsat.split(' ', 1)[0])].add('missing')
                data['src'][strip_arch(pkg['package'])].add('missing')
            if reason.get('conflict'):
                conflict = reason['conflict']
                data['bin'][strip_arch(
                    conflict['pkg1']['package'])].add('conflict')
                data['bin'][strip_arch(
                    conflict['pkg2']['package'])].add('conflict')
                data['src'][strip_arch(pkg['package'])].add('conflict')
    return data


def vpkgdisj2deb(disj, parentarch, nativearch, hostarch, parenttype):
    """Normalize a dose vpkg disjunction ("a | b ...") into a canonical,
    architecture-qualified, deduplicated and sorted Debian dependency
    string.

    disj       -- '|'-separated list of (possibly versioned) vpkgs
    parentarch -- architecture of the depending package
    nativearch -- native architecture
    hostarch   -- host architecture (used for source package dependers)
    parenttype -- "src" or "bin", the type of the depending package
    """
    disjl = disj.split('|')

    def vpkg2deb(vpkg):
        vpkg = vpkg.strip()
        try:
            name, ver = vpkg.split(' ', 1)
        except ValueError:
            name = vpkg
            ver = None
        try:
            name, arch = name.split(':', 1)
        except ValueError:
            # if the vpkg does not have an architecture qualification, then
            # the architecture of the dependee has to be either the hostarch,
            # the native arch or the parent arch, depending on whether the
            # depender is a source package, an arch:all package or an arch:any
            # package
            if parenttype == "src":
                arch = hostarch
            elif parentarch == "all":
                arch = nativearch
            else:
                arch = parentarch
            # bug fix: a stray "arch = parentarch" here used to override the
            # selection above, wrongly qualifying unqualified dependencies of
            # source and arch:all packages with the parent architecture
        if name.startswith('--virtual-'):
            # dose prefixes virtual packages with "--virtual-"; strip it
            name = name[10:]
        if ver:
            return "%s:%s %s" % (name, arch, ver)
        else:
            return "%s:%s" % (name, arch)
    # make unique, sort and join
    return " | ".join(sorted(set(vpkg2deb(vpkg) for vpkg in disjl)))


def get_depchain(depchain, parentarch, nativearch, hostarch, parenttype):
    """Convert a dose depchain into a list of
    (package, architecture, version, normalized-dependency) tuples.

    The first tuple stands for the depending package itself and carries
    empty/None package fields with only the normalized first dependency.
    """
    elems = depchain['depchain']
    result = [("", None, None,
               vpkgdisj2deb(elems[0]['depends'], parentarch, nativearch,
                            hostarch, parenttype))]
    for elem in elems[1:]:
        if elem.get('depends') is None:
            # if the depchain element does not have a 'depends' field then it
            # will be the last element of the chain and be in the chain
            # because it is Essential:yes
            vpkg = "implicit dependency via an Essential:yes package"
        else:
            vpkg = vpkgdisj2deb(elem['depends'], elem['architecture'],
                                nativearch, hostarch, "bin")
        result.append((elem['package'], elem['architecture'],
                       elem['version'], vpkg))
    return result


def parse_dose_yaml_mc(yamlin):
    """Extract detailed "missing" and "conflict" data from a parsed dose
    distcheck YAML report.

    Returns a (missing, conflict) tuple:
      missing:  unsatisfied-dependency -> depchains tuple -> set of
                (pkgname, arch, version, None) affected packages
      conflict: unsatisfied-conflict -> (depchain1, depchain2) -> set of
                (pkgname, arch, version, None) affected packages

    Raises Exception when output-version is missing or not '1.2'.
    """
    missing = defaultdict(lambda: defaultdict(set))
    conflict = defaultdict(lambda: defaultdict(set))

    if yamlin.get('output-version') is None:
        raise Exception('missing yaml field output-version')

    if yamlin['output-version'] != '1.2':
        raise Exception('yaml output-version is unequal 1.2: ' +
                        yamlin['output-version'])

    if yamlin.get('report') is None:
        return ({}, {})

    nativearch = yamlin['native-architecture']
    # the host architecture defaults to the native architecture
    hostarch = yamlin.get('host-architecture', nativearch)

    for p in yamlin['report']:
        n, v, a = p['package'], p['version'], p['architecture']
        t = p.get("type", "bin")
        if t == "src":
            # source packages are distinguished by a "src:" name prefix
            n = "src:" + n
        # only add name to avoid duplicates when more than one
        # version exists
        if p['status'] == 'broken':
            for r in p['reasons']:
                if r.get('missing'):
                    pkg = r['missing']['pkg']
                    # normalize the unsatisfied dependency string
                    unsatdep = vpkgdisj2deb(pkg['unsat-dependency'],
                                            pkg['architecture'], nativearch,
                                            hostarch, t)
                    lastpkg = (pkg['package'], pkg['architecture'],
                               pkg['version'], unsatdep)
                    c = r['missing'].get('depchains')
                    if c:
                        # append the package with the unsatisfied dependency
                        # to each chain leading to it
                        depchains = [tuple(get_depchain(depchain, a,
                                                        nativearch, hostarch,
                                                        t) + [lastpkg])
                                     for depchain in c]
                        # normalize the order of depchains
                        depchains = tuple(sorted(set(depchains)))
                    else:
                        # no depchain: the broken package itself has the
                        # unsatisfiable dependency
                        depchains = tuple([(("", None, None, unsatdep),)])
                    missing[unsatdep][depchains].add((n, a, v, None))
                if r.get('conflict'):
                    conf1 = r['conflict']['pkg1']
                    conf2 = r['conflict']['pkg2']
                    # normalize the unsatisfiable conflict string
                    unsatconf = vpkgdisj2deb(conf1['unsat-conflict'],
                                             conf1['architecture'], nativearch,
                                             hostarch, t)
                    lastpkg1 = (conf1['package'], conf1['architecture'],
                                conf1['version'], unsatconf)
                    lastpkg2 = (conf2['package'], conf2['architecture'],
                                conf2['version'], None)
                    c1 = r['conflict'].get('depchain1')
                    if c1:
                        depchain1 = [tuple(get_depchain(depchain, a,
                                                        nativearch, hostarch,
                                                        t) + [lastpkg1])
                                     for depchain in c1]
                        depchain1 = tuple(sorted(set(depchain1)))
                    else:
                        depchain1 = tuple([(("", None, None, unsatconf),)])
                    c2 = r['conflict'].get('depchain2')
                    if c2:
                        depchain2 = [tuple(get_depchain(depchain, a,
                                                        nativearch, hostarch,
                                                        t) + [lastpkg2])
                                     for depchain in c2]
                        depchain2 = tuple(sorted(set(depchain2)))
                    else:
                        depchain2 = tuple([(("", None, None, None),)])
                    depchains = tuple([depchain1, depchain2])
                    conflict[unsatconf][depchains].add((n, a, v, None))
    return (missing, conflict)


def human_readable_size(val):
    """Format a byte count with one decimal and a binary unit suffix.

    NOTE: values below 1 KiB keep a trailing space because the unit
    string is empty -- this matches the original output format.
    """
    size = float(val)
    for unit in ('', 'KiB', 'MiB'):
        if size < 1024:
            return "%.1f %s" % (size, unit)
        size = size / 1024.0
    return "%.1f GiB" % size


def find_node(g, selectors):
    """Return the first node of *g* matching all (key, value) selectors.

    Prints a warning to stderr when more than one node matches; raises
    (via find_nodes) when none does.
    """
    matches = find_nodes(g, selectors)
    if len(matches) > 1:
        sel = ",".join("%s:%s" % s for s in selectors)
        print("found %d results for selectors %s - picking first"
              % (len(matches), sel), file=sys.stderr)
    return matches[0]


def find_nodes(g, selectors):
    """Return the sorted list of nodes of *g* whose attributes match every
    (key, value) pair in *selectors*; the special key '__ID__' matches the
    node identifier itself.

    Raises Exception when no node matches.
    """
    def matches(node, attrs):
        for key, value in selectors:
            if key == '__ID__':
                if node != value:
                    return False
            elif attrs.get(key) != value:
                return False
        return True
    found = [node for node, attrs in g.nodes(data=True)
             if matches(node, attrs)]
    if not found:
        sel = ",".join("%s:%s" % s for s in selectors)
        raise Exception("package cannot be found for selectors %s" % sel)
    return sorted(found)
