"""Some utility functions for rank estimation."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np
from scipy import linalg

from ._fiff.meas_info import Info, _simplify_info
from ._fiff.pick import _picks_by_type, _picks_to_idx, pick_channels_cov, pick_info
from ._fiff.proj import make_projector
from .defaults import _handle_default
from .utils import (
    _apply_scaling_cov,
    _check_on_missing,
    _check_rank,
    _compute_row_norms,
    _on_missing,
    _pl,
    _scaled_array,
    _undo_scaling_cov,
    _validate_type,
    fill_doc,
    logger,
    verbose,
    warn,
)


@verbose
def estimate_rank(
    data,
    tol="auto",
    return_singular=False,
    norm=True,
    tol_kind="absolute",
    verbose=None,
):
    """Estimate the rank of data.

    This function will normalize the rows of the data (typically
    channels or vertices) such that non-zero singular values
    should be close to one.

    Parameters
    ----------
    data : array
        Data to estimate the rank of (should be 2-dimensional).
    %(tol_rank)s
    return_singular : bool
        If True, also return the singular values that were used
        to determine the rank.
    norm : bool
        If True, data will be scaled by their estimated row-wise norm.
        Else data are assumed to be scaled. Defaults to True.
    %(tol_kind_rank)s
    %(verbose)s

    Returns
    -------
    rank : int
        Estimated rank of the data.
    s : array
        If return_singular is True, the singular values that were
        thresholded to determine the rank are also returned.
    """
    if norm:
        # Divide out-of-place: the caller's array is never modified, and
        # (unlike ``data.copy(); data /= ...``) integer-dtype input works,
        # since an int array cannot store the true-division result in place.
        norms = _compute_row_norms(data)
        data = data / norms[:, np.newaxis]
    s = linalg.svdvals(data)
    rank = _estimate_rank_from_s(s, tol, tol_kind)
    if return_singular is True:
        return rank, s
    return rank


def _estimate_rank_from_s(s, tol="auto", tol_kind="absolute"):
    """Estimate the rank of a matrix from its singular values.

    Parameters
    ----------
    s : ndarray, shape (..., ndim)
        The singular values of the matrix, or of a stack of matrices (one
        matrix per leading index).
    tol : float | ndarray of float | ``'auto'`` | ``'float32'``
        Tolerance for singular values to consider non-zero in calculating the
        rank. ``'auto'`` uses ``s.shape[-1] * max(s) * eps`` with float64
        machine epsilon, computed per matrix (similar to the thresholding of
        ``scipy.linalg.orth``). ``'float32'`` does the same with float32
        machine epsilon. A floating-point ndarray is used as-is as an absolute
        threshold and must broadcast against ``s`` (e.g. shape ``(..., 1)``).
    tol_kind : str
        Can be ``"absolute"`` or ``"relative"``. Only applies to a scalar
        numeric ``tol``. A relative tolerance is multiplied by the largest
        singular value of each matrix.

    Returns
    -------
    rank : ndarray, shape (...)
        The estimated rank.
    """
    s = np.array(s, float)
    max_s = np.amax(s, axis=-1)
    if isinstance(tol, str):
        if tol not in ("auto", "float32"):
            raise ValueError(f'tol must be "auto" or float, got {repr(tol)}')
        # XXX this should be float32 probably due to how we save and
        # load data, but it breaks test_make_inverse_operator (!)
        # The factor of 2 gets test_compute_covariance_auto_reg[None]
        # to pass without breaking minimum norm tests. :(
        # Passing 'float32' is a hack workaround for test_maxfilter_get_rank :(
        if tol == "float32":
            eps = np.finfo(np.float32).eps
        else:
            eps = np.finfo(np.float64).eps
        tol = s.shape[-1] * max_s * eps
        if s.ndim == 1:  # typical
            logger.info(
                "    Using tolerance %0.2g (%0.2g eps * %d dim * %0.2g"
                "  max singular value)",
                tol,
                eps,
                len(s),
                max_s,
            )
        # One threshold per matrix: add a trailing axis so each threshold is
        # compared against all singular values of its own matrix instead of
        # being broadcast along the singular-value axis.
        tol = np.asarray(tol)[..., np.newaxis]
    elif not (isinstance(tol, np.ndarray) and tol.dtype.kind == "f"):
        tol = float(tol)
        if tol_kind == "relative":
            # Scale by each matrix's largest singular value (per matrix).
            tol = np.asarray(tol * max_s)[..., np.newaxis]
    # A float ndarray ``tol`` is taken as an absolute, already-broadcastable
    # threshold and is used unchanged.
    rank = np.sum(s > tol, axis=-1)
    return rank


def _estimate_rank_raw(
    raw, picks=None, tol=1e-4, scalings="norm", with_ref_meg=False, tol_kind="absolute"
):
    """Estimate the rank of (a subset of) the channels of a Raw instance.

    Convenience wrapper that exposes the expert ``tol`` and ``scalings``
    options; it aids the transition away from ``raw.estimate_rank``.
    """
    picks = (
        _picks_to_idx(raw.info, picks, with_ref_meg=with_ref_meg)
        if picks is None
        else picks
    )
    signals = raw[picks][0]
    sub_info = pick_info(raw.info, picks)
    return _estimate_rank_meeg_signals(
        signals,
        sub_info,
        scalings,
        tol=tol,
        return_singular=False,
        tol_kind=tol_kind,
    )


@fill_doc
def _estimate_rank_meeg_signals(
    data,
    info,
    scalings,
    tol="auto",
    return_singular=False,
    tol_kind="absolute",
    log_ch_type=None,
):
    """Estimate rank for M/EEG data.

    Parameters
    ----------
    data : np.ndarray of float, shape(n_channels, n_samples)
        The M/EEG signals.
    %(info_not_none)s
    scalings : dict | ``'norm'`` | np.ndarray | None
        The rescaling method to be applied. If dict, it will override the
        following default dict:

            dict(mag=1e15, grad=1e13, eeg=1e6)

        If ``'norm'`` data will be scaled by channel-wise norms. If array,
        pre-specified norms will be used. If None, no scaling will be applied.
    tol : float | str
        Tolerance. See ``estimate_rank``.
    return_singular : bool
        If True, also return the singular values that were used
        to determine the rank.
    tol_kind : str
        Tolerance kind. See ``estimate_rank``.
    log_ch_type : str | None
        Channel-type label used in the log message. If None, the channel
        types found in ``info`` are joined with ``" + "``.

    Returns
    -------
    rank : int
        Estimated rank of the data.
    s : array
        If return_singular is True, the singular values that were
        thresholded to determine the rank are also returned.
    """
    picks_list = _picks_by_type(info)
    if data.shape[1] < data.shape[0]:
        # Too few samples makes the estimate unreliable but not invalid,
        # so tell the user instead of failing.
        warn(
            "You've got fewer samples than channels, your "
            "rank estimate might be inaccurate."
        )
    with _scaled_array(data, picks_list, scalings):
        out = estimate_rank(
            data,
            tol=tol,
            norm=False,
            return_singular=return_singular,
            tol_kind=tol_kind,
        )
    rank = out[0] if isinstance(out, tuple) else out
    if log_ch_type is None:
        # Generator join also copes with an empty picks_list.
        ch_type = " + ".join(this_type for this_type, _ in picks_list)
    else:
        ch_type = log_ch_type
    logger.info("    Estimated rank (%s): %d", ch_type, rank)
    return out


@verbose
def _estimate_rank_meeg_cov(
    data,
    info,
    scalings,
    tol="auto",
    return_singular=False,
    *,
    log_ch_type=None,
    verbose=None,
):
    """Estimate rank of M/EEG covariance data, given the covariance.

    ``data`` is passed to ``_apply_scaling_cov`` before the estimation and to
    ``_undo_scaling_cov`` afterwards. The undo step also runs if the
    estimation raises. (These helpers presumably rescale ``data`` in place.)

    Parameters
    ----------
    data : np.ndarray of float, shape (n_channels, n_channels)
        The M/EEG covariance.
    %(info_not_none)s
    scalings : dict | 'norm' | np.ndarray | None
        The rescaling method to be applied. If dict, it will override the
        following default dict:

            dict(mag=1e12, grad=1e11, eeg=1e5)

        If 'norm' data will be scaled by channel-wise norms. If array,
        pre-specified norms will be used. If None, no scaling will be applied.
    tol : float | str
        Tolerance. See ``estimate_rank``.
    return_singular : bool
        If True, also return the singular values that were used
        to determine the rank.
    log_ch_type : str | None
        Channel-type label used in the log message. If None, the channel
        types found in ``info`` are joined with ``" + "``.
    %(verbose)s

    Returns
    -------
    rank : int
        Estimated rank of the data.
    s : array
        If return_singular is True, the singular values that were
        thresholded to determine the rank are also returned.
    """
    picks_list = _picks_by_type(info, exclude=[])
    scalings = _handle_default("scalings_cov_rank", scalings)
    # No sample-count check here: ``data`` is a square covariance matrix.
    _apply_scaling_cov(data, picks_list, scalings)
    try:
        out = estimate_rank(
            data, tol=tol, norm=False, return_singular=return_singular
        )
    finally:
        # Restore the caller's covariance even if the estimation fails.
        _undo_scaling_cov(data, picks_list, scalings)
    rank = out[0] if isinstance(out, tuple) else out
    if log_ch_type is None:
        # Generator join also copes with an empty picks_list.
        ch_type_ = " + ".join(this_type for this_type, _ in picks_list)
    else:
        ch_type_ = log_ch_type
    logger.info(f"    Estimated rank ({ch_type_}): {rank}")
    return out


@verbose
def _get_rank_sss(
    inst, msg="You should use data-based rank estimate instead", verbose=None
):
    """Look up the rank implied by the Maxfilter (SSS) processing record.

    .. note::
        Throws an error if SSS has not been applied.

    Parameters
    ----------
    inst : instance of Raw, Epochs or Evoked, or Info
        Any MNE object with an .info attribute, or an Info itself.
    msg : str
        Text appended to the error raised when no Maxfilter record is found.
    %(verbose)s

    Returns
    -------
    rank : int
        ``nfree = (in_order + 1) ** 2 - 1``, reduced by ``len(c) - c.sum()``
        where ``c`` is the first ``nfree`` entries of
        ``sss_info["components"]`` (presumably a 0/1 mask of retained
        components — NOTE(review): confirm). Only the first
        ``proc_history`` record is used.

    Raises
    ------
    ValueError
        If ``info["proc_history"]`` is empty, or its first record lacks a
        ``max_info`` entry with ``sss_info["in_order"]``.
    """
    # XXX this is too basic for movement compensated data
    # https://github.com/mne-tools/mne-python/issues/4676
    info = inst if isinstance(inst, Info) else inst.info
    del inst

    history = info.get("proc_history", [])
    if len(history) > 1:
        logger.info("Found multiple SSS records. Using the first.")
    has_sss = (
        len(history) != 0
        and "max_info" in history[0]
        and "in_order" in history[0]["max_info"]["sss_info"]
    )
    if not has_sss:
        raise ValueError(
            f'Could not find Maxfilter information in info["proc_history"]. {msg}'
        )
    sss_info = history[0]["max_info"]["sss_info"]
    n_max = (sss_info["in_order"] + 1) ** 2 - 1
    kept = sss_info["components"][:n_max]
    return n_max - (len(kept) - kept.sum())


def _info_rank(info, ch_type, picks, rank):
    """Return the theoretical rank of ``picks`` based on ``info``.

    For MEG channel types (``"meg"``, ``"mag"``, ``"grad"``) with ``rank``
    other than ``"full"``, the SSS-based rank from ``_get_rank_sss`` is
    returned when available. In every other case, including when that lookup
    raises ``ValueError``, the number of picked channels is returned.
    """
    if ch_type not in ("meg", "mag", "grad") or rank == "full":
        return len(picks)
    try:
        sss_rank = _get_rank_sss(info)
    except ValueError:
        # No usable Maxfilter record: fall back to the channel count.
        return len(picks)
    return sss_rank


def _compute_rank_int(inst, *args, **kwargs):
    """Return the total rank from compute_rank, summed over channel types."""
    # XXX eventually we should unify how channel types are handled
    # so that we don't need to do this, or we do it everywhere.
    # Using pca=True in compute_whitener might help.
    per_type = compute_rank(inst, *args, **kwargs)
    return sum(per_type.values())


@verbose
def compute_rank(
    inst,
    rank=None,
    scalings=None,
    info=None,
    tol="auto",
    proj=True,
    tol_kind="absolute",
    on_rank_mismatch="ignore",
    verbose=None,
):
    """Compute the rank of data or noise covariance.

    This function will normalize the rows of the data (typically
    channels or vertices) such that non-zero singular values
    should be close to one. It operates on :term:`data channels` only.

    Parameters
    ----------
    inst : instance of Raw, Epochs, or Covariance
        Raw measurements to compute the rank from or the covariance.
    %(rank_none)s
    scalings : dict | None (default None)
        Defaults to ``dict(mag=1e12, grad=1e11, eeg=1e5)`` (the
        ``"scalings_cov_rank"`` defaults), for both data and covariance
        inputs. A dict overrides individual entries.
        These defaults will scale different channel types
        to comparable values.
    %(info)s Only necessary if ``inst`` is a :class:`mne.Covariance`
        object (since this does not provide ``inst.info``).
    %(tol_rank)s
    proj : bool
        If True, all projs in ``inst`` and ``info`` will be applied or
        considered when ``rank=None`` or ``rank='info'``.
    %(tol_kind_rank)s
    %(on_rank_mismatch)s
    %(verbose)s

    Returns
    -------
    rank : dict
        Estimated rank of the data for each channel type.
        To get the total rank, you can use ``sum(rank.values())``.

    Notes
    -----
    .. versionadded:: 0.18
    """
    # ``verbose`` is handled by the decorator; the implementation inherits
    # the logging level that is active during this call.
    return _compute_rank(
        inst=inst,
        rank=rank,
        scalings=scalings,
        info=info,
        tol=tol,
        proj=proj,
        tol_kind=tol_kind,
        on_rank_mismatch=on_rank_mismatch,
    )


@verbose
def _compute_rank(
    inst,
    rank=None,
    scalings=None,
    info=None,
    *,
    tol="auto",
    proj=True,
    tol_kind="absolute",
    on_rank_mismatch="ignore",
    log_ch_type=None,
    verbose=None,
):
    """Compute the rank of data or a covariance for each channel type.

    Implementation behind ``compute_rank``.

    Parameters
    ----------
    inst : instance of Raw, Epochs, or Covariance
        Object to compute the rank from.
    rank : None | str | dict
        Validated with ``_check_rank``. A string (``'info'`` or ``'full'``)
        selects an info-based rank. None or a dict selects data-driven
        estimation. Channel types already present in a dict are kept as
        given; for a non-diagonal covariance, the ``'meg'`` entry is also
        checked against the estimate unless ``on_rank_mismatch='ignore'``.
    scalings : dict | None
        Resolved with ``_handle_default("scalings_cov_rank", scalings)``.
    info : instance of Info | None
        Required when ``inst`` is a Covariance. Otherwise ``inst.info`` is
        used and this argument is ignored.
    tol : float | str
        Tolerance passed on to the estimation helpers.
    proj : bool
        If True, a projector is built from ``info["projs"]`` for each channel
        type. It is applied to the data/covariance before estimation, and its
        count is subtracted from info-based and diagonal-covariance ranks.
    tol_kind : str
        Tolerance kind. Only forwarded for Raw/Epochs data.
    on_rank_mismatch : str
        How to report a user-supplied MEG rank that exceeds the estimated
        covariance rank by a singular-value ratio greater than 100.
    log_ch_type : str | None
        Channel-type label used in log messages. If None, the upper-cased
        channel type is used.
    verbose : bool | str | int | None
        Handled by the ``verbose`` decorator.

    Returns
    -------
    rank : dict
        Maps each channel type from ``_picks_by_type(info, meg_combined=True,
        ...)`` to an int rank. Entries supplied by the caller are kept as-is.
        NOTE(review): if ``_check_rank`` does not copy its input, a
        caller-supplied dict is updated in place — confirm.
    """
    from .cov import Covariance
    from .epochs import BaseEpochs
    from .io import BaseRaw

    rank = _check_rank(rank)
    scalings = _handle_default("scalings_cov_rank", scalings)
    _check_on_missing(on_rank_mismatch, "on_rank_mismatch")

    if isinstance(inst, Covariance):
        inst_type = "covariance"
        if info is None:
            raise ValueError("info cannot be None if inst is a Covariance.")
        # Reset bads as it's already taken into account in inst['names']
        info = info.copy()
        info["bads"] = []
        # Keep only channels present in both the covariance and info
        # (info["bads"] is empty here, so only the covariance's bads are
        # excluded).
        inst = pick_channels_cov(
            inst,
            set(inst["names"]) & set(info["ch_names"]),
            exclude=info["bads"] + inst["bads"],
            ordered=False,
        )
        # Make info's channel order match the covariance's.
        if info["ch_names"] != inst["names"]:
            info = pick_info(
                info, [info["ch_names"].index(name) for name in inst["names"]]
            )
    else:
        info = inst.info
        inst_type = "data"
    logger.info(f"Computing rank from {inst_type} with rank={repr(rank)}")

    _validate_type(rank, (str, dict, None), "rank")
    if isinstance(rank, str):  # string, either 'info' or 'full'
        rank_type = "info"
        info_type = rank
        rank = dict()
    else:  # None or dict
        rank_type = "estimated"
        if rank is None:
            rank = dict()

    simple_info = _simplify_info(info)
    picks_list = _picks_by_type(info, meg_combined=True, ref_meg=False, exclude="bads")
    for ch_type, picks in picks_list:
        est_verbose = None
        if ch_type in rank:
            # raise an error if user-supplied rank exceeds number of channels
            if rank[ch_type] > len(picks):
                raise ValueError(
                    f"rank[{repr(ch_type)}]={rank[ch_type]} exceeds the number"
                    f" of channels ({len(picks)})"
                )
            # special case: if whitening a covariance, check the passed rank
            # against the estimated one
            est_verbose = False
            # Otherwise the user-supplied value is kept and nothing is
            # estimated for this channel type.
            if not (
                on_rank_mismatch != "ignore"
                and rank_type == "estimated"
                and ch_type == "meg"
                and isinstance(inst, Covariance)
                and not inst["diag"]
            ):
                continue
        ch_names = [info["ch_names"][pick] for pick in picks]
        n_chan = len(ch_names)
        if proj:
            proj_op, n_proj, _ = make_projector(info["projs"], ch_names)
        else:
            proj_op, n_proj = None, 0
        if log_ch_type is None:
            ch_type_ = ch_type.upper()
        else:
            ch_type_ = log_ch_type
        if rank_type == "info":
            # use info
            this_rank = _info_rank(info, ch_type, picks, info_type)
            if info_type != "full":
                this_rank -= n_proj
                logger.info(
                    f"    {ch_type_}: rank {this_rank} after "
                    f"{n_proj} projector{_pl(n_proj)} applied to "
                    f"{n_chan} channel{_pl(n_chan)}"
                )
            else:
                logger.info(f"    {ch_type_}: rank {this_rank} from info")
        else:
            # Use empirical estimation
            assert rank_type == "estimated"
            if isinstance(inst, BaseRaw | BaseEpochs):
                if isinstance(inst, BaseRaw):
                    data = inst.get_data(picks, reject_by_annotation="omit")
                else:  # isinstance(inst, BaseEpochs):
                    # Stack epochs side by side along the time axis.
                    data = np.concatenate(inst.get_data(picks), axis=1)
                if proj:
                    data = np.dot(proj_op, data)
                this_rank = _estimate_rank_meeg_signals(
                    data,
                    pick_info(simple_info, picks),
                    scalings,
                    tol,
                    False,
                    tol_kind,
                    log_ch_type=log_ch_type,
                )
            else:
                assert isinstance(inst, Covariance)
                if inst["diag"]:
                    # Diagonal covariance: count positive variances.
                    this_rank = (inst["data"][picks] > 0).sum() - n_proj
                else:
                    data = inst["data"][picks][:, picks]
                    if proj:
                        data = np.dot(np.dot(proj_op, data), proj_op.T)

                    this_rank, sing = _estimate_rank_meeg_cov(
                        data,
                        pick_info(simple_info, picks),
                        scalings,
                        tol,
                        return_singular=True,
                        log_ch_type=log_ch_type,
                        verbose=est_verbose,
                    )
                    if ch_type in rank:
                        # Ratio of the last singular value kept by the
                        # estimate to the last one kept by the user's rank;
                        # per the message below, whitening noise grows by
                        # about sqrt(ratio).
                        ratio = sing[this_rank - 1] / sing[rank[ch_type] - 1]
                        if ratio > 100:
                            msg = (
                                f"The passed rank[{repr(ch_type)}]="
                                f"{rank[ch_type]} exceeds the estimated rank "
                                f"of the noise covariance ({this_rank}) "
                                f"leading to a potential increase in "
                                f"noise during whitening by a factor "
                                f"of {np.sqrt(ratio):0.1g}. Ensure that the "
                                f"rank correctly corresponds to that of the "
                                f"given noise covariance matrix."
                            )
                            _on_missing(on_rank_mismatch, msg, "on_rank_mismatch")
                        # Keep the user-supplied rank; skip the info check
                        # and assignment below.
                        continue
            # Sanity check: the estimate should not exceed the info rank.
            this_info_rank = _info_rank(info, ch_type, picks, "info")
            logger.info(
                f"    {ch_type_}: rank {this_rank} computed from "
                f"{n_chan} data channel{_pl(n_chan)} with "
                f"{n_proj} projector{_pl(n_proj)}"
            )
            if this_rank > this_info_rank:
                warn(
                    "Something went wrong in the data-driven estimation of the data "
                    "rank as it exceeds the theoretical rank from the info "
                    f"({this_rank} > {this_info_rank}). Consider setting rank "
                    'to "auto" or setting it explicitly as an integer.'
                )
        # Only fill channel types the caller did not specify.
        if ch_type not in rank:
            rank[ch_type] = int(this_rank)

    return rank
