# -*- coding: UTF-8 -*-
# Authors: Thomas Hartmann <thomas.hartmann@th-ht.de>
#          Dirk Gütlin <dirk.guetlin@stud.sbg.ac.at>
#
# License: BSD (3-clause)

import copy
import itertools
import os.path
import warnings

import numpy as np
import pytest

import mne
from mne.datasets import testing
from mne.io.fieldtrip.utils import NOINFO_WARNING, _create_events
from mne.utils import _check_pandas_installed, requires_h5py
from mne.io.fieldtrip.tests.helpers import (check_info_fields, get_data_paths,
                                            get_raw_data, get_epochs,
                                            get_evoked, _has_h5py,
                                            pandas_not_found_warning_msg,
                                            get_raw_info, check_data,
                                            assert_warning_in_record)

# Systems that are not covered by these tests:
# KIT: the biggest problem here is that the channels do not have the same
#      names.
# EGI: no calibration is done in FieldTrip, so the data is VERY different.

# Acquisition systems for which reference files exist, per data type.
all_systems_raw = ['neuromag306', 'CTF', 'CNT', 'BTI', 'eximia']
all_systems_epochs = ['neuromag306', 'CTF', 'CNT']
# MATLAB file format versions the FieldTrip files were saved in.
all_versions = ['v7', 'v73']
# Whether the info of the original raw file is handed to the reader.
use_info = [True, False]

# Full cartesian product (system, version, use_info) for the parametrized
# tests; the loop nesting reproduces the ordering of itertools.product.
all_test_params_raw = [
    (system, version, with_info)
    for system in all_systems_raw
    for version in all_versions
    for with_info in use_info
]
all_test_params_epochs = [
    (system, version, with_info)
    for system in all_systems_epochs
    for version in all_versions
    for with_info in use_info
]

# Warning category and text expected when reading without an info structure.
no_info_warning = dict(expected_warning=RuntimeWarning,
                       message=NOINFO_WARNING)


@testing.requires_testing_data
# Reading the sample CNT data results in a RuntimeWarning because it cannot
# parse the measurement date. We need to ignore that warning.
@pytest.mark.filterwarnings('ignore::RuntimeWarning')
@pytest.mark.parametrize('cur_system, version, use_info',
                         all_test_params_epochs)
def test_read_evoked(cur_system, version, use_info):
    """Test comparing reading an Evoked object and the FieldTrip version.

    Parameters
    ----------
    cur_system : str
        Name of the acquisition system whose test data is read.
    version : str
        MATLAB file version of the FieldTrip file ('v7' or 'v73').
    use_info : bool
        If True, the info of the original raw file is passed to the reader.
        If False, ``None`` is passed and the reader is expected to emit
        ``NOINFO_WARNING`` as a RuntimeWarning.
    """
    test_data_folder_ft = get_data_paths(cur_system)
    mne_avg = get_evoked(cur_system)
    info = get_raw_info(cur_system) if use_info else None

    cur_fname = os.path.join(test_data_folder_ft,
                             'averaged_%s.mat' % (version,))
    if version == 'v73' and not _has_h5py():
        # Reading a v73 file without h5py must fail with an ImportError.
        with pytest.raises(ImportError):
            mne.io.read_evoked_fieldtrip(cur_fname, info)
        return

    # Record all warnings explicitly. ``pytest.warns(None)`` and an unknown
    # ``message`` keyword for ``pytest.warns`` are both rejected by recent
    # pytest versions, so the expected warning is verified by hand below.
    with warnings.catch_warnings(record=True) as warn_record:
        warnings.simplefilter('always')
        avg_ft = mne.io.read_evoked_fieldtrip(cur_fname, info)

    if not use_info:
        runtime_messages = [str(w.message) for w in warn_record
                            if issubclass(w.category, RuntimeWarning)]
        assert NOINFO_WARNING in runtime_messages

    # The last sample of the MNE data is dropped before the comparison.
    mne_data = mne_avg.data[:, :-1]
    ft_data = avg_ft.data

    check_data(mne_data, ft_data, cur_system)
    check_info_fields(mne_avg, avg_ft, use_info)


