File: _thresholding.py

package info (click to toggle)
pywavelets 1.4.1-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 13,680 kB
  • sloc: python: 8,849; ansic: 5,134; makefile: 93
file content (250 lines) | stat: -rw-r--r-- 8,793 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
# Copyright (c) 2006-2012 Filip Wasilewski <http://en.ig.ma/>
# Copyright (c) 2012-2016 The PyWavelets Developers
#                         <https://github.com/PyWavelets/pywt>
# See COPYING for license details.

"""
The thresholding helper module implements the most popular signal thresholding
functions.
"""

from __future__ import division, print_function, absolute_import
import numpy as np

__all__ = ['threshold', 'threshold_firm']


def soft(data, value, substitute=0):
    """Soft thresholding.

    Shrinks the magnitude of every element of `data` toward zero by `value`.
    Elements whose magnitude does not exceed `value` become zero.  Only the
    magnitude is rescaled, so complex-valued input keeps its phase.

    Parameters
    ----------
    data : array_like
        Numeric data (real or complex).
    value : scalar
        Thresholding value.
    substitute : float, optional
        Value written in place of elements whose magnitude is strictly less
        than `value` (default: 0).

    Returns
    -------
    output : array
        Thresholded data.

    Notes
    -----
    Elements that are exactly zero are returned unchanged, even when `value`
    is 0, where evaluating ``1 - value/abs(data)`` directly would give
    ``0/0 = nan``.
    """
    data = np.asarray(data)
    magnitude = np.absolute(data)

    with np.errstate(divide='ignore', invalid='ignore'):
        # Zero magnitudes produce inf (or nan when value is also 0).  Those
        # entries are overwritten below, so the warnings are just noise.
        # The function form of clip is used because, for 0-d input, the
        # intermediate is a NumPy scalar that cannot serve as an ``out=``
        # target.
        shrink = np.clip(1 - value / magnitude, 0, None)
        thresholded = data * shrink

    # A zero element has nothing to shrink: pass it through as-is so that
    # 0/0 never leaks a nan into the result.
    thresholded = np.where(magnitude == 0, data, thresholded)

    if substitute == 0:
        return thresholded
    else:
        cond = np.less(magnitude, value)
        return np.where(cond, substitute, thresholded)


def nn_garrote(data, value, substitute=0):
    """Non-negative Garrote thresholding.

    Each element ``x`` of `data` is scaled by
    ``max(1 - (value/abs(x))**2, 0)``.  Elements whose magnitude does not
    exceed `value` become zero, and larger elements are shrunk less and less
    as their magnitude grows.

    Parameters
    ----------
    data : array_like
        Numeric data (real or complex).
    value : scalar
        Thresholding value.
    substitute : float, optional
        Value written in place of elements whose magnitude is strictly less
        than `value` (default: 0).

    Returns
    -------
    output : array
        Thresholded data.

    Notes
    -----
    The ratio ``value/abs(x)`` is formed in floating point *before* it is
    squared.  Squaring the magnitude first would be done in the input's own
    dtype and silently wraps around for small integer types (for example
    ``200**2`` in int16).

    Elements that are exactly zero are returned unchanged, even when `value`
    is 0, where the ratio would otherwise be ``0/0 = nan``.
    """
    data = np.asarray(data)
    magnitude = np.absolute(data)

    with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
        # inf/nan/overflow can only arise for zero (or vanishingly small)
        # magnitudes; those entries are clipped to 0 or overwritten below.
        ratio = value / magnitude
        shrink = np.clip(1 - ratio * ratio, 0, None)
        thresholded = data * shrink

    # A zero element has nothing to shrink: pass it through as-is so that
    # 0/0 never leaks a nan into the result.
    thresholded = np.where(magnitude == 0, data, thresholded)

    if substitute == 0:
        return thresholded
    else:
        cond = np.less(magnitude, value)
        return np.where(cond, substitute, thresholded)


def hard(data, value, substitute=0):
    """Hard thresholding.

    Elements of `data` whose absolute value is strictly less than `value`
    are replaced by `substitute`; all other elements are returned unchanged.
    """
    arr = np.asarray(data)
    below_threshold = np.absolute(arr) < value
    return np.where(below_threshold, substitute, arr)


