File: setup_minimiser.py

package info (click to toggle)
dials 3.25.0+dfsg3-3
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 20,112 kB
  • sloc: python: 134,740; cpp: 34,526; makefile: 160; sh: 142
file content (69 lines) | stat: -rw-r--r-- 2,104 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""Set up the minimiser (refinement engine) for refinement test cases"""

from __future__ import annotations

from libtbx.phil import command_line, parse

from dials.algorithms.refinement.engine import (
    GaussNewtonIterations,
    LBFGScurvs,
    SimpleLBFGS,
)


class Extract:
    """Parse and extract minimiser setup from PHIL.

    A working PHIL scope is built from ``master_phil`` plus optional local and
    command-line overrides. Its ``minimiser.parameters`` scope selects and
    configures a refinement engine, which is stored on ``self.refiner``.
    """

    def __init__(
        self,
        master_phil,
        target,
        prediction_parameterisation,
        local_overrides="",
        cmdline_args=None,
        verbose=True,
    ):
        """Build the working PHIL parameters and construct the minimiser.

        Args:
            master_phil: Master PHIL scope. The extracted working parameters
                must provide ``minimiser.parameters.engine`` and
                ``minimiser.parameters.logfile``.
            target: Refinement target, passed unchanged to the engine.
            prediction_parameterisation: Prediction parameterisation, passed
                unchanged to the engine.
            local_overrides: PHIL text, parsed and used as the first override
                source on top of ``master_phil``.
            cmdline_args: Optional iterable of command-line style arguments
                (presumably ``"name=value"`` strings -- confirm with callers).
                Each is converted to PHIL by an argument interpreter bound to
                ``master_phil`` and listed after ``local_overrides``.
            verbose: Stored as ``self._verbose``; not read by this class.

        Raises:
            ValueError: If the selected engine name is not recognised
                (raised by ``build_minimiser``).
        """
        self._target = target
        self._prediction_parameterisation = prediction_parameterisation
        self._verbose = verbose

        arg_interpreter = command_line.argument_interpreter(master_phil=master_phil)

        user_phil = parse(local_overrides)
        cmdline_phils = [arg_interpreter.process(arg) for arg in cmdline_args or ()]

        # Command-line sources come after the local overrides; presumably later
        # sources take precedence in fetch() -- confirm against libtbx.phil.
        working_phil = master_phil.fetch(sources=[user_phil] + cmdline_phils)

        self._params = working_phil.extract().minimiser.parameters

        self.refiner = self.build_minimiser()

    def build_minimiser(self):
        """Instantiate the refinement engine named by ``self._params.engine``.

        Engine names map to classes as follows:
        ``"SimpleLBFGS"`` -> ``SimpleLBFGS``,
        ``"LBFGScurvs"`` -> ``LBFGScurvs``,
        ``"GaussNewton"`` -> ``GaussNewtonIterations``.
        Every engine is given the stored target, the prediction
        parameterisation and ``self._params.logfile`` as ``log``.

        Returns:
            The newly constructed engine instance.

        Raises:
            ValueError: If ``self._params.engine`` is not one of the names
                listed above.
        """
        engine = self._params.engine

        # All engines take the same constructor keywords; only the class differs.
        if engine == "SimpleLBFGS":
            engine_class = SimpleLBFGS
        elif engine == "LBFGScurvs":
            engine_class = LBFGScurvs
        elif engine == "GaussNewton":
            engine_class = GaussNewtonIterations
        else:
            # Raise explicitly rather than assert: assert statements are
            # stripped under ``python -O``, which would let an unknown engine
            # fall through and silently return None.
            raise ValueError(
                f"Unknown minimiser engine {engine!r}; expected one of "
                "'SimpleLBFGS', 'LBFGScurvs', 'GaussNewton'"
            )

        return engine_class(
            target=self._target,
            prediction_parameterisation=self._prediction_parameterisation,
            log=self._params.logfile,
        )