"""Functions to plot on circle as for connectivity."""

# Authors: Alexandre Gramfort <alexandre.gramfort@inria.fr>
#          Denis Engemann <denis.engemann@gmail.com>
#          Martin Luessi <mluessi@nmr.mgh.harvard.edu>
#
# License: Simplified BSD


from itertools import cycle
from functools import partial

import numpy as np

from .utils import plt_show, _get_cmap
from ..utils import _validate_type


def circular_layout(node_names, node_order, start_pos=90, start_between=True,
                    group_boundaries=None, group_sep=10):
    """Create layout arranging nodes on a circle.

    Parameters
    ----------
    node_names : list of str
        Node names.
    node_order : list of str
        List with node names defining the order in which the nodes are
        arranged. Must have the same elements as node_names but the order
        can be different. The nodes are placed at increasing angles starting
        at "start_pos" degrees.
    start_pos : float
        Angle in degrees that defines where the first node is plotted.
    start_between : bool
        If True, the layout starts with the position between the nodes, i.e.
        half of the angular separation between two nodes is added to
        start_pos. Without group separators this is the same as adding
        "180. / len(node_names)" to start_pos.
    group_boundaries : None | array-like
        List of boundaries between groups at which point a "group_sep" will
        be inserted. E.g. "[0, len(node_names) / 2]" will create two groups.
        Values are cast to integers and must be strictly increasing and
        between 0 and "len(node_names) - 1". An empty list is treated like
        None.
    group_sep : float
        Group separation angle in degrees. See "group_boundaries".

    Returns
    -------
    node_angles : array, shape=(n_node_names,)
        Node angles in degrees, in the order of node_names.

    Raises
    ------
    ValueError
        If node_order and node_names differ in length, if a node name is
        missing from node_order, if two node names map to the same position,
        or if group_boundaries is out of range or not strictly increasing.
    """
    n_nodes = len(node_names)

    if len(node_order) != n_nodes:
        raise ValueError('node_order has to be the same length as node_names')

    boundaries = None
    n_group_sep = 0
    if group_boundaries is not None:
        boundaries = np.array(group_boundaries, dtype=np.int64)
        if np.any(boundaries >= n_nodes) or np.any(boundaries < 0):
            raise ValueError('"group_boundaries" has to be between 0 and '
                             'n_nodes - 1.')
        if len(boundaries) > 1 and np.any(np.diff(boundaries) <= 0):
            raise ValueError('"group_boundaries" must have strictly '
                             'increasing values.')
        n_group_sep = len(boundaries)
        if n_group_sep == 0:
            # nothing to separate; avoids indexing an empty array below
            boundaries = None

    # Map every name to its (first) position in node_order so the conversion
    # to indices is a single pass instead of one list search per node.
    position = {}
    for idx, name in enumerate(node_order):
        position.setdefault(name, idx)
    missing = [name for name in node_names if name not in position]
    if missing:
        raise ValueError(f'node_order is missing the following entries of '
                         f'node_names: {missing}')
    node_order = np.array([position[name] for name in node_names])
    if len(np.unique(node_order)) != n_nodes:
        raise ValueError('node_order has repeated entries')

    # angular step between neighbouring nodes once the separators are removed
    node_sep = (360. - n_group_sep * group_sep) / n_nodes

    if start_between:
        start_pos += node_sep / 2

        if boundaries is not None and boundaries[0] == 0:
            # special case when a group separator is at the start: centre
            # the separator on start_pos and drop it from the list
            start_pos += group_sep / 2
            boundaries = boundaries[1:] if n_group_sep > 1 else None

    # per-node increments; the cumulative sum turns them into angles
    node_angles = np.ones(n_nodes, dtype=np.float64) * node_sep
    node_angles[0] = start_pos
    if boundaries is not None:
        node_angles[boundaries] += group_sep

    # reorder from layout position to the order of node_names
    node_angles = np.cumsum(node_angles)[node_order]

    return node_angles


