File: _doc_utils.py

package info (click to toggle)
pywavelets 1.4.1-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 13,680 kB
  • sloc: python: 8,849; ansic: 5,134; makefile: 93
file content (187 lines) | stat: -rw-r--r-- 5,823 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""Utilities used to generate various figures in the documentation."""
from itertools import product

import numpy as np
from matplotlib import pyplot as plt

from ._dwt import pad

# Public names of this module; the underscore-prefixed helpers defined below
# are internal building blocks for the drawing functions.
__all__ = ['wavedec_keys', 'wavedec2_keys', 'draw_2d_wp_basis',
           'draw_2d_fswavedecn_basis', 'boundary_mode_subplot']


def wavedec_keys(level):
    """Subband keys corresponding to a wavedec decomposition.

    Parameters
    ----------
    level : int
        Number of decomposition levels.

    Returns
    -------
    keys : list of str
        One detail key per level (``'d'``, ``'ad'``, ``'aad'``, ...) ordered
        from finest to coarsest, with the final approximation key
        (``'a' * level``) placed just before the coarsest detail key.
        Empty when ``level`` is not positive.
    """
    keys = []
    for depth in range(level):
        prefix = 'a' * depth
        if depth == level - 1:
            # Only the coarsest approximation survives; every finer
            # approximation is split again at the next level.
            keys.append(prefix + 'a')
        keys.append(prefix + 'd')
    return keys


def wavedec2_keys(level):
    """Subband keys corresponding to a wavedec2 decomposition.

    Parameters
    ----------
    level : int
        Number of decomposition levels.

    Returns
    -------
    keys : list of str
        For each level, the three detail keys ``prefix + 'h'``,
        ``prefix + 'v'`` and ``prefix + 'd'`` with ``prefix = 'a' * depth``,
        ordered from finest to coarsest. The final approximation key
        (``'a' * level``) sits just before the coarsest detail keys.
        Empty when ``level`` is not positive.
    """
    detail_bands = ('h', 'v', 'd')
    keys = []
    for depth in range(level):
        prefix = 'a' * depth
        # Only the last level keeps an approximation subband.
        if depth == level - 1:
            keys.append(prefix + 'a')
        keys.extend(prefix + band for band in detail_bands)
    return keys


def _box(bl, ur):
    """(x, y) coordinates for the 4 lines making up a rectangular box.

    Parameters
    ==========
    bl : float
        The bottom left corner of the box
    ur : float
        The upper right corner of the box

    Returns
    =======
    coords : 2-tuple
        The first and second elements of the tuple are the x and y coordinates
        of the box.
    """
    xl, xr = bl[0], ur[0]
    yb, yt = bl[1], ur[1]
    box_x = [xl, xr,
             xr, xr,
             xr, xl,
             xl, xl]
    box_y = [yb, yb,
             yb, yt,
             yt, yt,
             yt, yb]
    return (box_x, box_y)


def _2d_wp_basis_coords(shape, keys):
    """Box outlines and label positions used by ``draw_2d_wp_basis``.

    Parameters
    ----------
    shape : sequence of int
        Size of the full image. ``shape[0]`` spans the x axis and
        ``shape[1]`` spans the y axis, which is drawn downward (negated).
    keys : iterable of str
        Wavelet packet node paths made of the letters ``'a'``, ``'h'``,
        ``'v'`` and ``'d'``. An empty path outlines the whole image.

    Returns
    -------
    coords : list of tuple
        ``(x, y)`` coordinate lists (as produced by ``_box``), one per key.
    centers : dict
        Maps each key to the ``(x, y)`` center of its box, for labeling.
    """
    coords = []
    centers = {}  # retain center of boxes for use in labeling
    for key in keys:
        offset_x = offset_y = 0
        for depth, char in enumerate(key):
            # Each path letter halves the subband again: 'h'/'d' move to the
            # right half along x, 'v'/'d' move to the lower half along y.
            step = 2**(depth + 1)
            if char in ('h', 'd'):
                offset_x += shape[0] // step
            if char in ('v', 'd'):
                offset_y += shape[1] // step
        # The box size depends only on the path length. Using len(key) rather
        # than the loop index (undefined after an empty loop) lets the empty
        # root path '' produce the full-image box instead of crashing.
        scale = 2**len(key)
        sx = shape[0] // scale
        sy = shape[1] // scale
        coords.append(_box((offset_x, -offset_y),
                           (offset_x + sx, -offset_y - sy)))
        centers[key] = (offset_x + sx // 2, -offset_y - sy // 2)
    return coords, centers


def draw_2d_wp_basis(shape, keys, fmt='k', plot_kwargs=None, ax=None,
                     label_levels=0):
    """Plot a 2D representation of a WaveletPacket2D basis.

    Parameters
    ----------
    shape : sequence of int
        Size of the full image (``shape[0]`` along x, ``shape[1]`` along y).
    keys : iterable of str
        Node paths (letters ``'a'``, ``'h'``, ``'v'``, ``'d'``) of the
        subbands to outline.
    fmt : str, optional
        Matplotlib format string used for every box outline.
    plot_kwargs : dict, optional
        Extra keyword arguments forwarded to ``ax.plot`` for every outline.
    ax : matplotlib Axes, optional
        Axes to draw into. A new figure and axes are created when omitted.
    label_levels : int, optional
        Subbands whose path length is at most this value are labeled with
        their key. No labels are drawn when 0.

    Returns
    -------
    fig : matplotlib Figure
        The figure that owns ``ax``.
    ax : matplotlib Axes
        The axes that were drawn on.
    """
    # None (not a shared mutable {}) as the default; an empty dict means
    # "no extra styling".
    if plot_kwargs is None:
        plot_kwargs = {}
    coords, centers = _2d_wp_basis_coords(shape, keys)
    if ax is None:
        fig, ax = plt.subplots(1, 1)
    else:
        fig = ax.get_figure()
    for xs, ys in coords:
        # Forward the caller's styling to every outline.
        ax.plot(xs, ys, fmt, **plot_kwargs)
    ax.set_axis_off()
    ax.axis('square')
    if label_levels > 0:
        for key, (xc, yc) in centers.items():
            if len(key) <= label_levels:
                ax.text(xc, yc, key,
                        horizontalalignment='center',
                        verticalalignment='center')
    return fig, ax


def _2d_fswavedecn_coords(shape, levels):
    """Box outlines and label positions used by ``draw_2d_fswavedecn_basis``.

    Every subband is a pair ``(key0, key1)`` of 1D ``wavedec_keys(levels)``
    keys, one per axis, so the layout along each axis is computed
    independently.

    Returns
    -------
    coords : list of tuple
        ``(x, y)`` coordinate lists (as produced by ``_box``), one per pair.
    centers : dict
        Maps each ``(key0, key1)`` pair to the ``(x, y)`` center of its box.
    """
    def _extent(key, size):
        # Start offset and width of one 1D subband along an axis of `size`:
        # each 'd' at depth n skips past the approximation half of that level.
        start = sum(size // 2**(depth + 1)
                    for depth, letter in enumerate(key) if letter == 'd')
        return start, size // 2**len(key)

    coords = []
    centers = {}  # box centers, kept for labeling
    for key0, key1 in product(wavedec_keys(levels), repeat=2):
        x0, width_x = _extent(key0, shape[0])
        y0, width_y = _extent(key1, shape[1])
        # y is negated so deeper subbands are drawn downward.
        coords.append(_box((x0, -y0), (x0 + width_x, -y0 - width_y)))
        centers[(key0, key1)] = (x0 + width_x / 2, -y0 - width_y / 2)
    return coords, centers


def draw_2d_fswavedecn_basis(shape, levels, fmt='k', plot_kwargs=None,
                             ax=None, label_levels=0):
    """Plot a 2D representation of a fully separable (fswavedecn) basis.

    Each subband is the combination of one 1D ``wavedec`` subband per axis,
    using ``levels`` decomposition levels along both axes.

    Parameters
    ----------
    shape : sequence of int
        Size of the full image (``shape[0]`` along x, ``shape[1]`` along y).
    levels : int
        Number of decomposition levels along each axis.
    fmt : str, optional
        Matplotlib format string used for every box outline.
    plot_kwargs : dict, optional
        Extra keyword arguments forwarded to ``ax.plot`` for every outline.
    ax : matplotlib Axes, optional
        Axes to draw into. A new figure and axes are created when omitted.
    label_levels : int, optional
        A subband is labeled with its ``(key0, key1)`` tuple when both
        per-axis keys are at most this long. No labels are drawn when 0.

    Returns
    -------
    fig : matplotlib Figure
        The figure that owns ``ax``.
    ax : matplotlib Axes
        The axes that were drawn on.
    """
    # None (not a shared mutable {}) as the default; an empty dict means
    # "no extra styling".
    if plot_kwargs is None:
        plot_kwargs = {}
    coords, centers = _2d_fswavedecn_coords(shape, levels)
    if ax is None:
        fig, ax = plt.subplots(1, 1)
    else:
        fig = ax.get_figure()
    for xs, ys in coords:
        # Forward the caller's styling to every outline.
        ax.plot(xs, ys, fmt, **plot_kwargs)
    ax.set_axis_off()
    ax.axis('square')
    if label_levels > 0:
        for key, (xc, yc) in centers.items():
            # The deeper of the two per-axis keys decides whether to label.
            if max(len(k) for k in key) <= label_levels:
                ax.text(xc, yc, key,
                        horizontalalignment='center',
                        verticalalignment='center')
    return fig, ax


def boundary_mode_subplot(x, mode, ax, symw=True):
    """Plot an illustration of the boundary mode in a subplot axis.

    The signal is extended on both sides by ``pad`` and drawn in black, with
    the caller's samples overdrawn in red and vertical bars marking the
    boundaries.

    Parameters
    ----------
    x : sequence or ndarray
        The 1D signal to illustrate.
    mode : str
        Signal extension mode passed to ``pad`` and used as the axes title.
    ax : matplotlib Axes
        Axes to draw into.
    symw : bool, optional
        Placement of the boundary bars. When True, the bars sit on sample
        positions spaced ``len(x) - 1`` apart; when False, they sit half a
        sample earlier and are spaced ``len(x)`` apart. Here ``len(x)`` is
        measured after any periodization padding.
    """
    n_orig = len(x)
    # if odd-length, periodization replicates the last sample to make it even
    if mode == 'periodization' and n_orig % 2 == 1:
        x = np.concatenate((x, (x[-1], )))

    npad = 2 * len(x)
    t = np.arange(len(x) + 2 * npad)
    xp = pad(x, (npad, npad), mode=mode)

    ax.plot(t, xp, 'k.')
    ax.set_title(mode)

    # Plot the caller's samples in red. Only a sample replicated above (odd
    # length + periodization) is excluded; even-length periodization input
    # keeps all of its samples.
    ax.plot(t[npad:npad + n_orig], x[:n_orig], 'r.')

    # add vertical bars indicating points of symmetry or boundary extension
    o2 = np.ones(2)
    if symw:
        left = npad
        step = len(x) - 1
    else:
        left = npad - 0.5
        step = len(x)
    if mode in ['smooth', 'constant', 'zero']:
        rng = range(0, 2)
    else:
        rng = range(-2, 4)
    ymin, ymax = xp.min() - .5, xp.max() + .5
    for rep in rng:
        ax.plot((left + rep * step) * o2, [ymin, ymax], 'k-')