from __future__ import print_function
from collections import defaultdict
import operator
import os
import os.path as op
import copy
from progressbar import ProgressBar

import math

import seqcluster.libs.logger as mylog
from seqcluster.libs import utils
from seqcluster.libs import parameters
from seqcluster.libs.classes import *
from seqcluster.libs.mystats import up_threshold
from seqcluster.libs.bayes import decide_by_bayes

logger = mylog.getLogger(__name__)

REMOVED = 0
CONFLICT = 0


def _get_seqs_from_cluster(seqs, seen):
    """
    Returns the sequences that are already part of the cluster

    :param seqs: list of sequences ids
    :param clus_id: dict of sequences ids that are part of a cluster

    :returns:
        * :code:`already_in`list of cluster id that contained some of the sequences
        * :code:`not_in`list of sequences that don't belong to any cluster yet
    """
    already_in = set()
    not_in = []

    already_in = [e for e in map(seen.get, seqs)]
    # if isinstance(already_in, list):
    already_in = filter(None, already_in)
    not_in = set(seqs) - set(seen.keys())
    return list(set(already_in)), list(not_in)


def reduceloci(clus_obj,  path):
    """reduce number of loci a cluster has
    :param clus_obj: cluster object object
    :param path: output path
    """
    filtered = {}
    n_cluster = 0
    large = 0
    current = clus_obj.clusid
    logger.info("Number of loci: %s" % len(clus_obj.loci.keys()))
    avg_loci_cluster = len(clus_obj.loci.keys())/len(clus_obj.seq)
    if avg_loci_cluster > 0.5:
        logger.warning("The avg number of loci by sequences is close to 0.5: %s",
                       avg_loci_cluster)
        logger.warning("This could mean you have sequences over the genome, rare in "
                       "a typical small RNA data. This can cause errors during the execution "
                       "or long computing time.")
    bar = ProgressBar(maxval=len(current)).start()
    bar.update(0)
    for itern, idmc in enumerate(current):
        bar.update(itern)
        logger.debug("_reduceloci: cluster %s" % idmc)
        c = copy.deepcopy(list(current[idmc]))
        n_loci = len(c)
        if n_loci < 100:
            filtered, n_cluster = _iter_loci(c, clus_obj.clus, (clus_obj.loci, clus_obj.seq), filtered, n_cluster)
        else:
            large += 1
            n_cluster += 1
            _write_cluster(c, clus_obj.clus, clus_obj.loci, n_cluster, path)
            filtered[n_cluster] = _add_complete_cluster(n_cluster, c, clus_obj.clus)
    clus_obj.clus = filtered

    seqs = 0
    for idc in filtered:
        seqs += len(filtered[idc].idmembers)
    logger.info("Seqs in clusters %s" % (seqs))

    logger.info("Clusters too long to be analyzed: %s" % large)
    logger.info("Number of clusters removed because low number of reads: %s" % REMOVED)
    logger.info("Number of clusters with conflicts: %s" % CONFLICT)
    return clus_obj


def _write_cluster(metacluster, cluster, loci, idx, path):
    """
    For complex meta-clusters, write all the loci for further debug
    """
    out_file = op.join(path, 'log', str(idx) + '.bed')
    with utils.safe_run(out_file):
        with open(out_file, 'w') as out_handle:
            for idc in metacluster:
                for idl in cluster[idc].loci2seq:
                    pos = [e for e in loci[idl].list()]
                    print("\t".join(pos[:4] + [str(len(cluster[idc].loci2seq[idl]))] + [pos[-1]]), file=out_handle, end="")


def _add_complete_cluster(idx, meta, clusters):
    logger.debug("Not resolving cluster %s, too many loci" % (idx))
    clus = {}
    [clus.update(clusters[idc].locilen) for idc in meta]
    locilen_sorted = sorted(iter(clus.items()), key=operator.itemgetter(1), reverse=True)
    maxidl = locilen_sorted[0][0]
    c = cluster(idx)
    for idc in meta:
        c.add_id_member(clusters[idc].idmembers.keys(), maxidl)
    c.id = idx
    c.toomany = len(meta)
    return c


