#!/usr/bin/env python
# Created by John Travers, Robert Hetland, 2007
""" Test functions for rbf module """
from __future__ import division, print_function, absolute_import


import numpy as np
from numpy.testing import (assert_, assert_array_almost_equal,
                           assert_almost_equal, run_module_suite)
from numpy import linspace, sin, random, exp, allclose
from scipy.interpolate.rbf import Rbf

# Basis-function names passed as ``function=`` to Rbf by the checks below.
FUNCTIONS = (
    'multiquadric',
    'inverse multiquadric',
    'gaussian',
    'cubic',
    'quintic',
    'thin-plate',
    'linear',
)


def check_rbf1d_interpolation(function):
    """Verify that a 1D Rbf using `function` reproduces its node values."""
    # Floating-point warnings are silenced only for the duration of the check;
    # the previous error state is restored when the block exits.
    with np.errstate(all="ignore"):
        nodes = linspace(0, 10, 9)
        values = sin(nodes)
        interpolant = Rbf(nodes, values, function=function)
        assert_array_almost_equal(values, interpolant(nodes))
        # Evaluation at a plain Python float is exercised as well.
        assert_almost_equal(interpolant(float(nodes[0])), values[0])


def check_rbf2d_interpolation(function):
    """Verify that a 2D Rbf using `function` reproduces its node values."""
    with np.errstate(all="ignore"):
        # 50 random sample sites in the square [-2, 2) x [-2, 2).
        xs = random.rand(50, 1) * 4 - 2
        ys = random.rand(50, 1) * 4 - 2
        # The sampled data are complex-valued.
        samples = xs * exp(-xs**2 - 1j * ys**2)
        interpolant = Rbf(xs, ys, samples, epsilon=2, function=function)
        fitted = interpolant(xs, ys).reshape(xs.shape)
        assert_array_almost_equal(samples, fitted)


def check_rbf3d_interpolation(function):
    """Verify that a 3D Rbf using `function` reproduces its node values."""
    with np.errstate(all="ignore"):
        # 50 random sample sites in the cube [-2, 2)^3.
        xs = random.rand(50, 1) * 4 - 2
        ys = random.rand(50, 1) * 4 - 2
        zs = random.rand(50, 1) * 4 - 2
        # The data depend on the first two coordinates only.
        samples = xs * exp(-xs**2 - ys**2)
        interpolant = Rbf(xs, ys, zs, samples, epsilon=2, function=function)
        fitted = interpolant(xs, ys, zs).reshape(xs.shape)
        assert_array_almost_equal(fitted, samples)


def test_rbf_interpolation():
    """Run the 1D, 2D and 3D node-interpolation checks for every basis
    function listed in FUNCTIONS.

    The checks are called directly instead of being yielded: generator
    ("yield") tests are not executed by pytest >= 4, so a yielding version
    would silently skip every check there. A plain loop runs under both
    nose and pytest.
    """
    for function in FUNCTIONS:
        check_rbf1d_interpolation(function)
        check_rbf2d_interpolation(function)
        check_rbf3d_interpolation(function)


def check_rbf1d_regularity(function, atol):
    """Verify that a 1D Rbf using `function` stays within `atol` of sin(x)
    on a grid that is finer than the node spacing."""
    with np.errstate(all="ignore"):
        nodes = linspace(0, 10, 9)
        interpolant = Rbf(nodes, sin(nodes), function=function)
        # Compare against the true function on 100 evenly spaced points.
        grid = linspace(0, 10, 100)
        expected = sin(grid)
        approx = interpolant(grid)
        worst = abs(approx - expected).max()
        msg = "abs-diff: %f" % worst
        assert_(allclose(approx, expected, atol=atol), msg)


def test_rbf_regularity():
    """Run the 1D regularity check for every basis function in FUNCTIONS,
    each with its own absolute tolerance.

    The check is called directly instead of being yielded: generator
    ("yield") tests are not executed by pytest >= 4, so a yielding version
    would silently skip every check there. A plain loop runs under both
    nose and pytest.
    """
    # Per-function absolute tolerance; names missing here fall back to 1e-2.
    tolerances = {
        'multiquadric': 0.05,
        'inverse multiquadric': 0.02,
        'gaussian': 0.01,
        'cubic': 0.15,
        'quintic': 0.1,
        'thin-plate': 0.1,
        'linear': 0.2
    }
    for function in FUNCTIONS:
        check_rbf1d_regularity(function, tolerances.get(function, 1e-2))


def test_default_construction():
    """Build an Rbf without naming a basis function and check that it
    reproduces its node values. Regression test for ticket #1228."""
    nodes = linspace(0, 10, 9)
    values = sin(nodes)
    interpolant = Rbf(nodes, values)
    assert_array_almost_equal(values, interpolant(nodes))


def test_function_is_callable():
    """Build an Rbf from a one-argument callable and check that it
    reproduces its node values."""
    def identity(r):
        return r

    nodes = linspace(0, 10, 9)
    values = sin(nodes)
    interpolant = Rbf(nodes, values, function=identity)
    assert_array_almost_equal(values, interpolant(nodes))


def test_two_arg_function_is_callable():
    """Build an Rbf from a two-argument callable and check that it
    reproduces its node values."""
    def _shifted_linear(self, r):
        # The first argument gives the callable access to `epsilon`.
        return self.epsilon + r

    nodes = linspace(0, 10, 9)
    values = sin(nodes)
    interpolant = Rbf(nodes, values, function=_shifted_linear)
    assert_array_almost_equal(values, interpolant(nodes))


def test_rbf_epsilon_none():
    """Check that an Rbf can be constructed with an explicit epsilon=None."""
    nodes = linspace(0, 10, 9)
    # Construction alone is the test; the instance is not evaluated.
    Rbf(nodes, sin(nodes), epsilon=None)


# When executed as a script, hand off to the numpy.testing module runner
# imported at the top of this file.
if __name__ == "__main__":
    run_module_suite()