@testing.requires_testing_data
# Reading the sample CNT data results in a RuntimeWarning because it cannot
# parse the measurement date. We need to ignore that warning.
@pytest.mark.filterwarnings('ignore::RuntimeWarning')
@pytest.mark.parametrize('cur_system, version, use_info',
                         all_test_params_epochs)
def test_read_epochs(cur_system, version, use_info):
    """Test comparing reading an Epochs object and the FieldTrip version.

    Parameters
    ----------
    cur_system : str
        Name of the acquisition system whose test data is read.
    version : str
        MATLAB file version of the FieldTrip file ('v7' or 'v73').
    use_info : bool
        If True, the info of the original raw file is passed to the reader.
        If False, ``None`` is passed and the reader is expected to emit
        ``NOINFO_WARNING`` as a RuntimeWarning.
    """
    pandas = _check_pandas_installed(strict=False)
    has_pandas = pandas is not False
    test_data_folder_ft = get_data_paths(cur_system)
    mne_epoched = get_epochs(cur_system)
    info = get_raw_info(cur_system) if use_info else None

    cur_fname = os.path.join(test_data_folder_ft,
                             'epoched_%s.mat' % (version,))

    # Record all warnings explicitly. ``pytest.warns(None)`` and an unknown
    # ``message`` keyword for ``pytest.warns`` are both rejected by recent
    # pytest versions, so the expected warnings are verified by hand below.
    with warnings.catch_warnings(record=True) as warn_record:
        warnings.simplefilter('always')
        if version == 'v73' and not _has_h5py():
            # Reading a v73 file without h5py must fail with an ImportError.
            # The check stays inside the recording context so that a warning
            # emitted before the error cannot interfere with it.
            with pytest.raises(ImportError):
                mne.io.read_epochs_fieldtrip(cur_fname, info)
            return
        epoched_ft = mne.io.read_epochs_fieldtrip(cur_fname, info)

    if has_pandas:
        assert isinstance(epoched_ft.metadata, pandas.DataFrame)
    else:
        # Without pandas no metadata can be built and a warning is expected.
        assert epoched_ft.metadata is None
        assert_warning_in_record(pandas_not_found_warning_msg, warn_record)

    if not use_info:
        runtime_messages = [str(w.message) for w in warn_record
                            if issubclass(w.category, RuntimeWarning)]
        assert NOINFO_WARNING in runtime_messages

    # The last sample of the MNE data is dropped before the comparison.
    mne_data = mne_epoched.get_data()[:, :, :-1]
    ft_data = epoched_ft.get_data()

    check_data(mne_data, ft_data, cur_system)
    check_info_fields(mne_epoched, epoched_ft, use_info)


@testing.requires_testing_data
# Reading the sample CNT data results in a RuntimeWarning because it cannot
# parse the measurement date. We need to ignore that warning.
@pytest.mark.filterwarnings('ignore::RuntimeWarning')
@pytest.mark.parametrize('cur_system, version, use_info', all_test_params_raw)
def test_raw(cur_system, version, use_info):
    """Test comparing reading a raw fiff file and the FieldTrip version.

    Parameters
    ----------
    cur_system : str
        Name of the acquisition system whose test data is read.
    version : str
        MATLAB file version of the FieldTrip file ('v7' or 'v73').
    use_info : bool
        If True, the info of the original raw file is passed to the reader.
        If False, ``None`` is passed and the reader is expected to emit
        ``NOINFO_WARNING`` as a RuntimeWarning.
    """
    # Load the raw fiff file with mne
    test_data_folder_ft = get_data_paths(cur_system)
    raw_fiff_mne = get_raw_data(cur_system, drop_extra_chs=True)
    info = get_raw_info(cur_system) if use_info else None

    cur_fname = os.path.join(test_data_folder_ft,
                             'raw_%s.mat' % (version,))

    if version == 'v73' and not _has_h5py():
        # Reading a v73 file without h5py must fail with an ImportError.
        with pytest.raises(ImportError):
            mne.io.read_raw_fieldtrip(cur_fname, info)
        return

    # Record all warnings explicitly. ``pytest.warns(None)`` and an unknown
    # ``message`` keyword for ``pytest.warns`` are both rejected by recent
    # pytest versions, so the expected warning is verified by hand below.
    with warnings.catch_warnings(record=True) as warn_record:
        warnings.simplefilter('always')
        raw_fiff_ft = mne.io.read_raw_fieldtrip(cur_fname, info)

    if not use_info:
        runtime_messages = [str(w.message) for w in warn_record
                            if issubclass(w.category, RuntimeWarning)]
        assert NOINFO_WARNING in runtime_messages

    # When read without info, these channels are removed from the FieldTrip
    # data before it is compared with the MNE data.
    if cur_system == 'BTI' and not use_info:
        raw_fiff_ft.drop_channels(['MzA', 'MxA', 'MyaA',
                                   'MyA', 'MxaA', 'MzaA'])

    if cur_system == 'eximia' and not use_info:
        raw_fiff_ft.drop_channels(['TRIG2', 'TRIG1', 'GATE'])

    # Check that the data was loaded correctly
    check_data(raw_fiff_mne.get_data(),
               raw_fiff_ft.get_data(),
               cur_system)

    # Check info field
    check_info_fields(raw_fiff_mne, raw_fiff_ft, use_info)


