1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
#!/usr/bin/env python
"""
Script to run a batch fit on a series of files and plot the fitted parameters.
Usage syntax::
python batch_fit.py model.py sample1.dat sample2.dat ... other_sample.dat
(files named sample1.dat, sample2.dat, ..., other_sample.dat)
or if the file names are numbers (and the extension is .dat)::
python batch_fit.py model.py 93190 93210
(files named 093190.dat, 093191.dat, ..., 093210.dat)
or for Grasp-like naming::
python batch_fit.py model.py 93190 93210 200
(files named 093190_200.dat, 093191_201.dat, ..., 093210_220.dat)
The script reads a series of files and fits the model defined by model.py.
For example model_ellipsoid_hayter_msa.py fits a model consisting of an
ellipsoid form factor multiplied by a Hayter MSA structure factor.
The file *model.py* must load the data using::
data = load_data(sys.argv[1])
Include options to bumps (minimizer, steps, etc.) as desired. For example::
python batch_fit.py model.py 93190 93210 200 --fit=lm --steps=200 --ftol=1.5e-8 --xtol=1.5e-8
Fit options can come before or after the model and files.
For each file a directory named Fit_filename is created. There the file
model.par contains the fitted parameters. These are gathered together into
batch_fit.csv in the current directory.
Finally the fitted parameters are plotted for the full series.
Example::
python batch_fit.py model_ellipsoid_hayter_msa.py 93191 93195 201
Note:
If sasmodels, sasview or bumps are not on the Python path, add their
locations to the PYTHONPATH environment variable.
"""
from __future__ import print_function

import os
import subprocess
import sys

import matplotlib.pyplot as plt
import numpy as np
from bumps.dream.views import tile_axes  # make a grid of plots
# GET INPUT AND ENSURE MODEL AND DATA FILES ARE DEFINED
# Arguments starting with "--" are collected as fit options; the remaining
# arguments are the model file followed by the data file specification.
fit_opts = [v for v in sys.argv[1:] if v.startswith('--')]
args = [v for v in sys.argv[1:] if not v.startswith('--')]
nargs = len(args)
if nargs < 2:
    print("Error in the list of arguments! \n")
    # Exit with a non-zero status so the failure is visible to the caller.
    sys.exit(1)
model_file = args[0]
if not model_file.endswith('.py'):
    print("Expected model.py as the first argument")
    sys.exit(1)
if '.' in args[1]:
    # Explicit list of data file names.
    data_files = args[1:]
else:
    # Numeric form: "first last" or "first last ext".
    if nargs < 3:
        # Without this check args[2] below would raise IndexError.
        print("Expected the first and last file numbers")
        sys.exit(1)
    if nargs > 4:
        print("Unexpected arguments: " + " ".join(args[4:]))
        sys.exit(1)
    try:
        numbers = [int(v) for v in args[1:]]
    except ValueError:
        print("Expected integer file numbers: " + " ".join(args[1:]))
        sys.exit(1)
    first, last = numbers[0], numbers[1]
    count = last - first + 1
    if count < 1:
        # An inverted range would otherwise give an empty list of files.
        print("The last file number must not be smaller than the first")
        sys.exit(1)
    if nargs == 3:
        data_files = ['%06d.dat' % (first + i) for i in range(count)]
    else:
        # Grasp-like naming: the second number increases with the first.
        ext = numbers[2]
        data_files = ['%06d_%d.dat' % (first + i, ext + i)
                      for i in range(count)]
# CHECK THAT DATA FILES EXIST
# Collect every data file that is not a regular file on disk and stop with
# an error listing all of them at once.
missing = []
for candidate in data_files:
    if os.path.isfile(candidate):
        continue
    missing.append(candidate)
if missing:
    missing_names = ", ".join(missing)
    print("Missing data files: %s" % missing_names)
    sys.exit(1)
# STORE DIRECTORY FOR BUMPS FITS
def fit_dir(filename):
    """
    Return the store directory name for the given file.

    The name is ``Fit_`` followed by *filename* with its last extension
    removed.
    """
    stem, _extension = os.path.splitext(filename)
    return "Fit_" + stem
# LOOP OVER FILES AND CALL TO BUMPS FOR EACH OF THEM
# The command is passed as an argument list (no shell), so file names and
# options containing spaces or shell metacharacters reach bumps unchanged.
# Use the interpreter running this script, falling back to "python" if
# sys.executable is not available.
bumps_cmd = [sys.executable or "python", "-m", "bumps.cli", "--batch"]
for data_file in data_files:
    store_opts = "--store=" + fit_dir(data_file)
    cmd = bumps_cmd + list(fit_opts) + [store_opts, model_file, data_file]
    status = subprocess.call(cmd)
    if status != 0:
        # Keep going with the remaining files, but report the failure.
        print("Warning: bumps exited with status %d for %s"
              % (status, data_file))
# GATHER RESULTS
# results maps each parameter label to the list of its fitted values, one
# value per data file, in the order of data_files.
results = {}
# NOTE(review): assumes bumps writes "<model name>.par" directly inside the
# store directory, named after the base name of the model file -- confirm.
# Any directory part of model_file is dropped so that the path stays inside
# the store directory.
par_file = os.path.splitext(os.path.basename(model_file))[0] + '.par'
for data_file in data_files:
    with open(os.path.join(fit_dir(data_file), par_file), 'r') as fid:
        for line in fid:
            if not line.strip():
                continue  # ignore blank lines
            # The value is the last field of the line; everything before it
            # is the parameter label, which may itself contain spaces.
            parameter, value = line.rsplit(None, 1)
            results.setdefault(parameter, []).append(float(value))
# SAVE RESULTS INTO FILE
# One header line with the sorted parameter names, then one line per data
# file holding the file name followed by its fitted values.
with open('batch_fit.csv', 'w') as fid:
    parameters = sorted(results)
    columns = [results[name] for name in parameters]
    values_by_file = zip(*columns)
    header = ['filename'] + parameters
    fid.write(','.join(header) + '\n')
    for filename, values in zip(data_files, values_by_file):
        row = [filename]
        row.extend(str(v) for v in values)
        fid.write(','.join(row) + '\n')
# SHOW FITTED PARAMETERS
# One subplot per parameter, showing its fitted value for each data file.
nplots = len(results)
nh, nw = tile_axes(nplots)
ticks = np.arange(1, len(data_files)+1)
labels = [os.path.splitext(filename)[0] for filename in data_files]
for k, (parameter, values) in enumerate(sorted(results.items())):
    plt.subplot(nh, nw, k+1)
    plt.plot(ticks, values)
    plt.xlim(ticks[0]-0.5, ticks[-1]+0.5)
    # plt.subplot numbers the axes row by row, so the plot directly below
    # plot k is plot k+nw.  Show the file labels only on plots that have
    # nothing below them; the others get blank tick labels.
    if k + nw >= nplots:
        plt.xticks(ticks, labels, rotation=30)
    else:
        plt.xticks(ticks, [' ']*len(labels))
    plt.ylabel(parameter)
plt.suptitle("Fit " + args[0])
plt.show()
|