File: test_functions.py

package info (click to toggle)
python-gplearn 0.4.2-5
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 3,308 kB
  • sloc: python: 2,755; makefile: 158
file content (180 lines) | stat: -rw-r--r-- 6,579 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""Testing the Genetic Programming functions module."""

# Author: Trevor Stephens <trevorstephens.com>
#
# License: BSD 3 clause

import pickle

import numpy as np
import pytest
from numpy import maximum
from sklearn.datasets import load_diabetes, load_breast_cancer
from sklearn.utils.validation import check_random_state

from gplearn.functions import _protected_sqrt, make_function
from gplearn.genetic import SymbolicRegressor, SymbolicTransformer
from gplearn.genetic import SymbolicClassifier

# Shared fixture: the diabetes dataset, with rows shuffled once by a
# generator seeded with 0 so every test sees the same ordering.
rng = check_random_state(0)
diabetes = load_diabetes()
perm = rng.permutation(diabetes.target.size)
diabetes.data, diabetes.target = diabetes.data[perm], diabetes.target[perm]

# Shared fixture: the breast cancer dataset, shuffled by a separate generator
# that is likewise seeded with 0 (it does not consume state from ``rng``).
cancer = load_breast_cancer()
perm = check_random_state(0).permutation(cancer.target.size)
cancer.data, cancer.target = cancer.data[perm], cancer.target[perm]


def test_validate_function():
    """Check that make_function accepts a valid spec and rejects bad ones.

    One well-formed definition is built first; after that every entry in
    the table of malformed definitions must raise ``ValueError``.
    """

    # Helper callables referenced by the rejection table below.
    def bad_fun1(x1, x2):
        # Returns a string rather than an array.
        return 'ni'

    def bad_fun2(x1):
        # Returns a fixed (2, 1) array regardless of its input.
        return np.ones((2, 1))

    def _unprotected_sqrt(x1):
        # Plain sqrt; only the floating-point warnings are silenced.
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.sqrt(x1)

    def _unprotected_div(x1, x2):
        # Plain division; only the floating-point warnings are silenced.
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.divide(x1, x2)

    # A well-formed definition must be accepted without raising.
    make_function(function=_protected_sqrt, name='sqrt', arity=1)

    # Each entry is a set of keyword arguments that must be refused.
    rejected = (
        # arity that is not an integer
        dict(function=_protected_sqrt, name='sqrt', arity='1'),
        dict(function=_protected_sqrt, name='sqrt', arity=1.0),
        # wrap that is not a bool
        dict(function=_protected_sqrt, name='sqrt', arity=1, wrap='f'),
        # arity that does not match the callable
        dict(function=_protected_sqrt, name='sqrt', arity=2),
        dict(function=maximum, name='max', arity=1),
        # name that is not a string (an int here)
        dict(function=_protected_sqrt, name=2, arity=1),
        # return type check
        dict(function=bad_fun1, name='ni', arity=2),
        # return shape check
        dict(function=bad_fun2, name='ni', arity=1),
        # closure check for negative inputs
        dict(function=_unprotected_sqrt, name='sqrt', arity=1),
        # closure check for zeros
        dict(function=_unprotected_div, name='div', arity=2),
    )
    for kwargs in rejected:
        with pytest.raises(ValueError):
            make_function(**kwargs)


def test_function_in_program():
    """Check that a custom function can be used inside evolved programs.

    A four-argument selector is registered through ``make_function``, added
    to the function set of a ``SymbolicTransformer`` and the string form of
    one stored program is compared against a pinned formula.
    """

    def logic(x1, x2, x3, x4):
        # Element-wise choice: x3 where x1 > x2, otherwise x4.
        return np.where(x1 > x2, x3, x4)

    logical = make_function(function=logic, name='logical', arity=4)

    est = SymbolicTransformer(
        generations=2,
        population_size=2000,
        hall_of_fame=100,
        n_components=10,
        function_set=['add', 'sub', 'mul', 'div', logical],
        parsimony_coefficient=0.0005,
        max_samples=0.9,
        random_state=0,
    )
    # Only the first 300 samples are used for fitting.
    X_train = diabetes.data[:300, :]
    y_train = diabetes.target[:300]
    est.fit(X_train, y_train)

    # The expected text is tied to random_state=0 and the settings above;
    # it contains the custom 'logical' node.
    expected_formula = ('add(X3, logical(div(X5, sub(X5, X5)), '
                        'add(X9, -0.621), X8, X4))')
    # NOTE(review): _programs[0][3] is presumably program 3 of the first
    # generation -- confirm against the estimator internals.
    formula = str(est._programs[0][3])
    assert expected_formula == formula


def test_parallel_custom_function():
    """Regression test for running parallel training with custom functions.

    Three scenarios are exercised with the same locally defined function:

    1. built with the default ``wrap`` setting and fitted with ``n_jobs=2``:
       the fitted estimator must pickle without error;
    2. built with ``wrap=False`` and fitted with ``n_jobs=2``: fitting must
       succeed but pickling the estimator must fail;
    3. built with ``wrap=False`` and fitted with the default number of jobs:
       pickling must fail as well.
    """

    def _logical(x1, x2, x3, x4):
        # Element-wise choice: x3 where x1 > x2, otherwise x4.
        return np.where(x1 > x2, x3, x4)

    # The exception raised when pickle meets a function defined inside
    # another function depends on the interpreter: historically
    # AttributeError, while newer CPython releases (3.14+, per its changelog)
    # report pickle.PicklingError instead.  Accept both so the test checks
    # "pickling fails" rather than one version's exception type.
    unpicklable_errors = (AttributeError, pickle.PicklingError)

    logical = make_function(function=_logical,
                            name='logical',
                            arity=4)
    est = SymbolicRegressor(generations=2,
                            function_set=['add', 'sub', 'mul', 'div', logical],
                            random_state=0,
                            n_jobs=2)
    est.fit(diabetes.data, diabetes.target)
    _ = pickle.dumps(est)

    # Unwrapped functions should fail
    logical = make_function(function=_logical,
                            name='logical',
                            arity=4,
                            wrap=False)
    est = SymbolicRegressor(generations=2,
                            function_set=['add', 'sub', 'mul', 'div', logical],
                            random_state=0,
                            n_jobs=2)
    est.fit(diabetes.data, diabetes.target)
    with pytest.raises(unpicklable_errors):
        pickle.dumps(est)

    # Single threaded will also fail in non-interactive sessions
    est = SymbolicRegressor(generations=2,
                            function_set=['add', 'sub', 'mul', 'div', logical],
                            random_state=0)
    est.fit(diabetes.data, diabetes.target)
    with pytest.raises(unpicklable_errors):
        pickle.dumps(est)


def test_parallel_custom_transformer():
    """Regression test for running parallel training with custom transformer.

    Three scenarios are exercised with the same locally defined sigmoid:

    1. built with the default ``wrap`` setting and fitted with ``n_jobs=2``:
       the fitted classifier must pickle without error;
    2. built with ``wrap=False`` and fitted with ``n_jobs=2``: fitting must
       succeed but pickling the classifier must fail;
    3. built with ``wrap=False`` and fitted with the default number of jobs:
       pickling must fail as well.
    """

    def _sigmoid(x1):
        # Logistic function; overflow/underflow warnings from exp are muted.
        with np.errstate(over='ignore', under='ignore'):
            return 1 / (1 + np.exp(-x1))

    # The exception raised when pickle meets a function defined inside
    # another function depends on the interpreter: historically
    # AttributeError, while newer CPython releases (3.14+, per its changelog)
    # report pickle.PicklingError instead.  Accept both so the test checks
    # "pickling fails" rather than one version's exception type.
    unpicklable_errors = (AttributeError, pickle.PicklingError)

    sigmoid = make_function(function=_sigmoid,
                            name='sig',
                            arity=1)
    est = SymbolicClassifier(generations=2,
                             transformer=sigmoid,
                             random_state=0,
                             n_jobs=2)
    est.fit(cancer.data, cancer.target)
    _ = pickle.dumps(est)

    # Unwrapped functions should fail
    sigmoid = make_function(function=_sigmoid,
                            name='sig',
                            arity=1,
                            wrap=False)
    est = SymbolicClassifier(generations=2,
                             transformer=sigmoid,
                             random_state=0,
                             n_jobs=2)
    est.fit(cancer.data, cancer.target)
    with pytest.raises(unpicklable_errors):
        pickle.dumps(est)

    # Single threaded will also fail in non-interactive sessions
    est = SymbolicClassifier(generations=2,
                             transformer=sigmoid,
                             random_state=0)
    est.fit(cancer.data, cancer.target)
    with pytest.raises(unpicklable_errors):
        pickle.dumps(est)