# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

from copy import deepcopy

import numpy as np
from scipy.fft import fft, fftfreq, ifft

from .._fiff.pick import _pick_data_channels, pick_info
from ..parallel import parallel_func
from ..utils import _validate_type, legacy, logger, verbose
from .tfr import AverageTFRArray, _ensure_slice, _get_data


def _check_input_st(x_in, n_fft):
    """Resolve the FFT length and zero-pad the signal along its last axis.

    Parameters
    ----------
    x_in : ndarray
        The signal; time is taken to be the last axis.
    n_fft : int | None
        Requested FFT length. ``None``, or a value that fails the
        power-of-two check while being shorter than the signal, is replaced
        by the smallest power of two that is at least the signal length.

    Returns
    -------
    x_in : ndarray
        The input, with zeros appended on the last axis when it is shorter
        than ``n_fft`` (otherwise returned untouched).
    n_fft : int
        The FFT length actually used.
    zero_pad : int
        Number of zero samples that were appended (0 if none).

    Raises
    ------
    ValueError
        If ``n_fft`` is shorter than the signal and was not replaced by the
        automatic choice described above.
    """
    n_times = x_in.shape[-1]

    def _has_several_bits_set(n):
        # True only for positive integers that are not an exact power of two
        return bool(n > 0 and (n & (n - 1)))

    pick_automatically = n_fft is None or (
        _has_several_bits_set(n_fft) and n_times > n_fft
    )
    if pick_automatically:
        # round the signal length up to a power of two
        n_fft = 2 ** int(np.ceil(np.log2(n_times)))
    elif n_fft < n_times:
        raise ValueError(
            f"n_fft cannot be smaller than signal size. Got {n_fft} < {n_times}."
        )

    if not n_times < n_fft:
        # signal already fills the FFT length, nothing to append
        return x_in, n_fft, 0

    logger.info(
        f'The input signal is shorter ({x_in.shape[-1]}) than "n_fft" ({n_fft}). '
        "Applying zero padding."
    )
    zero_pad = n_fft - n_times
    tail_shape = x_in.shape[:-1] + (zero_pad,)
    tail = np.zeros(tail_shape, x_in.dtype)
    x_in = np.concatenate((x_in, tail), axis=-1)
    return x_in, n_fft, zero_pad


def _precompute_st_windows(n_samp, start_f, stop_f, sfreq, width):
    """Build the frequency-domain Gaussian windows of the Stockwell transform.

    One window is produced for every integer bin in ``[start_f, stop_f)``.

    Parameters
    ----------
    n_samp : int
        Number of samples of the (padded) signal; also the window length.
    start_f : int
        First frequency bin (inclusive).
    stop_f : int
        Last frequency bin (exclusive).
    sfreq : float
        The sampling frequency.
    width : float
        Scaling of the Gaussian; 1 gives the classical Stockwell transform.

    Returns
    -------
    windows : ndarray of complex, shape (stop_f - start_f, n_samp)
        FFT of each unit-sum window, one row per frequency bin.
    """
    # FFT sample frequencies divided by n_samp, keeping the first entry in
    # place and reversing the order of all the others
    axis = fftfreq(n_samp, 1.0 / sfreq) / n_samp
    axis = np.concatenate((axis[:1], axis[:0:-1]))
    n_axis = len(axis)

    bins = np.arange(start_f, stop_f, 1)
    windows = np.empty((len(bins), n_axis), dtype=np.complex128)
    amp_denominator = np.sqrt(2.0 * np.pi) * width
    for row, freq_bin in enumerate(bins):
        if freq_bin == 0.0:
            # bin zero gets a flat window
            gauss = np.ones(n_axis)
        else:
            amplitude = freq_bin / amp_denominator
            exponent_scale = -0.5 * (1.0 / width**2.0) * (freq_bin**2.0)
            gauss = amplitude * np.exp(exponent_scale * axis**2.0)
        gauss /= gauss.sum()  # normalise to unit sum
        windows[row] = fft(gauss)
    return windows


