1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
|
# Author: Jean-Remi King <jeanremi.king@gmail.com>
#
# License: BSD (3-clause)
import numpy as np
from .mixin import TransformerMixin
from .base import BaseEstimator
from ..time_frequency.tfr import _compute_tfr, _check_tfr_param
class TimeFrequency(TransformerMixin, BaseEstimator):
    """Scikit-learn style transformer producing time-frequency data.

    The time-frequency decomposition is applied to the time series found
    along the last axis of the input.

    Parameters
    ----------
    freqs : array-like of floats, shape (n_freqs,)
        The frequencies at which the decomposition is computed.
    sfreq : float | int, defaults to 1.0
        Sampling frequency of the data.
    method : 'multitaper' | 'morlet', defaults to 'morlet'
        Which time-frequency method to use. 'morlet' convolves a Morlet
        wavelet. 'multitaper' uses Morlet wavelets windowed with multiple
        DPSS multitapers.
    n_cycles : float | array of float, defaults to 7.0
        Number of cycles in the Morlet wavelet, given either as a single
        fixed number or as one value per frequency.
    time_bandwidth : float, defaults to None
        Time x (Full) Bandwidth product. Only applies if
        method == 'multitaper'. If None and method=multitaper, will be set
        to 4.0 (3 tapers). The number of good tapers (low-bias) is chosen
        automatically based on this to equal floor(time_bandwidth - 1).
    use_fft : bool, defaults to True
        Whether the convolutions are carried out with the FFT.
    decim : int | slice, defaults to 1
        Decimation applied after the time-frequency decomposition in
        order to reduce memory usage.
        If `int`, returns tfr[..., ::decim].
        If `slice`, returns tfr[..., decim].

        .. note:: Decimation may create aliasing artifacts, yet
                  decimation is done after the convolutions.
    output : str, defaults to 'complex'

        * 'complex' : single trial complex.
        * 'power' : single trial power.
        * 'phase' : single trial phase.
    n_jobs : int, defaults to 1
        The number of epochs to process at the same time. The
        parallelization is implemented across channels.
    verbose : bool, str, int, or None, defaults to None
        If not None, override default verbose level (see
        :func:`mne.verbose` and
        :ref:`Logging documentation <tut_logging>` for more).

    See Also
    --------
    mne.time_frequency.tfr_morlet
    mne.time_frequency.tfr_multitaper
    """

    def __init__(self, freqs, sfreq=1.0, method='morlet', n_cycles=7.0,
                 time_bandwidth=None, use_fft=True, decim=1,
                 output='complex', n_jobs=1, verbose=None):  # noqa: D102
        """Check the time-frequency parameters and store them.

        For ``freqs``, ``sfreq``, ``n_cycles``, ``time_bandwidth`` and
        ``decim``, the values returned by ``_check_tfr_param`` are stored
        rather than the raw arguments.

        Raises
        ------
        ValueError
            If ``output`` is not one of 'complex', 'power' or 'phase'.
        """
        # NOTE(review): the fourth positional argument (True) is presumably
        # the zero-mean flag of _check_tfr_param -- confirm against its
        # signature. The third returned value is deliberately discarded.
        checked = _check_tfr_param(freqs, sfreq, method, True, n_cycles,
                                   time_bandwidth, use_fft, decim, output)
        freqs, sfreq, _, n_cycles, time_bandwidth, decim = checked

        self.freqs = freqs
        self.sfreq = sfreq
        self.method = method
        self.n_cycles = n_cycles
        self.time_bandwidth = time_bandwidth
        self.use_fft = use_fft
        self.decim = decim

        # Only single-trial outputs are accepted here; averaged metrics
        # (e.g. ITC) are rejected.
        if output not in ('complex', 'power', 'phase'):
            raise ValueError("output must be 'complex', 'power', 'phase'. "
                             "Got %s instead." % output)

        self.output = output
        self.n_jobs = n_jobs
        self.verbose = verbose

    def fit(self, X, y=None):  # noqa: D401
        """Do nothing (present for scikit-learn compatibility only).

        Parameters
        ----------
        X : array, shape (n_samples, n_channels, n_times)
            The training data. Unused.
        y : array | None
            The target values. Unused.

        Returns
        -------
        self : object
            The instance itself, unchanged.
        """
        return self

    def fit_transform(self, X, y=None):
        """Fit (a no-op) then time-frequency transform ``X``.

        Parameters
        ----------
        X : array, shape (n_samples, n_channels, n_times)
            The training data samples. The channel dimension can be zero-
            or 1-dimensional.
        y : None
            For scikit-learn compatibility purposes.

        Returns
        -------
        Xt : array, shape (n_samples, n_channels, n_freqs, n_times)
            The time-frequency transform of the data, where n_channels can
            be zero- or 1-dimensional.
        """
        fitted = self.fit(X, y)
        return fitted.transform(X)

    def transform(self, X):
        """Compute the time-frequency transform along the last axis of X.

        Parameters
        ----------
        X : array, shape (n_samples, n_channels, n_times)
            The data samples. The channel dimension can be zero- or
            1-dimensional.

        Returns
        -------
        Xt : array, shape (n_samples, n_channels, n_freqs, n_times)
            The time-frequency transform of the data, where n_channels can
            be zero- or 1-dimensional.
        """
        # No axes between the first and the last one means that X comes
        # without a channel axis; a temporary one is inserted for the
        # computation and dropped again from the result.
        lacks_channel_axis = not X.shape[1:-1]
        data = X[:, np.newaxis, :] if lacks_channel_axis else X

        tfr = _compute_tfr(data, self.freqs, self.sfreq, self.method,
                           self.n_cycles, True, self.time_bandwidth,
                           self.use_fft, self.decim, self.output,
                           self.n_jobs, self.verbose)

        if lacks_channel_axis:
            return tfr[:, 0, :]
        return tfr
|