File: transform2d.py

package info (click to toggle)
python-dtcwt 0.14.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 8,588 kB
  • sloc: python: 6,287; sh: 29; makefile: 13
file content (289 lines) | stat: -rw-r--r-- 11,237 bytes parent folder | download | duplicates (5)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
from __future__ import division, absolute_import

import logging
import numpy as np
from six.moves import xrange

from dtcwt.coeffs import biort as _biort, qshift as _qshift
from dtcwt.defaults import DEFAULT_BIORT, DEFAULT_QSHIFT
from dtcwt.utils import appropriate_complex_type_for, asfarray, memoize
from dtcwt.opencl.lowlevel import axis_convolve, axis_convolve_dfilter, q2c
from dtcwt.opencl.lowlevel import to_device, to_queue, to_array, empty

from dtcwt.numpy import Pyramid
from dtcwt.numpy import Transform2d as Transform2dNumPy

try:
    from pyopencl.array import concatenate, Array as CLArray
except ImportError:
    # The lack of OpenCL will be caught by the low-level routines.
    pass

def dtwavexfm2(X, nlevels=3, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, include_scale=False, queue=None):
    """Functional front-end to :py:meth:`Transform2d.forward`.

    Builds a :py:class:`Transform2d` from *biort*, *qshift* and *queue*, runs
    an *nlevels*-level forward transform of *X* and unpacks the resulting
    pyramid into a plain tuple.

    :returns: ``(lowpass, highpasses)``, or
        ``(lowpass, highpasses, scales)`` when *include_scale* is true.
    """
    transform = Transform2d(biort=biort, qshift=qshift, queue=queue)
    pyramid = transform.forward(X, nlevels=nlevels, include_scale=include_scale)

    outputs = (pyramid.lowpass, pyramid.highpasses)
    if include_scale:
        outputs += (pyramid.scales,)
    return outputs

class Pyramid(object):
    """
    An interface-compatible version of
    :py:class:`dtcwt.Pyramid` where the initialiser
    arguments are assumed to be :py:class:`pyopencl.array.Array` instances.

    The attributes defined in :py:class:`dtcwt.Pyramid`
    are implemented via properties. The original OpenCL arrays may be accessed
    via the ``cl_...`` attributes.

    .. note::

        The copy from device to host is performed *once* and then memoized.
        This makes repeated access to the host-side attributes efficient but
        will mean that any changes to the device-side arrays will not be
        reflected in the host-side attributes after their first access. You
        should not be modifying the arrays once an instance of this class has
        been created anyway but if you do, beware!

    .. py:attribute:: cl_lowpass

        The CL array containing the lowpass image (may be ``None``).

    .. py:attribute:: cl_highpasses

        A tuple of CL arrays containing the subband images (may be ``None``).

    .. py:attribute:: cl_scales

        *(optional)* Either ``None`` or a tuple of lowpass images for each
        scale.

    """
    def __init__(self, lowpass, highpasses, scales=None):
        self.cl_lowpass = lowpass
        self.cl_highpasses = highpasses
        self.cl_scales = scales

    @staticmethod
    def _to_host_tuple(cl_arrays):
        """Return a tuple of host copies of *cl_arrays*, or ``None`` if it is ``None``."""
        if cl_arrays is None:
            return None
        return tuple(to_array(x) for x in cl_arrays)

    @property
    def lowpass(self):
        """Host-side copy of :py:attr:`cl_lowpass` (``None`` if that is ``None``).

        Computed on first access and memoized thereafter.
        """
        if not hasattr(self, '_lowpass'):
            self._lowpass = None if self.cl_lowpass is None else to_array(self.cl_lowpass)
        return self._lowpass

    @property
    def highpasses(self):
        """Tuple of host-side copies of :py:attr:`cl_highpasses`, or ``None``.

        Computed on first access and memoized thereafter.
        """
        if not hasattr(self, '_highpasses'):
            self._highpasses = self._to_host_tuple(self.cl_highpasses)
        return self._highpasses

    @property
    def scales(self):
        """Tuple of host-side copies of :py:attr:`cl_scales`, or ``None``.

        Computed on first access and memoized thereafter.
        """
        if not hasattr(self, '_scales'):
            self._scales = self._to_host_tuple(self.cl_scales)
        return self._scales