def _st(x, start_f, windows):
    """Compute the Stockwell transform of ``x`` (used in tests).

    Based on Ali Moukadem's MATLAB code.

    Parameters
    ----------
    x : ndarray, shape (..., n_times)
        The signal; the transform runs along the last axis.
    start_f : int
        FFT bin that corresponds to the first row of ``windows``.
    windows : ndarray, shape (n_freqs, n_times)
        Frequency-domain windows, one per consecutive FFT bin.

    Returns
    -------
    ST : ndarray of complex, shape (..., n_freqs, n_times)
        The transform, one time course per window.
    """
    n_times = x.shape[-1]
    spectrum = fft(x)
    # Two copies back to back: a slice of length n_times starting at any bin
    # then wraps around the spectrum without explicit modular indexing.
    doubled = np.concatenate((spectrum, spectrum), axis=-1)

    out_shape = x.shape[:-1] + (len(windows), n_times)
    out = np.empty(out_shape, dtype=np.complex128)
    for row, win in enumerate(windows):
        first_bin = start_f + row
        shifted = doubled[..., first_bin : first_bin + n_times]
        out[..., row, :] = ifft(shifted * win)
    return out


def _st_power_itc(x, start_f, compute_itc, zero_pad, decim, W):
    """Compute averaged Stockwell power (and optionally ITC) of a 2D signal.

    Parameters
    ----------
    x : ndarray, shape (n_epochs, n_times)
        The (possibly zero-padded) signal. Results are averaged over the
        first axis.
    start_f : int
        FFT bin that corresponds to the first row of ``W``.
    compute_itc : bool
        Whether to also compute the intertrial coherence.
    zero_pad : int
        Number of trailing padding samples in ``x``; they are excluded from
        the output.
    decim : int | slice
        Time decimation; passed through ``_ensure_slice`` and then applied
        to the unpadded time axis.
    W : ndarray, shape (n_freqs, n_times)
        Frequency-domain windows, one per consecutive FFT bin.

    Returns
    -------
    psd : ndarray, shape (n_freqs, n_out)
        Mean squared magnitude of the transform over the first axis of ``x``.
    itc : ndarray, shape (n_freqs, n_out) | None
        Magnitude of the mean unit phasor over the first axis of ``x``, or
        None if ``compute_itc`` is False.

    Notes
    -----
    Samples whose transform is exactly zero contribute zero power. The
    substitution of 1.0 for zero magnitudes is only applied to the divisor
    used for the ITC phasors, so it cannot leak into the power estimate.
    """
    decim = _ensure_slice(decim)
    n_samp = x.shape[-1]
    # resolve the decimation against the unpadded length so that padding
    # samples never reach the output
    decim_indices = decim.indices(n_samp - zero_pad)
    keep = slice(*decim_indices)
    n_out = len(range(*decim_indices))
    psd = np.empty((len(W), n_out))
    itc = np.empty_like(psd) if compute_itc else None
    X = fft(x)
    # two copies back to back so that slicing at any bin wraps around
    XX = np.concatenate([X, X], axis=-1)
    for i_f, window in enumerate(W):
        f = start_f + i_f
        tfr = ifft(XX[:, f : f + n_samp] * window)[:, keep]
        tfr_abs = np.abs(tfr)
        if compute_itc:
            # avoid 0 / 0: a zero-magnitude sample is divided by 1 and thus
            # contributes a zero phasor to the mean
            divisor = np.where(tfr_abs == 0, 1.0, tfr_abs)
            itc[i_f] = np.abs(np.mean(tfr / divisor, axis=0))
        psd[i_f] = np.mean(tfr_abs * tfr_abs, axis=0)
    return psd, itc