def _iter_loci(meta, clusters, s2p, filtered, n_cluster):
    """
    Go through all locus and decide if they are part
    of the same TU or not.

    :param idx: int cluster id
    :param s2p: dict with [loci].coverage[start] = # of sequences there
    :param filtered: dict with clusters object
    :param n_cluster: int cluster id

    :return:
        * filtered: dict of cluster objects
        * n_cluster: int cluster id
    """
    global CONFLICT
    loci = dict(zip(meta, [clusters[idc] for idc in meta]))

    n_loci = len(meta)
    n_loci_prev = n_loci + 1
    cicle = 0
    internal_cluster = {}
    if n_loci == 1:
        n_cluster += 1
        filtered[n_cluster] = clusters[meta[0]]
        filtered[n_cluster].update(id=n_cluster)
        filtered[n_cluster].set_freq(s2p[1])

    while n_loci < n_loci_prev and n_loci != 1:
        n_loci_prev = n_loci
        cicle += 1
        if (cicle % 1) == 0:
            logger.debug("_iter_loci:number of cicle: %s with n_loci %s" % (cicle, n_loci))
        loci_similarity = _calculate_similarity(loci)
        internal_cluster = _merge_similar(loci, loci_similarity)
        n_loci = len(internal_cluster)
        loci = internal_cluster
        logger.debug("_iter_loci: n_loci %s" % n_loci)

    if n_loci > 1:
        n_internal_cluster = sorted(internal_cluster.keys(), reverse=True)[0]
        CONFLICT += 1
        internal_cluster = _solve_conflict(internal_cluster, s2p, n_internal_cluster)

    internal_cluster = _clean_cluster(internal_cluster)

    for idc in internal_cluster:
        n_cluster += 1
        logger.debug("_iter_loci: add to filtered %s" % n_cluster)
        filtered[n_cluster] = internal_cluster[idc]
        filtered[n_cluster].id = n_cluster
        filtered[n_cluster].update(id=n_cluster)
        filtered[n_cluster].set_freq(s2p[1])
    logger.debug("_iter_loci: filtered %s" % filtered.keys())

    # for new_c in internal_cluster.values():
    #    [logger.note("%s %s %s %s" % (meta, new_c.id, idl, len(new_c.loci2seq[idl]))) for idl in new_c.loci2seq]
    return filtered, n_cluster


# def _remove_loci(ci, idl):
#     for (idl, lenl) in locilen_sorted:
#         logger.debug("_remove_loci:remove locus %s with len %s:" % (idl, lenl))
#         c.loci2seq.pop(idl, "None")
#         c.locilen.pop(idl, "None")


def _convert_to_clusters(c):
    """Return 1 cluster per loci"""
    new_dict = {}
    n_cluster = 0
    logger.debug("_convert_to_cluster: loci %s" % c.loci2seq.keys())
    for idl in c.loci2seq:
        n_cluster += 1
        new_c = cluster(n_cluster)
        #new_c.id_prev = c.id
        new_c.loci2seq[idl] = c.loci2seq[idl]
        new_dict[n_cluster] = new_c
    logger.debug("_convert_to_cluster: new ids %s" % new_dict.keys())
    return new_dict


def _calculate_similarity(c):
    """Get a similarity matrix of % of shared sequence

    :param c: cluster object

    :return ma: similarity matrix
    """
    ma = {}
    for idc in c:
        set1 = _get_seqs(c[idc])
        [ma.update({(idc, idc2): _common(set1, _get_seqs(c[idc2]), idc, idc2)}) for idc2 in c if idc != idc2 and (idc2, idc) not in ma]
    # logger.debug("_calculate_similarity_ %s" % ma)
    return ma


def _get_seqs(list_idl):
    """get all sequences in a cluster knowing loci"""
    seqs = set()
    for idl in list_idl.loci2seq:
        # logger.debug("_get_seqs_: loci %s" % idl)
        [seqs.add(s) for s in list_idl.loci2seq[idl]]
    # logger.debug("_get_seqs_: %s" % len(seqs))
    return seqs


def _common(s1, s2, i1, i2):
    """calculate the common % percentage of sequences"""
    c = len(set(s1).intersection(s2))
    t = min(len(s1), len(s2))
    pct = 1.0 * c / t * t
    is_gt = up_threshold(pct, t * 1.0, parameters.similar)
    logger.debug("_common: pct %s of clusters:%s %s = %s" % (1.0 * c / t, i1, i2, is_gt))
    if pct < parameters.similar and is_gt and pct > 0:
        pct = parameters.similar
    return pct / t


