from __future__ import annotations

import warnings
from functools import partial
from operator import add
from typing import Literal

import numpy as np

from ._logging import get_logger
from ._typing import MapFunctor
from .api import Cooler
from .parallel import partition, split
from .util import mad

# Public API of this module: only ``balance_cooler`` is exported by star imports.
__all__ = ["balance_cooler"]

# Module-level logger; the balancing loops below log chromosome names and the
# per-iteration marginal variance through it.
logger = get_logger(__name__)


class ConvergenceWarning(UserWarning):
    """Emitted when balancing hits ``max_iters`` before the variance drops below ``tol``."""


def _init(chunk):
    """Return a fresh copy of the chunk's pixel ``count`` column.

    Later pipeline stages modify the returned array in place, so it must not
    alias the array stored in ``chunk``.
    """
    counts = chunk["pixels"]["count"]
    return np.array(counts, copy=True)


def _binarize(chunk, data):
    """Set every nonzero entry of ``data`` to 1, in place, and return ``data``.

    ``chunk`` is not used; it is accepted to match the ``(chunk, data)``
    signature shared by the other pipeline stages.
    """
    nonzero = data != 0
    data[nonzero] = 1
    return data


def _zero_diags(n_diags, chunk, data):
    """Zero pixels whose diagonal offset ``|bin1_id - bin2_id|`` is below ``n_diags``.

    With ``n_diags=1`` only the main diagonal is cleared. ``data`` is modified
    in place and returned.
    """
    px = chunk["pixels"]
    offset = np.abs(px["bin1_id"] - px["bin2_id"])
    data[offset < n_diags] = 0
    return data


def _zero_trans(chunk, data):
    """Zero pixels whose two bins lie on different chromosomes (in place).

    ``chunk["bins"]["chrom"]`` is indexed by bin id to find each bin's
    chromosome; ``data`` is returned after modification.
    """
    chrom_of_bin = chunk["bins"]["chrom"]
    px = chunk["pixels"]
    is_trans = chrom_of_bin[px["bin1_id"]] != chrom_of_bin[px["bin2_id"]]
    data[is_trans] = 0
    return data


def _zero_cis(chunk, data):
    """Zero pixels whose two bins lie on the same chromosome (in place).

    ``chunk["bins"]["chrom"]`` is indexed by bin id to find each bin's
    chromosome; ``data`` is returned after modification.
    """
    chrom_of_bin = chunk["bins"]["chrom"]
    px = chunk["pixels"]
    is_cis = chrom_of_bin[px["bin1_id"]] == chrom_of_bin[px["bin2_id"]]
    data[is_cis] = 0
    return data


def _timesouterproduct(vec, chunk, data):
    """Multiply each pixel value by ``vec[bin1_id] * vec[bin2_id]``.

    This applies the outer product of ``vec`` with itself to the sparse
    pixel list. A new array is returned; ``data`` is left unchanged.
    """
    px = chunk["pixels"]
    row_factor = vec[px["bin1_id"]]
    col_factor = vec[px["bin2_id"]]
    return row_factor * col_factor * data


def _marginalize(chunk, data):
    """Return the per-bin marginal sums of the pixel values in ``data``.

    Each pixel adds its value to both its ``bin1_id`` and its ``bin2_id``
    entry (so a diagonal pixel counts twice). The output length is the
    number of bins in ``chunk["bins"]``.
    """
    n_bins = len(chunk["bins"]["chrom"])
    px = chunk["pixels"]
    row_sums = np.bincount(px["bin1_id"], weights=data, minlength=n_bins)
    col_sums = np.bincount(px["bin2_id"], weights=data, minlength=n_bins)
    return row_sums + col_sums


