File: fastica_.py

package info (click to toggle)
scikit-learn 0.11.0-2%2Bdeb7u1
  • links: PTS, VCS
  • area: main
  • in suites: wheezy
  • size: 13,900 kB
  • sloc: python: 34,740; ansic: 8,860; cpp: 8,849; pascal: 230; makefile: 211; sh: 14
file content (388 lines) | stat: -rw-r--r-- 12,615 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
"""
Python implementation of the fast ICA algorithms.

Reference: Tables 8.3 and 8.4 page 196 in the book:
Independent Component Analysis, by  Hyvarinen et al.
"""

# Author: Pierre Lafaye de Micheaux, Stefan van der Walt, Gael Varoquaux,
#         Bertrand Thirion, Alexandre Gramfort
# License: BSD 3 clause
import warnings
import numpy as np
from scipy import linalg

from ..base import BaseEstimator
from ..utils import array2d, as_float_array, check_random_state

# Names exported by ``from <this module> import *``: the functional entry
# point and the estimator class. Underscore-prefixed helpers stay private.
__all__ = ['fastica', 'FastICA']


def _gs_decorrelation(w, W, j):
    """Gram-Schmidt step: remove from w its components along W[:j].

    Parameters
    ----------
    w: array of shape(n), vector to be orthogonalized
    W: array of shape(p, n), rows span the directions to remove
    j: int < p, number of leading rows of W to use

    Returns
    -------
    w: the same array object that was passed in

    Caveats
    -------
    The first j rows of W are assumed to be orthonormal.
    w is modified in place and is not re-normalized here.
    """
    basis = W[:j]
    # Coordinates of w along each of the first j rows of W.
    coefficients = np.dot(w, basis.T)
    # Subtract the projection in place so the caller's array is updated.
    np.subtract(w, np.dot(coefficients, basis), out=w)
    return w


def _sym_decorrelation(W):
    """ Symmetric decorrelation of the rows of W
    i.e. W <- (W * W.T) ^{-1/2} * W

    Returns a new array; the input W is not modified.
    """
    gram = np.dot(W, W.T)
    # eigh returns the eigenvalues and the eigenvectors (as columns) of the
    # symmetric matrix W * W.T
    eigvals, eigvecs = linalg.eigh(gram)
    inv_sqrt_eigvals = np.diag(1.0 / np.sqrt(eigvals))
    # eigvecs * diag(eigvals ^ -1/2) * eigvecs.T is (W * W.T) ^ {-1/2}
    gram_inv_sqrt = np.dot(np.dot(eigvecs, inv_sqrt_eigvals), eigvecs.T)
    return np.dot(gram_inv_sqrt, W)


def _ica_def(X, tol, g, gprime, fun_args, max_iter, w_init):
    """Deflationary FastICA using fun approx to neg-entropy function

    Components are extracted one at a time; each new direction is
    decorrelated from the previously extracted ones and rescaled to unit
    norm after every update.

    g and gprime are called as ``g(x, fun_args)`` / ``gprime(x, fun_args)``.
    Row j of w_init is the starting point for component j (w_init itself is
    left untouched). Each component gets at most ``max_iter - 1`` updates.

    Used internally by FastICA.
    """
    n_components = w_init.shape[0]
    unmixing = np.zeros((n_components, n_components), dtype=float)
    update_budget = max_iter - 1

    # j is the index of the component currently being extracted
    for j in range(n_components):
        direction = w_init[j, :].copy()
        direction /= np.sqrt(np.sum(direction ** 2))

        # start above tol so that the loop body runs at least once
        # (provided the update budget allows it)
        lim = tol + 1
        n_updates = 0
        while lim > tol and n_updates < update_budget:
            projection = np.dot(direction.T, X)
            g_values = g(projection, fun_args)
            gprime_values = gprime(projection, fun_args)
            candidate = ((X * g_values).mean(axis=1)
                         - gprime_values.mean() * direction)

            # remove what is already explained by components 0..j-1
            # (candidate is updated in place)
            _gs_decorrelation(candidate, unmixing, j)

            candidate /= np.sqrt(np.sum(candidate ** 2))

            # both vectors have unit norm: |<candidate, direction>| reaches 1
            # when the direction stops changing (up to sign)
            lim = np.abs(np.abs(np.sum(candidate * direction)) - 1)
            direction = candidate
            n_updates += 1

        unmixing[j, :] = direction

    return unmixing


