1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
|
#!/usr/bin/env python
# Created by John Travers, Robert Hetland, 2007
""" Test functions for rbf module """
from __future__ import division, print_function, absolute_import
import numpy as np
from numpy.testing import (assert_, assert_array_almost_equal,
assert_almost_equal, run_module_suite)
from numpy import linspace, sin, random, exp, allclose
from scipy.interpolate.rbf import Rbf
FUNCTIONS = ('multiquadric', 'inverse multiquadric', 'gaussian',
'cubic', 'quintic', 'thin-plate', 'linear')
def check_rbf1d_interpolation(function):
"""Check that the Rbf function interpolates through the nodes (1D)"""
olderr = np.seterr(all="ignore")
try:
x = linspace(0,10,9)
y = sin(x)
rbf = Rbf(x, y, function=function)
yi = rbf(x)
assert_array_almost_equal(y, yi)
assert_almost_equal(rbf(float(x[0])), y[0])
finally:
np.seterr(**olderr)
def check_rbf2d_interpolation(function):
"""Check that the Rbf function interpolates through the nodes (2D)"""
olderr = np.seterr(all="ignore")
try:
x = random.rand(50,1)*4-2
y = random.rand(50,1)*4-2
z = x*exp(-x**2-1j*y**2)
rbf = Rbf(x, y, z, epsilon=2, function=function)
zi = rbf(x, y)
zi.shape = x.shape
assert_array_almost_equal(z, zi)
finally:
np.seterr(**olderr)
def check_rbf3d_interpolation(function):
"""Check that the Rbf function interpolates through the nodes (3D)"""
olderr = np.seterr(all="ignore")
try:
x = random.rand(50,1)*4-2
y = random.rand(50,1)*4-2
z = random.rand(50,1)*4-2
d = x*exp(-x**2-y**2)
rbf = Rbf(x, y, z, d, epsilon=2, function=function)
di = rbf(x, y, z)
di.shape = x.shape
assert_array_almost_equal(di, d)
finally:
np.seterr(**olderr)
def test_rbf_interpolation():
for function in FUNCTIONS:
yield check_rbf1d_interpolation, function
yield check_rbf2d_interpolation, function
yield check_rbf3d_interpolation, function
def check_rbf1d_regularity(function, atol):
"""Check that the Rbf function approximates a smooth function well away
from the nodes."""
olderr = np.seterr(all="ignore")
try:
x = linspace(0, 10, 9)
y = sin(x)
rbf = Rbf(x, y, function=function)
xi = linspace(0, 10, 100)
yi = rbf(xi)
#import matplotlib.pyplot as plt
#plt.figure()
#plt.plot(x, y, 'o', xi, sin(xi), ':', xi, yi, '-')
#plt.title(function)
#plt.show()
msg = "abs-diff: %f" % abs(yi - sin(xi)).max()
assert_(allclose(yi, sin(xi), atol=atol), msg)
finally:
np.seterr(**olderr)
def test_rbf_regularity():
tolerances = {
'multiquadric': 0.05,
'inverse multiquadric': 0.02,
'gaussian': 0.01,
'cubic': 0.15,
'quintic': 0.1,
'thin-plate': 0.1,
'linear': 0.2
}
for function in FUNCTIONS:
yield check_rbf1d_regularity, function, tolerances.get(function, 1e-2)
def test_default_construction():
"""Check that the Rbf class can be constructed with the default
multiquadric basis function. Regression test for ticket #1228."""
x = linspace(0,10,9)
y = sin(x)
rbf = Rbf(x, y)
yi = rbf(x)
assert_array_almost_equal(y, yi)
def test_function_is_callable():
"""Check that the Rbf class can be constructed with function=callable."""
x = linspace(0,10,9)
y = sin(x)
linfunc = lambda x:x
rbf = Rbf(x, y, function=linfunc)
yi = rbf(x)
assert_array_almost_equal(y, yi)
def test_two_arg_function_is_callable():
"""Check that the Rbf class can be constructed with a two argument
function=callable."""
def _func(self, r):
return self.epsilon + r
x = linspace(0,10,9)
y = sin(x)
rbf = Rbf(x, y, function=_func)
yi = rbf(x)
assert_array_almost_equal(y, yi)
def test_rbf_epsilon_none():
x = linspace(0, 10, 9)
y = sin(x)
rbf = Rbf(x, y, epsilon=None)
if __name__ == "__main__":
run_module_suite()
|