def _plot_connectivity_circle_onpick(event, fig=None, ax=None, indices=None,
                                     n_nodes=0, node_angles=None,
                                     ylim=(9, 10)):
    """Isolate connections around a single node when user left clicks a node.

    On right click, resets all connections.

    Parameters
    ----------
    event : matplotlib mouse event
        The button press event; ``inaxes``, ``button``, ``xdata`` (angle in
        radians) and ``ydata`` (radius) are read.
    fig : matplotlib figure
        Figure whose canvas is redrawn after visibility changes.
    ax : matplotlib axes
        Axes the callback is bound to; events in other axes are ignored.
    indices : sequence of two arrays
        Start and end node index of every drawn connection. The i-th
        connection is expected to be the i-th patch of the axes.
    n_nodes : int
        Unused; kept for interface compatibility.
    node_angles : array
        Node angles in radians.
    ylim : sequence of two floats
        Radial range within which a left click counts as hitting a node.
    """
    if event.inaxes != ax:
        return

    if event.button == 1:  # left click
        # click must be near node radius
        if not ylim[0] <= event.ydata <= ylim[1]:
            return

        # Use the circular distance so that a click just below 2*pi still
        # selects a node sitting just above 0 (and vice versa).
        full_circle = 2 * np.pi
        dist = np.abs(event.xdata - np.asarray(node_angles)) % full_circle
        dist = np.minimum(dist, full_circle - dist)
        node = np.argmin(dist)

        patches = event.inaxes.patches
        for ii, (x, y) in enumerate(zip(indices[0], indices[1])):
            patches[ii].set_visible(node in [x, y])
        fig.canvas.draw()
    elif event.button == 3:  # right click
        patches = event.inaxes.patches
        for ii in range(len(indices[0])):
            patches[ii].set_visible(True)
        fig.canvas.draw()


