File: fitobjective_api.py

package info (click to toggle)
bornagain 23.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 103,936 kB
  • sloc: cpp: 423,131; python: 40,997; javascript: 11,167; awk: 630; sh: 318; ruby: 173; xml: 130; makefile: 51; ansic: 24
file content (139 lines) | stat: -rw-r--r-- 3,990 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
Testing python specific API for FitObjective related classes.
"""
import unittest
import numpy as np
import bornagain as ba
from bornagain import deg


class SimulationBuilder:
    """Test fixture producing a minimal ScatteringSimulation on demand.

    Each call to ``build_simulation`` is counted and the parameters it
    received are stored, so tests can verify how ``ba.FitObjective``
    invokes the Python callback.
    """

    def __init__(self):
        # Bookkeeping inspected by the tests.
        self.m_ncalls = 0
        self.m_pars = None
        # Detector grid dimensions (rows, columns) used below and in size().
        self.m_nrow, self.m_ncol = 3, 4
        self.beam = ba.Beam(1., 1., 0)
        self.detector = ba.SphericalDetector(
            self.m_ncol, -2 * deg, 2 * deg,
            self.m_nrow, 0., 3 * deg)

    def size(self):
        """Return the number of detector bins (rows times columns)."""
        nrow, ncol = self.m_nrow, self.m_ncol
        return nrow * ncol

    def build_simulation(self, pars):
        """Record *pars* and return a new two-layer ScatteringSimulation.

        The parameter values are only stored in ``m_pars``; they do not
        influence the returned simulation.
        """
        self.m_ncalls = self.m_ncalls + 1
        self.m_pars = dict(pars)

        shell = ba.RefractiveMaterial("Shell", 0, 0)
        sample = ba.Sample()
        # Two separate Layer objects, both made of the same material.
        for _ in range(2):
            sample.addLayer(ba.Layer(shell))

        return ba.ScatteringSimulation(self.beam, sample, self.detector)

    def create_data(self):
        """Return a Datafield over the detector's clipped frame, all bins 1."""
        field = ba.Datafield(self.detector.clippedFrame())
        field.setAllTo(1.)
        return field


class FitObserver:
    """Records each iteration number at which a FitObjective notifies it."""

    def __init__(self):
        self.m_ncalls = 0        # number of update() invocations
        self.m_iterations = []   # iterationCount() observed on each call, in order

    def update(self, fit_objective):
        """Observer callback: count the call and log the current iteration."""
        self.m_ncalls += 1
        count = fit_objective.iterationInfo().iterationCount()
        self.m_iterations.append(count)


class FitObjectiveAPITest(unittest.TestCase):
    """Tests of the Python callback API of ba.FitObjective and ba.IterationInfo."""

    def test_SimulationBuilderCallback(self):
        """
        Testing simulation construction using Python callback
        """
        pars = ba.Parameters()
        pars.add(ba.Parameter("par0", 0))
        pars.add(ba.Parameter("par1", 1))

        builder = SimulationBuilder()
        data = builder.create_data()

        # adding simulation callback and experimental data;
        # registering the pair must not invoke the builder yet
        objective = ba.FitObjective()
        objective.addFitPair(builder.build_simulation, data, 1)
        self.assertEqual(builder.m_ncalls, 0)

        # running objective function: the builder is called once and
        # receives the current parameter values
        objective.evaluate(pars)
        self.assertEqual(builder.m_ncalls, 1)
        self.assertEqual(builder.m_pars["par0"], 0)
        self.assertEqual(builder.m_pars["par1"], 1)

        # checking arrays of experimental and simulated data: one entry per
        # detector bin, zeros expected from the simulation, ones from the data
        npixels = builder.size()
        self.assertEqual([0] * npixels, list(objective.flatSimData()))
        self.assertEqual([1] * npixels, list(objective.flatExpData()))

    def test_FittingObserver(self):
        """
        Testing that an observer registered via initPlot(5, ...) is called
        on the first iteration and then on every 5th iteration
        """
        pars = ba.Parameters()
        pars.add(ba.Parameter("par0", 0))
        pars.add(ba.Parameter("par1", 1))

        # adding simulation callback and experimental data
        builder = SimulationBuilder()
        data = builder.create_data()
        objective = ba.FitObjective()
        objective.addFitPair(builder.build_simulation, data, 1)

        # adding observer with an interval of 5 iterations
        observer = FitObserver()
        objective.initPlot(5, observer.update)

        # running objective function 11 times
        for _ in range(11):
            objective.evaluate(pars)

        # every evaluation rebuilds the simulation ...
        self.assertEqual(builder.m_ncalls, 11)
        # ... but the observer fires only on iterations 1, 6 and 11
        self.assertEqual(observer.m_ncalls, 3)
        self.assertEqual(observer.m_iterations, [1, 6, 11])

    def test_IterationInfo(self):
        """
        Testing map of parameters obtained from IterationInfo
        """
        params = ba.Parameters()
        params.add("bbb", 1)
        params.add("aaa", 2)

        info = ba.IterationInfo()
        # NOTE(review): the second argument (3) is presumably the chi2 value;
        # it is not checked here
        info.update(params, 3)
        par_map = info.parameterMap()

        # keys are expected in name order ("aaa" before "bbb"),
        # not in the order the parameters were added
        names = list(par_map)
        values = [par_map[name] for name in names]

        self.assertEqual(names, ["aaa", "bbb"])
        self.assertEqual(values, [2.0, 1.0])
        # a single update() counts as one iteration
        self.assertEqual(info.iterationCount(), 1)


if __name__ == "__main__":
    # Discover and run the TestCase classes of this module when it is
    # executed directly as a script (module="__main__" is the default).
    unittest.main(module="__main__")