File: test_optimisers.py

package info (click to toggle)
python-cogent 2024.5.7a1+dfsg-3
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 74,600 kB
  • sloc: python: 92,479; makefile: 117; sh: 16
file content (127 lines) | stat: -rw-r--r-- 3,442 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import contextlib
import os
import sys
import tempfile

from unittest import TestCase

import numpy
import pytest

from cogent3.maths.optimisers import (
    MaximumEvaluationsReached,
    _standardise_data,
    maximise,
)


def quartic(x):
    """Return the quartic polynomial ``x**2 * (3*x**2 + 8*x - 48)``.

    Expanded this is ``3*x**4 + 8*x**3 - 48*x**2``, whose derivative is
    ``12*x*(x + 4)*(x - 2)``. The stationary points are therefore:

    * x = -4: global minimum (value -512)
    * x = 2:  secondary local minimum (value -80)
    * x = 0:  local maximum (value 0)

    The optimisation tests maximise the *negated*, 10-fold scaled-down
    version of this polynomial (see ``MakeF``). For that function -4 is the
    global maximum and 2 the secondary local maximum.
    """
    # Plot: http://www.wolframalpha.com/input/?i=x**2*%283*x**2%2B8*x-48%29
    return x**2 * (3 * x**2 + 8 * x - 48)


class NullFile:
    """Write-only text sink that discards everything written to it.

    ``quiet`` installs an instance as ``sys.stdout``. It therefore provides
    the parts of the text-stream interface that ``print`` uses: ``write``,
    ``flush`` (called by ``print(..., flush=True)``) and ``isatty``.
    """

    def write(self, x):
        """Discard *x* and return its length, mirroring ``TextIOBase.write``."""
        return len(x)

    def flush(self):
        """Do nothing; present so flushing the silenced stream cannot fail."""

    def isatty(self):
        """Report that this sink is not an interactive terminal."""
        return False


def quiet(f, *args, **kw):
    """Call ``f(*args, **kw)`` with standard output discarded; return its result.

    The checkpointer still emits ``print`` output, so tests run the optimiser
    through this wrapper. Only ``sys.stdout`` is redirected; standard error
    is untouched. The previous ``sys.stdout`` is restored even if ``f``
    raises, and the exception propagates to the caller.
    """
    with contextlib.redirect_stdout(NullFile()):
        return f(*args, **kw)


def MakeF():
    """Build an instrumented objective function for the optimiser tests.

    Returns ``(f, last, evals)``. Each call ``f(x)`` returns
    ``-0.1 * quartic(x)``, stores ``x`` in ``last[0]`` and increments
    ``evals[0]``. The two one-element lists are shared mutable cells, so
    callers can inspect the optimiser's last input and call count afterwards.
    """
    call_count = [0]
    latest_input = [0]

    def f(x):
        latest_input[0] = x
        call_count[0] = call_count[0] + 1
        # negate so the quartic's minima become maxima; scaled down 10-fold
        # to avoid having to change init_temp
        return -0.1 * quartic(x)

    return f, latest_input, call_count


class OptimiserTestCase(TestCase):
    """Tests for ``maximise`` on the objective built by ``MakeF``.

    That objective is ``-0.1 * quartic(x)``. Its global maximum is at
    x = -4 and its secondary local maximum at x = 2.
    """

    def _test_optimisation(self, target=-4, xinit=1.0, bounds=None, **kw):
        """Maximise the test objective from ``xinit`` and check it reaches ``target``.

        ``bounds`` is passed straight to ``maximise``. The default
        ``[-10, 10]`` brackets both maxima. Extra keyword arguments are
        forwarded to ``maximise``.
        """
        bounds = bounds or [-10, 10]
        f, last, _ = MakeF()

        x = quiet(maximise, f, [xinit], bounds, show_progress=False, **kw)
        # the objective must have been evaluated last at the returned
        # solution (important for Calculator)
        self.assertEqual(x, last[0])
        error = abs(x[0] - target)
        self.assertLess(error, 0.0001, (kw, x, target, error))

    def test_global(self):
        """the global optimiser finds the global maximum at x = -4"""
        self._test_optimisation(local=False, seed=1)

    def test_bounded(self):
        """bounds of [0, 10] exclude the global maximum, so the secondary
        maximum at x = 2 is found"""
        self._test_optimisation(bounds=([0.0], [10.0]), target=2, seed=1)

    def test_local(self):
        """the local optimiser started at x = 1 reaches the nearby maximum at
        x = 2 rather than the global one"""
        self._test_optimisation(local=True, target=2)

    def test_limited(self):
        """exceeding max_evaluations raises MaximumEvaluationsReached"""
        with self.assertRaises(MaximumEvaluationsReached):
            self._test_optimisation(max_evaluations=5)

    # def test_limited_warning(self):
    #     """optimiser warning if max_evaluations exceeded"""
    #     self._test_optimisation(max_evaluations=5, limit_action='warn')

    def test_get_max_eval_count(self):
        """return the evaluation count from optimisation"""
        f, _, _ = MakeF()
        _, e = quiet(
            maximise,
            f,
            xinit=[1.0],
            bounds=[-10, 10],
            return_eval_count=True,
            show_progress=False,
        )
        # picking arbitrary numerical value
        self.assertGreaterEqual(e, 10)

    def test_checkpointing(self):
        """repeating a run with the same checkpoint file and settings succeeds,
        while changing init_temp for that checkpoint makes the run fail"""
        # Use a private temporary directory so the test never writes into the
        # current working directory (which may be read-only or shared), and
        # the checkpoint is removed even when an assertion fails.
        with tempfile.TemporaryDirectory() as tmpdir:
            filename = os.path.join(tmpdir, "checkpoint.tmp.pickle")
            self._test_optimisation(filename=filename, seed=1, init_temp=10)
            self._test_optimisation(filename=filename, seed=1, init_temp=10)
            # only failure is pinned, not the specific exception type raised
            # for a mismatched checkpoint
            self.assertRaises(
                Exception,
                self._test_optimisation,
                filename=filename,
                seed=1,
                init_temp=3.21,
            )


@pytest.mark.parametrize("val", [numpy.array(3.7), numpy.array([3.7]), 3.7])
def test_standardise_data(val):
    """a 0-d array, a one-element array and a plain float all give (3.7,)"""
    assert _standardise_data(val) == (3.7,)


@pytest.mark.parametrize("val", [37, "37"])
def test_standardise_data_str(val):
    """an int and its string form both standardise to the 1-tuple ("37",)"""
    expected = ("37",)
    assert _standardise_data(val) == expected