#!/usr/bin/env python
"""
Run each available fitter (or the fitters named on the command line) on a
test problem and check that they all converge to the expected chisq.
"""
import sys
import os
import tempfile
import subprocess
from pathlib import Path
import h5py
sys.dont_write_bytecode = True
# Put this source tree first on PYTHONPATH so that the "python -m bumps"
# subprocesses import this copy of bumps
ROOT = Path(__file__).absolute().parent
packages = [str(ROOT)]
if "PYTHONPATH" in os.environ:
    packages.append(os.environ["PYTHONPATH"])
os.environ["PYTHONPATH"] = os.pathsep.join(packages)
# Need bumps on the path to pull in the available fitters
sys.path.insert(0, str(ROOT))
from bumps.fitters import FIT_AVAILABLE_IDS
BUMPS_COMMAND = [sys.executable, "-m", "bumps"]
EXAMPLEDIR = ROOT / "doc" / "examples"


def decode(b):
    return b.decode("utf-8")


def run_fit(fitter, model_args, store, seed=1):
    """
    Run a single bumps fit in a subprocess, saving the session to *store*.

    Raises KeyboardInterrupt if the subprocess output reports one, and
    RuntimeError if the fit fails for any other reason.
    """
    command = [
        *BUMPS_COMMAND,
        str(model_args[0]),
        f"--fit={fitter}",
        f"--session={store}",
        f"--seed={seed}",
        "--batch",
        f"--export={Path(store).parent / ('T1-' + fitter)}",
    ]
    # Any remaining entries in model_args are forwarded after --args.
    if len(model_args) > 1:
        command += ["--args", *model_args[1:]]
    try:
        # print("Command:", " ".join(command))
        output = subprocess.check_output(command, stderr=subprocess.STDOUT)
        output = decode(output.strip())
        if output:
            print(output)
    except subprocess.CalledProcessError as exc:
        output = decode(exc.output.strip())
        if output:
            print(output)
        if "KeyboardInterrupt" in output:
            raise KeyboardInterrupt()
        else:
            raise RuntimeError("fit failed:\n" + " ".join(command)) from exc
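

def check_fit(fitter, store, target):
    """
    Verify that the final chisq recorded in *store* matches *target*.

    The tolerance is 1% of the target, or an absolute 0.01 when
    abs(target) < 1, so that a target of 0 does not divide by zero.
    """
    with h5py.File(store, "r") as fd:
        group = fd["problem_history"]
        last_item = list(group.keys())[-1]
        chisq_str = group[last_item].attrs["chisq"]
    # chisq is stored as text; keep the number before any "(...)" suffix.
    value = float(chisq_str.split("(")[0])
    tolerance = 1e-2 * max(abs(target), 1.0)
    assert abs(value - target) < tolerance, (
        f"error in {fitter}: expected {target} but got {value}"
    )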
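

# Illustrative sketch (hypothetical helpers, not called anywhere in this
# script): the parse-and-compare step of check_fit split into pure functions
# so it can be exercised without running a fit or opening an HDF5 file. It
# assumes the "chisq" attribute is text of the form "value(uncertainty)",
# e.g. "1.760(12)", as implied by the split on "(" in check_fit.
def parse_chisq(chisq_str):
    """Return the numeric value from a "value(uncertainty)" chisq string."""
    return float(chisq_str.split("(")[0])


def chisq_matches(value, target, rel_tol=1e-2):
    """
    True if value is within rel_tol of target. Below abs(target) == 1 the
    tolerance becomes absolute (rel_tol itself), so target == 0 is safe.
    """
    return abs(value - target) < rel_tol * max(abs(target), 1.0)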
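

def run_fits(model_args, path, fitters=FIT_AVAILABLE_IDS, seed=1, target=0):
    """
    Run and check each fitter in turn, returning the list of those that failed.
    """
    failed = []
    for f in fitters:
        print(f"====== fitter: {f}")
        try:
            store = Path(path) / f"{f}.hdf"
            run_fit(f, model_args, str(store), seed=seed)
            check_fit(f, store, target)
        # KeyboardInterrupt is not an Exception subclass, so an interrupt
        # aborts the whole run instead of being recorded as a failure.
        except Exception as exc:
            # import traceback; traceback.print_exc()
            print(str(exc))
            failed.append(f)
    return failed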


def main():
    # Note: bumps.fitters.test_fitters already runs curvefit on the "active" fitters
    fitters = sys.argv[1:] if len(sys.argv) > 1 else FIT_AVAILABLE_IDS
    # TODO: use a test function that defines residuals
    test_functions = EXAMPLEDIR / "test_functions" / "model.py"
    # Alternative test problems:
    # model_args = [test_functions, '"fk(rosenbrock, 3)"']
    # model_args, target = [test_functions, "gauss", "3"], 0
    model_args, target = [EXAMPLEDIR / "curvefit" / "curve.py"], 1.760
    seed = 1
    with tempfile.TemporaryDirectory() as path:
        failed = run_fits(model_args, path, fitters=fitters, seed=seed, target=target)
    if failed:
        print("======")
        print(f"Fits failed for: {', '.join(failed)}")
        sys.exit(1)


if __name__ == "__main__":
    main()