def _ica_par(X, tol, g, gprime, fun_args, max_iter, w_init):
    """Parallel FastICA.

    All rows of the un-mixing matrix are updated together and then
    symmetrically decorrelated after every update.

    g and gprime are called as ``g(x, fun_args)`` / ``gprime(x, fun_args)``.
    At most ``max_iter - 1`` updates are performed.

    Used internally by FastICA --main loop

    """
    # X must be 2-D; only its second dimension is needed below
    _, n_columns = X.shape
    column_count = float(n_columns)

    W = _sym_decorrelation(w_init)

    # start above tol so that the loop body runs at least once
    # (provided max_iter allows it)
    lim = tol + 1
    n_updates = 0
    while n_updates < max_iter - 1 and lim > tol:
        wtx = np.dot(W, X)
        g_values = g(wtx, fun_args)
        gprime_values = gprime(wtx, fun_args)

        data_term = np.dot(g_values, X.T) / column_count
        scaling_term = np.dot(np.diag(gprime_values.mean(axis=1)), W)
        W_next = _sym_decorrelation(data_term - scaling_term)

        # row-wise inner products between the new and the previous rows;
        # lim is the largest deviation of their absolute value from 1
        row_overlaps = np.diag(np.dot(W_next, W.T))
        lim = max(abs(abs(row_overlaps) - 1))
        W = W_next
        n_updates += 1

    return W


