# -*- coding: utf-8 -*-
"""Dipole viz specific functions."""

# Authors: Eric Larson <larson.eric.d@gmail.com>
#
# License: Simplified BSD

import os.path as op
import numpy as np

from .utils import plt_show, _validate_if_list_of_axes
from .._freesurfer import _get_head_surface, _estimate_talxfm_rigid
from ..surface import read_surface
from ..transforms import apply_trans, invert_transform, _get_trans
from ..utils import _validate_type, _check_option, get_subjects_dir


def _check_concat_dipoles(dipole):
    """Return a single ``Dipole`` instance.

    If ``dipole`` is already a ``Dipole`` it is returned unchanged; otherwise
    it is handed to ``_concatenate_dipoles`` (e.g. a sequence of dipoles) and
    the combined result is returned.
    """
    from ..dipole import Dipole, _concatenate_dipoles
    if isinstance(dipole, Dipole):
        return dipole
    return _concatenate_dipoles(dipole)


def _plot_dipole_mri_outlines(dipoles, *, subject, trans, ax, subjects_dir,
                              color, scale, coord_frame, show, block,
                              head_source, title, surf, width):
    """Plot dipoles over 2D head (and optionally brain) outlines.

    Three orthogonal views (axial, coronal, sagittal) are drawn. Each shows
    the dipole locations, their projected orientations, and contour lines
    of the head surface (plus the chosen cortical surface, if any) cut at
    fixed levels.

    Parameters
    ----------
    dipoles : Dipole | list of Dipole
        The dipole(s) to plot; passed through ``_check_concat_dipoles`` so a
        list is combined into a single Dipole. Positions and orientations
        are taken to be in head coordinates.
    subject : str
        FreeSurfer subject name.
    trans : object
        Head<->MRI transform, resolved with ``_get_trans`` (head -> MRI).
    ax : None | list of Axes
        Three axes for the axial, coronal and sagittal views. If None, a new
        1x3 figure is created.
    subjects_dir : None | path-like
        Subjects directory, resolved with ``get_subjects_dir``.
    color : None | color
        Dipole color; ``'r'`` if None.
    scale : None | float
        Multiplier for the orientation arrows (and circle radius); 0.03 if
        None.
    coord_frame : 'head' | 'mri' | 'mri_rotated'
        Frame in which to plot. ``'mri_rotated'`` additionally applies only
        the rotation part of ``_estimate_talxfm_rigid``.
    show : bool
        Passed to ``plt_show``.
    block : bool
        Passed to ``plt_show``.
    head_source : object
        Passed to ``_get_head_surface`` to select the head surface.
    title : None | str
        Figure suptitle, if not None.
    surf : None | 'white' | 'pial'
        If not None, the ``lh``/``rh`` surface of this name is read from
        ``<subjects_dir>/<subject>/surf`` and its outlines are drawn too.
    width : None | float
        Arrow shaft width passed to ``quiver`` (also scales the dipole
        circles); 0.015 if None.

    Returns
    -------
    fig : matplotlib Figure
        The figure containing ``ax``.
    """
    from matplotlib.collections import LineCollection, PatchCollection
    from matplotlib.patches import Circle
    from scipy.spatial import ConvexHull
    import matplotlib.pyplot as plt
    extra = 'when mode is "outlines"'
    trans = _get_trans(trans, fro='head', to='mri')[0]
    _check_option('coord_frame', coord_frame, ['head', 'mri', 'mri_rotated'],
                  extra=extra)
    _validate_type(surf, (str, None), 'surf')
    _check_option('surf', surf, ('white', 'pial', None))
    if ax is None:
        _, ax = plt.subplots(
            1, 3, figsize=(7, 2.5), squeeze=True, constrained_layout=True)
    _validate_if_list_of_axes(ax, 3, name='ax')
    dipoles = _check_concat_dipoles(dipoles)
    color = 'r' if color is None else color
    scale = 0.03 if scale is None else scale
    width = 0.015 if width is None else width
    fig = ax[0].figure
    # Resolve the subjects directory before it is used to build surface
    # paths, so a None value is not handed directly to op.join.
    subjects_dir = get_subjects_dir(subjects_dir)

    # --- Load surfaces (all in MRI space) -------------------------------
    surfs = dict()
    hemis = ('lh', 'rh')
    if surf is not None:
        for hemi in hemis:
            surfs[hemi] = read_surface(
                op.join(subjects_dir, subject, 'surf', f'{hemi}.{surf}'),
                return_dict=True)[2]
            # Divide by 1000 so the vertices match the transforms below;
            # everything is multiplied back by 1000 (to mm) afterwards.
            surfs[hemi]['rr'] /= 1000.
    surfs['head'] = _get_head_surface(head_source, subject, subjects_dir)
    del head_source

    # --- Transforms into the plotting frame -----------------------------
    # mri_trans maps the (MRI-space) surfaces and head_trans maps the
    # (head-space) dipoles into the requested coord_frame.
    mri_trans = head_trans = np.eye(4)
    if coord_frame in ('mri', 'mri_rotated'):
        head_trans = trans['trans']
        if coord_frame == 'mri_rotated':
            # Keep only the rotation of the Talairach alignment.
            rot = _estimate_talxfm_rigid(subject, subjects_dir)
            rot[:3, 3] = 0.
            head_trans = rot @ head_trans
            mri_trans = rot @ mri_trans
    else:
        assert coord_frame == 'head'
        mri_trans = invert_transform(trans)['trans']
    for this_surf in surfs.values():
        this_surf['rr'] = 1000 * apply_trans(mri_trans, this_surf['rr'])
    del mri_trans
    # Dipole geometry is the same for every view, so compute it only once.
    # Orientations are rotated but not translated (move=False).
    pos = 1000 * apply_trans(head_trans, dipoles.pos)
    ori = 1000 * apply_trans(head_trans, dipoles.ori, move=False)

    # --- Contour levels ---------------------------------------------------
    if surf is not None:
        use_rr = np.concatenate([surfs[key]['rr'] for key in hemis])
    else:
        use_rr = surfs['head']['rr']
    views = [('Axial', 'XY'), ('Coronal', 'XZ'), ('Sagittal', 'YZ')]
    # Axial cut at the 20th percentile of Z and coronal cut at the 55th
    # percentile of Y, computed over the brain vertices if a surface was
    # requested, otherwise over the head vertices.
    axial = float(np.percentile(use_rr[:, 2], 20.))
    coronal = float(np.percentile(use_rr[:, 1], 55.))
    levels = {key: dict(Axial=axial, Coronal=coronal)
              for key in hemis + ('head',)}
    if surf is not None:
        # Only the right hemisphere gets a sagittal contour (at its median X).
        levels['rh']['Sagittal'] = float(
            np.percentile(surfs['rh']['rr'][:, 0], 50))
    levels['head']['Sagittal'] = 0.

    # --- Draw each view --------------------------------------------------
    axis_index = dict(X=0, Y=1, Z=2)
    for ax_, (name, coords) in zip(ax, views):
        idx = [axis_index[char] for char in coords]
        # The axis orthogonal to this view, along which contours are cut.
        miss = np.setdiff1d(np.arange(3), idx)[0]
        head_rr = surfs['head']['rr']
        # Axis limits come from the extent of the head surface.
        lims = {char: np.array([head_rr[:, ii].min(), head_rr[:, ii].max()])
                for char, ii in zip(coords, idx)}
        ax_.quiver(
            pos[:, idx[0]], pos[:, idx[1]],
            scale * ori[:, idx[0]], scale * ori[:, idx[1]],
            color=color, pivot='middle', zorder=5,
            scale_units='xy', angles='xy', scale=1.,
            width=width, minshaft=0.5, headwidth=2.5, headlength=2.5,
            headaxislength=2)
        dip_circles = PatchCollection(
            [Circle((x, y), radius=scale * 1000 * width * 6)
             for x, y in zip(pos[:, idx[0]], pos[:, idx[1]])],
            linewidths=0., facecolors=color, zorder=6)
        for key, this_surf in surfs.items():
            level = levels[key].get(name)
            if level is None:  # no contour for this surface in this view
                continue
            if key != 'head':
                # Outline the 2D projection of the hemisphere via the edges
                # of its convex hull.
                rrs = this_surf['rr'][:, idx]
                tris = ConvexHull(rrs).simplices
                segments = LineCollection(
                    rrs[tris],
                    linewidths=1, linestyles='-', colors='k', zorder=3,
                    alpha=0.25)
                ax_.add_collection(segments)
            ax_.tricontour(
                this_surf['rr'][:, idx[0]], this_surf['rr'][:, idx[1]],
                this_surf['tris'], this_surf['rr'][:, miss],
                levels=[level], colors='k', linewidths=1.0, linestyles=['-'],
                zorder=4, alpha=0.5)
            # TODO: disabling clipping (set_clip_on(False)) on the contour
            # collections breaks the dipole PatchCollection in matplotlib,
            # so contours are currently left clipped to the axes.
        ax_.add_collection(dip_circles)
        ax_.set(
            title=name, xlim=lims[coords[0]], ylim=lims[coords[1]],
            xlabel=coords[0] + ' (mm)', ylabel=coords[1] + ' (mm)')
        for spine in ax_.spines.values():
            spine.set_visible(False)
        ax_.grid(True, ls=':', zorder=2)
        ax_.set_aspect('equal')

    if title is not None:
        fig.suptitle(title)
    plt_show(show, block=block)

    return fig


