File: test_two_theta_refinement.py

package info (click to toggle)
dials 3.25.0%2Bdfsg3-3
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 20,112 kB
  • sloc: python: 134,740; cpp: 34,526; makefile: 160; sh: 142
file content (331 lines) | stat: -rw-r--r-- 10,318 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
"""
Test refinement of a crystal unit cell using a two theta target.
"""

from __future__ import annotations

from copy import deepcopy
from math import pi

from dxtbx.model.experiment_list import Experiment, ExperimentList
from libtbx.test_utils import approx_equal

from dials.algorithms.refinement.two_theta_refiner import (
    TwoThetaExperimentsPredictor,
    TwoThetaPredictionParameterisation,
    TwoThetaReflectionManager,
    TwoThetaTarget,
)


def generate_reflections(experiments):
    """Simulate a table of observed reflections for the first experiment.

    Indices are generated for the crystal's unit cell out to a resolution
    limit of 2.0, rays are predicted over the scan's oscillation range and
    only those that intersect the detector are kept. The predicted centroids
    are then copied into the observed centroid column and constant variances
    are attached.

    Args:
        experiments: experiment list; the detector, crystal and scan are taken
            from its first element, and the whole list is passed to the
            predictors.

    Returns:
        A tuple ``(reflections, predictor)`` of the simulated reflection table
        and the ``ScansExperimentsPredictor`` that was used to build it.
    """
    from cctbx.sgtbx import space_group, space_group_symbols
    from scitbx.array_family import flex

    from dials.algorithms.refinement.prediction.managed_predictors import (
        ScansExperimentsPredictor,
        ScansRayPredictor,
    )
    from dials.algorithms.spot_prediction import IndexGenerator, ray_intersection

    expt = experiments[0]
    detector = expt.detector

    # Every index inside a 2.0 Angstrom sphere, using space group number 1
    d_min = 2.0
    sg_type = space_group(space_group_symbols(1).hall()).type()
    indices = IndexGenerator(expt.crystal.get_unit_cell(), sg_type, d_min).to_array()

    # Rays predicted across the oscillation range of the scan (in radians)
    osc_range = expt.scan.get_oscillation_range(deg=False)
    predict_rays = ScansRayPredictor(experiments, osc_range)
    rays = predict_rays(indices)

    # Discard the rays that do not hit the detector
    hits = ray_intersection(detector, rays)
    rays = rays.select(hits)

    # Re-predict with a reflection predictor. The positions do not change, but
    # this also provides the flags and xyzcal.px columns
    predictor = ScansExperimentsPredictor(experiments)
    rays["id"] = flex.int(len(rays), 0)
    simulated = predictor(rays)

    # The predicted centroids double as the 'observed' centroids
    simulated["xyzobs.mm.value"] = simulated["xyzcal.mm"]

    # Made-up centroid variances: half a pixel in x and y, and half of a
    # 0.1 degree image width (expressed in radians) in phi
    n_refs = len(simulated)
    image_width = 0.1 * pi / 180.0
    pixel_size = detector[0].get_pixel_size()
    variance_x = flex.double(n_refs, (pixel_size[0] / 2.0) ** 2)
    variance_y = flex.double(n_refs, (pixel_size[1] / 2.0) ** 2)
    variance_phi = flex.double(n_refs, (image_width / 2.0) ** 2)
    simulated["xyzobs.mm.variance"] = flex.vec3_double(
        variance_x, variance_y, variance_phi
    )

    return simulated, predictor