def _balance_genomewide(
    bias,
    clr,
    spans,
    filters,
    chunksize,
    map,
    tol,
    max_iters,
    rescale_marginals,
    use_lock,
):
    """Iteratively balance the whole contact matrix, updating ``bias`` in place.

    Parameters
    ----------
    bias : 1D float array
        Starting weights, one per bin; zeros mark filtered-out bins.
        Modified in place and also returned.
    clr : Cooler
        Source of the pixel data.
    spans : list of (int, int)
        Pixel-index ranges dispatched through ``map`` on every pass.
    filters : list of callables
        ``(chunk, data) -> data`` transforms applied before weighting.
    chunksize : int
        Not used by this mode; kept for signature parity with the other
        balancers.
    map, use_lock
        Forwarded to ``split``.
    tol : float
        Stop once the variance of the nonzero (un-normalized) marginals is
        below this value.
    max_iters : int
        Iteration limit. If it is less than 1 no pass runs. In that case
        ``bias`` is left as given (apart from zeros becoming NaN), ``scale``
        stays 1.0, ``var`` is NaN, and a ``ConvergenceWarning`` is emitted.
    rescale_marginals : bool
        Divide the final weights by ``sqrt(scale)``.

    Returns
    -------
    bias : 1D array
        Final weights; filtered bins are NaN.
    scale : float
        Mean of the nonzero marginals from the last pass. It is NaN when no
        nonzero marginals remain.
    var : float
        Variance of those marginals (0.0 in the all-zero case).
    """
    scale = 1.0
    var = np.nan
    # Nonzero marginals of the most recent pass; stays None if no pass runs.
    nzmarg = None
    n_bins = len(bias)

    for _ in range(max_iters):
        marg = (
            split(clr, spans=spans, map=map, use_lock=use_lock)
            .prepare(_init)
            .pipe(filters)
            .pipe(_timesouterproduct, bias)
            .pipe(_marginalize)
            .reduce(add, np.zeros(n_bins))
        )

        nzmarg = marg[marg != 0]
        if not len(nzmarg):
            # Nothing left to balance: every weight becomes NaN.
            scale = np.nan
            bias[:] = np.nan
            var = 0.0
            break

        marg = marg / nzmarg.mean()
        # Bins with no signal keep their current weight.
        marg[marg == 0] = 1
        bias /= marg

        var = nzmarg.var()
        logger.info(f"variance is {var}")
        if var < tol:
            break
    else:
        warnings.warn(
            "Iteration limit reached without convergence.",
            ConvergenceWarning,
            stacklevel=1,
        )

    # Skip the empty case. There, scale is already NaN, and calling mean()
    # would only emit a "Mean of empty slice" RuntimeWarning.
    if nzmarg is not None and len(nzmarg):
        scale = nzmarg.mean()
    bias[bias == 0] = np.nan
    if rescale_marginals:
        bias /= np.sqrt(scale)

    return bias, scale, var


def _balance_cisonly(
    bias,
    clr,
    spans,
    filters,
    chunksize,
    map,
    tol,
    max_iters,
    rescale_marginals,
    use_lock,
):
    """Balance each chromosome's intra-chromosomal block independently.

    ``bias`` is updated in place, one chromosome's slice at a time. For this
    mode, ``balance_cooler`` includes ``_zero_trans`` in ``filters``.

    Parameters
    ----------
    bias : 1D float array
        Starting weights, one per bin; zeros mark filtered-out bins.
    clr : Cooler
        Source of the pixel data and the chromosome/bin1 offset indexes.
    spans : ignored
        Recomputed per chromosome from the ``bin1_offset`` index.
    filters : list of callables
        ``(chunk, data) -> data`` transforms applied before weighting.
    chunksize : int
        Chunk size passed to ``partition`` for each chromosome's pixel range.
    map, use_lock
        Forwarded to ``split``.
    tol : float
        Per-chromosome convergence threshold on the marginal variance.
    max_iters : int
        Per-chromosome iteration limit. If it is less than 1, the chromosome's
        scale stays 1.0, its variance is NaN, and a ``ConvergenceWarning`` is
        emitted.
    rescale_marginals : bool
        Divide each chromosome's weights by ``sqrt`` of its scale.

    Returns
    -------
    bias : 1D array
        Final weights; filtered bins are NaN.
    scales : 1D array
        Per-chromosome mean nonzero marginal from the last pass.
    variances : 1D array
        Per-chromosome marginal variance from the last pass.
    """
    chroms = clr.chroms()["name"][:]
    chrom_ids = np.arange(len(clr.chroms()))
    chrom_offsets = clr._load_dset("indexes/chrom_offset")
    bin1_offsets = clr._load_dset("indexes/bin1_offset")
    scales = np.ones(len(chrom_ids))
    variances = np.full_like(scales, np.nan)
    n_bins = len(bias)

    for cid, lo, hi in zip(chrom_ids, chrom_offsets[:-1], chrom_offsets[1:]):
        logger.info(chroms[cid])

        # Pixel range presumably covering rows whose bin1 is on this
        # chromosome (per the bin1_offset index) -- TODO confirm.
        plo, phi = bin1_offsets[lo], bin1_offsets[hi]
        spans = list(partition(plo, phi, chunksize))
        scale = 1.0
        var = np.nan
        # Nonzero marginals of the most recent pass; None if no pass runs.
        nzmarg = None
        for _ in range(max_iters):
            marg = (
                split(clr, spans=spans, map=map, use_lock=use_lock)
                .prepare(_init)
                .pipe(filters)
                .pipe(_timesouterproduct, bias)
                .pipe(_marginalize)
                .reduce(add, np.zeros(n_bins))
            )

            marg = marg[lo:hi]
            nzmarg = marg[marg != 0]
            if not len(nzmarg):
                # Nothing to balance on this chromosome.
                scale = np.nan
                bias[lo:hi] = np.nan
                var = 0.0
                break

            marg = marg / nzmarg.mean()
            marg[marg == 0] = 1
            bias[lo:hi] /= marg

            var = nzmarg.var()
            logger.info(f"variance is {var}")
            if var < tol:
                break

        else:
            warnings.warn(
                f"Iteration limit reached without convergence on {chroms[cid]}.",
                ConvergenceWarning,
                stacklevel=1,
            )

        # Skip the empty case. There, scale is already NaN, and mean() would
        # only emit a "Mean of empty slice" RuntimeWarning.
        if nzmarg is not None and len(nzmarg):
            scale = nzmarg.mean()
        b = bias[lo:hi]  # view: NaN-ing zeros here updates bias
        b[b == 0] = np.nan
        scales[cid] = scale
        variances[cid] = var
        if rescale_marginals:
            bias[lo:hi] /= np.sqrt(scale)

    return bias, scales, variances


