File: _utils.py

package info (click to toggle)
pywavelets 1.4.1-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 13,680 kB
  • sloc: python: 8,849; ansic: 5,134; makefile: 93
file content (93 lines) | stat: -rw-r--r-- 3,296 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# Copyright (c) 2017 The PyWavelets Developers
#                    <https://github.com/PyWavelets/pywt>
# See COPYING for license details.
import inspect
import numpy as np
from collections.abc import Iterable

from ._extensions._pywt import (Wavelet, ContinuousWavelet,
                                DiscreteContinuousWavelet, Modes)


def _as_wavelet(wavelet):
    """Return ``wavelet`` as a discrete wavelet object.

    Parameters
    ----------
    wavelet : Wavelet, ContinuousWavelet or other
        Objects that are already ``Wavelet`` or ``ContinuousWavelet``
        instances are used as-is.  Anything else (e.g. a wavelet name) is
        passed to ``DiscreteContinuousWavelet`` to build the object.

    Returns
    -------
    wavelet : object
        Either the input itself or the object built by
        ``DiscreteContinuousWavelet``, provided that it is not a
        ``ContinuousWavelet``.

    Raises
    ------
    ValueError
        If the resulting object is a ``ContinuousWavelet``.

    """
    if isinstance(wavelet, (ContinuousWavelet, Wavelet)):
        resolved = wavelet
    else:
        resolved = DiscreteContinuousWavelet(wavelet)

    # Anything that is not a continuous wavelet is accepted as-is.
    if not isinstance(resolved, ContinuousWavelet):
        return resolved

    message = (
        "A ContinuousWavelet object was provided, but only discrete "
        "Wavelet objects are supported by this function.  A list of all "
        "supported discrete wavelets can be obtained by running:\n"
        "print(pywt.wavelist(kind='discrete'))")
    raise ValueError(message)


def _wavelets_per_axis(wavelet, axes):
    """Build the list of Wavelets to use, one entry per transformed axis.

    Parameters
    ----------
    wavelet : str, Wavelet or iterable of these
        A single name or Wavelet is used for all axes.  An iterable of
        length one is treated the same way.  Any other iterable must
        provide exactly one entry per axis.
    axes : iterable
        The axes to be transformed.

    Returns
    -------
    wavelets : list
        A list, equal in length to ``axes``, holding the result of
        ``_as_wavelet`` for each axis.

    Raises
    ------
    ValueError
        If ``wavelet`` is not a str, Wavelet or iterable, or if the number
        of wavelets is neither one nor the number of axes.

    """
    n_axes = len(tuple(axes))

    # A single name/object: same wavelet on all axes.
    if isinstance(wavelet, (str, Wavelet)):
        return [_as_wavelet(wavelet)] * n_axes

    if not isinstance(wavelet, Iterable):
        raise ValueError("wavelet must be a str, Wavelet or iterable")

    n_given = len(wavelet)
    if n_given == 1:
        # A lone entry is repeated for every axis.
        return [_as_wavelet(wavelet[0])] * n_axes

    # (potentially) unique wavelet per axis (e.g. for dual-tree DWT)
    if n_given != n_axes:
        raise ValueError(
            "The number of wavelets must match the number of axes "
            "to be transformed.")
    return [_as_wavelet(entry) for entry in wavelet]


def _modes_per_axis(modes, axes):
    """Initialize mode for each axis to be transformed.

    Parameters
    ----------
    modes : str or tuple of strings
        If a single mode is provided, it will used for all axes.  Otherwise
        one mode per axis must be provided.
    axes : tuple
        The tuple of axes to be transformed.

    Returns
    -------
    modes : tuple of int
        A tuple of Modes equal in length to ``axes``.

    """
    axes = tuple(axes)
    if isinstance(modes, (int, str)):
        # same wavelet on all axes
        modes = [Modes.from_object(modes), ] * len(axes)
    elif isinstance(modes, Iterable):
        if len(modes) == 1:
            modes = [Modes.from_object(modes[0]), ] * len(axes)
        else:
            # (potentially) unique wavelet per axis (e.g. for dual-tree DWT)
            if len(modes) != len(axes):
                raise ValueError(("The number of modes must match the number "
                                  "of axes to be transformed."))
        modes = [Modes.from_object(mode) for mode in modes]
    else:
        raise ValueError("modes must be a str, Mode enum or iterable")
    return modes