File: time_frequency.py

package info (click to toggle)
python-mne 0.17%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 95,104 kB
  • sloc: python: 110,639; makefile: 222; sh: 15
file content (151 lines) | stat: -rw-r--r-- 5,436 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# Author: Jean-Remi King <jeanremi.king@gmail.com>
#
# License: BSD (3-clause)

import numpy as np
from .mixin import TransformerMixin
from .base import BaseEstimator
from ..time_frequency.tfr import _compute_tfr, _check_tfr_param


class TimeFrequency(TransformerMixin, BaseEstimator):
    """Time frequency transformer.

    Time-frequency transform of times series along the last axis.

    Parameters
    ----------
    freqs : array-like of floats, shape (n_freqs,)
        The frequencies.
    sfreq : float | int, defaults to 1.0
        Sampling frequency of the data.
    method : 'multitaper' | 'morlet', defaults to 'morlet'
        The time-frequency method. 'morlet' convolves a Morlet wavelet.
        'multitaper' uses Morlet wavelets windowed with multiple DPSS
        multitapers.
    n_cycles : float | array of float, defaults to 7.0
        Number of cycles in the Morlet wavelet. Fixed number
        or one per frequency.
    time_bandwidth : float, defaults to None
        If None and method=multitaper, will be set to 4.0 (3 tapers).
        Time x (Full) Bandwidth product. Only applies if
        method == 'multitaper'. The number of good tapers (low-bias) is
        chosen automatically based on this to equal floor(time_bandwidth - 1).
    use_fft : bool, defaults to True
        Use the FFT for convolutions or not.
    decim : int | slice, defaults to 1
        To reduce memory usage, decimation factor after time-frequency
        decomposition.
        If `int`, returns tfr[..., ::decim].
        If `slice`, returns tfr[..., decim].

        .. note:: Decimation may create aliasing artifacts, yet decimation
                  is done after the convolutions.

    output : str, defaults to 'complex'
        * 'complex' : single trial complex.
        * 'power' : single trial power.
        * 'phase' : single trial phase.

        Averaged outputs (e.g. inter-trial coherence) are rejected, because
        the transformer returns one result per sample.
    n_jobs : int, defaults to 1
        The number of jobs to run in parallel, forwarded to the underlying
        time-frequency computation. The parallelization is implemented
        across channels.
    verbose : bool, str, int, or None, defaults to None
        If not None, override default verbose level (see :func:`mne.verbose`
        and :ref:`Logging documentation <tut_logging>` for more).

    See Also
    --------
    mne.time_frequency.tfr_morlet
    mne.time_frequency.tfr_multitaper
    """

    def __init__(self, freqs, sfreq=1.0, method='morlet', n_cycles=7.0,
                 time_bandwidth=None, use_fft=True, decim=1, output='complex',
                 n_jobs=1, verbose=None):  # noqa: D102
        """Init TimeFrequency transformer.

        Raises
        ------
        ValueError
            If ``output`` is not one of 'complex', 'power' or 'phase'.
        """
        # Run the shared parameter check first and store the values it
        # returns rather than the raw arguments.
        freqs, sfreq, _, n_cycles, time_bandwidth, decim = \
            _check_tfr_param(freqs, sfreq, method, True, n_cycles,
                             time_bandwidth, use_fft, decim, output)
        self.freqs = freqs
        self.sfreq = sfreq
        self.method = method
        self.n_cycles = n_cycles
        self.time_bandwidth = time_bandwidth
        self.use_fft = use_fft
        self.decim = decim
        # Check that output is not an average metric (e.g. ITC)
        if output not in ('complex', 'power', 'phase'):
            # Wrap in a 1-tuple so that a tuple-valued ``output`` is reported
            # in the message instead of breaking the %-formatting.
            raise ValueError("output must be 'complex', 'power', 'phase'. "
                             "Got %s instead." % (output,))
        self.output = output
        self.n_jobs = n_jobs
        self.verbose = verbose

    def fit_transform(self, X, y=None):
        """Time-frequency transform of times series along the last axis.

        Parameters
        ----------
        X : array, shape (n_samples, n_channels, n_times)
            The training data samples. The channel dimension can be zero- or
            1-dimensional, i.e. X may also have shape (n_samples, n_times).
        y : None
            For scikit-learn compatibility purposes.

        Returns
        -------
        Xt : array, shape (n_samples, n_channels, n_freqs, n_times)
            The time-frequency transform of the data, where n_channels can be
            zero- or 1-dimensional.
        """
        return self.fit(X, y).transform(X)

    def fit(self, X, y=None):  # noqa: D401
        """Do nothing (for scikit-learn compatibility purposes).

        Parameters
        ----------
        X : array, shape (n_samples, n_channels, n_times)
            The training data.
        y : array | None
            The target values.

        Returns
        -------
        self : object
            Return self.
        """
        return self

    def transform(self, X):
        """Time-frequency transform of times series along the last axis.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_channels, n_times)
            The data samples. The channel dimension can be zero- or
            1-dimensional, i.e. X may also have shape (n_samples, n_times).

        Returns
        -------
        Xt : array, shape (n_samples, n_channels, n_freqs, n_times)
            The time-frequency transform of the data. If X has no channel
            dimension, neither has Xt, whose shape is then
            (n_samples, n_freqs, n_times).

        Raises
        ------
        ValueError
            If X has fewer than two dimensions.
        """
        # Accept any array-like (e.g. nested lists); an ndarray is returned
        # unchanged by np.asarray.
        X = np.asarray(X)
        if X.ndim < 2:
            raise ValueError('X must have shape (n_samples, n_times) or '
                             '(n_samples, n_channels, n_times). Got shape %s '
                             'instead.' % (X.shape,))

        # Ensure 3-dimensional X: an empty tuple here means that X has no
        # channel axis, so a singleton one is inserted.
        channel_shape = X.shape[1:-1]
        if not channel_shape:
            X = X[:, np.newaxis, :]

        # Compute time-frequency
        Xt = _compute_tfr(X, self.freqs, self.sfreq, self.method,
                          self.n_cycles, True, self.time_bandwidth,
                          self.use_fft, self.decim, self.output, self.n_jobs,
                          self.verbose)

        # Back to original shape: drop the singleton channel axis added above
        if not channel_shape:
            Xt = Xt[:, 0, :]

        return Xt