
|
from math import floor, ceil
from ._extensions._pywt import (DiscreteContinuousWavelet, ContinuousWavelet,
Wavelet, _check_dtype)
from ._functions import integrate_wavelet, scale2frequency
__all__ = ["cwt"]
import numpy as np
try:
    # Prefer scipy.fft (new in SciPy 1.4)
    import scipy.fft
    fftmodule = scipy.fft
    next_fast_len = fftmodule.next_fast_len
except ImportError:
    try:
        # older SciPy: fall back to the legacy scipy.fftpack interface
        import scipy.fftpack
        fftmodule = scipy.fftpack
        next_fast_len = fftmodule.next_fast_len
    except ImportError:
        # provide a fallback so scipy is an optional requirement
        fftmodule = np.fft

        def next_fast_len(n):
            """Round `n` up to the nearest power of two.

            Given a number of samples `n`, return the smallest power of two
            that is greater than or equal to `n`, to take advantage of FFT
            speedup. This fallback is less efficient than
            `scipy.fftpack.next_fast_len`.

            Raises
            ------
            ValueError
                If `n` is negative.
            """
            n = int(n)
            if n < 0:
                raise ValueError("n must be non-negative")
            # Exact integer arithmetic: ``2**ceil(log2(n))`` goes through a
            # float log2 that can round down for very large ``n`` and raises
            # for ``n == 0`` (ceil(-inf)).
            return 1 << max(n - 1, 0).bit_length()
