# Copyright (c) 2017 The PyWavelets Developers
# <https://github.com/PyWavelets/pywt>
# See COPYING for license details.
import inspect
import numpy as np
from collections.abc import Iterable

from ._extensions._pywt import (Wavelet, ContinuousWavelet,
                                DiscreteContinuousWavelet, Modes)


def _as_wavelet(wavelet):
    """Convert a wavelet name to a discrete Wavelet object.

    Wavelet objects are returned unchanged. A ValueError is raised if the
    name or object refers to a continuous wavelet.
    """
    if not isinstance(wavelet, (ContinuousWavelet, Wavelet)):
        wavelet = DiscreteContinuousWavelet(wavelet)
    if isinstance(wavelet, ContinuousWavelet):
        raise ValueError(
            "A ContinuousWavelet object was provided, but only discrete "
            "Wavelet objects are supported by this function. A list of all "
            "supported discrete wavelets can be obtained by running:\n"
            "print(pywt.wavelist(kind='discrete'))")
    return wavelet


def _wavelets_per_axis(wavelet, axes):
    """Initialize a Wavelet for each axis to be transformed.

    Parameters
    ----------
    wavelet : str, Wavelet or iterable of these
        If a single wavelet (or an iterable of length one) is provided, it
        will be used for all axes. Otherwise one wavelet per axis must be
        provided.
    axes : sequence of int
        The axes to be transformed.

    Returns
    -------
    wavelets : list of Wavelet objects
        A list of Wavelets equal in length to ``axes``.
    """
    axes = tuple(axes)
    if isinstance(wavelet, (str, Wavelet)):
        # same wavelet on all axes
        wavelets = [_as_wavelet(wavelet), ] * len(axes)
    elif isinstance(wavelet, Iterable):
        # (potentially) unique wavelet per axis (e.g. for dual-tree DWT)
        if len(wavelet) == 1:
            wavelets = [_as_wavelet(wavelet[0]), ] * len(axes)
        else:
            if len(wavelet) != len(axes):
                raise ValueError(
                    "The number of wavelets must match the number of axes "
                    "to be transformed.")
            wavelets = [_as_wavelet(w) for w in wavelet]
    else:
        raise ValueError("wavelet must be a str, Wavelet or iterable")
    return wavelets
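

# Hypothetical, illustration-only sketch (not part of the PyWavelets API):
# the broadcasting rule shared by ``_wavelets_per_axis`` and
# ``_modes_per_axis``, written without any pywt types. A scalar value, or an
# iterable of length one, is repeated once per axis. Any other iterable must
# supply exactly one value per axis. ``convert`` plays the role of
# ``_as_wavelet`` or ``Modes.from_object``. ``scalar_types`` lists the types
# treated as a single value. ``str`` must be among them because strings are
# themselves iterable.
def _broadcast_per_axis_sketch(value, axes, convert=lambda v: v,
                               scalar_types=(str,)):
    # local import keeps this sketch self-contained
    from collections.abc import Iterable
    n_axes = len(tuple(axes))
    if isinstance(value, scalar_types):
        # one converted value, shared (by reference) by every axis
        return [convert(value)] * n_axes
    if isinstance(value, Iterable):
        # materialize so that len() also works for generators;
        # _wavelets_per_axis and _modes_per_axis call len() directly and
        # therefore need a sized iterable
        values = list(value)
        if len(values) == 1:
            return [convert(values[0])] * n_axes
        if len(values) != n_axes:
            raise ValueError("The number of values must match the number "
                             "of axes to be transformed.")
        return [convert(v) for v in values]
    raise ValueError("value must be one of scalar_types or an iterable")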


def _modes_per_axis(modes, axes):
    """Initialize the signal extension mode for each axis to be transformed.

    Parameters
    ----------
    modes : str, int or iterable of these
        Each mode may be a mode name or the corresponding integer (e.g. an
        attribute of ``pywt.Modes``). If a single mode (or an iterable of
        length one) is provided, it will be used for all axes. Otherwise one
        mode per axis must be provided.
    axes : sequence of int
        The axes to be transformed.

    Returns
    -------
    modes : list of int
        A list of mode integers, as returned by ``Modes.from_object``, equal
        in length to ``axes``.
    """
    axes = tuple(axes)
    if isinstance(modes, (int, str)):
        # same mode on all axes
        modes = [Modes.from_object(modes), ] * len(axes)
    elif isinstance(modes, Iterable):
        if len(modes) == 1:
            modes = [Modes.from_object(modes[0]), ] * len(axes)
        else:
            # (potentially) unique mode per axis
            if len(modes) != len(axes):
                raise ValueError("The number of modes must match the number "
                                 "of axes to be transformed.")
            modes = [Modes.from_object(mode) for mode in modes]
    else:
        raise ValueError("modes must be a str, Mode enum or iterable")
    return modes
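

# Hypothetical, illustration-only sketch (not part of the PyWavelets API) of
# how the per-axis lists are meant to be consumed. Both helpers return one
# entry per axis, in the order of ``axes``, so a caller can pair them up with
# ``zip``: ``axes[i]`` is transformed with ``wavelets[i]`` and ``modes[i]``.
# The n-dimensional transforms (e.g. ``dwtn``) appear to loop over the axes
# in this way. Plain strings and arbitrary ints stand in for Wavelet objects
# and mode integers here.
def _per_axis_plan_sketch(axes, wavelets, modes):
    axes = tuple(axes)
    # guard against mismatched lists, which zip() would silently truncate
    if not (len(axes) == len(wavelets) == len(modes)):
        raise ValueError("axes, wavelets and modes must have equal length")
    return list(zip(axes, wavelets, modes))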