def _plot_connectivity_circle(con, node_names, indices=None, n_lines=None,
                              node_angles=None, node_width=None,
                              node_height=None, node_colors=None,
                              facecolor='black', textcolor='white',
                              node_edgecolor='black', linewidth=1.5,
                              colormap='hot', vmin=None, vmax=None,
                              colorbar=True, title=None,
                              colorbar_size=None, colorbar_pos=None,
                              fontsize_title=12, fontsize_names=8,
                              fontsize_colorbar=8, padding=6.,
                              ax=None, interactive=True,
                              node_linewidth=2., show=True):
    """Draw connectivity between nodes arranged on a circle (polar axes).

    Parameters
    ----------
    con : array
        Connectivity values. Either 1D (one value per connection, requires
        ``indices``) or a square matrix of shape (n_nodes, n_nodes), of which
        only the lower-triangular part is used.
    node_names : list of str
        Node names, one per node.
    indices : tuple of arrays | None
        Two arrays with the start and end node index of each connection.
        Required if ``con`` is 1D, ignored (recomputed) if ``con`` is 2D.
    n_lines : int | None
        If not None and smaller than the number of connections, only
        connections whose absolute value reaches the ``n_lines``-th largest
        absolute value are drawn.
    node_angles : array-like | None
        Node angles in degrees. If None, nodes are spread uniformly.
    node_width : float | None
        Width of each node in degrees. If None, the smallest angle between
        any two nodes is used.
    node_height : float | None
        Radial height of the node ring. If None, 1.0 is used.
    node_colors : list | None
        Node face colors. Cycled if fewer colors than nodes are given. If
        None, colors are sampled from matplotlib's spectral colormap.
    facecolor : color
        Background color of the figure (if created here) and of the axes.
    textcolor : color
        Color of node labels, title and colorbar tick labels.
    node_edgecolor : color
        Edge color of the nodes.
    linewidth : float
        Line width of the connections.
    colormap : str | colormap
        Colormap for the connections; passed through ``_get_cmap``.
    vmin, vmax : float | None
        Color scaling limits. If None, the minimum / maximum of the drawn
        connection values is used.
    colorbar : bool
        Whether to add a colorbar.
    title : str | None
        Axes title.
    colorbar_size : float | None
        Passed to ``fig.colorbar`` as ``shrink`` if not None.
    colorbar_pos : tuple | None
        Passed to ``fig.colorbar`` as ``anchor`` if not None.
    fontsize_title, fontsize_names, fontsize_colorbar : int
        Font sizes of the title, the node labels and the colorbar ticks.
    padding : float
        Extra radial space added beyond the node ring (for the labels).
    ax : instance of PolarAxes | None
        Axes to draw into. If None, a new 8 x 8 figure is created.
    interactive : bool
        If True, connect ``_plot_connectivity_circle_onpick`` to button
        press events (left click isolates a node, right click resets).
    node_linewidth : float
        Line width of the node outlines.
    show : bool
        Passed to ``plt_show``.

    Returns
    -------
    fig : instance of matplotlib.figure.Figure
        The figure handle.
    ax : instance of matplotlib.projections.polar.PolarAxes
        The axes that were drawn into.
    """
    import matplotlib.pyplot as plt
    import matplotlib.path as m_path
    import matplotlib.patches as m_patches
    from matplotlib.projections.polar import PolarAxes

    _validate_type(ax, (None, PolarAxes))

    n_nodes = len(node_names)

    if node_angles is not None:
        if len(node_angles) != n_nodes:
            raise ValueError('node_angles has to be the same length '
                             'as node_names')
        # convert it to radians (asarray so plain lists work as well)
        node_angles = np.asarray(node_angles) * np.pi / 180
    else:
        # uniform layout on unit circle
        node_angles = np.linspace(0, 2 * np.pi, n_nodes, endpoint=False)

    if node_width is None:
        # widths correspond to the minimum angle between two nodes
        dist_mat = node_angles[None, :] - node_angles[:, None]
        # exclude the zero self-distances from the minimum
        dist_mat[np.diag_indices(n_nodes)] = 1e9
        node_width = np.min(np.abs(dist_mat))
    else:
        node_width = node_width * np.pi / 180

    if node_height is None:
        node_height = 1.0

    if node_colors is not None:
        if len(node_colors) < n_nodes:
            node_colors = cycle(node_colors)
    else:
        # assign colors using colormap; the attribute name differs between
        # matplotlib versions
        try:
            spectral = plt.cm.spectral
        except AttributeError:
            spectral = plt.cm.Spectral
        node_colors = [spectral(i / float(n_nodes))
                       for i in range(n_nodes)]

    # handle 1D and 2D connectivity information
    if con.ndim == 1:
        if indices is None:
            raise ValueError('indices has to be provided if con.ndim == 1')
    elif con.ndim == 2:
        if con.shape[0] != n_nodes or con.shape[1] != n_nodes:
            raise ValueError('con has to be 1D or a square matrix')
        # we use the lower-triangular part
        indices = np.tril_indices(n_nodes, -1)
        con = con[indices]
    else:
        raise ValueError('con has to be 1D or a square matrix')

    # get the colormap
    colormap = _get_cmap(colormap)

    # Use a polar axes
    if ax is None:
        fig = plt.figure(figsize=(8, 8), facecolor=facecolor)
        ax = fig.add_subplot(polar=True)
    else:
        fig = ax.figure
    ax.set_facecolor(facecolor)

    # No ticks, we'll put our own
    ax.set_xticks([])
    ax.set_yticks([])

    # Set y axes limit, add additional space if requested
    ax.set_ylim(0, 10 + padding)

    # Remove the black axes border which may obscure the labels
    ax.spines['polar'].set_visible(False)

    # Draw lines between connected nodes, only draw the strongest connections
    if n_lines is not None and len(con) > n_lines:
        con_thresh = np.sort(np.abs(con).ravel())[-n_lines]
    else:
        con_thresh = 0.

    # get the connections which we are drawing and sort by connection strength
    # this will allow us to draw the strongest connections first
    # (NaN entries fail the comparison and are therefore dropped here)
    con_abs = np.abs(con)
    con_draw_idx = np.where(con_abs >= con_thresh)[0]

    con = con[con_draw_idx]
    con_abs = con_abs[con_draw_idx]
    indices = [ind[con_draw_idx] for ind in indices]

    # now sort them
    sort_idx = np.argsort(con_abs)
    del con_abs
    con = con[sort_idx]
    indices = [ind[sort_idx] for ind in indices]

    # Get vmin vmax for color scaling; con only holds drawn connections here
    if vmin is None:
        vmin = np.min(con)
    if vmax is None:
        vmax = np.max(con)
    vrange = vmax - vmin

    # We want to add some "noise" to the start and end position of the
    # edges: We modulate the noise with the number of connections of the
    # node and the connection strength, such that the strongest connections
    # are closer to the node center
    nodes_n_con = np.zeros((n_nodes), dtype=np.int64)
    for i, j in zip(indices[0], indices[1]):
        nodes_n_con[i] += 1
        nodes_n_con[j] += 1

    # initialize random number generator so plot is reproducible
    rng = np.random.mtrand.RandomState(0)

    n_con = len(indices[0])
    noise_max = 0.25 * node_width
    start_noise = rng.uniform(-noise_max, noise_max, n_con)
    end_noise = rng.uniform(-noise_max, noise_max, n_con)

    nodes_n_con_seen = np.zeros_like(nodes_n_con)
    for i, (start, end) in enumerate(zip(indices[0], indices[1])):
        nodes_n_con_seen[start] += 1
        nodes_n_con_seen[end] += 1

        start_noise[i] *= ((nodes_n_con[start] - nodes_n_con_seen[start]) /
                           float(nodes_n_con[start]))
        end_noise[i] *= ((nodes_n_con[end] - nodes_n_con_seen[end]) /
                         float(nodes_n_con[end]))

    # scale connectivity for colormap (vmin<=>0, vmax<=>1)
    if vrange == 0:
        # degenerate range (e.g. all connections equal): avoid dividing by
        # zero and map everything to the low end of the colormap
        con_val_scaled = np.zeros(len(con), dtype=np.float64)
    else:
        con_val_scaled = (con - vmin) / vrange

    # Finally, we draw the connections
    for pos, (i, j) in enumerate(zip(indices[0], indices[1])):
        # Start point
        t0, r0 = node_angles[i], 10

        # End point
        t1, r1 = node_angles[j], 10

        # Some noise in start and end point
        t0 += start_noise[pos]
        t1 += end_noise[pos]

        # cubic Bezier from node to node with control points at radius 5
        verts = [(t0, r0), (t0, 5), (t1, 5), (t1, r1)]
        codes = [m_path.Path.MOVETO, m_path.Path.CURVE4, m_path.Path.CURVE4,
                 m_path.Path.LINETO]
        path = m_path.Path(verts, codes)

        color = colormap(con_val_scaled[pos])

        # Actual line
        patch = m_patches.PathPatch(path, fill=False, edgecolor=color,
                                    linewidth=linewidth, alpha=1.)
        ax.add_patch(patch)

    # Draw ring with colored nodes
    height = np.ones(n_nodes) * node_height
    bars = ax.bar(node_angles, height, width=node_width, bottom=9,
                  edgecolor=node_edgecolor, lw=node_linewidth,
                  facecolor='.9', align='center')

    for bar, color in zip(bars, node_colors):
        bar.set_facecolor(color)

    # Draw node labels
    angles_deg = 180 * node_angles / np.pi
    for name, angle_rad, angle_deg in zip(node_names, node_angles, angles_deg):
        # Normalize to [0, 360) so the decision below does not depend on how
        # many full turns the caller's angle contains.
        angle_deg = angle_deg % 360
        if 90 <= angle_deg < 270:
            # Flip the label, so text is always upright
            angle_deg += 180
            ha = 'right'
        else:
            ha = 'left'

        ax.text(angle_rad, 9.4 + node_height, name, size=fontsize_names,
                rotation=angle_deg, rotation_mode='anchor',
                horizontalalignment=ha, verticalalignment='center',
                color=textcolor)

    if title is not None:
        ax.set_title(title, color=textcolor, fontsize=fontsize_title)

    if colorbar:
        sm = plt.cm.ScalarMappable(cmap=colormap,
                                   norm=plt.Normalize(vmin, vmax))
        sm.set_array(np.linspace(vmin, vmax))
        colorbar_kwargs = dict()
        if colorbar_size is not None:
            colorbar_kwargs.update(shrink=colorbar_size)
        if colorbar_pos is not None:
            colorbar_kwargs.update(anchor=colorbar_pos)
        cb = fig.colorbar(sm, ax=ax, **colorbar_kwargs)
        cb_yticks = plt.getp(cb.ax.axes, 'yticklabels')
        cb.ax.tick_params(labelsize=fontsize_colorbar)
        plt.setp(cb_yticks, color=textcolor)

    # Add callback for interaction
    if interactive:
        callback = partial(_plot_connectivity_circle_onpick, fig=fig,
                           ax=ax, indices=indices, n_nodes=n_nodes,
                           node_angles=node_angles)

        fig.canvas.mpl_connect('button_press_event', callback)

    plt_show(show)
    return fig, ax


