File: test_rbf.py

package info (click to toggle)
python-scipy 0.14.0-2
  • links: PTS, VCS
  • area: main
  • in suites: jessie, jessie-kfreebsd
  • size: 52,228 kB
  • ctags: 63,719
  • sloc: python: 112,726; fortran: 88,685; cpp: 86,979; ansic: 85,860; makefile: 530; sh: 236
file content (144 lines) | stat: -rw-r--r-- 4,159 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#!/usr/bin/env python
# Created by John Travers, Robert Hetland, 2007
""" Test functions for rbf module """
from __future__ import division, print_function, absolute_import


import numpy as np
from numpy.testing import (assert_, assert_array_almost_equal,
                           assert_almost_equal, run_module_suite)
from numpy import linspace, sin, random, exp, allclose
from scipy.interpolate.rbf import Rbf

# Names of the built-in radial basis functions that the tests pass to
# ``Rbf(..., function=...)``, in the order they are exercised.
FUNCTIONS = (
    'multiquadric',
    'inverse multiquadric',
    'gaussian',
    'cubic',
    'quintic',
    'thin-plate',
    'linear',
)


def check_rbf1d_interpolation(function):
    """Check that the Rbf function interpolates through the nodes (1D).

    Fits ``sin`` sampled at 9 equally spaced nodes on ``[0, 10]`` and checks
    that the interpolant reproduces the data at those nodes, both for an
    array argument and for a single Python float.

    Parameters
    ----------
    function : str
        Name of the radial basis, forwarded to ``Rbf(..., function=...)``.

    Floating-point warnings are suppressed while the interpolant is built
    and checked. The previous NumPy error state is restored afterwards,
    even when an assertion fails.
    """
    x = linspace(0, 10, 9)
    y = sin(x)
    # np.errstate saves and restores the error settings itself, replacing
    # the manual seterr / try / finally bookkeeping.
    with np.errstate(all="ignore"):
        rbf = Rbf(x, y, function=function)
        yi = rbf(x)
        assert_array_almost_equal(yi, y)
        # Evaluating at a scalar (not an array) must also hit the node value.
        assert_almost_equal(rbf(float(x[0])), y[0])


def check_rbf2d_interpolation(function, seed=1234):
    """Check that the Rbf function interpolates through the nodes (2D).

    Fits complex-valued data ``z = x * exp(-x**2 - 1j*y**2)`` sampled at 50
    random nodes in ``[-2, 2)**2`` with ``epsilon=2``. It then checks that
    evaluating the interpolant at those same nodes reproduces ``z``.

    Parameters
    ----------
    function : str
        Name of the radial basis, forwarded to ``Rbf(..., function=...)``.
    seed : int or None, optional
        Seed for a private ``RandomState`` that draws the node coordinates.
        A fixed default makes any failure reproducible rather than dependent
        on the global NumPy random state. Pass ``None`` to draw fresh nodes.

    Floating-point warnings are suppressed for the duration of the check and
    the previous NumPy error state is restored afterwards.
    """
    # Private generator: reproducible, and leaves global np.random untouched.
    rng = np.random.RandomState(seed)
    x = rng.rand(50, 1)*4 - 2
    y = rng.rand(50, 1)*4 - 2
    z = x*exp(-x**2 - 1j*y**2)
    with np.errstate(all="ignore"):
        rbf = Rbf(x, y, z, epsilon=2, function=function)
        # Restore the (50, 1) layout of the inputs. reshape copies if needed,
        # whereas assigning to .shape raises when no view is possible.
        zi = rbf(x, y).reshape(x.shape)
        assert_array_almost_equal(zi, z)


def check_rbf3d_interpolation(function, seed=1234):
    """Check that the Rbf function interpolates through the nodes (3D).

    Fits ``d = x * exp(-x**2 - y**2)`` sampled at 50 random nodes in
    ``[-2, 2)**3`` with ``epsilon=2``. It then checks that evaluating the
    interpolant at those same nodes reproduces ``d``.

    Parameters
    ----------
    function : str
        Name of the radial basis, forwarded to ``Rbf(..., function=...)``.
    seed : int or None, optional
        Seed for a private ``RandomState`` that draws the node coordinates.
        A fixed default makes any failure reproducible rather than dependent
        on the global NumPy random state. Pass ``None`` to draw fresh nodes.

    Floating-point warnings are suppressed for the duration of the check and
    the previous NumPy error state is restored afterwards.
    """
    # Private generator: reproducible, and leaves global np.random untouched.
    rng = np.random.RandomState(seed)
    x = rng.rand(50, 1)*4 - 2
    y = rng.rand(50, 1)*4 - 2
    z = rng.rand(50, 1)*4 - 2
    d = x*exp(-x**2 - y**2)
    with np.errstate(all="ignore"):
        rbf = Rbf(x, y, z, d, epsilon=2, function=function)
        # Restore the (50, 1) layout of the inputs. reshape copies if needed,
        # whereas assigning to .shape raises when no view is possible.
        di = rbf(x, y, z).reshape(x.shape)
        assert_array_almost_equal(di, d)