def test_fd_derivatives():
    """Check analytical 2theta gradients against central finite differences.

    A single experiment is assembled from the models built by ``Extract`` and
    a mock 720-image scan, with only the crystal unit cell parameterised. For
    each parameter the "2theta_resid" column is evaluated half a step below
    and half a step above the current value, and the difference quotient is
    compared with the "d2theta_dp" entry from ``pred_param.get_gradients``.
    """
    from dxtbx.model import ScanFactory
    from libtbx.phil import parse

    from dials.algorithms.refinement.parameterisation.crystal_parameters import (
        CrystalUnitCellParameterisation,
    )

    from . import geometry_phil
    from .setup_geometry import Extract

    # Build the models, overriding the allowed range of the cell edge lengths
    phil_overrides = """geometry.parameters.crystal.a.length.range = 10 50
  geometry.parameters.crystal.b.length.range = 10 50
  geometry.parameters.crystal.c.length.range = 10 50"""
    models = Extract(parse(geometry_phil), phil_overrides)
    crystal = models.crystal

    # Mock scan: 720 images of 0.1 degrees each, i.e. a 72 degree sequence
    scan = ScanFactory().make_scan(
        image_range=(1, 720),
        exposure_times=0.1,
        oscillation=(0, 0.1),
        epochs=list(range(720)),
        deg=True,
    )

    # Only the crystal unit cell is parameterised
    cell_param = CrystalUnitCellParameterisation(crystal)

    experiments = ExperimentList()
    experiments.append(
        Experiment(
            beam=models.beam,
            detector=models.detector,
            goniometer=models.goniometer,
            scan=scan,
            crystal=crystal,
            imageset=None,
        )
    )

    # Prediction parameterisation for two theta prediction
    pred_param = TwoThetaPredictionParameterisation(
        experiments,
        detector_parameterisations=None,
        beam_parameterisations=None,
        xl_orientation_parameterisations=None,
        xl_unit_cell_parameterisations=[cell_param],
    )

    # Simulated observations; the predictor returned alongside is not needed
    # because a two theta predictor is built below
    observed, _ = generate_reflections(experiments)

    # Reflection manager, predictor and least squares target for 2theta residuals
    refman = TwoThetaReflectionManager(observed, experiments, outlier_detector=None)
    two_theta_predictor = TwoThetaExperimentsPredictor(experiments)
    target = TwoThetaTarget(experiments, two_theta_predictor, refman, pred_param)

    def residuals_at(values):
        """Return a copy of the 2theta residuals predicted at ``values``."""
        pred_param.set_param_vals(values)
        target.predict()
        return refman.get_matches()["2theta_resid"].deep_copy()

    # Analytical gradients for the reflections that pass the inclusion criteria
    analytical = pred_param.get_gradients(refman.get_matches())

    # Central finite differences, one parameter at a time
    params = pred_param.get_param_vals()
    steps = [1.0e-7] * len(params)

    for i, step in enumerate(steps):
        saved = params[i]

        # half a step below the current value
        params[i] -= step / 2.0
        below = residuals_at(params)

        # half a step above the current value
        params[i] += step
        above = residuals_at(params)

        params[i] = saved

        finite_diff = above - below
        finite_diff /= step

        # compare with analytical calculation
        assert approx_equal(finite_diff, analytical[i]["d2theta_dp"], eps=1.0e-6)

    # return to the initial state
    pred_param.set_param_vals(params)