class Transform2d(Transform2dNumPy):
    """
    An implementation of the 2D DT-CWT via OpenCL. *biort* and *qshift* are the
    wavelets which parameterise the transform.

    If *queue* is non-*None* it is an instance of
    :py:class:`pyopencl.CommandQueue` which is used to compile and execute the
    OpenCL kernels which implement the transform. If it is *None*, the first
    available compute device is used.

    If *biort* or *qshift* are strings, they are used as an argument to the
    :py:func:`dtcwt.coeffs.biort` or :py:func:`dtcwt.coeffs.qshift` functions.
    Otherwise, they are interpreted as tuples of vectors giving filter
    coefficients. In the *biort* case, this should be (h0o, g0o, h1o, g1o). In
    the *qshift* case, this should be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).

    :py:meth:`forward` also accepts 6-element *biort* and 12-element *qshift*
    tuples (modified rotationally symmetric wavelets). Their extra ``h2``
    filters replace the ``h1`` filters when computing the ``diag`` term that
    is passed to ``q2c``.

    .. note::

        At the moment *only* the **forward** transform is accelerated. The
        inverse transform uses the NumPy backend.

    """
    def __init__(self, biort=DEFAULT_BIORT, qshift=DEFAULT_QSHIFT, queue=None):
        super(Transform2d, self).__init__(biort=biort, qshift=qshift)
        # Resolve the optional user-supplied queue into the one used for all kernels.
        self.queue = to_queue(queue)

    @staticmethod
    def _shape_to_str(shape):
        """Format an array shape as e.g. ``'3x4'`` for use in messages."""
        return 'x'.join(str(s) for s in shape)

    def forward(self, X, nlevels=3, include_scale=False):
        """Perform a *n*-level DTCWT-2D decomposition on a 2D matrix *X*.

        :param X: 2D real array
        :param nlevels: Number of levels of wavelet decomposition (must be
            non-negative)
        :param include_scale: If true, the lowpass image computed at every
            level is also returned, via the ``scales`` attribute of the result.

        :returns: A :py:class:`dtcwt.Pyramid` compatible object representing the transform-domain signal

        :raises ValueError: if *nlevels* is negative, if *X* is not
            two-dimensional, or if the *biort* / *qshift* filter tuples have an
            unsupported number of components.

        .. note::

            *X* may be a :py:class:`pyopencl.array.Array` instance which has
            already been copied to the device. In which case, it must be 2D.
            (I.e. a vector will not be auto-promoted.)

        .. note::

            If a dimension of *X* is odd, the bottom row and/or rightmost
            column is duplicated before decomposition and a warning is logged.

        .. codeauthor:: Rich Wareham <rjw57@cantab.net>, Aug 2013
        .. codeauthor:: Nick Kingsbury, Cambridge University, Sept 2001
        .. codeauthor:: Cian Shaffrey, Cambridge University, Sept 2001

        """
        queue = self.queue

        # Without this check a negative level count fell through every branch
        # and crashed with UnboundLocalError on LoLo.
        if nlevels < 0:
            raise ValueError(
                'Number of levels must be non-negative, got {0}.'.format(nlevels))

        if isinstance(X, CLArray):
            # Device arrays are never auto-promoted, so they must already be 2D.
            if len(X.shape) != 2:
                raise ValueError('Input array must be two-dimensional')
        else:
            # Host input: convert to a float NumPy array with at least two
            # dimensions. The copy to the device happens further below.
            X = np.atleast_2d(asfarray(X))

        # If biort has 6 elements instead of 4, then it's a modified
        # rotationally symmetric wavelet with an extra (h2o, g2o) pair.
        # FIXME: there's probably a nicer way to do this
        if len(self.biort) == 4:
            h0o, g0o, h1o, g1o = self.biort
        elif len(self.biort) == 6:
            h0o, g0o, h1o, g1o, h2o, g2o = self.biort
        else:
            raise ValueError('Biort wavelet must have 6 or 4 components.')

        # If qshift has 12 elements instead of 8, then it's a modified
        # rotationally symmetric wavelet; only the first ten filters are
        # needed for the forward transform.
        # FIXME: there's probably a nicer way to do this
        if len(self.qshift) == 8:
            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b = self.qshift
        elif len(self.qshift) == 12:
            h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b, h2a, h2b = self.qshift[:10]
        else:
            raise ValueError('Qshift wavelet must have 12 or 8 components.')

        # Whether the extra h2 filters replace h1 when computing ``diag``.
        use_h2o = len(self.biort) >= 6
        use_h2b = len(self.qshift) >= 12

        original_size = X.shape

        if len(X.shape) >= 3:
            raise ValueError('The entered image is {0}, please enter each image slice separately.'.
                    format(self._shape_to_str(X.shape)))

        # If the image is odd in size, duplicate the bottom row and/or the
        # rightmost column so that both dimensions are even. This is done on
        # the host, so a device array is first copied back with to_array.
        initial_row_extend = original_size[0] % 2 != 0
        initial_col_extend = original_size[1] % 2 != 0

        if initial_row_extend:
            X = to_array(X)
            X = np.vstack((X, X[[-1], :]))  # Any further extension will be done in due course.

        if initial_col_extend:
            X = to_array(X)
            X = np.hstack((X, X[:, [-1]]))

        extended_size = X.shape

        # Copy X to the device if necessary
        X = to_device(X, queue=queue)

        if nlevels == 0:
            if include_scale:
                return Pyramid(X, (), ())
            else:
                return Pyramid(X, ())

        # From here on nlevels >= 1.
        Yh = [None,] * nlevels
        # Per-level lowpass images; only collected when the caller asks for them.
        Yscale = [None,] * nlevels if include_scale else None

        # Level 1 uses the odd-length biorthogonal filters.
        # Filter down the columns (axis 0).
        Lo = axis_convolve(X, h0o, axis=0, queue=queue)
        Hi = axis_convolve(X, h1o, axis=0, queue=queue)
        if use_h2o:
            Ba = axis_convolve(X, h2o, axis=0, queue=queue)

        # Filter along the rows (axis 1).
        LoLo = axis_convolve(Lo, h0o, axis=1, queue=queue)

        if use_h2o:
            diag = axis_convolve(Ba, h2o, axis=1, queue=queue)
        else:
            diag = axis_convolve(Hi, h1o, axis=1, queue=queue)

        Yh[0] = q2c(
            axis_convolve(Hi, h0o, axis=1, queue=queue),
            axis_convolve(Lo, h1o, axis=1, queue=queue),
            diag,
            queue=queue
        )

        if include_scale:
            Yscale[0] = LoLo

        for level in xrange(1, nlevels):
            row_size, col_size = LoLo.shape

            if row_size % 4 != 0:
                # Extend by 2 rows (duplicate first and last) if the number of
                # rows of LoLo is not divisible by 4.
                LoLo = to_array(LoLo)
                LoLo = np.vstack((LoLo[:1, :], LoLo, LoLo[-1:, :]))

            if col_size % 4 != 0:
                # Extend by 2 cols (duplicate first and last) if the number of
                # cols of LoLo is not divisible by 4.
                LoLo = to_array(LoLo)
                LoLo = np.hstack((LoLo[:, :1], LoLo, LoLo[:, -1:]))

            # Levels >= 2 use the even-length q-shift filters.
            # Filter down the columns (axis 0).
            Lo = axis_convolve_dfilter(LoLo, h0b, axis=0, queue=queue)
            Hi = axis_convolve_dfilter(LoLo, h1b, axis=0, queue=queue)
            if use_h2b:
                Ba = axis_convolve_dfilter(LoLo, h2b, axis=0, queue=queue)

            # Filter along the rows (axis 1).
            LoLo = axis_convolve_dfilter(Lo, h0b, axis=1, queue=queue)

            if use_h2b:
                diag = axis_convolve_dfilter(Ba, h2b, axis=1, queue=queue)
            else:
                diag = axis_convolve_dfilter(Hi, h1b, axis=1, queue=queue)

            Yh[level] = q2c(
                axis_convolve_dfilter(Hi, h0b, axis=1, queue=queue),
                axis_convolve_dfilter(Lo, h1b, axis=1, queue=queue),
                diag,
                queue=queue
            )

            if include_scale:
                Yscale[level] = LoLo

        Yl = LoLo

        # Tell the user if the input had to be padded to even dimensions.
        # logging.warning replaces the deprecated logging.warn alias.
        if initial_row_extend or initial_col_extend:
            logging.warning('The image entered is now a {0} NOT a {1}.'.format(
                self._shape_to_str(extended_size),
                self._shape_to_str(original_size)))
            if initial_row_extend and initial_col_extend:
                logging.warning(
                    'The bottom row and rightmost column have been duplicated, prior to decomposition.')
            elif initial_row_extend:
                logging.warning(
                    'The bottom row has been duplicated, prior to decomposition.')
            else:
                logging.warning(
                    'The rightmost column has been duplicated, prior to decomposition.')

        if include_scale:
            return Pyramid(Yl, tuple(Yh), tuple(Yscale))
        return Pyramid(Yl, tuple(Yh))