1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430
|
"""Functions to plot on circle as for connectivity."""
from __future__ import print_function
# Authors: Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
# Denis Engemann <denis.engemann@gmail.com>
# Martin Luessi <mluessi@nmr.mgh.harvard.edu>
#
# License: Simplified BSD
from itertools import cycle
from functools import partial
import numpy as np
from .utils import plt_show, _set_ax_facecolor
from ..externals.six import string_types
def circular_layout(node_names, node_order, start_pos=90, start_between=True,
                    group_boundaries=None, group_sep=10):
    """Create layout arranging nodes on a circle.

    Parameters
    ----------
    node_names : list of str
        Node names.
    node_order : list of str
        List with node names defining the order in which the nodes are
        arranged. Must have the same elements as node_names but the order
        can be different. The nodes are arranged clockwise starting at
        "start_pos" degrees.
    start_pos : float
        Angle in degrees that defines where the first node is plotted.
    start_between : bool
        If True, the layout starts with the position between the nodes. This
        is the same as adding "180. / len(node_names)" to start_pos.
    group_boundaries : None | array-like
        List of boundaries between groups at which point a "group_sep" will
        be inserted. E.g. "[0, len(node_names) // 2]" will create two groups.
        An empty sequence is treated like None (no groups).
    group_sep : float
        Group separation angle in degrees. See "group_boundaries".

    Returns
    -------
    node_angles : array, shape=(len(node_names,))
        Node angles in degrees.

    Raises
    ------
    ValueError
        If node_order and node_names differ in length, if a node name is
        missing from node_order, if the resulting order has repeated
        entries, or if group_boundaries is out of range or not strictly
        increasing.
    """
    n_nodes = len(node_names)

    if len(node_order) != n_nodes:
        raise ValueError('node_order has to be the same length as node_names')

    boundaries = None
    n_group_sep = 0
    if group_boundaries is not None:
        # builtin int: the ``np.int`` alias was removed in NumPy 1.24
        boundaries = np.array(group_boundaries, dtype=int)
        if np.any(boundaries >= n_nodes) or np.any(boundaries < 0):
            raise ValueError('"group_boundaries" has to be between 0 and '
                             'n_nodes - 1.')
        if len(boundaries) > 1 and np.any(np.diff(boundaries) <= 0):
            raise ValueError('"group_boundaries" must have non-decreasing '
                             'values.')
        n_group_sep = len(boundaries)
        if n_group_sep == 0:
            # nothing to separate; avoids indexing an empty array below
            boundaries = None

    # convert node_order to the position of each node name on the circle;
    # list() lets tuples and arrays be used as well
    node_order = list(node_order)
    node_order = np.array([node_order.index(name) for name in node_names])
    if len(np.unique(node_order)) != n_nodes:
        raise ValueError('node_order has repeated entries')

    # angle left for each node once the group separators are accounted for
    node_sep = (360. - n_group_sep * group_sep) / n_nodes

    if start_between:
        start_pos += node_sep / 2

        if boundaries is not None and boundaries[0] == 0:
            # special case when a group separator is at the start: half of
            # it goes before the first node, and it must not be added again
            start_pos += group_sep / 2
            boundaries = boundaries[1:] if n_group_sep > 1 else None

    # per-position increments; their cumulative sum gives the angles
    node_angles = np.full(n_nodes, node_sep, dtype=float)
    node_angles[0] = start_pos
    if boundaries is not None:
        node_angles[boundaries] += group_sep

    node_angles = np.cumsum(node_angles)[node_order]

    return node_angles
