File: mxne_debiasing.py

package info (click to toggle)
python-mne 0.17+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 95,104 kB
  • sloc: python: 110,639; makefile: 222; sh: 15
file content (136 lines) | stat: -rwxr-xr-x 3,775 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# Authors: Daniel Strohmeier <daniel.strohmeier@tu-ilmenau.de>
#          Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
#
# License: BSD (3-clause)

from math import sqrt
import numpy as np
from scipy import linalg

from ..utils import check_random_state, logger, verbose


def power_iteration_kron(A, C, max_iter=1000, tol=1e-3, random_state=0):
    """Find the largest singular value for the matrix kron(C.T, A).

    It uses power iterations on the linear operator
    ``B -> A.T A B C C.T``. Since ``vec(A.T A B C C.T)`` equals
    ``kron(C C.T, A.T A) vec(B) = K.T K vec(B)`` with ``K = kron(C.T, A)``,
    the value returned is the largest eigenvalue of ``K.T K``, i.e. the
    *squared* largest singular value of ``K``. This is the Lipschitz
    constant of the gradient of ``1/2 * ||M - A B C||_fro^2`` w.r.t. ``B``.

    Parameters
    ----------
    A : array, shape (n, p)
        An array
    C : array, shape (q, m)
        An array
    max_iter : int
        Maximum number of iterations
    tol : float
        Absolute tolerance: iterations stop once the estimate changes by
        less than ``tol`` between two successive iterations.
    random_state : int | RandomState | None
        Random state for random number generation

    Returns
    -------
    L : float
        Largest eigenvalue of ``kron(C C.T, A.T A)`` (the squared largest
        singular value of ``kron(C.T, A)``). ``0.`` when that operator is
        zero (e.g. ``A`` or ``C`` is all zeros or empty).

    Notes
    -----
    http://en.wikipedia.org/wiki/Power_iteration
    """
    rng = check_random_state(random_state)
    # The iterate lives in the domain of B -> A.T A B C C.T, i.e. it has
    # shape (A.shape[1], C.shape[0]); no further compatibility between A
    # and C is required.
    B = rng.randn(A.shape[1], C.shape[0])
    B /= linalg.norm(B, 'fro')
    ATA = np.dot(A.T, A)
    CCT = np.dot(C, C.T)
    L0 = np.inf
    for _ in range(max_iter):
        Y = np.dot(np.dot(ATA, B), CCT)
        L = linalg.norm(Y, 'fro')

        if L == 0.:
            # An iterate only maps to zero if the operator itself is
            # (almost surely) zero; normalizing Y would divide by zero.
            return 0.

        if abs(L - L0) < tol:
            break

        B = Y / L
        L0 = L
    return L


@verbose
def compute_bias(M, G, X, max_iter=1000, tol=1e-6, n_orient=1, verbose=None):
    """Compute scaling to correct amplitude bias.

    It solves the following optimization problem using FISTA:

    min 1/2 * (|| M - GDX ||fro)^2
    s.t. D >= 1 and D is a diagonal matrix

    Reference for the FISTA algorithm:
    Amir Beck and Marc Teboulle
    A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse
    Problems, SIAM J. Imaging Sci., 2(1), 183-202. (20 pages)
    http://epubs.siam.org/doi/abs/10.1137/080716542

    Parameters
    ----------
    M : array, shape (n_sensors, n_times)
        measurement data.
    G : array, shape (n_sensors, n_sources)
        leadfield matrix.
    X : array, shape (n_sources, n_times)
        reconstructed time courses with amplitude bias.
    max_iter : int
        Maximum number of iterations.
    tol : float
        The tolerance on convergence (on the max-norm of the change of D
        between two successive iterations).
    n_orient : int
        The number of orientations (1 for fixed and 3 otherwise).
        ``n_sources`` must be a multiple of ``n_orient``.
    verbose : bool, str, int, or None
        If not None, override default verbose level (see :func:`mne.verbose`
        and :ref:`Logging documentation <tut_logging>` for more).

    Returns
    -------
    D : array, shape (n_sources,)
        Debiasing weights (the diagonal of D). All weights are >= 1 and,
        when ``n_orient > 1``, each group of ``n_orient`` consecutive
        weights shares the same value. If ``G`` or ``X`` is all zeros (or
        ``X`` is empty), the data fit does not depend on D and all weights
        are 1.
    """
    n_sources = X.shape[0]

    # initializations
    D = np.ones(n_sources)
    if not np.any(G) or not np.any(X):
        # The model G D X is identically zero, so the objective does not
        # depend on D and the feasible point D = 1 is optimal. Returning
        # here also avoids dividing by a zero Lipschitz constant and taking
        # the max-norm of an empty array when there are no sources.
        return D

    # Upper bound of the Lipschitz constant of the gradient, with a 10%
    # safety margin; the gradient step size is 1 / lipschitz_constant.
    lipschitz_constant = 1.1 * power_iteration_kron(G, X)

    Y = np.ones(n_sources)
    t = 1.0
    Ddiff = np.inf  # defined even if the loop body never runs

    for i in range(max_iter):
        D0 = D

        # gradient step at Y; G * Y scales the columns of G, i.e. G diag(Y)
        R = M - np.dot(G * Y, X)
        D = Y + np.sum(np.dot(G.T, R) * X, axis=1) / lipschitz_constant
        # Equivalent but faster than:
        # D = Y + np.diag(np.dot(np.dot(G.T, R), X.T)) / lipschitz_constant

        # prox ie projection on constraint
        if n_orient != 1:  # take care of orientations
            # The scaling has to be the same for all orientations
            D = np.mean(D.reshape(-1, n_orient), axis=1)
            D = np.repeat(D, n_orient)
        D = np.maximum(D, 1.0)

        # FISTA momentum: extrapolate from the last two iterates
        t0 = t
        t = 0.5 * (1.0 + sqrt(1.0 + 4.0 * t ** 2))
        dt = (t0 - 1.0) / t
        Y = D + dt * (D - D0)

        Ddiff = linalg.norm(D - D0, np.inf)

        if Ddiff < tol:
            logger.info("Debiasing converged after %d iterations "
                        "max(|D - D0|) = %e < %e", i + 1, Ddiff, tol)
            break
    else:
        logger.info("Debiasing did not converge after %d iterations! "
                    "max(|D - D0|) = %e >= %e", max_iter, Ddiff, tol)
    return D