"""Some utility functions."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import json
import logging
from collections import OrderedDict
from copy import deepcopy

import numpy as np

from ._logging import verbose, warn
from ._typing import Self
from .check import _check_pandas_installed, _check_preload, _validate_type
from .numerics import _time_mask, object_hash, object_size

logger = logging.getLogger("mne")  # one selection here used across mne-python
logger.propagate = False  # don't propagate (in case of multiple imports)


class SizeMixin:
    """Estimate MNE object sizes."""

    def __eq__(self, other):
        """Compare self to other.

        Parameters
        ----------
        other : object
            The object to compare to.

        Returns
        -------
        eq : bool
            True if the two objects are equal.
        """
        return isinstance(other, type(self)) and hash(self) == hash(other)

    @property
    def _size(self):
        """Estimate the object size."""
        try:
            size = object_size(self.info)
        except Exception:
            warn("Could not get size for self.info")
            return -1
        if hasattr(self, "data"):
            size += object_size(self.data)
        elif hasattr(self, "_data"):
            size += object_size(self._data)
        return size

    def __hash__(self):
        """Hash the object.

        Returns
        -------
        hash : int
            The hash
        """
        from ..epochs import BaseEpochs
        from ..evoked import Evoked
        from ..io import BaseRaw

        if isinstance(self, Evoked):
            return object_hash(dict(info=self.info, data=self.data))
        elif isinstance(self, BaseEpochs | BaseRaw):
            _check_preload(self, "Hashing ")
            return object_hash(dict(info=self.info, data=self._data))
        else:
            raise RuntimeError(f"Hashing unknown object type: {type(self)}")


class GetEpochsMixin:
    """Class to add epoch selection and metadata to certain classes."""

    def __getitem__(
        self: Self,
        item,
    ) -> Self:
        """Return an Epochs object with a copied subset of epochs.

        Parameters
        ----------
        item : int | slice | array-like | str
            See Notes for use cases.

        Returns
        -------
        epochs : instance of Epochs
            The subset of epochs.

        Notes
        -----
        Epochs can be accessed as ``epochs[...]`` in several ways:

        1. **Integer or slice:** ``epochs[idx]`` will return an `~mne.Epochs`
           object with a subset of epochs chosen by index (supports single
           index and Python-style slicing).

        2. **String:** ``epochs['name']`` will return an `~mne.Epochs` object
           comprising only the epochs labeled ``'name'`` (i.e., epochs created
           around events with the label ``'name'``).

           If there are no epochs labeled ``'name'`` but there are epochs
           labeled with /-separated tags (e.g. ``'name/left'``,
           ``'name/right'``), then ``epochs['name']`` will select the epochs
           with labels that contain that tag (e.g., ``epochs['left']`` selects
           epochs labeled ``'audio/left'`` and ``'visual/left'``, but not
           ``'audio_left'``).

           If multiple tags are provided *as a single string* (e.g.,
           ``epochs['name_1/name_2']``), this selects epochs containing *all*
           provided tags. For example, ``epochs['audio/left']`` selects
           ``'audio/left'`` and ``'audio/quiet/left'``, but not
           ``'audio/right'``. Note that tag-based selection is insensitive to
           order: tags like ``'audio/left'`` and ``'left/audio'`` will be
           treated the same way when selecting via tag.

        3. **List of strings:** ``epochs[['name_1', 'name_2', ... ]]`` will
           return an `~mne.Epochs` object comprising epochs that match *any* of
           the provided names (i.e., the list of names is treated as an
           inclusive-or condition). If *none* of the provided names match any
           epoch labels, a ``KeyError`` will be raised.

           If epoch labels are /-separated tags, then providing multiple tags
           *as separate list entries* will likewise act as an inclusive-or
           filter. For example, ``epochs[['audio', 'left']]`` would select
           ``'audio/left'``, ``'audio/right'``, and ``'visual/left'``, but not
           ``'visual/right'``.

        4. **Pandas query:** ``epochs['pandas query']`` will return an
           `~mne.Epochs` object with a subset of epochs (and matching
           metadata) selected by the query called with
           ``self.metadata.eval``, e.g.::

               epochs["col_a > 2 and col_b == 'foo'"]

           would return all epochs whose associated ``col_a`` metadata was
           greater than two, and whose ``col_b`` metadata was the string 'foo'.
           Query-based indexing only works if Pandas is installed and
           ``self.metadata`` is a :class:`pandas.DataFrame`.

           .. versionadded:: 0.16
        """
        return self._getitem(item)

    def _item_to_select(self, item):
        if isinstance(item, str):
            item = [item]

        # Convert string to indices
        if (
            isinstance(item, list | tuple)
            and len(item) > 0
            and isinstance(item[0], str)
        ):
            select = self._keys_to_idx(item)
        elif isinstance(item, slice):
            select = item
        else:
            select = np.atleast_1d(item)
            if len(select) == 0:
                select = np.array([], int)
        return select

    def _getitem(
        self,
        item,
        reason="IGNORED",
        copy=True,
        drop_event_id=True,
        select_data=True,
        return_indices=False,
    ):
        """
        Select epochs from current object.

        Parameters
        ----------
        item: slice, array-like, str, or list
            see `__getitem__` for details.
        reason: str, list/tuple of str
            entry in `drop_log` for unselected epochs
        copy: bool
            return a copy of the current object
        drop_event_id: bool
            remove non-existing event-ids after selection
        select_data: bool
            apply selection to data
            (use `select_data=False` if subclasses do not have a
             valid `_data` field, or data has already been subselected)
        return_indices: bool
            return the indices of selected epochs from the original object
            in addition to the new `Epochs` objects

        Returns
        -------
        `Epochs` or tuple(Epochs, np.ndarray) if `return_indices` is True
            subset of epochs (and optionally array with kept epoch indices)
        """
        inst = self.copy() if copy else self
        if self._data is not None:
            np.copyto(inst._data, self._data, casting="no")
        del self

        select = inst._item_to_select(item)
        has_selection = hasattr(inst, "selection")
        if has_selection:
            key_selection = inst.selection[select]
            drop_log = list(inst.drop_log)
            if reason is not None:
                _validate_type(reason, (list, tuple, str), "reason")
                if isinstance(reason, list | tuple):
                    for r in reason:
                        _validate_type(r, str, r)
                if isinstance(reason, str):
                    reason = (reason,)
                reason = tuple(reason)
                for idx in np.setdiff1d(inst.selection, key_selection):
                    drop_log[idx] = reason
            inst.drop_log = tuple(drop_log)
            inst.selection = key_selection
            del drop_log

        inst.events = np.atleast_2d(inst.events[select])
        if inst.metadata is not None:
            pd = _check_pandas_installed(strict=False)
            if pd:
                metadata = inst.metadata.iloc[select]
                if has_selection:
                    metadata.index = inst.selection
            else:
                metadata = np.array(inst.metadata, "object")[select].tolist()

            # will reset the index for us
            GetEpochsMixin.metadata.fset(inst, metadata, verbose=False)
        if inst.preload and select_data:
            # ensure that each Epochs instance owns its own data so we can
            # resize later if necessary
            inst._data = np.require(inst._data[select], requirements=["O"])
        if drop_event_id:
            # update event id to reflect new content of inst
            inst.event_id = {
                k: v for k, v in inst.event_id.items() if v in inst.events[:, 2]
            }

        if return_indices:
            return inst, select
        else:
            return inst

    def _keys_to_idx(self, keys):
        """Find entries in event dict."""
        from ..event import match_event_names  # avoid circular import

        keys = keys if isinstance(keys, list | tuple) else [keys]
        try:
            # Assume it's a condition name
            return np.where(
                np.any(
                    np.array(
                        [
                            self.events[:, 2] == self.event_id[k]
                            for k in match_event_names(self.event_id, keys)
                        ]
                    ),
                    axis=0,
                )
            )[0]
        except KeyError as err:
            # Could we in principle use metadata with these Epochs and keys?
            if len(keys) != 1 or self.metadata is None:
                # If not, raise original error
                raise
            msg = str(err.args[0])  # message for KeyError
            pd = _check_pandas_installed(strict=False)
            # See if the query can be done
            if pd:
                md = self.metadata if hasattr(self, "_metadata") else None
                self._check_metadata(metadata=md)
                try:
                    # Try metadata
                    vals = (
                        self.metadata.reset_index()
                        .query(keys[0], engine="python")
                        .index.values
                    )
                except Exception as exp:
                    msg += (
                        " The epochs.metadata Pandas query did not "
                        f"yield any results: {exp.args[0]}"
                    )
                else:
                    return vals
            else:
                # If not, warn this might be a problem
                msg += (
                    " The epochs.metadata Pandas query could not "
                    "be performed, consider installing Pandas."
                )
            raise KeyError(msg)

    def __len__(self):
        """Return the number of epochs.

        Returns
        -------
        n_epochs : int
            The number of remaining epochs.

        Notes
        -----
        This function only works if bad epochs have been dropped.

        Examples
        --------
        This can be used as::

            >>> epochs.drop_bad()  # doctest: +SKIP
            >>> len(epochs)  # doctest: +SKIP
            43
            >>> len(epochs.events)  # doctest: +SKIP
            43
        """
        from ..epochs import BaseEpochs

        if isinstance(self, BaseEpochs) and not self._bad_dropped:
            raise RuntimeError(
                "Since bad epochs have not been dropped, the "
                "length of the Epochs is not known. Load the "
                "Epochs with preload=True, or call "
                "Epochs.drop_bad(). To find the number "
                "of events in the Epochs, use "
                "len(Epochs.events)."
            )
        return len(self.events)

    def __iter__(self):
        """Facilitate iteration over epochs.

        This method resets the object iteration state to the first epoch.

        Notes
        -----
        This enables the use of this Python pattern::

            >>> for epoch in epochs:  # doctest: +SKIP
            >>>     print(epoch)  # doctest: +SKIP

        Where ``epoch`` is given by successive outputs of
        :meth:`mne.Epochs.next`.
        """
        self._current = 0
        self._current_detrend_picks = self._detrend_picks
        return self

    def __next__(self, return_event_id=False):
        """Iterate over epoch data.

        Parameters
        ----------
        return_event_id : bool
            If True, return both the epoch data and an event_id.

        Returns
        -------
        epoch : array of shape (n_channels, n_times)
            The epoch data.
        event_id : int
            The event id. Only returned if ``return_event_id`` is ``True``.
        """
        if not hasattr(self, "_current_detrend_picks"):
            self.__iter__()  # ensure we're ready to iterate
        if self.preload:
            if self._current >= len(self._data):
                self._stop_iter()
            epoch = self._data[self._current]
            self._current += 1
        else:
            is_good = False
            while not is_good:
                if self._current >= len(self.events):
                    self._stop_iter()
                epoch_noproj = self._get_epoch_from_raw(self._current)
                epoch_noproj = self._detrend_offset_decim(
                    epoch_noproj, self._current_detrend_picks
                )
                epoch = self._project_epoch(epoch_noproj)
                self._current += 1
                is_good, _ = self._is_good_epoch(epoch)
            # If delayed-ssp mode, pass 'virgin' data after rejection decision.
            if self._do_delayed_proj:
                epoch = epoch_noproj

        if not return_event_id:
            return epoch
        else:
            return epoch, self.events[self._current - 1][-1]

    def _stop_iter(self):
        del self._current
        del self._current_detrend_picks
        raise StopIteration  # signal the end

    next = __next__  # originally for Python2, now b/c public

    def _check_metadata(self, metadata=None, reset_index=False):
        """Check metadata consistency."""
        # reset_index=False will not copy!
        if metadata is None:
            return
        else:
            pd = _check_pandas_installed(strict=False)
            if pd:
                _validate_type(metadata, types=pd.DataFrame, item_name="metadata")
                if len(metadata) != len(self.events):
                    raise ValueError(
                        "metadata must have the same number of "
                        f"rows ({len(metadata)}) as events ({len(self.events)})"
                    )
                if reset_index:
                    if hasattr(self, "selection"):
                        # makes a copy
                        metadata = metadata.reset_index(drop=True)
                        metadata.index = self.selection
                    else:
                        metadata = deepcopy(metadata)
            else:
                _validate_type(metadata, types=list, item_name="metadata")
                if reset_index:
                    metadata = deepcopy(metadata)
        return metadata

    @property
    def metadata(self):
        """Get the metadata."""
        return self._metadata

    @metadata.setter
    @verbose
    def metadata(self, metadata, verbose=None):
        metadata = self._check_metadata(metadata, reset_index=True)
        if metadata is not None:
            if _check_pandas_installed(strict=False):
                n_col = metadata.shape[1]
            else:
                n_col = len(metadata[0])
            n_col = f" with {n_col} columns"
        else:
            n_col = ""
        if hasattr(self, "_metadata") and self._metadata is not None:
            action = "Removing" if metadata is None else "Replacing"
            action += " existing"
        else:
            action = "Not setting" if metadata is None else "Adding"
        logger.info(f"{action} metadata{n_col}")
        self._metadata = metadata


def _check_decim(info, decim, offset, check_filter=True):
    """Check decimation parameters."""
    if decim < 1 or decim != int(decim):
        raise ValueError("decim must be an integer > 0")
    decim = int(decim)
    new_sfreq = info["sfreq"] / float(decim)
    offset = int(offset)
    if not 0 <= offset < decim:
        raise ValueError(
            f"decim must be at least 0 and less than {decim}, got {offset}"
        )
    if check_filter:
        lowpass = info["lowpass"]
        if decim > 1 and lowpass is None:
            warn(
                "The measurement information indicates data is not low-pass "
                f"filtered. The decim={decim} parameter will result in a "
                f"sampling frequency of {new_sfreq} Hz, which can cause "
                "aliasing artifacts."
            )
        elif decim > 1 and new_sfreq < 3 * lowpass:
            warn(
                "The measurement information indicates a low-pass frequency "
                f"of {lowpass} Hz. The decim={decim} parameter will result "
                f"in a sampling frequency of {new_sfreq} Hz, which can "
                "cause aliasing artifacts."
            )  # > 50% nyquist lim
    return decim, offset, new_sfreq


class TimeMixin:
    """Class for time operations on any MNE object that has a time axis."""

    def time_as_index(self, times, use_rounding=False):
        """Convert time to indices.

        Parameters
        ----------
        times : list-like | float | int
            List of numbers or a number representing points in time.
        use_rounding : bool
            If True, use rounding (instead of truncation) when converting
            times to indices. This can help avoid non-unique indices.

        Returns
        -------
        index : ndarray
            Indices corresponding to the times supplied.
        """
        from ..source_estimate import _BaseSourceEstimate

        if isinstance(self, _BaseSourceEstimate):
            sfreq = 1.0 / self.tstep
        else:
            sfreq = self.info["sfreq"]
        index = (np.atleast_1d(times) - self.times[0]) * sfreq
        if use_rounding:
            index = np.round(index)
        return index.astype(int)

    def _handle_tmin_tmax(self, tmin, tmax):
        """Convert seconds to index into data.

        Parameters
        ----------
        tmin : int | float | None
            Start time of data to get in seconds.
        tmax : int | float | None
            End time of data to get in seconds.

        Returns
        -------
        start : int
            Integer index into data corresponding to tmin.
        stop : int
            Integer index into data corresponding to tmax.

        """
        _validate_type(
            tmin,
            types=("numeric", None),
            item_name="tmin",
            type_name="int, float, None",
        )
        _validate_type(
            tmax,
            types=("numeric", None),
            item_name="tmax",
            type_name="int, float, None",
        )

        # handle tmin/tmax as start and stop indices into data array
        n_times = self.times.size
        start = 0 if tmin is None else self.time_as_index(tmin)[0]
        stop = n_times if tmax is None else self.time_as_index(tmax)[0]

        # truncate start/stop to the open interval [0, n_times]
        start = min(max(0, start), n_times)
        stop = min(max(0, stop), n_times)

        return start, stop

    @property
    def times(self):
        """Time vector in seconds."""
        return self._times_readonly

    def _set_times(self, times):
        """Set self._times_readonly (and make it read only)."""
        # naming used to indicate that it shouldn't be
        # changed directly, but rather via this method
        self._times_readonly = times.copy()
        self._times_readonly.flags["WRITEABLE"] = False


class ExtendedTimeMixin(TimeMixin):
    """Class for time operations on epochs/evoked-like MNE objects."""

    @property
    def tmin(self):
        """First time point."""
        return self.times[0]

    @property
    def tmax(self):
        """Last time point."""
        return self.times[-1]

    @verbose
    def crop(self, tmin=None, tmax=None, include_tmax=True, verbose=None):
        """Crop data to a given time interval.

        Parameters
        ----------
        tmin : float | None
            Start time of selection in seconds.
        tmax : float | None
            End time of selection in seconds.
        %(include_tmax)s
        %(verbose)s

        Returns
        -------
        inst : instance of Raw, Epochs, Evoked, AverageTFR, or SourceEstimate
            The cropped time-series object, modified in-place.

        Notes
        -----
        %(notes_tmax_included_by_default)s
        """
        t_vars = dict(tmin=tmin, tmax=tmax)
        for name, t_var in t_vars.items():
            _validate_type(
                t_var,
                types=("numeric", None),
                item_name=name,
            )

        if tmin is None:
            tmin = self.tmin
        elif tmin < self.tmin:
            warn(
                f"tmin is not in time interval. tmin is set to "
                f"{type(self)}.tmin ({self.tmin:g} s)"
            )
            tmin = self.tmin

        if tmax is None:
            tmax = self.tmax
        elif tmax > self.tmax:
            warn(
                f"tmax is not in time interval. tmax is set to "
                f"{type(self)}.tmax ({self.tmax:g} s)"
            )
            tmax = self.tmax
            include_tmax = True

        mask = _time_mask(
            self.times, tmin, tmax, sfreq=self.info["sfreq"], include_tmax=include_tmax
        )
        self._set_times(self.times[mask])
        self._raw_times = self._raw_times[mask]
        self._update_first_last()
        self._data = self._data[..., mask]

        return self

    @verbose
    def decimate(self, decim, offset=0, *, verbose=None):
        """Decimate the time-series data.

        Parameters
        ----------
        %(decim)s
        %(offset_decim)s
        %(verbose)s

        Returns
        -------
        inst : MNE-object
            The decimated object.

        See Also
        --------
        mne.Epochs.resample
        mne.io.Raw.resample

        Notes
        -----
        %(decim_notes)s

        If ``decim`` is 1, this method does not copy the underlying data.

        .. versionadded:: 0.10.0

        References
        ----------
        .. footbibliography::
        """
        # if epochs have frequencies, they are not in time (EpochsTFR)
        # and so do not need to be checked whether they have been
        # appropriately filtered to avoid aliasing
        from ..epochs import BaseEpochs
        from ..evoked import Evoked
        from ..time_frequency import BaseTFR

        # This should be the list of classes that inherit
        _validate_type(self, (BaseEpochs, Evoked, BaseTFR), "inst")
        decim, offset, new_sfreq = _check_decim(
            self.info, decim, offset, check_filter=not hasattr(self, "freqs")
        )
        start_idx = int(round(-self._raw_times[0] * (self.info["sfreq"] * self._decim)))
        self._decim *= decim
        i_start = start_idx % self._decim + offset
        decim_slice = slice(i_start, None, self._decim)
        with self.info._unlock():
            self.info["sfreq"] = new_sfreq

        if self.preload:
            if decim != 1:
                self._data = self._data[..., decim_slice].copy()
                self._raw_times = self._raw_times[decim_slice].copy()
            else:
                self._data = np.ascontiguousarray(self._data)
            self._decim_slice = slice(None)
            self._decim = 1
        else:
            self._decim_slice = decim_slice
        self._set_times(self._raw_times[self._decim_slice])
        self._update_first_last()
        return self

    def shift_time(self, tshift, relative=True):
        """Shift time scale in epoched or evoked data.

        Parameters
        ----------
        tshift : float
            The (absolute or relative) time shift in seconds. If ``relative``
            is True, positive tshift increases the time value associated with
            each sample, while negative tshift decreases it.
        relative : bool
            If True, increase or decrease time values by ``tshift`` seconds.
            Otherwise, shift the time values such that the time of the first
            sample equals ``tshift``.

        Returns
        -------
        epochs : MNE-object
            The modified instance.

        Notes
        -----
        This method allows you to shift the *time* values associated with each
        data sample by an arbitrary amount. It does *not* resample the signal
        or change the *data* values in any way.
        """
        _check_preload(self, "shift_time")
        start = tshift + (self.times[0] if relative else 0.0)
        new_times = start + np.arange(len(self.times)) / self.info["sfreq"]
        self._set_times(new_times)
        self._update_first_last()
        return self

    def _update_first_last(self):
        """Update self.first and self.last (sample indices)."""
        from ..dipole import DipoleFixed
        from ..evoked import Evoked

        if isinstance(self, Evoked | DipoleFixed):
            self.first = int(round(self.times[0] * self.info["sfreq"]))
            self.last = len(self.times) + self.first - 1


def _prepare_write_metadata(metadata):
    """Convert metadata to JSON for saving."""
    if metadata is not None:
        if not isinstance(metadata, list):
            metadata = metadata.reset_index().to_json(orient="records")
        else:  # Pandas DataFrame
            metadata = json.dumps(metadata)
        assert isinstance(metadata, str)
    return metadata


def _prepare_read_metadata(metadata):
    """Convert saved metadata back from JSON."""
    if metadata is not None:
        pd = _check_pandas_installed(strict=False)
        # use json.loads because this preserves ordering
        # (which is necessary for round-trip equivalence)
        metadata = json.loads(metadata, object_pairs_hook=OrderedDict)
        assert isinstance(metadata, list)
        if pd:
            metadata = pd.DataFrame.from_records(metadata)
            if "index" in metadata.columns:
                metadata.set_index("index", inplace=True)
            assert isinstance(metadata, pd.DataFrame)
    return metadata
