File: montage.py

package info (click to toggle)
python-mne 1.9.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 131,492 kB
  • sloc: python: 213,302; javascript: 12,910; sh: 447; makefile: 144
file content (135 lines) | stat: -rw-r--r-- 4,324 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

"""Functions to plot EEG sensor montages or digitizer montages."""

from copy import deepcopy

import numpy as np
from scipy.spatial.distance import cdist

from .._fiff._digitization import _get_fid_coords
from .._fiff.meas_info import create_info
from ..utils import _check_option, _validate_type, logger, verbose
from .utils import plot_sensors


@verbose
def plot_montage(
    montage,
    *,
    scale=None,
    scale_factor=None,
    show_names=True,
    kind="topomap",
    show=True,
    sphere=None,
    axes=None,
    verbose=None,
):
    """Plot a montage.

    Parameters
    ----------
    montage : instance of DigMontage
        The montage to visualize.
    scale : float
        Determines the scale of the channel points and labels; values < 1 will scale
        down, whereas values > 1 will scale up. Defaults to None, which implies 1.
    scale_factor : float
        Determines the size of the points. Deprecated, use scale instead.
    show_names : bool | list
        Whether to display all channel names. If a list, only the channel
        names in the list are shown. Defaults to True.
    kind : str
        Whether to plot the montage as '3d' or 'topomap' (default).
    show : bool
        Show figure if True.
    %(sphere_topomap_auto)s
    %(axes_montage)s

        .. versionadded:: 1.4
    %(verbose)s

    Returns
    -------
    fig : instance of matplotlib.figure.Figure
        The figure object.

    Raises
    ------
    ValueError
        If both ``scale`` and ``scale_factor`` are given.
    RuntimeError
        If the montage contains no channels.

    Notes
    -----
    Channels whose positions coincide (up to numerical tolerance) are plotted
    only once: for every coincident pair the later channel is dropped and the
    earlier one is kept.
    """
    import matplotlib.pyplot as plt

    from ..channels import DigMontage, make_dig_montage

    # Reconcile the deprecated ``scale_factor`` with its replacement ``scale``.
    if scale_factor is not None:
        msg = "scale_factor has been deprecated and will be removed. Use scale instead."
        if scale is not None:
            raise ValueError(
                " ".join(["scale and scale_factor cannot be used together.", msg])
            )
        logger.info(msg)
    if scale is None:
        scale = 1

    _check_option("kind", kind, ["topomap", "3d"])
    _validate_type(montage, DigMontage, item_name="montage")
    ch_names = montage.ch_names

    if len(ch_names) == 0:
        raise RuntimeError("No valid channel positions found.")

    pos = np.array(list(montage._get_ch_pos().values()))

    # Pairwise distances between channel positions. Blank out the lower
    # triangle and the diagonal so each coincident pair (i < j) appears once
    # and a channel is never compared with itself.
    dists = cdist(pos, pos)
    dists[np.tril_indices(dists.shape[0])] = np.nan
    dupes = np.argwhere(np.isclose(dists, 0))
    if len(dupes) > 0:
        montage = deepcopy(montage)
        n_chans = pos.shape[0]
        n_dupes = dupes.shape[0]
        # Keep every channel that is never the second member of a pair.
        idx = np.setdiff1d(np.arange(n_chans), dupes[:, 1]).tolist()
        logger.info(f"{n_dupes} duplicate electrode labels found:")
        logger.info(", ".join([ch_names[d[0]] + "/" + ch_names[d[1]] for d in dupes]))
        # Report the channels actually kept: with 3+ channels at one position
        # there are more coincident pairs than dropped channels.
        logger.info(f"Plotting {len(idx)} unique labels.")
        ch_names = [ch_names[i] for i in idx]
        ch_pos = dict(zip(ch_names, pos[idx, :]))
        # XXX: this might cause trouble if montage was originally in head
        fid, _ = _get_fid_coords(montage.dig)
        montage = make_dig_montage(ch_pos=ch_pos, **fid)

    info = create_info(ch_names, sfreq=256, ch_types="eeg")
    info.set_montage(montage, on_missing="ignore")
    fig = plot_sensors(
        info,
        kind=kind,
        show_names=show_names,
        show=show,
        title=None,
        sphere=sphere,
        axes=axes,
    )

    ax = fig.axes[0]
    # The sensor markers are the first collection on the axes.
    collection = ax.collections[0]
    if scale_factor is not None:
        # Deprecated path: absolute marker size, text is left untouched.
        collection.set_sizes([scale_factor])
    else:
        collection.set_sizes([scale * 10])

        # Axis labels and tick labels keep their size; every other text
        # artist in the figure (e.g. channel names) is scaled.
        excluded = [ax.xaxis.label, ax.yaxis.label]
        excluded += ax.get_xticklabels() + ax.get_yticklabels()
        if kind == "3d":
            excluded.append(ax.zaxis.label)
            excluded += ax.get_zticklabels()
        # Artists compare by identity, so an id-set is an exact, O(1) lookup.
        excluded_ids = {id(text) for text in excluded}
        for label in fig.findobj(match=plt.Text):
            if id(label) not in excluded_ids:
                label.set_fontsize(label.get_fontsize() * scale)

    return fig