File: test_optimisers.py

package info (click to toggle)
python-cogent 2023.2.12a1+dfsg-2+deb12u1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 12,416 kB
  • sloc: python: 89,165; makefile: 117; sh: 16
file content (120 lines) | stat: -rw-r--r-- 3,343 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python


import os
import sys
import tempfile

from contextlib import redirect_stdout
from unittest import TestCase, main

from cogent3.maths.optimisers import MaximumEvaluationsReached, maximise


__author__ = "Peter Maxwell and Gavin Huttley"
__copyright__ = "Copyright 2007-2022, The Cogent Project"
__credits__ = ["Peter Maxwell", "Gavin Huttley"]
__license__ = "BSD-3"
__version__ = "2023.2.12a1"
__maintainer__ = "Gavin Huttley"
__email__ = "gavin.huttley@anu.edu.au"
__status__ = "Production"


def quartic(x):
    """Return the quartic ``x**2 * (3*x**2 + 8*x - 48)``.

    Its stationary points are x = -4 (global minimum, value -512),
    x = 0 (local maximum, value 0) and x = 2 (local minimum, value -80).
    The optimiser tests maximise the negated, scaled form
    ``-0.1 * quartic(x)``, whose global maximum is therefore at -4 and
    whose lower local maximum is at 2.
    """
    # plot: http://www.wolframalpha.com/input/?i=x**2*%283*x**2%2B8*x-48%29
    return x ** 2 * (3 * x ** 2 + 8 * x - 48)


class NullFile:
    """Minimal write-only text sink that discards everything written to it.

    Intended as a stand-in for ``sys.stdout`` when output should be silenced.
    """

    def write(self, x):
        """Discard ``x`` and return its length, as text-stream ``write`` does."""
        return len(x)

    def flush(self):
        """No-op, so callers that flush stdout do not raise AttributeError."""

    def isatty(self):
        """Report that this sink is not an interactive terminal."""
        return False


def quiet(f, *args, **kw):
    """Call ``f(*args, **kw)`` with standard output discarded.

    The checkpointer still has print statements, which would clutter the
    test output. ``sys.stdout`` is redirected to a ``NullFile`` for the
    duration of the call and restored afterwards, even if ``f`` raises.

    Returns whatever ``f`` returns; exceptions from ``f`` propagate.
    """
    with redirect_stdout(NullFile()):
        return f(*args, **kw)


def MakeF():
    """Create an instrumented objective function for the optimiser tests.

    Returns a 3-tuple ``(f, last, evals)``:

    * ``f(x)`` evaluates ``-0.1 * quartic(x)``;
    * ``last`` is a one-element list holding the most recent ``x`` passed
      to ``f`` (initially 0);
    * ``evals`` is a one-element list counting calls to ``f`` (initially 0).

    One-element lists are used so the caller observes the updates made
    inside the closure.
    """
    counter, latest = [0], [0]

    def objective(x):
        counter[0] += 1
        latest[0] = x
        # scaled down 10-fold so that init_temp need not be changed
        return -0.1 * quartic(x)

    return objective, latest, counter


class OptimiserTestCase(TestCase):
    """Tests for ``maximise`` on the objective built by ``MakeF``.

    That objective, ``-0.1 * quartic(x)``, has its global maximum at x = -4
    and a lower local maximum at x = 2.
    """

    def _test_optimisation(self, target=-4, xinit=1.0, bounds=None, **kw):
        """Maximise the test objective and check the result is near ``target``.

        Parameters
        ----------
        target
            x value the optimiser is expected to converge to.
        xinit
            starting value of the single free parameter.
        bounds
            bounds passed straight through to ``maximise``; ``None`` means
            the default ``[-10, 10]``.
        **kw
            extra keyword arguments forwarded to ``maximise``.
        """
        # explicit None check: ``bounds or ...`` would misfire on falsy or
        # array-valued bounds
        if bounds is None:
            bounds = [-10, 10]
        f, last, _ = MakeF()

        x = quiet(maximise, f, [xinit], bounds, **kw)
        # the returned point must be the one last passed to f
        self.assertEqual(x, last[0])  # important for Calculator
        error = abs(x[0] - target)
        self.assertLess(error, 0.0001, f"kw={kw!r}, x={x!r}, target={target!r}")

    def test_global(self):
        """non-local optimisation finds the global maximum at -4"""
        self._test_optimisation(local=False, seed=1)

    def test_bounded(self):
        """with x restricted to [0, 10] the global maximum is out of bounds,
        so the local maximum at 2 is found"""
        self._test_optimisation(bounds=([0.0], [10.0]), target=2, seed=1)

    def test_local(self):
        """local optimisation from x=1 climbs to the nearer maximum at 2
        rather than the global maximum at -4"""
        self._test_optimisation(local=True, target=2)

    def test_limited(self):
        """exceeding max_evaluations raises MaximumEvaluationsReached"""
        with self.assertRaises(MaximumEvaluationsReached):
            self._test_optimisation(max_evaluations=5)

    # def test_limited_warning(self):
    #     """optimiser warning if max_evaluations exceeded"""
    #     self._test_optimisation(max_evaluations=5, limit_action='warn')

    def test_get_max_eval_count(self):
        """return the evaluation count from optimisation"""
        f, _, _ = MakeF()
        _, e = quiet(
            maximise, f, xinit=[1.0], bounds=[-10, 10], return_eval_count=True
        )
        # picking arbitrary numerical value
        self.assertGreaterEqual(e, 10)

    def test_checkpointing(self):
        """a second run reuses the checkpoint file written by the first, and
        a run with a different init_temp against that checkpoint raises"""
        # a private temporary directory avoids clobbering files in the CWD
        # and guarantees cleanup even when an assertion fails part-way
        with tempfile.TemporaryDirectory() as tmpdir:
            filename = os.path.join(tmpdir, "checkpoint.tmp.pickle")
            self._test_optimisation(filename=filename, seed=1, init_temp=10)
            self._test_optimisation(filename=filename, seed=1, init_temp=10)
            with self.assertRaises(Exception):
                self._test_optimisation(
                    filename=filename,
                    seed=1,
                    init_temp=3.21,
                )


if __name__ == "__main__":
    # run this module's tests with unittest's command-line runner
    main()