def fastica(X, n_components=None, algorithm="parallel", whiten=True,
            fun="logcosh", fun_prime='', fun_args=None, max_iter=200,
            tol=1e-04, w_init=None, random_state=None):
    """Perform Fast Independent Component Analysis.

    Parameters
    ----------
    X : array-like, shape = [n_samples, n_features]
        Training vector, where n_samples is the number of samples and
        n_features is the number of features.
    n_components : int, optional
        Number of components to extract. If None no dimension reduction
        is performed. Values larger than min(n_samples, n_features) are
        clamped to that bound and a warning is issued.
    algorithm : {'parallel', 'deflation'}, optional
        Apply a parallel or deflational FASTICA algorithm.
    whiten: boolean, optional
        If True perform an initial whitening of the data.
        If False, the data is assumed to have already been
        preprocessed: it should be centered, normed and white.
        Otherwise you will get incorrect results.
        In this case the parameter n_components will be ignored.
    fun : string or function, optional
        The functional form of the G function used in the
        approximation to neg-entropy. Could be either 'logcosh', 'exp',
        or 'cube'.
        You can also provide your own function but in this case, its
        derivative should be provided via argument fun_prime. Both are
        called as ``f(x, **fun_args)``.
    fun_prime : empty string ('') or function, optional
        See fun. Must be callable whenever fun is a function.
    fun_args: dictionary, optional
        Extra arguments for the non-linearity. If None (default) or empty
        and if fun='logcosh', alpha takes the value 1.0. When given,
        'alpha' must lie in [1, 2].
    max_iter: int, optional
        Maximum number of iterations to perform
    tol: float, optional
        A positive scalar giving the tolerance at which the
        un-mixing matrix is considered to have converged
    w_init: (n_components, n_components) array, optional
        Initial un-mixing array of dimension (n.comp,n.comp).
        If None (default) then an array of normal r.v.'s is used
    random_state: int or RandomState
        Pseudo number generator state used for random sampling.

    Returns
    -------
    K: (n_components, n_features) array or None.
        If whiten is 'True', K is the pre-whitening matrix that projects data
        onto the first n.comp principal components. If whiten is 'False', K is
        'None'.

    W: (n_components, n_components) array
        estimated un-mixing matrix.
        With whitening, the full un-mixing operator is ``np.dot(W, K)`` and
        the mixing matrix can be obtained as its pseudo-inverse::

            A = linalg.pinv(np.dot(W, K))

    S: (n_samples, n_components) array
        estimated source matrix

    Raises
    ------
    ValueError
        If algorithm is not 'parallel' or 'deflation', if alpha is outside
        [1, 2], if fun is neither a known name nor a callable, if fun is a
        callable but fun_prime is not, or if w_init has the wrong shape.

    Notes
    -----

    The data matrix X is considered to be a linear combination of
    non-Gaussian (independent) components i.e. X = AS where columns of S
    contain the independent components and A is a linear mixing
    matrix. In short ICA attempts to `un-mix' the data by estimating an
    un-mixing matrix W where ``S = W K X.``

    This implementation was originally made for data of shape
    [n_features, n_samples]. Now the input is transposed
    before the algorithm is applied. This makes it slightly
    faster for Fortran-ordered input.

    Implemented using FastICA:
    `A. Hyvarinen and E. Oja, Independent Component Analysis:
    Algorithms and Applications, Neural Networks, 13(4-5), 2000,
    pp. 411-430`

    """
    random_state = check_random_state(random_state)
    # None stands for "no extra arguments"; building the dict here avoids a
    # mutable default shared between calls.
    if fun_args is None:
        fun_args = {}

    # make interface compatible with other decompositions
    X = array2d(X).T

    algorithm_funcs = {'parallel': _ica_par,
                       'deflation': _ica_def}
    # Reject unknown names up front instead of failing later with an
    # unrelated "'str' object is not callable" error.
    if algorithm not in algorithm_funcs:
        raise ValueError("algorithm argument should be either 'parallel' "
                         "or 'deflation', got %r" % (algorithm,))
    func = algorithm_funcs[algorithm]

    alpha = fun_args.get('alpha', 1.0)
    if (alpha < 1) or (alpha > 2):
        raise ValueError("alpha must be in [1,2]")

    if isinstance(fun, str):
        # Some standard nonlinear functions
        # XXX: these should be optimized, as they can be a bottleneck.
        if fun == 'logcosh':
            def g(x, fun_args):
                alpha = fun_args.get('alpha', 1.0)
                return np.tanh(alpha * x)

            def gprime(x, fun_args):
                alpha = fun_args.get('alpha', 1.0)
                return alpha * (1 - (np.tanh(alpha * x)) ** 2)

        elif fun == 'exp':
            def g(x, fun_args):
                return x * np.exp(-(x ** 2) / 2)

            def gprime(x, fun_args):
                return (1 - x ** 2) * np.exp(-(x ** 2) / 2)

        elif fun == 'cube':
            def g(x, fun_args):
                return x ** 3

            def gprime(x, fun_args):
                return 3 * x ** 2
        else:
            raise ValueError(
                        'fun argument should be one of logcosh, exp or cube')
    elif callable(fun):
        # The default fun_prime ('') is not usable with a custom fun; fail
        # here rather than in the middle of the iterations.
        if not callable(fun_prime):
            raise ValueError('fun_prime should be a function when fun '
                             'is a function')

        def g(x, fun_args):
            return fun(x, **fun_args)

        def gprime(x, fun_args):
            return fun_prime(x, **fun_args)
    else:
        raise ValueError('fun argument should be either a string '
                         '(one of logcosh, exp or cube) or a function')

    # After the transpose: n is the number of features, p the number of
    # samples.
    n, p = X.shape

    # Same truthiness test as the whitening branch below, so both always
    # agree on whether whitening is happening.
    if not whiten and n_components is not None:
        n_components = None
        warnings.warn('Ignoring n_components with whiten=False.')

    if n_components is None:
        n_components = min(n, p)
    if (n_components > min(n, p)):
        n_components = min(n, p)
        warnings.warn("n_components is too large: it will be set to %s"
                      % n_components)

    if whiten:
        # Centering the columns (ie the variables)
        X = X - X.mean(axis=-1)[:, np.newaxis]

        # Whitening and preprocessing by PCA
        u, d, _ = linalg.svd(X, full_matrices=False)

        del _
        K = (u / d).T[:n_components]  # see (6.33) p.140
        del u, d
        X1 = np.dot(K, X)
        # see (13.6) p.267 Here X1 is white and data
        # in X has been projected onto a subspace by PCA
        X1 *= np.sqrt(p)
    else:
        # X must be casted to floats to avoid typing issues with numpy
        # 2.0 and the line below
        X1 = as_float_array(X, copy=True)

    if w_init is None:
        w_init = random_state.normal(size=(n_components, n_components))
    else:
        w_init = np.asarray(w_init)
        if w_init.shape != (n_components, n_components):
            raise ValueError("w_init has invalid shape -- should be %(shape)s"
                             % {'shape': (n_components, n_components)})

    kwargs = {'tol': tol,
              'g': g,
              'gprime': gprime,
              'fun_args': fun_args,
              'max_iter': max_iter,
              'w_init': w_init}

    W = func(X1, **kwargs)
    del X1

    if whiten:
        # X is the centered data here; K maps it to the whitened subspace.
        S = np.dot(np.dot(W, K), X)
        return K, W, S.T
    else:
        S = np.dot(W, X)
        return None, W, S.T