def plot_channel_labels_circle(labels, colors=None, picks=None, **kwargs):
    """Plot labels for each channel in a circle plot.

    .. note:: This primarily makes sense for sEEG channels where each
              channel can be assigned an anatomical label as the electrode
              passes through various brain areas.

    Parameters
    ----------
    labels : dict
        Lists of labels (values) associated with each channel (keys).
    colors : dict | None
        The color (value) for each label (key). Every label used in
        ``labels`` must have an entry; extra entries are plotted as
        unconnected label nodes. If None, no node colors and no label
        colormap are set, so the plotting function's defaults apply.
    picks : list | tuple | None
        The channels to consider. If None, all channels in ``labels``.
    **kwargs : kwargs
        Keyword arguments for
        :func:`mne_connectivity.viz.plot_connectivity_circle`. Values given
        here take precedence over the defaults chosen by this function
        (``node_angles``, ``colorbar``, ``node_colors``, ``vmin``, ``vmax``,
        ``colormap``).

    Returns
    -------
    fig : instance of matplotlib.figure.Figure
        The figure handle.
    axes : instance of matplotlib.projections.polar.PolarAxes
        The subplot handle.

    Raises
    ------
    ValueError
        If ``colors`` is given but lacks a color for one of the labels.
    """
    from matplotlib.colors import LinearSegmentedColormap

    _validate_type(labels, dict, 'labels')
    _validate_type(colors, (dict, None), 'colors')
    _validate_type(picks, (list, tuple, None), 'picks')
    if picks is not None:
        labels = {k: v for k, v in labels.items() if k in picks}
    ch_names = list(labels.keys())
    # unique labels in first-seen order, so the layout is reproducible
    # across runs (a set would order them by randomized string hashes)
    all_labels = list(dict.fromkeys(
        label for val in labels.values() for label in val))
    label_cmap = None
    if colors is not None:
        for label in all_labels:
            if label not in colors:
                raise ValueError(f'No color provided for {label} in `colors`')
        # update all_labels, there may be unconnected labels in colors
        all_labels = list(colors.keys())
        # make colormap
        label_colors = [colors[label] for label in all_labels]
        node_colors = ['black'] * len(ch_names) + label_colors
        label_cmap = LinearSegmentedColormap.from_list(
            'label_cmap', label_colors, N=len(label_colors))
    else:
        node_colors = None
    n_labels = len(all_labels)

    # channel nodes come first, label nodes follow
    node_names = ch_names + all_labels
    n_channels = len(ch_names)
    label_pos = {label: pos for pos, label in enumerate(all_labels)}
    con = np.full((len(node_names), len(node_names)), np.nan)
    for idx, ch_name in enumerate(ch_names):
        for label in labels[ch_name]:
            pos = label_pos[label]
            node_idx = n_channels + pos
            # the connection value encodes the label's slot in the colormap
            label_color = pos / n_labels
            con[idx, node_idx] = con[node_idx, idx] = label_color  # symmetric
    # plot
    node_order = ch_names + all_labels[::-1]
    node_angles = circular_layout(node_names, node_order, start_pos=90,
                                  group_boundaries=[0, len(ch_names)])
    # provide defaults but don't overwrite
    kwargs.setdefault('node_angles', node_angles)
    kwargs.setdefault('colorbar', False)
    kwargs.setdefault('node_colors', node_colors)
    kwargs.setdefault('vmin', 0)
    kwargs.setdefault('vmax', 1)
    if label_cmap is not None:
        # only available when ``colors`` was given; otherwise leave the
        # plotting function's own default colormap in place
        kwargs.setdefault('colormap', label_cmap)
    return _plot_connectivity_circle(con, node_names, **kwargs)
