"""
==========================================
From raw data to dSPM on SPM Faces dataset
==========================================

Runs a full pipeline using MNE-Python:

- artifact removal
- averaging Epochs
- forward model computation
- source reconstruction using dSPM on the contrast: "faces - scrambled"

"""
print(__doc__)  # echo the example description (module docstring) when run

# Authors: Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
#          Denis Engemann <denis.engemann@gmail.com>
#
# License: BSD (3-clause)

import matplotlib.pyplot as plt

import mne
from mne.datasets import spm_face
from mne.preprocessing import ICA, create_eog_epochs
from mne import io
from mne.minimum_norm import make_inverse_operator, apply_inverse


# Root folder of the SPM faces sample data; every other path in this script
# (raw runs, trans file, anatomy) is built from it by string concatenation.
data_path = spm_face.data_path()
# Anatomy ("subjects") directory bundled with the dataset, used for subject
# 'spm' by the field map, source space and plotting calls.
subjects_dir = data_path + '/subjects'

###############################################################################
# Load and filter data, set up epochs

# '%d' is a run-number placeholder; only run 1 is loaded below.
raw_fname = data_path + '/MEG/spm/SPM_CTF_MEG_example_faces%d_3D_raw.fif'

raw = io.Raw(raw_fname % 1, preload=True)  # Take first run

# Restrict epochs to MEG channels, skipping channels marked bad in raw.info.
picks = mne.pick_types(raw.info, meg=True, exclude='bads')
# 1-30 Hz band-pass (IIR). The return value is discarded, so this relies on
# `raw` being modified in place.
raw.filter(1, 30, method='iir')

# Trigger codes are read from the 'UPPT001' stimulus channel.
events = mne.find_events(raw, stim_channel='UPPT001')

# plot the events to get an idea of the paradigm
mne.viz.plot_events(events, raw.info['sfreq'])

# Trigger code of each condition of interest.
event_ids = {"faces": 1, "scrambled": 2}

tmin, tmax = -0.2, 0.6  # epoch window (seconds) around each event
baseline = None  # no baseline as high-pass is applied
# Magnetometer rejection threshold (presumably peak-to-peak, in T); the same
# dict is reused for ICA fitting and EOG epochs below.
reject = dict(mag=5e-12)

epochs = mne.Epochs(raw, events, event_ids, tmin, tmax,  picks=picks,
                    baseline=baseline, preload=True, reject=reject)

# Fit ICA, find and remove major artifacts
# A float n_components selects components by explained variance (here 95%);
# decim=6 fits on every 6th sample to keep fitting fast.
ica = ICA(n_components=0.95).fit(raw, decim=6, reject=reject)

# compute correlation scores, get bad indices sorted by score
# NOTE(review): 'MRT31-2908' looks like an MEG sensor used as a stand-in
# ocular reference (no EOG channel is passed) — confirm for this dataset.
eog_epochs = create_eog_epochs(raw, ch_name='MRT31-2908', reject=reject)
eog_inds, eog_scores = ica.find_bads_eog(eog_epochs, ch_name='MRT31-2908')
ica.plot_scores(eog_scores, eog_inds)  # see scores the selection is based on
ica.plot_components(eog_inds)  # view topographic sensitivity of components
# Exclude only the top-scoring EOG component: on inspection the 2nd EOG
# component looked too dipolar to be removed as an artifact.
ica.exclude += eog_inds[:1]
ica.plot_overlay(eog_epochs.average())  # inspect artifact removal
# ica.apply works in place by default; copy=True cleans a copy instead, so
# `epochs` keeps the uncleaned data.
epochs_cln = ica.apply(epochs, copy=True)

# Average the cleaned epochs per condition. Iterate over an explicit tuple
# instead of the `event_ids` dict so the order is guaranteed to be
# (faces, scrambled); dict iteration order is arbitrary before Python 3.7.
evoked = [epochs_cln[k].average() for k in ('faces', 'scrambled')]

# Contrast "faces - scrambled", as announced in the module docstring
# (evoked[0] is faces, evoked[1] is scrambled).
contrast = evoked[0] - evoked[1]

# Plot both condition averages and the contrast on a common scale.
evoked.append(contrast)

for e in evoked:
    e.plot(ylim=dict(mag=[-400, 400]))

plt.show()

# estimate noise covariance from the pre-stimulus interval (tmax=0)
noise_cov = mne.compute_covariance(epochs_cln, tmax=0)

###############################################################################
# Visualize fields on MEG helmet

# Coregistration transform file for run 1; it is also passed as `mri` to the
# forward model below.
trans_fname = data_path + ('/MEG/spm/SPM_CTF_MEG_example_faces1_3D_'
                           'raw-trans.fif')

# Field maps for the first condition average (evoked[0]) of subject 'spm'.
maps = mne.make_field_map(evoked[0], trans_fname=trans_fname,
                          subject='spm', subjects_dir=subjects_dir,
                          n_jobs=1)


# Show the field maps at 0.170 s after stimulus onset.
evoked[0].plot_field(maps, time=0.170)


###############################################################################
# Compute forward model

# Make source space
# 'oct6' spacing on subject 'spm'. NOTE(review): overwrite=True presumably
# replaces a previously saved source-space file — confirm whether this call
# writes to disk.
src = mne.setup_source_space('spm', spacing='oct6', subjects_dir=subjects_dir,
                             overwrite=True)

# In this API the trans file is passed via the `mri` argument.
mri = trans_fname
# Precomputed BEM solution; the '5120-5120-5120' name suggests three surfaces
# of 5120 triangles each — confirm.
bem = data_path + '/subjects/spm/bem/spm-5120-5120-5120-bem-sol.fif'
forward = mne.make_forward_solution(contrast.info, mri=mri, src=src, bem=bem)
# Surface-oriented forward solution, needed for the loose orientation
# constraint (loose=0.2) used by the inverse operator below.
forward = mne.convert_forward_solution(forward, surf_ori=True)

###############################################################################
# Compute inverse solution

snr = 3.0
lambda2 = 1.0 / snr ** 2  # regularization parameter derived from the assumed SNR
method = 'dSPM'

# loose=0.2: loose orientation constraint; depth=0.8: depth weighting.
inverse_operator = make_inverse_operator(contrast.info, forward, noise_cov,
                                         loose=0.2, depth=0.8)

# Compute inverse solution on contrast
# pick_normal=False keeps all source orientations instead of only the normal
# component (presumably combined into a magnitude — confirm).
stc = apply_inverse(contrast, inverse_operator, lambda2, method,
                    pick_normal=False)
# stc.save('spm_%s_dSPM_inverse' % contrast.comment)

# plot contrast
# Plot brain in 3D with PySurfer if available. Note that the subject name
# is already known by the SourceEstimate stc object.
brain = stc.plot(surface='inflated', hemi='both', subjects_dir=subjects_dir)
brain.set_time(170.0)  # milliseconds
# Colormap thresholds, in the units of the dSPM values in stc.
brain.scale_data_colormap(fmin=4, fmid=6, fmax=8, transparent=True)
brain.show_view('ventral')
# brain.save_image('dSPM_map.png')