def greater(data, value, substitute=0):
    """Keep elements that are at least `value`; replace the rest.

    Elements of `data` strictly below `value` are replaced by `substitute`.
    Complex input is rejected with a ``ValueError``.
    """
    arr = np.asarray(data)
    if not np.iscomplexobj(arr):
        too_small = arr < value
        return np.where(too_small, substitute, arr)
    raise ValueError("greater thresholding only supports real data")


def less(data, value, substitute=0):
    """Keep elements that are at most `value`; replace the rest.

    Elements of `data` strictly above `value` are replaced by `substitute`.
    Complex input is rejected with a ``ValueError``.
    """
    arr = np.asarray(data)
    if np.iscomplexobj(arr):
        raise ValueError("less thresholding only supports real data")
    too_large = arr > value
    return np.where(too_large, substitute, arr)


# Maps each accepted ``mode`` name to the function implementing it.
thresholding_options = dict([
    ('soft', soft),
    ('hard', hard),
    ('greater', greater),
    ('less', less),
    ('garrote', nn_garrote),
    # 'garotte' is a misspelling kept so older callers keep working
    ('garotte', nn_garrote),
])


def threshold(data, value, mode='soft', substitute=0):
    """
    Thresholds the input data depending on the mode argument.

    In ``soft`` thresholding [1]_, data values with absolute value less than
    `value` are replaced with `substitute`. Data values with absolute value
    greater or equal to the thresholding value are shrunk toward zero
    by `value`.  In other words, the new value is
    ``data/np.abs(data) * np.maximum(np.abs(data) - value, 0)``.

    In ``hard`` thresholding, the data values where their absolute value is
    less than `value` are replaced with `substitute`. Data values with
    absolute value greater or equal to the thresholding value stay untouched.

    ``garrote`` corresponds to the Non-negative garrote threshold [2]_, [3]_.
    It is intermediate between ``hard`` and ``soft`` thresholding.  It behaves
    like soft thresholding for small data values and approaches hard
    thresholding for large data values.

    In ``greater`` thresholding, the data is replaced with `substitute` where
    data is below the thresholding value. Greater data values pass untouched.

    In ``less`` thresholding, the data is replaced with `substitute` where data
    is above the thresholding value. Lesser data values pass untouched.

    Both ``hard`` and ``soft`` thresholding also support complex-valued data.

    Parameters
    ----------
    data : array_like
        Numeric data.
    value : scalar
        Thresholding value.
    mode : {'soft', 'hard', 'garrote', 'greater', 'less'}
        Decides the type of thresholding to be applied on input data. Default
        is 'soft'.
    substitute : float, optional
        Substitute value (default: 0).

    Returns
    -------
    output : array
        Thresholded array.

    Raises
    ------
    ValueError
        If `mode` is not one of the supported thresholding modes.

    See Also
    --------
    threshold_firm

    References
    ----------
    .. [1] D.L. Donoho and I.M. Johnstone. Ideal Spatial Adaptation via
        Wavelet Shrinkage. Biometrika. Vol. 81, No. 3, pp.425-455, 1994.
        DOI:10.1093/biomet/81.3.425
    .. [2] L. Breiman. Better Subset Regression Using the Nonnegative Garrote.
        Technometrics, Vol. 37, pp. 373-384, 1995.
        DOI:10.2307/1269730
    .. [3] H-Y. Gao.  Wavelet Shrinkage Denoising Using the Non-Negative
        Garrote.  Journal of Computational and Graphical Statistics Vol. 7,
        No. 4, pp.469-488. 1998.
        DOI:10.1080/10618600.1998.10474789

    Examples
    --------
    >>> import numpy as np
    >>> import pywt
    >>> data = np.linspace(1, 4, 7)
    >>> data
    array([ 1. ,  1.5,  2. ,  2.5,  3. ,  3.5,  4. ])
    >>> pywt.threshold(data, 2, 'soft')
    array([ 0. ,  0. ,  0. ,  0.5,  1. ,  1.5,  2. ])
    >>> pywt.threshold(data, 2, 'hard')
    array([ 0. ,  0. ,  2. ,  2.5,  3. ,  3.5,  4. ])
    >>> pywt.threshold(data, 2, 'garrote')
    array([ 0.        ,  0.        ,  0.        ,  0.9       ,  1.66666667,
            2.35714286,  3.        ])
    >>> pywt.threshold(data, 2, 'greater')
    array([ 0. ,  0. ,  2. ,  2.5,  3. ,  3.5,  4. ])
    >>> pywt.threshold(data, 2, 'less')
    array([ 1. ,  1.5,  2. ,  0. ,  0. ,  0. ,  0. ])

    """

    # Validate the mode on its own rather than wrapping the whole call in
    # ``try/except KeyError``: a KeyError raised while thresholding would
    # otherwise be misreported as an invalid `mode`.
    if mode not in thresholding_options:
        # Make sure error is always identical by sorting keys
        keys = ("'{0}'".format(key) for key in sorted(thresholding_options))
        raise ValueError("The mode parameter only takes values from: {0}."
                         .format(', '.join(keys)))

    return thresholding_options[mode](data, value, substitute)