def _is_consistent(pairs, common, clus_seen, loci_similarity):
    """
    Check if loci shared that match sequences with all
    clusters seen until now.
    """
    all_true1 = all([all([common and loci_similarity[(p, c)] > parameters.similar  for p in pairs if (p, c) in loci_similarity]) for c in clus_seen])
    all_true2 = all([all([common and loci_similarity[(c, p)] > parameters.similar  for p in pairs if (c, p) in loci_similarity]) for c in clus_seen])
    return all_true1 * all_true2


def _merge_similar(loci, loci_similarity):
    """
    Internal function to reduce loci complexity

    :param loci: class cluster
    :param locilen_sorted: list of loci sorted by size

    :return
     c: updated class cluster
    """
    n_cluster = 0
    internal_cluster = {}
    clus_seen = {}
    loci_sorted = sorted(iter(loci_similarity.items()), key=operator.itemgetter(1), reverse=True)
    for pairs, sim in loci_sorted:
        common = sim > parameters.similar
        n_cluster += 1
        logger.debug("_merge_similar:try new cluster %s" % n_cluster)
        new_c = cluster(n_cluster)
        p_seen, p_unseen = [], []
        size = min(len(_get_seqs(loci[pairs[0]])), len(_get_seqs(loci[pairs[1]])))
        if common:
            consistent = _is_consistent(pairs, common, clus_seen, loci_similarity)
            logger.debug("_merge_similar: clusters seen: %s" % clus_seen)
            logger.debug("_merge_similar: id %s common %s|%s total  %s consistent %s" % (pairs, sim, common, size, consistent))
            if not consistent:
                continue
            if pairs[0] in clus_seen:
                p_seen.append(pairs[0])
                p_unseen.append(pairs[1])
            if pairs[1] in clus_seen:
                p_seen.append(pairs[1])
                p_unseen.append(pairs[0])
            if len(p_seen) == 0:
                new_c = _merge_cluster(loci[pairs[0]], new_c)
                new_c = _merge_cluster(loci[pairs[1]], new_c)
                [clus_seen.update({p: n_cluster}) for p in pairs]
                internal_cluster[n_cluster] = new_c
            if len(p_seen) == 1:
                idc_seen = clus_seen[p_seen[0]]
                internal_cluster[idc_seen] = _merge_cluster(loci[p_unseen[0]], internal_cluster[idc_seen])
                clus_seen[p_unseen[0]] = idc_seen
        else:
            logger.debug("_merge_similar: id %s %s  are different" % pairs)
            continue
    internal_cluster.update(_add_unseen(loci, clus_seen, n_cluster))
    logger.debug("_merge_similar: total clus %s" %
                 len(internal_cluster.keys()))
    return internal_cluster


def _merge_cluster(old, new):
    """merge one cluster to another"""
    logger.debug("_merge_cluster: %s to %s" % (old.id, new.id))
    logger.debug("_merge_cluster: add idls %s" % old.loci2seq.keys())
    for idl in old.loci2seq:
        # if idl in new.loci2seq:
        #    new.loci2seq[idl] = list(set(new.loci2seq[idl] + old.loci2seq[idl]))
        # new.loci2seq[idl] = old.loci2seq[idl]
        new.add_id_member(old.loci2seq[idl], idl)
    return new


def _solve_conflict(list_c, s2p, n_cluster):
    """
    Make sure sequences are counts once.
    Resolve by most-vote or exclussion

    :params list_c: dict of objects cluster
    :param s2p: dict of [loci].coverage = # num of seqs
    :param n_cluster: number of clusters

    return dict: new set of clusters
    """
    logger.debug("_solve_conflict: count once")
    if parameters.decision_cluster == "bayes":
        return decide_by_bayes(list_c, s2p)
    loci_similarity = _calculate_similarity(list_c)
    loci_similarity = sorted(iter(loci_similarity.items()), key=operator.itemgetter(1), reverse=True)
    common = sum([score for p, score in loci_similarity])
    while common > 0:
        n_cluster += 1
        logger.debug("_solve_conflict: ma %s" % loci_similarity)
        pairs = loci_similarity[0][0]
        score = loci_similarity[0][1]
        logger.debug("_solve_conflict: common %s, new %s" % (score, n_cluster))
        if parameters.decision_cluster.startswith("most-voted"):
            list_c = _split_cluster_by_most_vote(list_c, pairs)
        else:
            list_c = _split_cluster(list_c, pairs, n_cluster)
        list_c = {k: v for k, v in iter(list_c.items()) if len(v.loci2seq) > 0}
        loci_similarity = _calculate_similarity(list_c)
        loci_similarity = sorted(iter(loci_similarity.items()), key=operator.itemgetter(1), reverse=True)
        #logger.note("%s %s" % (pairs, loci_similarity[0][1]))
        common = sum([score for p, score in loci_similarity])
        logger.debug("_solve_conflict: solved clusters %s" % len(list_c.keys()))
    return list_c