class FastICA(BaseEstimator):
    """FastICA; a fast algorithm for Independent Component Analysis

    Parameters
    ----------
    n_components : int, optional
        Number of components to use. If none is passed, all are used.
    algorithm : {'parallel', 'deflation'}
        Apply parallel or deflational algorithm for FastICA
    whiten : boolean, optional
        If whiten is false, the data is already considered to be
        whitened, and no whitening is performed.
    fun : {'logcosh', 'exp', or 'cube'}, or a callable
        The non-linear function used in the FastICA loop to approximate
        negentropy. If a function is passed, its derivative should be
        passed as the 'fun_prime' argument.
    fun_prime : empty string ('') or a callable
        The derivative of the non-linearity used.
    fun_args : dictionary, optional
        Extra arguments forwarded to the non-linearity. None (default) is
        stored as an empty dictionary.
    max_iter : int, optional
        Maximum number of iterations during fit
    tol : float, optional
        Tolerance on update at each iteration
    w_init : None or an (n_components, n_components) ndarray
        The un-mixing matrix to be used to initialize the algorithm.
    random_state: int or RandomState
        Pseudo number generator state used for random sampling.

    Attributes
    ----------
    `unmixing_matrix_` : 2D array, [n_components, n_features]
        The unmixing matrix. When whitening is performed it already
        includes the whitening step.
    `components_` : 2D array, [n_samples, n_components]
        The sources estimated from the data passed to fit.

    Notes
    -----

    Implementation based on
    `A. Hyvarinen and E. Oja, Independent Component Analysis:
    Algorithms and Applications, Neural Networks, 13(4-5), 2000,
    pp. 411-430`

    """

    def __init__(self, n_components=None, algorithm='parallel', whiten=True,
                 fun='logcosh', fun_prime='', fun_args=None, max_iter=200,
                 tol=1e-4, w_init=None, random_state=None):
        super(FastICA, self).__init__()
        self.n_components = n_components
        self.algorithm = algorithm
        self.whiten = whiten
        self.fun = fun
        self.fun_prime = fun_prime
        # A fresh dict per instance; never share a mutable default.
        self.fun_args = {} if fun_args is None else fun_args
        self.max_iter = max_iter
        self.tol = tol
        self.w_init = w_init
        self.random_state = random_state

    def fit(self, X):
        """Estimate the un-mixing matrix from X.

        Parameters
        ----------
        X : array-like, shape = [n_samples, n_features]
            Training data.

        Returns
        -------
        self
        """
        whitening_, unmixing_, sources_ = fastica(
            X, self.n_components, self.algorithm, self.whiten, self.fun,
            self.fun_prime, self.fun_args, self.max_iter, self.tol,
            self.w_init, random_state=self.random_state)
        # fastica returns a whitening matrix only when it actually whitened
        # the data. Keying on that (rather than on ``self.whiten == True``)
        # keeps the two in agreement for any truthy value of whiten.
        if whitening_ is not None:
            self.unmixing_matrix_ = np.dot(unmixing_, whitening_)
        else:
            self.unmixing_matrix_ = unmixing_
        self.components_ = sources_
        return self

    def transform(self, X):
        """Apply un-mixing matrix "W" to X to recover the sources

        S = X * W.T

        X is used as given: no centering is applied here.
        """
        return np.dot(X, self.unmixing_matrix_.T)

    def get_mixing_matrix(self):
        """Compute the mixing matrix

        It is the pseudo-inverse of the fitted un-mixing matrix.
        """
        return linalg.pinv(self.unmixing_matrix_)