def _plot_connectivity_circle_onpick(event, fig=None, axes=None, indices=None,
n_nodes=0, node_angles=None,
ylim=[9, 10]):
"""Isolate connections around a single node when user left clicks a node.
On right click, resets all connections.
"""
if event.inaxes != axes:
return
if event.button == 1: # left click
# click must be near node radius
if not ylim[0] <= event.ydata <= ylim[1]:
return
# all angles in range [0, 2*pi]
node_angles = node_angles % (np.pi * 2)
node = np.argmin(np.abs(event.xdata - node_angles))
patches = event.inaxes.patches
for ii, (x, y) in enumerate(zip(indices[0], indices[1])):
patches[ii].set_visible(node in [x, y])
fig.canvas.draw()
elif event.button == 3: # right click
patches = event.inaxes.patches
for ii in range(np.size(indices, axis=1)):
patches[ii].set_visible(True)
fig.canvas.draw()
def plot_connectivity_circle(con, node_names, indices=None, n_lines=None,
                             node_angles=None, node_width=None,
                             node_colors=None, facecolor='black',
                             textcolor='white', node_edgecolor='black',
                             linewidth=1.5, colormap='hot', vmin=None,
                             vmax=None, colorbar=True, title=None,
                             colorbar_size=0.2, colorbar_pos=(-0.3, 0.1),
                             fontsize_title=12, fontsize_names=8,
                             fontsize_colorbar=8, padding=6.,
                             fig=None, subplot=111, interactive=True,
                             node_linewidth=2., show=True):
    """Visualize connectivity as a circular graph.

    Parameters
    ----------
    con : array
        Connectivity scores. Can be a square matrix, or a 1D array. If a 1D
        array is provided, "indices" has to be used to define the connection
        indices.
    node_names : list of str
        Node names. The order corresponds to the order in con.
    indices : tuple of arrays | None
        Two arrays with indices of connections for which the connections
        strengths are defined in con. Only needed if con is a 1D array.
    n_lines : int | None
        If not None, only the n_lines strongest connections (strength=abs(con))
        are drawn.
    node_angles : array, shape=(len(node_names,)) | None
        Array with node positions in degrees. If None, the nodes are equally
        spaced on the circle. See mne.viz.circular_layout.
    node_width : float | None
        Width of each node in degrees. If None, the minimum angle between any
        two nodes is used as the width.
    node_colors : list of tuples | list of str
        List with the color to use for each node. If fewer colors than nodes
        are provided, the colors will be repeated. Any color supported by
        matplotlib can be used, e.g., RGBA tuples, named colors.
    facecolor : str
        Color to use for background. See matplotlib.colors.
    textcolor : str
        Color to use for text. See matplotlib.colors.
    node_edgecolor : str
        Color to use for lines around nodes. See matplotlib.colors.
    linewidth : float
        Line width to use for connections.
    colormap : str
        Colormap to use for coloring the connections.
    vmin : float | None
        Minimum value for colormap. If None, it is determined automatically.
    vmax : float | None
        Maximum value for colormap. If None, it is determined automatically.
    colorbar : bool
        Display a colorbar or not.
    title : str
        The figure title.
    colorbar_size : float
        Size of the colorbar.
    colorbar_pos : 2-tuple
        Position of the colorbar.
    fontsize_title : int
        Font size to use for title.
    fontsize_names : int
        Font size to use for node names.
    fontsize_colorbar : int
        Font size to use for colorbar.
    padding : float
        Space to add around figure to accommodate long labels.
    fig : None | instance of matplotlib.pyplot.Figure
        The figure to use. If None, a new figure with the specified background
        color will be created.
    subplot : int | 3-tuple
        Location of the subplot when creating figures with multiple plots. E.g.
        121 or (1, 2, 1) for 1 row, 2 columns, plot 1. See
        matplotlib.pyplot.subplot.
    interactive : bool
        When enabled, left-click on a node to show only connections to that
        node. Right-click shows all connections.
    node_linewidth : float
        Line width for nodes.
    show : bool
        Show figure if True.

    Returns
    -------
    fig : instance of matplotlib.pyplot.Figure
        The figure handle.
    axes : instance of matplotlib.axes.PolarAxesSubplot
        The subplot handle.

    Raises
    ------
    ValueError
        If node_angles does not have one entry per node, if con is 1D and
        indices is None, or if con is neither 1D nor an
        (n_nodes, n_nodes) matrix.

    Notes
    -----
    This code is based on the circle graph example by Nicolas P. Rougier
    http://www.labri.fr/perso/nrougier/coding/.

    By default, :func:`matplotlib.pyplot.savefig` does not take ``facecolor``
    into account when saving, even if set when a figure is generated. This
    can be addressed via, e.g.::

        >>> fig.savefig(fname_fig, facecolor='black') # doctest:+SKIP

    If ``facecolor`` is not set via :func:`matplotlib.pyplot.savefig`, the
    figure labels, title, and legend may be cut off in the output figure.
    """
    import matplotlib.pyplot as plt
    import matplotlib.path as m_path
    import matplotlib.patches as m_patches

    n_nodes = len(node_names)

    if node_angles is not None:
        if len(node_angles) != n_nodes:
            raise ValueError('node_angles has to be the same length '
                             'as node_names')
        # convert it to radians (asarray so that plain lists work too)
        node_angles = np.asarray(node_angles, dtype=float) * np.pi / 180
    else:
        # uniform layout on unit circle
        node_angles = np.linspace(0, 2 * np.pi, n_nodes, endpoint=False)

    if node_width is None:
        # widths correspond to the minimum angle between two nodes
        dist_mat = node_angles[None, :] - node_angles[:, None]
        # large value on the diagonal so a node is not compared with itself
        dist_mat[np.diag_indices(n_nodes)] = 1e9
        node_width = np.min(np.abs(dist_mat))
    else:
        node_width = node_width * np.pi / 180

    if node_colors is not None:
        if len(node_colors) < n_nodes:
            node_colors = cycle(node_colors)
    else:
        # assign colors using colormap
        try:
            spectral = plt.cm.spectral
        except AttributeError:
            spectral = plt.cm.Spectral
        node_colors = [spectral(i / float(n_nodes))
                       for i in range(n_nodes)]

    # handle 1D and 2D connectivity information
    if con.ndim == 1:
        if indices is None:
            raise ValueError('indices has to be provided if con.ndim == 1')
    elif con.ndim == 2:
        if con.shape[0] != n_nodes or con.shape[1] != n_nodes:
            raise ValueError('con has to be 1D or a square matrix')
        # we use the lower-triangular part
        indices = np.tril_indices(n_nodes, -1)
        con = con[indices]
    else:
        raise ValueError('con has to be 1D or a square matrix')

    # get the colormap
    if isinstance(colormap, string_types):
        colormap = plt.get_cmap(colormap)

    # Make figure background the same colors as axes
    if fig is None:
        fig = plt.figure(figsize=(8, 8), facecolor=facecolor)

    # Use a polar axes. Everything below goes through ``fig`` / ``axes``
    # rather than pyplot's implicit "current" figure, so a figure passed in
    # by the caller is honoured even when it is not the current one.
    if not isinstance(subplot, tuple):
        subplot = (subplot,)
    axes = fig.add_subplot(*subplot, polar=True)
    _set_ax_facecolor(axes, facecolor)

    # No ticks, we'll put our own
    axes.set_xticks([])
    axes.set_yticks([])

    # Set y axes limit, add additional space if requested
    axes.set_ylim(0, 10 + padding)

    # Remove the black axes border which may obscure the labels
    axes.spines['polar'].set_visible(False)

    # Draw lines between connected nodes, only draw the strongest connections
    if n_lines is not None and len(con) > n_lines:
        con_thresh = np.sort(np.abs(con).ravel())[-n_lines]
    else:
        con_thresh = 0.

    # get the connections which we are drawing and sort by connection strength
    # this will allow us to draw the strongest connections first
    con_abs = np.abs(con)
    con_draw_idx = np.where(con_abs >= con_thresh)[0]

    con = con[con_draw_idx]
    con_abs = con_abs[con_draw_idx]
    # asarray so that index lists can be fancy-indexed like arrays
    indices = [np.asarray(ind)[con_draw_idx] for ind in indices]

    # now sort them
    sort_idx = np.argsort(con_abs)
    con = con[sort_idx]
    indices = [ind[sort_idx] for ind in indices]

    # Get vmin vmax for color scaling (con only holds drawn connections now)
    if vmin is None:
        vmin = np.min(con)
    if vmax is None:
        vmax = np.max(con)
    vrange = vmax - vmin

    # We want to add some "noise" to the start and end position of the
    # edges: We modulate the noise with the number of connections of the
    # node and the connection strength, such that the strongest connections
    # are closer to the node center
    # (builtin int: the ``np.int`` alias was removed in NumPy 1.24)
    nodes_n_con = np.zeros((n_nodes), dtype=int)
    for i, j in zip(indices[0], indices[1]):
        nodes_n_con[i] += 1
        nodes_n_con[j] += 1

    # initialize random number generator so plot is reproducible
    rng = np.random.mtrand.RandomState(seed=0)

    n_con = len(indices[0])
    noise_max = 0.25 * node_width
    start_noise = rng.uniform(-noise_max, noise_max, n_con)
    end_noise = rng.uniform(-noise_max, noise_max, n_con)

    nodes_n_con_seen = np.zeros_like(nodes_n_con)
    for i, (start, end) in enumerate(zip(indices[0], indices[1])):
        nodes_n_con_seen[start] += 1
        nodes_n_con_seen[end] += 1

        # connections are visited weakest first, so later (stronger) ones
        # get a smaller fraction of the noise
        start_noise[i] *= ((nodes_n_con[start] - nodes_n_con_seen[start]) /
                           float(nodes_n_con[start]))
        end_noise[i] *= ((nodes_n_con[end] - nodes_n_con_seen[end]) /
                         float(nodes_n_con[end]))

    # scale connectivity for colormap (vmin<=>0, vmax<=>1)
    if vrange == 0:
        # degenerate range: map everything to 0 (as matplotlib's Normalize
        # does) instead of dividing by zero and getting NaN colors
        con_val_scaled = np.zeros(len(con))
    else:
        con_val_scaled = (con - vmin) / vrange

    # Finally, we draw the connections
    for pos, (i, j) in enumerate(zip(indices[0], indices[1])):
        # Start point
        t0, r0 = node_angles[i], 10

        # End point
        t1, r1 = node_angles[j], 10

        # Some noise in start and end point
        t0 += start_noise[pos]
        t1 += end_noise[pos]

        verts = [(t0, r0), (t0, 5), (t1, 5), (t1, r1)]
        codes = [m_path.Path.MOVETO, m_path.Path.CURVE4, m_path.Path.CURVE4,
                 m_path.Path.LINETO]
        path = m_path.Path(verts, codes)

        color = colormap(con_val_scaled[pos])

        # Actual line
        patch = m_patches.PathPatch(path, fill=False, edgecolor=color,
                                    linewidth=linewidth, alpha=1.)
        axes.add_patch(patch)

    # Draw ring with colored nodes
    height = np.ones(n_nodes) * 1.0
    bars = axes.bar(node_angles, height, width=node_width, bottom=9,
                    edgecolor=node_edgecolor, lw=node_linewidth,
                    facecolor='.9', align='center')

    for bar, color in zip(bars, node_colors):
        bar.set_facecolor(color)

    # Draw node labels
    angles_deg = 180 * node_angles / np.pi
    for name, angle_rad, angle_deg in zip(node_names, node_angles, angles_deg):
        # Decide on the wrapped angle so that e.g. 45, 405 and -315 degrees
        # are all treated as the same position on the circle
        if 90 <= angle_deg % 360 < 270:
            # Left half: flip the label, so text is always upright
            angle_deg += 180
            ha = 'right'
        else:
            ha = 'left'

        axes.text(angle_rad, 10.4, name, size=fontsize_names,
                  rotation=angle_deg, rotation_mode='anchor',
                  horizontalalignment=ha, verticalalignment='center',
                  color=textcolor)

    if title is not None:
        axes.set_title(title, color=textcolor, fontsize=fontsize_title)

    if colorbar:
        sm = plt.cm.ScalarMappable(cmap=colormap,
                                   norm=plt.Normalize(vmin, vmax))
        sm.set_array(np.linspace(vmin, vmax))
        cb = fig.colorbar(sm, ax=axes, use_gridspec=False,
                          shrink=colorbar_size,
                          anchor=colorbar_pos)
        cb_yticks = plt.getp(cb.ax.axes, 'yticklabels')
        cb.ax.tick_params(labelsize=fontsize_colorbar)
        plt.setp(cb_yticks, color=textcolor)

    # Add callback for interaction
    if interactive:
        callback = partial(_plot_connectivity_circle_onpick, fig=fig,
                           axes=axes, indices=indices, n_nodes=n_nodes,
                           node_angles=node_angles)

        fig.canvas.mpl_connect('button_press_event', callback)

    plt_show(show)
    return fig, axes
|