def _compute_freqs_st(fmin, fmax, n_fft, sfreq):
    """Map a frequency range onto FFT bins.

    Parameters
    ----------
    fmin : float | None
        Lower frequency; None selects the smallest strictly positive FFT
        frequency.
    fmax : float | None
        Upper frequency; None selects the largest FFT frequency.
    n_fft : int
        The FFT length.
    sfreq : float
        The sampling frequency.

    Returns
    -------
    start_f : int
        Index of the FFT bin closest to ``fmin``.
    stop_f : int
        Index of the FFT bin closest to ``fmax``.
    freqs : ndarray
        The FFT frequencies of bins ``start_f`` up to, but not including,
        ``stop_f``.
    """
    fft_freqs = fftfreq(n_fft, 1.0 / sfreq)
    lower = fft_freqs[fft_freqs > 0][0] if fmin is None else fmin
    upper = fft_freqs.max() if fmax is None else fmax

    def _closest_bin(target):
        return np.abs(fft_freqs - target).argmin()

    start_f = _closest_bin(lower)
    stop_f = _closest_bin(upper)
    return start_f, stop_f, fft_freqs[start_f:stop_f]


@verbose
def tfr_array_stockwell(
    data,
    sfreq,
    fmin=None,
    fmax=None,
    n_fft=None,
    width=1.0,
    decim=1,
    return_itc=False,
    n_jobs=None,
    *,
    verbose=None,
):
    """Compute power and intertrial coherence using Stockwell (S) transform.

    Same computation as `~mne.time_frequency.tfr_stockwell`, but operates on
    :class:`NumPy arrays <numpy.ndarray>` instead of `~mne.Epochs` objects.

    See :footcite:`Stockwell2007,MoukademEtAl2014,WheatEtAl2010,JonesEtAl2006`
    for more information.

    Parameters
    ----------
    data : ndarray, shape (n_epochs, n_channels, n_times)
        The signal to transform.
    sfreq : float
        The sampling frequency.
    fmin : None, float
        The minimum frequency to include. If None defaults to the minimum fft
        frequency greater than zero.
    fmax : None, float
        The maximum frequency to include. If None defaults to the maximum fft.
    n_fft : int | None
        The length of the windows used for FFT. If None, it defaults to the
        next power of 2 larger than the signal length.
    width : float
        The width of the Gaussian window. If < 1, increased temporal
        resolution, if > 1, increased frequency resolution. Defaults to 1.
        (classical S-Transform).
    %(decim_tfr)s
    return_itc : bool
        Return intertrial coherence (ITC) as well as averaged power.
    %(n_jobs)s
    %(verbose)s

    Returns
    -------
    st_power : ndarray, shape (n_channels, n_freqs, n_times_out)
        The power of the Stockwell transformed data, averaged over epochs.
        The last two dimensions are frequency and time.
    itc : ndarray, shape (n_channels, n_freqs, n_times_out) | None
        The intertrial coherence. It is None if return_itc is False.
    freqs : ndarray
        The frequencies.

    See Also
    --------
    mne.time_frequency.tfr_stockwell
    mne.time_frequency.tfr_multitaper
    mne.time_frequency.tfr_array_multitaper
    mne.time_frequency.tfr_morlet
    mne.time_frequency.tfr_array_morlet

    References
    ----------
    .. footbibliography::
    """
    _validate_type(data, np.ndarray, "data")
    if data.ndim != 3:
        raise ValueError(
            "data must be 3D with shape (n_epochs, n_channels, n_times), "
            f"got {data.shape}"
        )
    decim = _ensure_slice(decim)
    # output sizes are taken from the data before any zero padding
    n_channels, n_out = data[..., decim].shape[1:]
    data, n_fft_, zero_pad = _check_input_st(data, n_fft)
    start_f, stop_f, freqs = _compute_freqs_st(fmin, fmax, n_fft_, sfreq)

    W = _precompute_st_windows(data.shape[-1], start_f, stop_f, sfreq, width)
    out_shape = (n_channels, stop_f - start_f, n_out)
    psd = np.empty(out_shape)
    itc = np.empty(out_shape) if return_itc else None

    # one job per channel
    parallel, my_st, n_jobs = parallel_func(_st_power_itc, n_jobs, verbose=verbose)
    per_channel = parallel(
        my_st(data[:, ch, :], start_f, return_itc, zero_pad, decim, W)
        for ch in range(n_channels)
    )
    for ch, (ch_psd, ch_itc) in enumerate(per_channel):
        psd[ch] = ch_psd
        if ch_itc is not None:
            itc[ch] = ch_itc

    return psd, itc, freqs