def _plot_dipole_3d(dipoles, *, coord_frame, color, fig, trans, scale, mode):
    """Plot dipoles as spheres (optionally with arrows) in a 3D renderer.

    Parameters
    ----------
    dipoles : Dipole
        Dipoles whose ``pos`` and ``ori`` are taken to be in head
        coordinates.
    coord_frame : 'head' | 'mri'
        Frame in which to plot. For ``'mri'`` the head->MRI transform from
        ``trans`` is applied.
    color : None | color
        Color of spheres and arrows; ``'r'`` if None.
    fig : object
        Passed to ``_get_renderer``.
    trans : object
        Head<->MRI transform; only used when ``coord_frame != 'head'``.
    scale : None | float
        Sphere scale (arrows use ``3 * scale``); 0.005 if None.
    mode : str
        If ``'arrow'``, orientation arrows are drawn as well as spheres.

    Returns
    -------
    fig : object
        The renderer scene (``renderer.scene()``).
    """
    from .backends.renderer import _get_renderer
    _check_option('coord_frame', coord_frame, ('head', 'mri'))
    color = 'r' if color is None else color
    scale = 0.005 if scale is None else scale
    renderer = _get_renderer(fig=fig, size=(600, 600))
    pos = dipoles.pos
    ori = dipoles.ori
    if coord_frame != 'head':
        trans = _get_trans(trans, fro='head', to=coord_frame)[0]
        pos = apply_trans(trans, pos)
        # Orientations are direction vectors: rotate them but do not apply
        # the translation (matches the outlines plot, which uses move=False).
        ori = apply_trans(trans, ori, move=False)

    renderer.sphere(center=pos, color=color, scale=scale)
    if mode == 'arrow':
        x, y, z = pos.T
        u, v, w = ori.T
        renderer.quiver3d(x, y, z, u, v, w, scale=3 * scale,
                          color=color, mode='arrow')
    renderer.show()
    return renderer.scene()
