1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
|
# Author: Jean-Remi King <jeanremi.king@gmail.com>
#
# License: BSD (3-clause)
import numpy as np
from .mixin import TransformerMixin
from .base import BaseEstimator
from ..time_frequency.tfr import _compute_tfr, _check_tfr_param
from ..utils import fill_doc, _check_option
@fill_doc
class TimeFrequency(TransformerMixin, BaseEstimator):
    """Time frequency transformer.

    Time-frequency transform of times series along the last axis.

    Parameters
    ----------
    freqs : array-like of float, shape (n_freqs,)
        The frequencies.
    sfreq : float | int, default 1.0
        Sampling frequency of the data.
    method : 'multitaper' | 'morlet', default 'morlet'
        The time-frequency method. 'morlet' convolves a Morlet wavelet.
        'multitaper' uses Morlet wavelets windowed with multiple DPSS
        multitapers.
    n_cycles : float | array of float, default 7.0
        Number of cycles in the Morlet wavelet. Fixed number
        or one per frequency.
    time_bandwidth : float, default None
        If None and method=multitaper, will be set to 4.0 (3 tapers).
        Time x (Full) Bandwidth product. Only applies if
        method == 'multitaper'. The number of good tapers (low-bias) is
        chosen automatically based on this to equal floor(time_bandwidth - 1).
    use_fft : bool, default True
        Use the FFT for convolutions or not.
    decim : int | slice, default 1
        To reduce memory usage, decimation factor after time-frequency
        decomposition.
        If `int`, returns tfr[..., ::decim].
        If `slice`, returns tfr[..., decim].

        .. note:: Decimation may create aliasing artifacts, yet decimation
                  is done after the convolutions.

    output : 'complex' | 'power' | 'phase', default 'complex'
        * 'complex' : single trial complex.
        * 'power' : single trial power.
        * 'phase' : single trial phase.

        Averaged metrics (e.g. ITC) are not supported by this transformer.
    %(n_jobs)s
        The parallelization is implemented across channels.
    %(verbose)s

    See Also
    --------
    mne.time_frequency.tfr_morlet
    mne.time_frequency.tfr_multitaper
    """

    def __init__(self, freqs, sfreq=1.0, method='morlet', n_cycles=7.0,
                 time_bandwidth=None, use_fft=True, decim=1, output='complex',
                 n_jobs=1, verbose=None):  # noqa: D102
        """Init TimeFrequency transformer."""
        # This transformer only produces single-trial outputs, so reject
        # averaged metrics (e.g. ITC) *before* the generic TFR parameter
        # check: an invalid ``output`` is then reported against the options
        # that are actually supported here.
        _check_option('output', output, ['complex', 'power', 'phase'])

        # Validate / normalize the remaining TFR parameters. The literal
        # ``True`` is forwarded positionally here and to ``_compute_tfr``
        # in ``transform`` (presumably the zero-mean wavelet flag -- confirm
        # against the helper signature). The third return value is unused.
        freqs, sfreq, _, n_cycles, time_bandwidth, decim = \
            _check_tfr_param(freqs, sfreq, method, True, n_cycles,
                             time_bandwidth, use_fft, decim, output)
        self.freqs = freqs
        self.sfreq = sfreq
        self.method = method
        self.n_cycles = n_cycles
        self.time_bandwidth = time_bandwidth
        self.use_fft = use_fft
        self.decim = decim
        self.output = output
        self.n_jobs = n_jobs
        self.verbose = verbose

    def fit_transform(self, X, y=None):
        """Time-frequency transform of times series along the last axis.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_channels, n_times) | (n_samples, n_times)
            The training data samples. If 2-D, the data are treated as a
            single channel.
        y : None
            For scikit-learn compatibility purposes.

        Returns
        -------
        Xt : array, shape (n_samples, n_channels, n_freqs, n_times) | (n_samples, n_freqs, n_times)
            The time-frequency transform of the data. The channel axis is
            present only if it was present in ``X``; ``n_times`` reflects
            ``decim``.
        """
        return self.fit(X, y).transform(X)

    def fit(self, X, y=None):  # noqa: D401
        """Do nothing (for scikit-learn compatibility purposes).

        Parameters
        ----------
        X : array, shape (n_samples, n_channels, n_times)
            The training data.
        y : array | None
            The target values.

        Returns
        -------
        self : object
            Return self.
        """
        return self

    def transform(self, X):
        """Time-frequency transform of times series along the last axis.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_channels, n_times) | (n_samples, n_times)
            The data samples. If 2-D, the data are treated as a single
            channel.

        Returns
        -------
        Xt : array, shape (n_samples, n_channels, n_freqs, n_times) | (n_samples, n_freqs, n_times)
            The time-frequency transform of the data. The channel axis is
            present only if it was present in ``X``; ``n_times`` reflects
            ``decim``.
        """
        # Accept any array-like (e.g. nested lists), as scikit-learn
        # transformers do; this is a no-op for ndarray input.
        X = np.asarray(X)

        # Dimensions between the sample axis and the time axis. Empty for
        # 2-D input, in which case a singleton channel axis is inserted so
        # the TFR computation always receives (n_samples, n_channels, n_times).
        shape = X.shape[1:-1]
        if not shape:
            X = X[:, np.newaxis, :]

        # Compute time-frequency
        Xt = _compute_tfr(X, self.freqs, self.sfreq, self.method,
                          self.n_cycles, True, self.time_bandwidth,
                          self.use_fft, self.decim, self.output, self.n_jobs,
                          self.verbose)

        # Back to original shape: drop the singleton channel axis added above.
        if not shape:
            Xt = Xt[:, 0, :]

        return Xt
|