def test_refinement(dials_data):
    """Test that a 2theta refinement run recovers a perturbed unit cell.

    Reflections are simulated for an invented P1 crystal with a 40 Angstrom
    orthogonal cell. The cell is then perturbed, refined against the simulated
    data using a 2theta target with a Levenberg-Marquardt engine, and finally
    compared with a copy of the original crystal.

    Args:
        dials_data: fixture used to locate the "refinement_test_data" dataset.
    """

    # Get a beam and detector from an experiments file. This one has a CS-PAD,
    # but that is irrelevant
    data_dir = dials_data("refinement_test_data", pathlib=True)
    experiments_path = data_dir / "cspad-single-image.expt"

    # load models
    from dxtbx.model.experiment_list import ExperimentListFactory

    experiments = ExperimentListFactory.from_serialized_format(
        experiments_path, check_format=False
    )
    im_set = experiments.imagesets()[0]
    detector = deepcopy(im_set.get_detector())
    beam = im_set.get_beam()

    # Invent a crystal, goniometer and scan for this test
    from dxtbx.model import Crystal

    crystal = Crystal(
        (40.0, 0.0, 0.0), (0.0, 40.0, 0.0), (0.0, 0.0, 40.0), space_group_symbol="P1"
    )
    # Reference copy, taken before the cell is perturbed below
    orig_xl = deepcopy(crystal)

    from dxtbx.model import GoniometerFactory

    goniometer = GoniometerFactory.known_axis((1.0, 0.0, 0.0))

    # Build a mock scan for a 180 degree sequence
    from dxtbx.model import ScanFactory

    sf = ScanFactory()
    scan = sf.make_scan(
        image_range=(1, 1800),
        exposure_times=0.1,
        oscillation=(0, 0.1),
        epochs=list(range(1800)),
        deg=True,
    )
    sequence_range = scan.get_oscillation_range(deg=False)
    im_width = scan.get_oscillation(deg=False)[1]
    # Compare with a tolerance, as is done for the image width: the range is a
    # computed floating point value, so exact equality with pi is not guaranteed
    assert approx_equal(sequence_range, (0.0, pi))
    assert approx_equal(im_width, 0.1 * pi / 180.0)

    # Build an experiment list
    experiments = ExperimentList()
    experiments.append(
        Experiment(
            beam=beam,
            detector=detector,
            goniometer=goniometer,
            scan=scan,
            crystal=crystal,
            imageset=None,
        )
    )

    # simulate some reflections
    refs, _ = generate_reflections(experiments)

    # change unit cell a bit (=0.1 Angstrom length upsets, 0.1 degree of
    # alpha and beta angles)
    from dials.algorithms.refinement.parameterisation.crystal_parameters import (
        CrystalUnitCellParameterisation,
    )

    xluc_param = CrystalUnitCellParameterisation(crystal)
    cell_params = crystal.get_unit_cell().parameters()
    cell_params = [a + b for a, b in zip(cell_params, [0.1, -0.1, 0.1, 0.1, -0.1, 0.0])]
    from cctbx.uctbx import unit_cell
    from rstbx.symmetry.constraints.parameter_reduction import symmetrize_reduce_enlarge
    from scitbx import matrix

    new_uc = unit_cell(cell_params)
    newB = matrix.sqr(new_uc.fractionalization_matrix()).transpose()
    S = symmetrize_reduce_enlarge(crystal.get_space_group())
    S.set_orientation(orientation=newB)
    # NOTE(review): the 1.0e5 factor presumably matches the internal scaling of
    # the unit cell parameterisation - confirm
    X = tuple(e * 1.0e5 for e in S.forward_independent_parameters())
    xluc_param.set_param_vals(X)

    # reparameterise the crystal at the perturbed geometry
    xluc_param = CrystalUnitCellParameterisation(crystal)

    # Dummy parameterisations for other models
    beam_param = None
    xlo_param = None
    det_param = None

    # parameterisation of the prediction equation
    from dials.algorithms.refinement.parameterisation.parameter_report import (
        ParameterReporter,
    )

    pred_param = TwoThetaPredictionParameterisation(
        experiments, det_param, beam_param, xlo_param, [xluc_param]
    )
    param_reporter = ParameterReporter(det_param, beam_param, xlo_param, [xluc_param])

    # reflection manager
    refman = TwoThetaReflectionManager(refs, experiments, nref_per_degree=20)

    # reflection predictor
    ref_predictor = TwoThetaExperimentsPredictor(experiments)

    # target function
    target = TwoThetaTarget(experiments, ref_predictor, refman, pred_param)

    # minimisation engine
    from dials.algorithms.refinement.engine import (
        LevenbergMarquardtIterations as Refinery,
    )

    refinery = Refinery(
        target=target,
        prediction_parameterisation=pred_param,
        log=None,
        max_iterations=20,
    )

    # Refiner
    from dials.algorithms.refinement.refiner import Refiner

    refiner = Refiner(
        experiments=experiments,
        pred_param=pred_param,
        param_reporter=param_reporter,
        refman=refman,
        target=target,
        refinery=refinery,
    )
    refiner.run()

    # compare crystal with original crystal
    refined_xl = refiner.get_experiments()[0].crystal

    assert refined_xl.is_similar_to(
        orig_xl, uc_rel_length_tolerance=0.001, uc_abs_angle_tolerance=0.01
    )