File: wavelet_packets.py

package info (click to toggle)
pywavelets 0.1.7~svn97-1
  • links: PTS, VCS
  • area: main
  • in suites: lenny
  • size: 2,408 kB
  • ctags: 1,492
  • sloc: ansic: 3,375; python: 1,910; makefile: 44
file content (361 lines) | stat: -rw-r--r-- 11,919 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
# -*- coding: utf-8 -*-

# Copyright (c) 2006-2008 Filip Wasilewski <filip.wasilewski@gmail.com>
# See COPYING for license details.

# $Id: wavelet_packets.py 95 2008-03-06 18:40:42Z filipw $

"""Wavelet packet transform"""

from _pywt import MODES, Wavelet, dwt, idwt, dwt_max_level

class Node(object):
    """
    WaveletPacket tree node.

    Subnodes are called 'a' and 'd', like approximation and detail
    coefficients in the Discrete Wavelet Transform. A node's `path` is the
    concatenation of subnode names leading to it from the root ("" for the
    root itself).
    """

    # String types accepted as node paths: `basestring` covers byte and
    # unicode strings on Python 2; it does not exist on Python 3, where
    # `str` is used instead.
    try:
        _string_types = basestring
    except NameError:
        _string_types = str

    def __init__(self, parent, data, nodeName):
        """
        parent - parent Node, or None for the tree root
        data - signal on level 0, coefficients on higher levels
        nodeName - subnode name ('a' or 'd'); not used for the root
        """
        self.parent = parent
        if parent is not None:
            # DWT settings and level bookkeeping are inherited from the
            # parent. The root (parent None) does not get them here - its
            # creator has to assign wavelet, mode, level and maxlevel.
            self.wavelet = parent.wavelet
            self.mode = parent.mode
            self.level = parent.level + 1
            self.maxlevel = parent.maxlevel
            self.path = parent.path + nodeName
        else:
            self.path = ""

        # data - signal on level 0, coeffs on higher levels
        self.data = data

        # children; None until created by decompose()/set_node()
        self.a = None
        self.d = None

        # changed through markZeroTree() / the isZeroTree property
        self._isZeroTree = False

    def createChild(self, part, data=None):
        """
        Create subnode `part` ('a' or 'd') holding `data`.

        Raises ValueError for any other part name, and Warning if the
        subnode already exists (an existing node is never replaced).
        """
        if part not in ("a", "d"):
            raise ValueError(
                "Child node can only have 'a' or 'd' name, not '%s'" % (part,))
        existing = getattr(self, part)
        if existing is not None:
            raise Warning("Node '%s' already exists" % (existing.path,))
        setattr(self, part, Node(self, data, part))

    def decompose(self):
        """
        Decompose node data, creating two subnodes with DWT coefficients.

        Returns the (a, d) subnodes. Raises ValueError when the node is
        already at maxlevel, and Warning (from createChild) when a subnode
        already exists.
        """
        if self.level < self.maxlevel:
            a, d = dwt(self.data, self.wavelet, self.mode)
            self.createChild("a", a)
            self.createChild("d", d)
            return self.a, self.d
        raise ValueError("Maximum level value reached")

    def reconstruct(self, update=False):
        """
        Reconstruct node's data value using coefficients from subnodes.
        If update param is True, then reconstructed data replaces node's data
        (subnodes are reconstructed without updating their data).

        Returns None if node is marked as ZeroTree.
        Returns original node data if all subnodes are None or are marked as ZeroTrees.
        Returns IDWT reconstructed data returned by reconstruct() method of two nodes otherwise.
        Raises ValueError if neither subnode yields any data.
        """
        if self.isZeroTree:
            return None

        a_pruned = self.a is None or self.a.isZeroTree
        d_pruned = self.d is None or self.d.isZeroTree
        if a_pruned and d_pruned:
            # nothing below contributes - this node acts as a leaf
            return self.data

        data_a = None
        data_d = None
        if self.a is not None:
            data_a = self.a.reconstruct()
        if self.d is not None:
            data_d = self.d.reconstruct()

        if data_a is None and data_d is None:
            raise ValueError("Can not reconstruct. Tree is missing data")

        rec = idwt(data_a, data_d, self.wavelet, self.mode, correct_size=True)
        if update:
            self.data = rec
        return rec

    def markZeroTree(self, flag=True, remove_sub=True):
        """
        Mark or unmark node as ZeroTree.

        flag - if true, node will be marked as ZT. If false, an existing ZT
            mark is removed; when the node then has no subnodes, holds data
            and is below maxlevel, it is decomposed again.
        remove_sub - if True (and flag is true), subnodes will be removed
            from tree.
        """
        if flag:
            self._isZeroTree = True
            if remove_sub:
                self.a = None
                self.d = None
        elif self._isZeroTree:
            self._isZeroTree = False
            # restore subnodes dropped when the node was marked; guarded so
            # decompose() is not asked to split a leaf or a node without data
            if (self.a is None and self.d is None and self.data is not None
                    and self.level < self.maxlevel):
                self.decompose()

    isZeroTree = property(lambda self: self._isZeroTree, markZeroTree)

    def getChild(self, part, decompose=True):
        """
        Returns subnode 'a' or 'd'.

        part - subnode name ('a' or 'd')
        decompose - if True and subnodes do not exist, they will be created with
            values from decomposition of current node (some lazy evaluation here)

        Returns None when this node is marked as ZeroTree.
        """
        if part not in ("a", "d"):
            raise ValueError(
                "Child node can only have 'a' or 'd' name, not '%s'" % (part,))
        if self.isZeroTree:
            return None
        child = getattr(self, part)
        if decompose and child is None:
            self.decompose()
            child = getattr(self, part)
        return child

    def __getitem__(self, path):
        """Return the data of the node at `path` (see get_node)."""
        return self.get_node(path).data

    def get_node(self, path):
        """
        Find node of given path in tree.

        path - string composed of "a" and "d", of total length not greater than maxlevel.

        If node does not exist yet, it will be created by decomposition of its
        parent node.

        Raises IndexError for a non-string path, or when the path leads
        through a node marked as ZeroTree (such a node has no subnodes).
        """
        if not isinstance(path, self._string_types):
            raise IndexError("Invalid path")
        if not path:
            return self
        child = self.getChild(path[0], True)
        if child is None:
            raise IndexError(
                "Node '%s' is a ZeroTree and has no subnodes" % (self.path,))
        # recurse on the node itself; indexing (child[...]) would yield data
        return child.get_node(path[1:])

    def __setitem__(self, path, data):
        """Set the data of the node at `path` (see set_node)."""
        self.set_node(path, data)

    def set_node(self, path, data):
        """
        Set data of the node at `path`, creating missing nodes on the way.

        Missing intermediate nodes are created empty (data None); only the
        target node receives `data`. Raises IndexError for a non-string path.
        """
        if not isinstance(path, self._string_types):
            raise IndexError("Invalid path")
        if not path:
            self.data = data
            return
        child = self.getChild(path[0], False)
        if child is None:
            self.createChild(path[0])
            # read the attribute directly: getChild() returns None for a
            # ZeroTree node even after the child has been created
            child = getattr(self, path[0])
        child.set_node(path[1:], data)

    def walk(self, func, args=tuple()):
        """
        Walk tree and call func on every node -> func(node, *args)
        If func returns True, descending to subnodes will be proceeded.
        Subnodes of ZeroTree nodes are not visited.

        func - callable object
        args - additional func parms
        """
        if func(self, *args) and self.level < self.maxlevel:
            children = [self.getChild("a"), self.getChild("d")]
            for child in children:
                if child is not None:
                    child.walk(func, args)

    def walk_depth(self, func, args=tuple()):
        """
        Walk tree and call func on every node starting from bottom most nodes.
        Subnodes of ZeroTree nodes are not visited.

        func - callable object
        args - additional func parms
        """
        if self.level < self.maxlevel:
            children = [self.getChild("a"), self.getChild("d")]
            for child in children:
                if child is not None:
                    child.walk_depth(func, args)
        func(self, *args)

    # other methods
    def energy(self):
        """sum of squared data values"""
        # assumes data supports element-wise `*` (e.g. a numpy array) - TODO confirm
        return sum(self.data*self.data)

    def __str__(self):
        return str(self.data)
    
