File: transforms.py

package info (click to toggle)
python-spectral 0.22.4-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, forky, sid, trixie
  • size: 1,064 kB
  • sloc: python: 13,161; makefile: 7
file content (162 lines) | stat: -rw-r--r-- 5,287 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
'''
Base classes for various types of transforms.
'''

from __future__ import absolute_import, division, print_function, unicode_literals

import collections
import numpy as np


class LinearTransform:
    '''A callable linear transform object.

    In addition to the __call__ method, which applies the transform to given
    data, a LinearTransform object also has the following members:

        `dim_in` (int):

            The expected length of input vectors. This will be `None` if the
            input dimension is unknown (e.g., if the transform is a scalar).

        `dim_out` (int):

            The length of output vectors (after linear transformation). This
            will be `None` if the output dimension is unknown (e.g., if
            the transform is a scalar).

        `dtype` (numpy dtype):

            The numpy dtype for the output ndarray data.
    '''
    def __init__(self, A, **kwargs):
        '''Arguments:

            `A` (:class:`~numpy.ndarray`):

                An (J,K) array to be applied to length-K targets. A 1-D
                length-K array is treated as a (1,K) array, and a 0-d array
                (scalar) scales its input without fixing any dimensions.

        Keyword Arguments:

            `pre` (scalar or length-K sequence):

                Additive offset to be applied prior to linear transformation.

            `post` (scalar or length-J sequence):

                An additive offset to be applied after linear transformation.

            `dtype` (numpy dtype):

                Explicit type for transformed data. Defaults to the dtype of
                `A`.
        '''
        # Offsets are stored as given; `None` means "no offset".
        self._pre = kwargs.get('pre', None)
        self._post = kwargs.get('post', None)
        # Copy so later changes to the caller's array cannot alter the transform.
        A = np.array(A, copy=True)
        if A.ndim == 0:
            # Scalar transform: input/output dimensions are unknown.
            self._A = A
            (self.dim_out, self.dim_in) = (None, None)
        else:
            if A.ndim == 1:
                # Promote a single row vector to a (1, K) matrix.
                self._A = A.reshape((1,) + A.shape)
            else:
                self._A = A
            (self.dim_out, self.dim_in) = self._A.shape
        self.dtype = kwargs.get('dtype', self._A.dtype)

    def __call__(self, X):
        '''Applies the linear transformation to the given data.

        Arguments:

            `X` (:class:`~numpy.ndarray` or object with `transform` method):

                If `X` is an ndarray, it is either an (M,N,K) array containing
                M*N length-K vectors to be transformed or it is an (R,K) array
                of length-K vectors to be transformed. If `X` is an object with
                a method named `transform` the result of passing the
                `LinearTransform` object to the `transform` method will be
                returned.

        Returns an (M,N,J) or (R,J) array, depending on shape of `X`, where J
        is the length of the first dimension of the array `A` passed to
        __init__. For 3-D input, length-1 axes are squeezed from the result.

        Raises `TypeError` if `X` is neither an ndarray nor an object with a
        callable `transform` attribute.
        '''
        if not isinstance(X, np.ndarray):
            # Delegate to objects that know how to apply a transform to
            # themselves. The builtin `callable` is used because
            # `collections.Callable` was removed in Python 3.10.
            if callable(getattr(X, 'transform', None)):
                return X.transform(self)
            raise TypeError('Unable to apply transform to object.')

        shape = X.shape
        if len(shape) == 3:
            # Flatten an (M,N,K) array into (M*N,K) rows of vectors.
            X = X.reshape((-1, shape[-1]))
        if self._pre is not None:
            X = X + self._pre
        Y = np.dot(self._A, X.T).T
        if self._post is not None:
            # Out-of-place add: `Y += post` fails when Y has an integer dtype
            # and the offset is floating point, or when the offset's shape
            # would enlarge Y.
            Y = Y + self._post
        if len(shape) == 3:
            # Restore the leading (M,N) dimensions; squeeze() drops every
            # length-1 axis (J == 1, but also M == 1 or N == 1).
            return Y.reshape(shape[:2] + (-1,)).squeeze().astype(self.dtype)
        return Y.astype(self.dtype)

    def chain(self, transform):
        '''Chains together two linear transforms.

        If the transform `f1` is given by

        .. math::

            F_1(X) = A_1(X + b_1) + c_1

        and `f2` by

        .. math::

            F_2(X) = A_2(X + b_2) + c_2

        then `f1.chain(f2)` returns a new LinearTransform, `f3`, that applies
        `f2` first and `f1` second:

        .. math::

            F_3(X) = F_1(F_2(X)) = A_1 A_2 (X + b_2) + A_1 (c_2 + b_1) + c_1

        Arguments:

            `transform` (:class:`LinearTransform` or :class:`~numpy.ndarray`):

                The transform to apply first. An ndarray is wrapped in a
                LinearTransform with no offsets.

        Returns a new LinearTransform. Its `dtype` is that of the combined
        matrix `A_1 A_2`; explicit `dtype` settings of the inputs are not
        carried over.

        Raises `Exception` if both dimensions are known and the input
        dimension of `f1` differs from the output dimension of `f2`.
        '''
        if isinstance(transform, np.ndarray):
            transform = LinearTransform(transform)
        if (self.dim_in is not None and transform.dim_out is not None
                and self.dim_in != transform.dim_out):
            raise Exception('Input/Output dimensions of chained transforms '
                            'do not match.')

        # With s = self (outer) and t = transform (inner):
        #   s(t(X)) = A_s A_t (X + t._pre) + A_s (t._post + s._pre) + s._post
        # Any of the offsets may be `None`, and none of them may be modified
        # in place (they may be scalars, lists, or integer arrays).
        pre = None if transform._pre is None else np.array(transform._pre)

        # Offset sitting between the two matrices: t._post + s._pre.
        mid = None
        for offset in (transform._post, self._pre):
            if offset is not None:
                mid = np.array(offset) if mid is None else mid + offset

        post = None
        if mid is not None:
            if mid.ndim == 0 and self._A.ndim == 2:
                # A scalar offset applies to every input component. Expand it
                # so the product is a length-J vector rather than a scaled
                # copy of the (J,K) matrix.
                mid = np.full(self._A.shape[1], mid)
            post = self._A.dot(mid)
        if self._post is not None:
            post = np.array(self._post) if post is None else post + self._post

        A = np.dot(self._A, transform._A)
        return LinearTransform(A, pre=pre, post=post)