def _split_cluster(c, pairs, n):
    """split cluster by exclussion"""
    old = c[p[0]]
    new = c[p[1]]
    new_c = cluster(n)
    common = set(_get_seqs(old)).intersection(_get_seqs(new))
    for idl in old.loci2seq:
        in_common = list(set(common).intersection(old.loci2seq[idl]))
        if len(in_common) > 0:
            logger.debug("_split_cluster: in_common %s with pair 1" % (len(in_common)))
            new_c.add_id_member(in_common, idl)
            old.loci2seq[idl] = list(set(old.loci2seq[idl]) - set(common))
            logger.debug("_split_cluster: len old %s with pair 1" % (len(old.loci2seq)))
    for idl in new.loci2seq:
        in_common = list(set(common).intersection(new.loci2seq[idl]))
        if len(in_common) > 0:
            logger.debug("_split_cluster: in_common %s with pair 2" % (len(in_common)))
            new_c.add_id_member(in_common, idl)
            new.loci2seq[idl] = list(set(new.loci2seq[idl]) - set(common))
            logger.debug("_split_cluster: len old %s with pair 2" % (len(new.loci2seq)))
    old.update()
    new.update()
    old.loci2seq = {k: v for k, v in iter(old.loci2seq.items()) if len(v) > 0}
    new.loci2seq = {k: v for k, v in iter(new.loci2seq.items()) if len(v) > 0}
    c[n] = new
    c[p[0]] = old
    c[p[1]] = new
    return c


def _split_cluster_by_most_vote(c, p):
    """split cluster by most-vote strategy"""
    old, new = c[p[0]], c[p[1]]
    old_size = _get_seqs(old)
    new_size = _get_seqs(new)
    logger.debug("_most_vote: size of %s with %s - %s with %s" % (old.id, len(old_size), new.id, len(new_size)))
    if len(old_size) > len(new_size):
        keep, remove = old, new
    else:
        keep, remove = new, old
    common = list(set(old_size).intersection(new_size))
    logger.debug("_most_vote: keep %s remove  %s with common %s" % (keep.id, remove.id, len(common)))
    for idl in remove.loci2seq:
        if len(common) > 0:
            remove.loci2seq[idl] = list(set(remove.loci2seq[idl]) - set(common))
    keep.loci2seq = {k: v for k, v in iter(keep.loci2seq.items()) if len(v) > 0}
    remove.loci2seq = {k: v for k, v in iter(remove.loci2seq.items()) if len(v) > 0}
    keep.update()
    remove.update()
    c[keep.id] = keep
    c[remove.id] = remove
    return c


def _add_unseen(loci, clus_seen, n_cluster):
    unseen = {}
    for idc in loci:
        if idc not in clus_seen:
            n_cluster += 1
            loci[idc].id = n_cluster
            unseen[n_cluster] = loci[idc]
            logger.debug("_add_unseen: add  %s as new %s" %
                         (idc, n_cluster))
    return unseen


def _clean_cluster(list_c):
    """
    Remove cluster with less than 10 sequences and
    loci with size smaller than 60%
    """
    global REMOVED
    init = len(list_c)
    list_c = {k: v for k, v in iter(list_c.items()) if len(_get_seqs(v)) > parameters.min_seqs}
    logger.debug("_clean_cluster: number of clusters %s " % len(list_c.keys()))
    list_c = {k: _select_loci(v) for k, v in iter(list_c.items())}
    end = len(list_c)
    REMOVED += init - end
    return list_c