def _balance_transonly(
    bias,
    clr,
    spans,
    filters,
    chunksize,
    map,
    tol,
    max_iters,
    rescale_marginals,
    use_lock,
):
    """Balance using only inter-chromosomal pixels, updating ``bias`` in place.

    Intra-chromosomal pixels are zeroed by ``_zero_cis`` on every pass. Each
    bin's weight is also multiplied by a per-chromosome factor
    ``1 / (1 - chrom_size / n_bins)`` before weighting.

    Parameters
    ----------
    bias : 1D float array
        Starting weights, one per bin; zeros mark filtered-out bins.
    clr : Cooler
        Source of the pixel data and the chromosome offset index.
    spans : list of (int, int)
        Pixel-index ranges dispatched through ``map`` on every pass.
    filters : list of callables
        ``(chunk, data) -> data`` transforms applied before ``_zero_cis``.
    chunksize : int
        Not used by this mode; kept for signature parity.
    map, use_lock
        Forwarded to ``split``.
    tol : float
        Stop once the variance of the nonzero marginals is below this value.
    max_iters : int
        Iteration limit. If it is less than 1 no pass runs. In that case
        ``scale`` stays 1.0, ``var`` is NaN, and a ``ConvergenceWarning`` is
        emitted.
    rescale_marginals : bool
        Divide the final weights by ``sqrt(scale)``.

    Returns
    -------
    bias, scale, var
        Same meaning as for genome-wide balancing.
    """
    scale = 1.0
    var = np.nan
    # Nonzero marginals of the most recent pass; stays None if no pass runs.
    nzmarg = None
    n_bins = len(bias)

    # Per-bin chromosome factor 1 / (1 - chrom_size / n_bins). This is
    # presumably meant to compensate for the share of the genome a bin cannot
    # contact in trans (TODO confirm intent). A chromosome that holds every bin
    # has no trans partners, so it gets factor 0 instead of 1/0. Its all-zero
    # marginals then reach the degenerate branch on the first pass instead of
    # producing inf/NaN weights.
    chrom_offsets = clr._load_dset("indexes/chrom_offset")
    chrom_sizes = np.diff(chrom_offsets)
    outside = 1 - chrom_sizes / n_bins
    cweights = np.repeat(
        np.divide(1.0, outside, out=np.zeros_like(outside), where=outside != 0),
        chrom_sizes,
    )

    for _ in range(max_iters):
        marg = (
            split(clr, spans=spans, map=map, use_lock=use_lock)
            .prepare(_init)
            .pipe(filters)
            .pipe(_zero_cis)
            .pipe(_timesouterproduct, bias * cweights)
            .pipe(_marginalize)
            .reduce(add, np.zeros(n_bins))
        )

        nzmarg = marg[marg != 0]
        if not len(nzmarg):
            # Nothing left to balance: every weight becomes NaN.
            scale = np.nan
            bias[:] = np.nan
            var = 0.0
            break

        marg = marg / nzmarg.mean()
        marg[marg == 0] = 1
        bias /= marg

        var = nzmarg.var()
        logger.info(f"variance is {var}")
        if var < tol:
            break
    else:
        warnings.warn(
            "Iteration limit reached without convergence.",
            ConvergenceWarning,
            stacklevel=1,
        )

    # Skip the empty case. There, scale is already NaN, and mean() would only
    # emit a "Mean of empty slice" RuntimeWarning.
    if nzmarg is not None and len(nzmarg):
        scale = nzmarg.mean()
    bias[bias == 0] = np.nan
    if rescale_marginals:
        bias /= np.sqrt(scale)

    return bias, scale, var