def threshold_firm(data, value_low, value_high):
    """Firm threshold.

    The approach is intermediate between soft and hard thresholding [1]_. It
    behaves the same as soft-thresholding for values below `value_low` and
    the same as hard-thresholding for values above `value_high`.  For
    intermediate values, the thresholded value is in between that corresponding
    to soft or hard thresholding.

    Parameters
    ----------
    data : array-like
        The data to threshold.  This can be either real or complex-valued.
    value_low : float
        Any values with magnitude less than or equal to `value_low` will be
        set to zero.
    value_high : float
        Any values with magnitude larger than `value_high` will not be
        modified.

    Notes
    -----
    This thresholding technique is also known as semi-soft thresholding [2]_.

    For each value, `x`, in `data`. This function computes::

        if np.abs(x) <= value_low:
            return 0
        elif np.abs(x) > value_high:
            return x
        elif value_low < np.abs(x) and np.abs(x) <= value_high:
            return (x * value_high * (1 - value_low/np.abs(x))
                    / (value_high - value_low))

    ``firm`` is a continuous function (like soft thresholding), but is
    unbiased for large values (like hard thresholding).

    If ``value_high == value_low`` this function becomes hard-thresholding.
    If ``value_high`` is infinity, this function becomes soft-thresholding.

    Returns
    -------
    val_new : array-like
        The values after firm thresholding at the specified thresholds.

    Raises
    ------
    ValueError
        If `value_low` is negative or `value_high` is less than `value_low`.

    See Also
    --------
    threshold

    References
    ----------
    .. [1] H.-Y. Gao and A.G. Bruce. Waveshrink with firm shrinkage.
        Statistica Sinica, Vol. 7, pp. 855-874, 1997.
    .. [2] A. Bruce and H-Y. Gao. WaveShrink: Shrinkage Functions and
        Thresholds. Proc. SPIE 2569, Wavelet Applications in Signal and
        Image Processing III, 1995.
        DOI:10.1117/12.217582
    """

    if value_low < 0:
        raise ValueError("value_low must be non-negative.")

    if value_high < value_low:
        raise ValueError(
            "value_high must be greater than or equal to value_low.")

    data = np.asarray(data)
    magnitude = np.absolute(data)

    with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
        # The formula is only meaningful where
        # value_low < magnitude <= value_high.  Everywhere else it may give
        # inf/nan (zero magnitude, value_high == value_low), but those
        # entries are overwritten by the two masks below.
        shrink = 1 - value_low / magnitude
        if not np.isinf(value_high):
            shrink = value_high * shrink / (value_high - value_low)
        # else: value_high / (value_high - value_low) -> 1 in the limit,
        # which leaves plain soft thresholding (inf/inf would give nan).
        thresholded = data * shrink

    # Values at or below the lower threshold are zeroed.
    thresholded = np.where(magnitude <= value_low, 0, thresholded)
    # Values above the upper threshold pass through untouched.  A boolean
    # mask is used instead of testing the index arrays from np.where, whose
    # truthiness is False when the only matches sit at index 0.
    return np.where(magnitude > value_high, data, thresholded)