class WaveletPacket(Node):
    """
    WaveletPacket(data, wavelet, mode='sp1', maxlevel=None)
    Data structure representing Wavelet Packet decomposition of signal.

    data - original data (signal)
    wavelet - wavelet used in DWT decomposition and reconstruction
        (a Wavelet instance, or a value passed to the Wavelet constructor)
    mode - signal extension mode - see MODES
    maxlevel - maximum level of decomposition; when None and data is given,
        it is computed by dwt_max_level from the data length and the
        wavelet's decomposition filter length
    """
    def __init__(self, data, wavelet, mode='sp1', maxlevel=None):
        Node.__init__(self, None, data, "")

        if not isinstance(wavelet, Wavelet):
            wavelet = Wavelet(wavelet)
        self.wavelet = wavelet
        self.mode = mode

        if data is not None:
            # remembered so reconstruct() can trim the result back to it
            self.data_size = len(data)
            if maxlevel is None:
                maxlevel = dwt_max_level(self.data_size, self.wavelet.dec_len)
        else:
            self.data_size = None

        self.maxlevel = maxlevel
        self.level = 0
        self.frequency = (0., 1.)

    def __getitem__(self, path):
        """Return the data of the node at `path` (see get_node)."""
        return self.get_node(path).data

    def get_node(self, path):
        """
        Find node of given path in tree.

        path - string composed of "a" and "d", of total length not greater than maxlevel.

        If node does not exist yet, it will be created by decomposition of its
        parent node. Raises IndexError when path is longer than maxlevel.
        """
        if len(path) > self.maxlevel:
            raise IndexError("Path length out of range")
        return Node.get_node(self, path)

    def __setitem__(self, path, value):
        """
        Set data of the node at `path` (see Node.set_node).
        Raises IndexError when path is longer than maxlevel.
        """
        if len(path) > self.maxlevel:
            raise IndexError("path length out of range")
        Node.__setitem__(self, path, value)

    def __delitem__(self, path):
        """
        Mark node of given path in tree as ZeroTree.

        path - string composed of "a" and "d", of total length not greater than maxlevel.

        If node does not exist yet, it will be created by decomposition of its
        parent node. The root (empty path) cannot be deleted - IndexError.
        """
        if len(path) > 0:
            self.get_node(path).markZeroTree(True, remove_sub=True)
        else:
            raise IndexError("Invalid path")

    def reconstruct(self, update=True):
        """
        Reconstruct signal data using coefficients from subnodes.

        If the root has no subnodes, its current data is returned unchanged.
        The result is trimmed to the original signal length when known.
        If update is True, the root's data is replaced by the reconstruction
        (the data of subnodes is not modified).

        Returns None if the root itself is marked as ZeroTree.
        """
        if self.a is None and self.d is None:
            return self.data  # return original data

        data = Node.reconstruct(self, update)
        if data is None:
            # the root is marked as ZeroTree - nothing to trim or store
            return None
        if self.data_size is not None and len(data) > self.data_size:
            data = data[:self.data_size]
        if update:
            self.data = data
        return data

    def _first_level_nodes(self):
        """Return the existing level-1 subnodes, creating them lazily."""
        if self.level >= self.maxlevel:
            # maxlevel 0: the signal cannot be decomposed at all
            return []
        children = [self.getChild("a"), self.getChild("d")]
        # getChild() yields None when the root is marked as ZeroTree
        return [child for child in children if child is not None]

    def walk(self, func, args=tuple()):
        """
        Walk tree and call func on every node below the root -> func(node, *args)
        The root node itself is not passed to func.
        If func returns True, descending to subnodes will be proceeded.

        func - callable object
        args - additional func parms
        """
        for child in self._first_level_nodes():
            child.walk(func, args)

    def walk_depth(self, func, args=tuple()):
        """
        Walk tree and call func on every node below the root, starting from
        bottom most nodes. The root node itself is not passed to func.

        func - callable object
        args - additional func parms
        """
        for child in self._first_level_nodes():
            child.walk_depth(func, args)

    def get_level(self, level, order="natural"):
        """
        Returns all nodes from specified level.

        order - "natural" - left to right in tree
              - "freq" - frequency ordered

        Nodes below ZeroTree nodes are not included. The root is not visited
        by walk(), so level 0 yields an empty list.
        Raises ValueError for level > maxlevel or an unknown order name.
        """
        if level > self.maxlevel:
            raise ValueError("Specified level is greater than maximum level number (%d > %d)" % (level, self.maxlevel))
        # validate before walking, so a bad name does not trigger decomposition
        if order not in ("natural", "freq"):
            raise ValueError("wrong order name %s" % (order,))

        result = []

        def collect(node):
            if node.level == level:
                result.append(node)
                return False
            return True

        self.walk(collect)

        if order == "freq":
            def freq_index(node):
                # Read with 'a' -> 0 and 'd' -> 1, a node's path is the
                # reflected Gray code of its position in frequency order;
                # decode it: each binary bit is the XOR of the Gray bits so far.
                index = bit = 0
                for part in node.path:
                    if part == "d":
                        bit ^= 1
                    index = (index << 1) | bit
                return index
            result.sort(key=freq_index)
        return result

    def get_nonzero(self, decompose=False):
        """
        Returns leaf nodes not belonging to any zero tree.

        decompose - if True, nodes that have no subnodes yet are decomposed
            further (down to maxlevel) instead of being treated as leaves.
        """
        result = []

        def pruned(child):
            return child is None or child.isZeroTree

        def collect(node):
            if node.isZeroTree:
                return False
            if node.level == node.maxlevel:
                result.append(node)
                return False
            if decompose:
                if node.a is None and node.d is None:
                    # no subnodes yet - walk() will create them lazily
                    return True
                is_leaf = (node.a is not None and node.a.isZeroTree and
                           node.d is not None and node.d.isZeroTree)
            else:
                is_leaf = pruned(node.a) and pruned(node.d)
            if is_leaf:
                result.append(node)
                return False
            return True

        self.walk(collect)
        return result

# Public API of this module.
__all__ = [
    "Node",
    "WaveletPacket",
]