def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
    """
    cwt(data, scales, wavelet)

    One dimensional Continuous Wavelet Transform.

    Parameters
    ----------
    data : array_like
        Input signal
    scales : array_like
        The wavelet scales to use. All scales must be strictly positive.
        One can use ``f = scale2frequency(wavelet, scale)/sampling_period``
        to determine the physical frequency ``f`` that corresponds to a
        given scale. Here, ``f`` is in hertz when the ``sampling_period``
        is given in seconds.
    wavelet : Wavelet object or name
        Wavelet to use
    sampling_period : float
        Sampling period for the frequencies output (optional).
        The values computed for ``coefs`` are independent of the choice of
        ``sampling_period`` (i.e. ``scales`` is not scaled by the sampling
        period).
    method : {'conv', 'fft'}, optional
        The method used to compute the CWT. Can be any of:
            - ``conv`` uses ``numpy.convolve``.
            - ``fft`` uses frequency domain convolution.
        The ``conv`` method complexity is ``O(len(scale) * len(data))``.
        The ``fft`` method is ``O(N * log2(N))`` with
        ``N = len(scale) + len(data) - 1``. It is well suited for large size
        signals but slightly slower than ``conv`` on small ones.
    axis: int, optional
        Axis over which to compute the CWT. If not given, the last axis is
        used.

    Returns
    -------
    coefs : array_like
        Continuous wavelet transform of the input signal for the given scales
        and wavelet. The first axis of ``coefs`` corresponds to the scales.
        The remaining axes match the shape of ``data``.
    frequencies : array_like
        If the sampling period is given in seconds, then the frequencies
        are in hertz. Otherwise, a sampling period of 1 is assumed.

    Raises
    ------
    ValueError
        If ``method`` is not ``'conv'`` or ``'fft'``, if any scale is not
        strictly positive, or if a scale is too small for the wavelet.
    AxisError
        If ``axis`` is not a scalar or is out of range for ``data``
        (``numpy.exceptions.AxisError``; a subclass of both ``ValueError``
        and ``IndexError``).

    Notes
    -----
    Size of coefficients arrays depends on the length of the input array and
    the length of given scales.

    Examples
    --------
    >>> import pywt
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> x = np.arange(512)
    >>> y = np.sin(2*np.pi*x/32)
    >>> coef, freqs=pywt.cwt(y,np.arange(1,129),'gaus1')
    >>> plt.matshow(coef) # doctest: +SKIP
    >>> plt.show() # doctest: +SKIP
    ----------
    >>> import pywt
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> t = np.linspace(-1, 1, 200, endpoint=False)
    >>> sig = np.cos(2 * np.pi * 7 * t) + np.real(np.exp(-7*(t-0.4)**2)*np.exp(1j*2*np.pi*2*(t-0.4)))
    >>> widths = np.arange(1, 31)
    >>> cwtmatr, freqs = pywt.cwt(sig, widths, 'mexh')
    >>> plt.imshow(cwtmatr, extent=[-1, 1, 1, 31], cmap='PRGn', aspect='auto',
    ...            vmax=abs(cwtmatr).max(), vmin=-abs(cwtmatr).max())  # doctest: +SKIP
    >>> plt.show() # doctest: +SKIP
    """
    # Validate the inexpensive arguments up front, before the wavelet is
    # constructed and integrated.
    if method not in ('conv', 'fft'):
        raise ValueError("method must be 'conv' or 'fft'")

    # ``np.AxisError`` was removed from the top-level namespace in NumPy 2.0;
    # it has been available as ``np.exceptions.AxisError`` since NumPy 1.25.
    try:
        axis_error = np.exceptions.AxisError
    except AttributeError:
        axis_error = np.AxisError
    if not np.isscalar(axis):
        raise axis_error("axis must be a scalar.")

    # Work with a 1-D array of scales. Unlike the ``np.isscalar`` test, this
    # also handles 0-d arrays, which cannot be iterated over.
    scales = np.atleast_1d(scales)
    if np.any(scales <= 0):
        # a zero scale divides by zero when resampling the wavelet below and
        # a negative scale makes ``np.sqrt(scale)`` NaN
        raise ValueError("`scales` must only include positive values")

    # accept array_like input; ``np.asarray`` only copies when the input is
    # not already an ndarray of dtype ``dt``
    dt = _check_dtype(data)
    data = np.asarray(data, dtype=dt)
    if not -data.ndim <= axis < data.ndim:
        raise axis_error(axis, data.ndim)
    dt_cplx = np.result_type(dt, np.complex64)
    if not isinstance(wavelet, (ContinuousWavelet, Wavelet)):
        wavelet = DiscreteContinuousWavelet(wavelet)

    dt_out = dt_cplx if wavelet.complex_cwt else dt
    out = np.empty((np.size(scales),) + data.shape, dtype=dt_out)
    precision = 10
    int_psi, x = integrate_wavelet(wavelet, precision=precision)
    int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi

    # convert int_psi, x to the same precision as the data
    dt_psi = dt_cplx if int_psi.dtype.kind == 'c' else dt
    int_psi = np.asarray(int_psi, dtype=dt_psi)
    x = np.asarray(x, dtype=data.real.dtype)
    # grid spacing of ``x`` (loop invariant); the indexing below assumes
    # ``x`` is uniformly spaced
    step = x[1] - x[0]

    if method == 'fft':
        # the data spectrum is cached and only recomputed when the padded
        # FFT length changes from one scale to the next
        size_scale0 = -1
        fft_data = None

    if data.ndim > 1:
        # move the axis to be transformed last, then flatten the remaining
        # axes so each row is one 1-D signal: (n_batch, data.shape[-1])
        data = data.swapaxes(-1, axis)
        data_shape_pre = data.shape
        data = data.reshape((-1, data.shape[-1]))

    for i, scale in enumerate(scales):
        # resample the integrated wavelet at this scale
        j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step)
        j = j.astype(int)  # floor (j is non-negative since scale > 0)
        if j[-1] >= int_psi.size:
            j = np.extract(j < int_psi.size, j)
        int_psi_scale = int_psi[j][::-1]

        if method == 'conv':
            if data.ndim == 1:
                conv = np.convolve(data, int_psi_scale)
            else:
                # batch convolution via loop
                conv_shape = list(data.shape)
                conv_shape[-1] += int_psi_scale.size - 1
                conv_shape = tuple(conv_shape)
                conv = np.empty(conv_shape, dtype=dt_out)
                for n in range(data.shape[0]):
                    conv[n, :] = np.convolve(data[n], int_psi_scale)
        else:
            # The padding is selected for:
            #  - optimal FFT complexity
            #  - to be larger than the two signals length to avoid circular
            #    convolution
            size_scale = next_fast_len(
                data.shape[-1] + int_psi_scale.size - 1
            )
            if size_scale != size_scale0:
                # Must recompute fft_data when the padding size changes.
                fft_data = fftmodule.fft(data, size_scale, axis=-1)
            size_scale0 = size_scale
            fft_wav = fftmodule.fft(int_psi_scale, size_scale, axis=-1)
            conv = fftmodule.ifft(fft_wav * fft_data, axis=-1)
            conv = conv[..., :data.shape[-1] + int_psi_scale.size - 1]

        # differentiate the convolution with the integrated wavelet
        coef = - np.sqrt(scale) * np.diff(conv, axis=-1)
        if out.dtype.kind != 'c':
            coef = coef.real

        # trim symmetrically back to the input length; the transform axis is
        # always -1 due to the data reshape above
        d = (coef.shape[-1] - data.shape[-1]) / 2.
        if d > 0:
            coef = coef[..., floor(d):-ceil(d)]
        elif d < 0:
            raise ValueError(
                "Selected scale of {} too small.".format(scale))

        if data.ndim > 1:
            # restore original data shape and axis position
            coef = coef.reshape(data_shape_pre)
            coef = coef.swapaxes(axis, -1)
        out[i, ...] = coef

    frequencies = scale2frequency(wavelet, scales, precision)
    if np.isscalar(frequencies):
        frequencies = np.array([frequencies])
    # out-of-place division: never mutates the helper's result and also
    # works if it happens to return an integer array
    frequencies = frequencies / sampling_period
    return out, frequencies
|