@testing.requires_testing_data
def test_load_epoched_as_raw():
    """Check that reading an epoched file with the raw reader raises."""
    system = 'neuromag306'
    epochs_fname = os.path.join(get_data_paths(system), 'epoched_v7.mat')
    raw_info = get_raw_info(system)

    # The raw reader is expected to reject epoched data with a RuntimeError.
    with pytest.raises(RuntimeError):
        mne.io.read_raw_fieldtrip(epochs_fname, raw_info)


@testing.requires_testing_data
def test_invalid_trialinfocolumn():
    """Check that invalid trialinfo_column values raise a ValueError."""
    system = 'neuromag306'
    epochs_fname = os.path.join(get_data_paths(system), 'epoched_v7.mat')
    raw_info = get_raw_info(system)

    # Both a negative column and column 3 are expected to be rejected.
    for bad_column in (-1, 3):
        with pytest.raises(ValueError):
            mne.io.read_epochs_fieldtrip(epochs_fname, raw_info,
                                         trialinfo_column=bad_column)


@testing.requires_testing_data
def test_create_events():
    """Check event creation from a two-dimensional trialinfo field."""
    from mne.externals.pymatreader import read_mat

    epochs_fname = os.path.join(get_data_paths('neuromag306'),
                                'epoched_v7.mat')
    loaded = read_mat(epochs_fname, ['data'])

    # Work on a copy and attach a 3 x 4 trialinfo matrix whose column ``c``
    # holds the constant value ``c + 1``.
    modified = copy.deepcopy(loaded)
    modified['trialinfo'] = np.array([[1, 2, 3, 4]] * 3)

    # A negative column index is invalid.
    with pytest.raises(ValueError):
        _create_events(modified, -1)

    # Every existing column must yield its own value as the event id.
    for column in np.arange(4):
        events = _create_events(modified, column)
        expected_id = column + 1
        assert np.all(events[:, 2] == expected_id)

    # A column index past the last column is invalid.
    with pytest.raises(ValueError):
        _create_events(modified, 4)


@testing.requires_testing_data
@pytest.mark.parametrize('version', all_versions)
@requires_h5py
def test_one_channel_elec_bug(version):
    """Test if loading data having only one elec in the elec field works.

    Parameters
    ----------
    version : str
        MATLAB file version of the FieldTrip file ('v7' or 'v73').
    """
    fname = os.path.join(mne.datasets.testing.data_path(), 'fieldtrip',
                         'one_channel_elec_bug_data_%s.mat' % (version, ))

    # Record all warnings explicitly: ``pytest.warns`` does not accept a
    # ``message`` keyword, so the expected warning is verified by hand.
    with warnings.catch_warnings(record=True) as warn_record:
        warnings.simplefilter('always')
        mne.io.read_raw_fieldtrip(fname, info=None)

    # Reading without info must emit the no-info RuntimeWarning.
    runtime_messages = [str(w.message) for w in warn_record
                        if issubclass(w.category, RuntimeWarning)]
    assert NOINFO_WARNING in runtime_messages