def _select_loci(c):
    """Select only loci with most abundant sequences"""
    loci_len = {k: len(v) for k, v in iter(c.loci2seq.items())}
    logger.debug("_select_loci: number of loci %s" % len(c.loci2seq.keys()))
    loci_len_sort = sorted(iter(loci_len.items()), key=operator.itemgetter(1), reverse=True)
    max_size = loci_len_sort[0][1]
    logger.debug("_select_loci: max size %s" % max_size)
    loci_clean = {locus: c.loci2seq[locus] for locus, size in loci_len_sort if size > 0.8 * max_size}
    c.loci2seq = loci_clean
    removed = list(set(c.idmembers.keys()) - set(_get_seqs(c)))
    c.add_id_member(removed, loci_len_sort[0][0])
    logger.debug("_select_loci: number of loci %s after cleaning" % len(c.loci2seq.keys()))
    return c


def _calculate_size_enrichment(c):
    """calculate whether there is a size
    enrichment in the cluster

    :param c: cluster object
    """
    seqs = _get_sequences(c)
    return True


def _solve_loci_deprecated(c, locilen_sorted, seen_seqs, filtered, maxseq, n_cluster):
    """internal function to reduce loci complexity

    The function will read the all loci in a cluster of
    sequences and will determine if all loci are part
    of the same transcriptional unit(TU) by most-vote locus
    or by exclusion of common sequence that are the
    minority of two loci.

    :param c: class cluster
    :param locilen_sorted: list of loci sorted by size
    :param seem_seqs: list of seen sequences
    :param filtered: final TU list
    :param maxseq: bigger locus
    "param n_cluster: integer with index of different TU"
    :return
     c: updated class cluster
     seen_seqs: updated list of sequences
     filtered: updated dict of TUs
     n_cluster: updated int with current index of TUs
    """
    first_run = 0
    seen_seqs = list()
    n_cluster += 1
    logger.debug("_solve_loci:new cluster %s" % n_cluster)
    new_c = cluster(n_cluster)
    for idl, lenl in locilen_sorted:
        locus_seqs = c.loci2seq[idl]
        if first_run == 0:
            seen_seqs = locus_seqs
            first_run = 1
            first_idl = idl
        intersect = list(set(seen_seqs).intersection(locus_seqs))
        common = 0
        if intersect:
            common = len(intersect)*1.0/min(len(seen_seqs), len(locus_seqs))
        logger.debug("_sole_loci:id %s idl %s len %s max %s seen %s inter %s common %s " % (c.id, idl, lenl, maxseq, len(seen_seqs), len(intersect), common))
        if common*1.0 >= 0.6:
            if lenl*1.0 >= 0.6*maxseq:
                c, new_c, seen_seqs = _merge_loci_in_cluster(c, new_c, idl, seen_seqs)
            else:
                c, new_c, seen_seqs = _merge_with_first_loci(c, new_c, first_idl, idl, seen_seqs)
        else:
            c = _remove_seqs_from_loci(c, idl, seen_seqs)
    filtered[n_cluster] = new_c
    return c, seen_seqs, filtered, n_cluster


def _merge_loci_in_cluster(c, new_c, idl, current_seqs):
    logger.debug("_merge_loci_in_cluster:join")
    locus_seqs = c.loci2seq[idl]
    common = len(set(locus_seqs).intersection(current_seqs))
    seen = list(set(locus_seqs).union(current_seqs))
    new_c.add_id_member(list(locus_seqs), idl)
    c.loci2seq.pop(idl, "None")
    c.locilen.pop(idl, "None")
    return c, new_c, seen


def _merge_with_first_loci(c, new_c, first_idl, idl, current_seqs):
    logger.debug("_merge_with_first_loci:join first")
    locus_seqs = c.loci2seq[idl]
    seen = list(set(locus_seqs).union(current_seqs))
    new_c.add_id_member(list(locus_seqs), first_idl)
    c.loci2seq.pop(idl, "None")
    c.locilen.pop(idl, "None")
    return c, new_c, seen


def _remove_seqs_from_loci(c, idl, current_seqs):
    current = c.loci2seq[idl]
    seen = list(set(current).intersection(current_seqs))
    unseen = list(set(sorted(current)).difference(sorted(seen)))
    logger.debug("_remove_seqs_from_loci:seen %s unseen %s" % (len(seen), len(unseen)))
    c.locilen[idl] = len(unseen)
    c.loci2seq[idl] = unseen
    logger.debug("_remove_seqs_from_loci:remove; new len %s" % len(unseen))
    if c.locilen[idl] == 0:
        c.loci2seq.pop(idl, "None")
        c.locilen.pop(idl, "None")
    return c