def balance_cooler(
    clr: Cooler,
    *,
    cis_only: bool = False,
    trans_only: bool = False,
    ignore_diags: int | Literal[False] = 2,
    mad_max: int = 5,
    min_nnz: int = 10,
    min_count: int = 0,
    blacklist: list[int] | np.ndarray | None = None,
    rescale_marginals: bool = True,
    x0: np.ndarray | None = None,
    tol: float = 1e-5,
    max_iters: int = 200,
    chunksize: int | None = 10_000_000,
    map: MapFunctor = map,
    use_lock: bool = False,
    store: bool = False,
    store_name: str = "weight",
) -> tuple[np.ndarray, dict]:
    """
    Iterative correction or matrix balancing of a sparse Hi-C contact map in
    Cooler HDF5 format.

    Parameters
    ----------
    clr : cooler.Cooler
        Cooler object
    cis_only : bool, optional
        Do iterative correction on intra-chromosomal data only.
        Inter-chromosomal data is ignored. Takes precedence over
        ``trans_only`` if both are set.
    trans_only : bool, optional
        Do iterative correction on inter-chromosomal data only.
        Intra-chromosomal data is ignored.
    ignore_diags : int or False, optional
        Drop elements occurring on the first ``ignore_diags`` diagonals of the
        matrix (including the main diagonal).
    mad_max : int, optional
        Pre-processing bin-level filter. Marginal sums are first divided by
        the median positive marginal of their chromosome. Bins whose
        normalized log marginal sum lies more than ``mad_max`` median absolute
        deviations below the genome-wide median are then dropped.
    min_nnz : int, optional
        Pre-processing bin-level filter. Drop bins with fewer nonzero elements
        than this value.
    min_count : int, optional
        Pre-processing bin-level filter. Drop bins with lower marginal sum than
        this value.
    blacklist : list or 1D array, optional
        An explicit list of IDs of bad bins to filter out when performing
        balancing.
    rescale_marginals : bool, optional
        Normalize the balancing weights such that the balanced matrix has rows
        / columns that sum to 1.0. The scale factor is stored in the ``stats``
        output dictionary.
    x0 : 1D array, optional
        Initial weight vector of length ``n_bins``. It is copied to float and
        never modified. Default is to start with ones(n_bins).
    tol : float, optional
        Convergence criterion is the variance of the marginal (row/col) sum
        vector.
    max_iters : int, optional
        Iteration limit.
    chunksize : int or None, optional
        Split the contact matrix pixel records into equally sized chunks to
        save memory and/or parallelize. Set to ``None`` to use all the pixels
        at once.
    map : callable, optional
        Map function to dispatch the matrix chunks to workers.
        Default is the builtin ``map``, but alternatives include parallel map
        implementations from a multiprocessing pool.
    use_lock : bool, optional
        Forwarded to ``split`` for every pass over the pixels (presumably to
        serialize file reads across workers -- confirm in ``split``).
    store : bool, optional
        Whether to store the results in the file when finished. Default is
        False.
    store_name : str, optional
        Name of the column of the bin table to save to. Default name is
        'weight'.

    Returns
    -------
    bias : 1D array, whose length is the number of bins in ``clr``.
        Vector of bin bias weights to normalize the observed contact map.
        Dropped bins will be assigned the value NaN.
        N[i, j] = O[i, j] * bias[i] * bias[j]
    stats : dict
        Summary of parameters used to perform balancing and the average
        magnitude of the corrected matrix's marginal sum at convergence.
        With ``cis_only``, ``scale``, ``var`` and ``converged`` are
        per-chromosome arrays.

    Raises
    ------
    ValueError
        If ``x0`` is not a 1D array with one entry per bin.

    """
    # Divide the pixel records into non-overlapping chunks
    nnz = int(clr.info["nnz"])
    if chunksize is None:
        # Use one span covering every pixel. chunksize stays >= 1 because
        # cis-only balancing passes it to ``partition`` as the chunk size.
        chunksize = max(nnz, 1)
        spans = [(0, nnz)]
    else:
        edges = np.arange(0, nnz + chunksize, chunksize)
        spans = list(zip(edges[:-1], edges[1:]))

    # List of pre-marginalization data transformations
    base_filters = []
    if cis_only:
        base_filters.append(_zero_trans)
    if ignore_diags:
        base_filters.append(partial(_zero_diags, ignore_diags))

    # Initialize the bias weights
    n_bins = int(clr.info["nbins"])
    if x0 is not None:
        # Use a float copy. This keeps the in-place updates below from
        # mutating the caller's array, and makes integer-dtype x0 usable.
        bias = np.array(x0, dtype=float)
        if bias.shape != (n_bins,):
            raise ValueError(
                f"x0 must be a 1D array of length {n_bins}, got shape {bias.shape}"
            )
        bias[np.isnan(bias)] = 0
    else:
        bias = np.ones(n_bins, dtype=float)

    # Drop bins with too few nonzeros from bias
    if min_nnz > 0:
        marg_nnz = (
            split(clr, spans=spans, map=map, use_lock=use_lock)
            .prepare(_init)
            .pipe([_binarize, *base_filters])
            .pipe(_marginalize)
            .reduce(add, np.zeros(n_bins))
        )
        bias[marg_nnz < min_nnz] = 0

    marg = (
        split(clr, spans=spans, map=map, use_lock=use_lock)
        .prepare(_init)
        .pipe(base_filters)
        .pipe(_marginalize)
        .reduce(add, np.zeros(n_bins))
    )

    # Drop bins with too few total counts from bias
    if min_count:
        bias[marg < min_count] = 0

    # MAD-max filter on the marginals
    if mad_max > 0:
        offsets = clr._load_dset("indexes/chrom_offset")
        for lo, hi in zip(offsets[:-1], offsets[1:]):
            c_marg = marg[lo:hi]
            c_pos = c_marg[c_marg > 0]
            # Normalize by this chromosome's median. A chromosome with no
            # positive marginals is left at 0 instead of becoming NaN
            # (median of an empty array), so the cutoff below drops it.
            if len(c_pos):
                marg[lo:hi] /= np.median(c_pos)
        logNzMarg = np.log(marg[marg > 0])
        if len(logNzMarg):
            med_logNzMarg = np.median(logNzMarg)
            dev_logNzMarg = mad(logNzMarg)
            cutoff = np.exp(med_logNzMarg - mad_max * dev_logNzMarg)
            bias[marg < cutoff] = 0

    # Filter out pre-determined bad bins
    if blacklist is not None:
        bias[blacklist] = 0

    # Do balancing
    if cis_only:
        balance = _balance_cisonly
    elif trans_only:
        balance = _balance_transonly
    else:
        balance = _balance_genomewide
    bias, scale, var = balance(
        bias,
        clr,
        spans,
        base_filters,
        chunksize,
        map,
        tol,
        max_iters,
        rescale_marginals,
        use_lock,
    )

    stats = {
        "tol": tol,
        "min_nnz": min_nnz,
        "min_count": min_count,
        "mad_max": mad_max,
        "cis_only": cis_only,
        "ignore_diags": ignore_diags,
        "scale": scale,
        "converged": var < tol,
        "var": var,
        "divisive_weights": False,
    }

    if store:
        with clr.open("r+") as grp:
            if store_name in grp["bins"]:
                del grp["bins"][store_name]
            h5opts = {"compression": "gzip", "compression_opts": 6}
            grp["bins"].create_dataset(store_name, data=bias, **h5opts)
            grp["bins"][store_name].attrs.update(stats)

    return bias, stats


# Alternative name bound to the same function object as ``balance_cooler``.
# It is not listed in ``__all__``, so star imports do not export it.
iterative_correction = balance_cooler  # alias