def test_rbf_interpolation():
    """Generate node-interpolation checks for every basis function.

    For each name in ``FUNCTIONS`` this yields, in order, the 1-D, 2-D and
    3-D interpolation checks as ``(check, name)`` pairs. This is the
    nose-style generator-test convention.
    """
    per_dimension_checks = (
        check_rbf1d_interpolation,
        check_rbf2d_interpolation,
        check_rbf3d_interpolation,
    )
    for name in FUNCTIONS:
        for check in per_dimension_checks:
            yield check, name


def check_rbf1d_regularity(function, atol):
    """Check that the Rbf function approximates a smooth function well away
    from the nodes.

    Interpolates ``sin`` from 9 equally spaced nodes on ``[0, 10]``. It then
    requires the interpolant to agree with ``sin`` on a 100-point grid over
    the same interval, using ``numpy.allclose`` with the given ``atol`` and
    its default ``rtol``.

    Parameters
    ----------
    function : str
        Name of the radial basis, forwarded to ``Rbf(..., function=...)``.
    atol : float
        Absolute tolerance passed to ``numpy.allclose``.

    Raises
    ------
    AssertionError
        If the interpolant deviates by more than the tolerance. The message
        reports the maximum absolute deviation.
    """
    x = linspace(0, 10, 9)
    y = sin(x)
    xi = linspace(0, 10, 100)
    expected = sin(xi)
    # np.errstate restores the previous error settings on exit, even when
    # the assertion below fails.
    with np.errstate(all="ignore"):
        rbf = Rbf(x, y, function=function)
        yi = rbf(xi)
        # Report the worst deviation so a failure shows how far off it was.
        msg = "abs-diff: %f" % abs(yi - expected).max()
        assert_(allclose(yi, expected, atol=atol), msg)


def test_rbf_regularity():
    """Generate one 1-D regularity check per basis function.

    Each name in ``FUNCTIONS`` is paired with its own absolute tolerance.
    A name missing from the table falls back to ``1e-2``. Cases are yielded
    as ``(check, name, atol)`` tuples (nose-style generator test).
    """
    default_atol = 1e-2
    tolerance_pairs = (
        ('multiquadric', 0.05),
        ('inverse multiquadric', 0.02),
        ('gaussian', 0.01),
        ('cubic', 0.15),
        ('quintic', 0.1),
        ('thin-plate', 0.1),
        ('linear', 0.2),
    )
    atol_by_name = dict(tolerance_pairs)
    for name in FUNCTIONS:
        atol = atol_by_name.get(name, default_atol)
        yield check_rbf1d_regularity, name, atol


def test_default_construction():
    """Regression test for ticket #1228.

    An Rbf built without an explicit ``function`` (i.e. with the default
    multiquadric basis) must construct successfully and still pass through
    its nodes.
    """
    nodes = linspace(0, 10, 9)
    values = sin(nodes)
    interpolant = Rbf(nodes, values)
    assert_array_almost_equal(values, interpolant(nodes))


def test_function_is_callable():
    """Check that the Rbf class can be constructed with function=callable.

    A plain one-argument function of the radius is supplied as the basis.
    The resulting interpolant must still pass through the nodes.
    """
    # A named def instead of an assigned lambda (PEP 8 E731). Its parameter
    # name no longer shadows the node array ``x`` below.
    def linfunc(r):
        # Identity basis: phi(r) = r.
        return r

    x = linspace(0, 10, 9)
    y = sin(x)
    rbf = Rbf(x, y, function=linfunc)
    yi = rbf(x)
    assert_array_almost_equal(y, yi)


def test_two_arg_function_is_callable():
    """Check that Rbf accepts a two-argument callable ``(self, r)`` as its
    ``function``.

    The callable reads ``self.epsilon`` from the Rbf instance. The basis used
    here is ``epsilon + r``, and the interpolant must pass through the nodes.
    """
    def _shifted_linear(self, r):
        return self.epsilon + r

    nodes = linspace(0, 10, 9)
    values = sin(nodes)
    interpolant = Rbf(nodes, values, function=_shifted_linear)
    assert_array_almost_equal(values, interpolant(nodes))


def test_rbf_epsilon_none():
    """Check that passing ``epsilon=None`` explicitly falls back to the
    default epsilon and yields a working interpolant.

    Building the object alone is not enough to catch a regression. The test
    also checks that the default was filled in with a positive value and
    that the interpolant passes through its nodes.
    """
    x = linspace(0, 10, 9)
    y = sin(x)
    rbf = Rbf(x, y, epsilon=None)
    # The constructor is expected to replace None with a computed value.
    assert_(rbf.epsilon is not None and rbf.epsilon > 0,
            "default epsilon not set: %r" % (rbf.epsilon,))
    assert_array_almost_equal(y, rbf(x))


if __name__ == "__main__":
    # Allow running this test module directly as a script, using numpy's
    # run_module_suite helper imported above from numpy.testing.
    run_module_suite()