@legacy(alt='.compute_tfr(method="stockwell", freqs="auto")')
@verbose
def tfr_stockwell(
    inst,
    fmin=None,
    fmax=None,
    n_fft=None,
    width=1.0,
    decim=1,
    return_itc=False,
    n_jobs=None,
    verbose=None,
):
    """Compute Time-Frequency Representation (TFR) using Stockwell Transform.

    Same computation as `~mne.time_frequency.tfr_array_stockwell`, but operates
    on `~mne.Epochs` objects instead of :class:`NumPy arrays <numpy.ndarray>`.

    See :footcite:`Stockwell2007,MoukademEtAl2014,WheatEtAl2010,JonesEtAl2006`
    for more information.

    Parameters
    ----------
    inst : Epochs | Evoked
        The epochs or evoked object.
    fmin : None, float
        The minimum frequency to include. If None defaults to the minimum fft
        frequency greater than zero.
    fmax : None, float
        The maximum frequency to include. If None defaults to the maximum fft.
    n_fft : int | None
        The length of the windows used for FFT. If None, it defaults to the
        next power of 2 larger than the signal length.
    width : float
        The width of the Gaussian window. If < 1, increased temporal
        resolution, if > 1, increased frequency resolution. Defaults to 1.
        (classical S-Transform).
    decim : int
        The decimation factor on the time axis. To reduce memory usage.
    return_itc : bool
        Return intertrial coherence (ITC) as well as averaged power.
    n_jobs : int
        The number of jobs to run in parallel (over channels).
    %(verbose)s

    Returns
    -------
    power : AverageTFR
        The averaged power.
    itc : AverageTFR
        The intertrial coherence. Only returned if return_itc is True.

    See Also
    --------
    mne.time_frequency.tfr_array_stockwell
    mne.time_frequency.tfr_multitaper
    mne.time_frequency.tfr_array_multitaper
    mne.time_frequency.tfr_morlet
    mne.time_frequency.tfr_array_morlet

    Notes
    -----
    .. versionadded:: 0.9.0

    References
    ----------
    .. footbibliography::
    """
    # verbose dec is used b/c subfunctions are verbose
    full_data = _get_data(inst, return_itc)
    # restrict both the measurement info and the array to the data channels
    data_picks = _pick_data_channels(inst.info)
    info = pick_info(inst.info, data_picks)
    data = full_data[:, data_picks, :]
    decim = _ensure_slice(decim)
    power, itc, freqs = tfr_array_stockwell(
        data,
        sfreq=info["sfreq"],
        fmin=fmin,
        fmax=fmax,
        n_fft=n_fft,
        width=width,
        decim=decim,
        return_itc=return_itc,
        n_jobs=n_jobs,
    )
    # the time axis has to follow the same decimation as the data
    times = inst.times[decim].copy()
    nave = len(data)
    power_tfr = AverageTFRArray(
        info=info,
        data=power,
        times=times,
        freqs=freqs,
        nave=nave,
        method="stockwell-power",
    )
    if not return_itc:
        return power_tfr

    # the ITC container receives its own copies of info, times and freqs
    itc_tfr = AverageTFRArray(
        info=deepcopy(info),
        data=itc,
        times=times.copy(),
        freqs=freqs.copy(),
        nave=nave,
        method="stockwell-itc",
    )
    return power_tfr, itc_tfr
