File: check_fitters.py

package info (click to toggle)
python-bumps 1.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 6,200 kB
  • sloc: python: 24,517; xml: 493; ansic: 373; makefile: 211; javascript: 99; sh: 94
file content (111 lines) | stat: -rwxr-xr-x 3,371 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#!/usr/bin/env python
"""
Run each available bumps fitter (or only those named on the command line) on a
reference example model and check that each one converges to the expected chisq.
"""

import sys
import os
from os.path import join as joinpath
import tempfile
import subprocess
from pathlib import Path
import h5py

# Avoid writing .pyc files for modules imported by this script.
sys.dont_write_bytecode = True

# Make the source checkout containing this script importable.  Prepend it to
# PYTHONPATH, keeping any value already set in the environment.  The
# `python -m bumps` child processes started by run_fit inherit os.environ.
ROOT = Path(__file__).absolute().parent
packages = [str(ROOT)]
try:
    packages.append(os.environ["PYTHONPATH"])
except KeyError:
    pass  # no inherited PYTHONPATH to preserve
os.environ["PYTHONPATH"] = os.pathsep.join(packages)

# This process also needs the in-tree bumps, to get the list of fitter ids.
sys.path[:0] = [str(ROOT)]
from bumps.fitters import FIT_AVAILABLE_IDS

# Interpreter invocation used to launch each bumps fit.
BUMPS_COMMAND = [sys.executable, "-m", "bumps"]

# Example models shipped with the source tree.
EXAMPLEDIR = ROOT.joinpath("doc", "examples")


def decode(b, errors="replace"):
    """
    Decode the bytes *b* (subprocess output in run_fit) as UTF-8 text.

    By default, undecodable bytes are replaced with U+FFFD rather than raising.
    The decoded text is only echoed for diagnostics, so a child that writes in
    another encoding (e.g. a non-UTF-8 locale codepage) must not turn a fit
    report into a UnicodeDecodeError.  Pass ``errors="strict"`` to get the
    raising behaviour instead.
    """
    return b.decode("utf-8", errors=errors)


def run_fit(fitter, model_args, store, seed=1):
    """
    Run a single bumps fit in a child ``python -m bumps`` process.

    *fitter* is passed as ``--fit``.  *model_args* is a sequence whose first
    element is the model file; any remaining elements are passed after
    ``--args``.  The session is saved to *store*, and ``--export`` points at
    ``T1-<fitter>`` in the same directory as *store*.  The child's combined
    stdout/stderr is printed (when non-empty).

    Raises KeyboardInterrupt if a failing child's output mentions
    KeyboardInterrupt.  Otherwise raises RuntimeError on a non-zero exit
    status, chained to the underlying subprocess.CalledProcessError.
    """
    export_path = Path(store).parent / f"T1-{fitter}"
    command = [
        *BUMPS_COMMAND,
        str(model_args[0]),
        f"--fit={fitter}",
        f"--session={store}",
        f"--seed={seed}",
        "--batch",
        f"--export={export_path}",
    ]
    if len(model_args) > 1:
        # Stringify like model_args[0] so Path arguments work and the
        # " ".join(command) in the failure message below cannot raise TypeError.
        command += ["--args", *(str(arg) for arg in model_args[1:])]
    try:
        raw = subprocess.check_output(command, stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as exc:
        output = decode(exc.output.strip())
        if output:
            print(output)
        if "KeyboardInterrupt" in output:
            raise KeyboardInterrupt() from exc
        # Chain so the child's exit status stays visible in the traceback.
        raise RuntimeError("fit failed:\n" + " ".join(command)) from exc
    output = decode(raw.strip())
    if output:
        print(output)


def check_fit(fitter, store, target, *, rel_tol=1e-2, abs_tol=1e-2):
    """
    Verify that the final chisq saved in a fit session file matches *target*.

    The chisq is read from the ``chisq`` attribute of the last entry of the
    file's ``problem_history`` group.  The text before any ``(`` is parsed as
    a float.

    The check passes when ``abs(chisq - target)`` is below
    ``rel_tol * abs(target)`` (1% by default).  When *target* is zero, a
    relative tolerance is meaningless (and previously divided by zero), so
    the check uses *abs_tol* instead.

    Raises AssertionError on a mismatch, including a NaN chisq.  The explicit
    raise replaces ``assert`` so the check is not stripped under ``python -O``.
    """
    with h5py.File(store, "r") as fd:
        group = fd["problem_history"]
        # NOTE(review): "last" means last in h5py's iteration order for this
        # group (by default ordered by name, not insertion).  Confirm this
        # matches the order in which bumps writes history entries.
        last_item = list(group.keys())[-1]
        chisq_str = group[last_item].attrs["chisq"]
    value = float(chisq_str.split("(")[0])
    # abs(target) keeps a negative target from making the check vacuous.
    tolerance = rel_tol * abs(target) if target else abs_tol
    # Written as "not <" so that a NaN chisq fails instead of passing.
    if not abs(value - target) < tolerance:
        raise AssertionError(f"error in {fitter}: expected {target} but got {value}")


def run_fits(model_args, path, fitters=FIT_AVAILABLE_IDS, seed=1, target=0):
    """
    Fit *model_args* with each fitter id in *fitters*, one after another.

    Each fitter's session is written to ``<path>/<fitter>.hdf`` by run_fit and
    then compared against *target* by check_fit.  Any Exception raised for one
    fitter is printed and recorded, and the loop moves on to the next fitter.
    KeyboardInterrupt is not an Exception subclass, so it still stops the run.

    Returns the ids of the fitters that failed, in the order they were tried.
    """
    failures = []
    for fitter_id in fitters:
        print(f"====== fitter: {fitter_id}")
        try:
            session_file = Path(path) / f"{fitter_id}.hdf"
            run_fit(fitter_id, model_args, str(session_file), seed=seed)
            check_fit(fitter_id, session_file, target)
        except Exception as err:  # report this fitter and carry on
            print(err)
            failures.append(fitter_id)
    return failures


def main():
    """
    Fit the reference example model with each requested fitter and report failures.

    Fitter ids come from the command line.  With no arguments, every id in
    FIT_AVAILABLE_IDS is tried.  Session files are written to a temporary
    directory that is removed afterwards.  Exits with status 1 if any fitter
    failed.
    """
    # Note: bumps.fitters.test_fitters already runs curvefit on the "active" fitters
    fitters = sys.argv[1:] or FIT_AVAILABLE_IDS
    # TODO: use a test function that defines residuals.  Candidates:
    #   test_functions = EXAMPLEDIR / "test_functions" / "model.py"
    #   model_args, target = [test_functions, '"fk(rosenbrock, 3)"'], 0
    #   model_args, target = [test_functions, "gauss", "3"], 0
    model_args, target = [EXAMPLEDIR / "curvefit" / "curve.py"], 1.760
    seed = 1
    with tempfile.TemporaryDirectory() as path:
        failed = run_fits(model_args, path, fitters=fitters, seed=seed, target=target)
    if failed:
        print("======")
        print(f"Fits failed for: {', '.join(failed)}")
        sys.exit(1)


if __name__ == "__main__":
    main()