1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165
|
#!/usr/bin/env python
#-----------------------------------------------------------------------------
# Model Factory Interface:
"""NOTES:
- forward model "forward_poly" calculates a function of x (w/ fixed a,b,c)
- ForwardPolyFactory is a "function generator", allowing a,b,c to be set
"""
def ForwardPolyFactory(params):
    """Build a quadratic forward model with fixed coefficients.

    params -- a sequence of exactly three coefficients (a, b, c)

    Returns a function forward_poly(x) that evaluates a*x*x + b*x + c.
    """
    a, b, c = params

    def forward_poly(x):
        """Evaluate a*x*x + b*x + c at x and return the result as an array.

        x may be a numpy array, a scalar, or a plain sequence of numbers.
        """
        # Coerce up-front so lists/tuples get element-wise arithmetic rather
        # than failing (or being repeated) under Python sequence operators.
        x = array(x)
        return array(a*x*x + b*x + c)

    return forward_poly
#-----------------------------------------------------------------------------
# Forward Model Invocation:
"""NOTES:
- fwd is an instance of "forward_poly", built with chosen a,b,c
- "data" converts a function of x into a function of a,b,c (w/ fixed x) [i.e. a functor]
- same methodology is used in COST FUNCTION to produce "goodness of fit"
"""
def data(params):
    """Evaluate the polynomial model for `params` on the fixed x grid.

    The grid is the 101 points -50., -49., ..., 50.; the return value is
    whatever the model built by ForwardPolyFactory(params) yields on it.
    """
    grid = array(list(range(101))) - 50.
    model = ForwardPolyFactory(params)
    return model(grid)
#-----------------------------------------------------------------------------
# Build "Measured" Data: (optional... use real measured data)
"""NOTES:
- target is "target solution" for a,b,c
- data is used to generate "measured data" (parameters a,b,c = target)
"""
# coefficients (a, b, c) of the "target solution"
target = [1.0, 2.0, 1.0]
# "measured" data, generated by evaluating the model at the target coefficients
datapts = data(target)
#-----------------------------------------------------------------------------
# Cost Function Generation: (optional... write your cost function explicitly)
"""NOTES:
- F is an instance of Cost Function (goodness of fit) generator
- myCost is an instance of a Cost Function
- (default metric) calculates the LeastSquared difference for fwd(x) & datapts
"""
# evaluation grid: the 101 points -50., -49., ..., 50.
x = array(list(range(101))) - 50.
F = CostFactory()
# register the model factory under the name 'poly'
# (the 3 is presumably the parameter count for a,b,c -- confirm against CostFactory)
F.addModel(ForwardPolyFactory, 3, 'poly')
# cost function comparing the model on x against the observed datapts
myCost = F.getCostFunction(evalpts=x, observations=datapts)
#-----------------------------------------------------------------------------
# Call to Solver:
"""NOTES:
- solution is set of solved parameters a,b,c
- stepmon holds a log of optimization steps
"""
# Run the solver on the cost function; de_solve returns the solved
# parameters together with the step monitor.
# NOTE(review): de_solve, ND, NP and MAX_GENERATIONS are defined further down
# in this listing -- confirm they are defined before this call in the
# runnable script.
solution, stepmon = de_solve(myCost)
#-----------------------------------------------------------------------------
ND = 3  # number of parameters (a,b,c)
NP = 80  # size of trial population
MAX_GENERATIONS = ND*NP  # maximum optimization iterations
#-----------------------------------------------------------------------------
# Standard "Solver" Configuration:
"""NOTES:
- ND is number of parameters (a,b,c)
- NP is size of trial population
- MAX_GENERATIONS is maximum optimization iterations
#-----------------------------------------------------------------------------
- VerboseMonitor logs/prints "goodness of fit" and "best solution" at each step
- minrange/maxrange provide box constraints (for parameters a,b,c)
- SetRandomInitialPoints chooses an initial solution within box constraints
- SetStrictRanges only allows trial solutions within box constraints
- the 'termination' condition ends the run when there is "no change" over 300 generations
- enable_signal_handler allows "interrupt" signal to be caught
- sigint_callback registers a user-provided function to the signal_handler
"""
def de_solve(CF):
    """Minimize the cost function CF with a differential evolution solver.

    The solver is built from the module-level ND and NP, limited to
    MAX_GENERATIONS iterations, and given plot_sol as its sigint callback.

    Returns a (solution, stepmon) pair: the solver's solution and the
    monitor that was attached to the solver.
    """
    # box constraints, one entry per parameter (a,b,c)
    lower = [-100., -100., -100.]
    upper = [100., 100., 100.]

    solver = DifferentialEvolutionSolver(ND, NP)
    solver.enable_signal_handler()
    monitor = VerboseMonitor(10, 50)

    # the same bounds seed the initial population and constrain the search
    solver.SetRandomInitialPoints(min=lower, max=upper)
    solver.SetStrictRanges(min=lower, max=upper)
    solver.SetEvaluationLimits(maxiter=MAX_GENERATIONS)
    solver.SetGenerationMonitor(monitor)

    termination = ChangeOverGeneration(generations=300)
    solver.Solve(
        CF,
        termination,
        CrossProbability=0.5,
        ScalingFactor=0.5,
        sigint_callback=plot_sol,
    )
    return solver.Solution(), monitor
#-----------------------------------------------------------------------------
# BONUS... the Callback Function:
"""NOTES:
- called on "catch" of signal-interrupt
- _MUST_ be a function of "params"
- _only_one_ configuration parameter is (currently) allowed
"""
def plot_sol(params, linestyle='b-'):
    """Plot the model curve for `params` over the fixed x grid.

    params    -- model coefficients, passed straight through to data()
    linestyle -- format string handed to pylab.plot (default 'b-')

    de_solve passes this function to the solver as its sigint_callback.
    """
    grid = array(list(range(101))) - 50.
    curve = data(params)
    pylab.plot(grid, curve, '%s' % linestyle, linewidth=2.0)
    # plotview is not defined in this listing -- presumably a module-level
    # axis specification; confirm where it is set
    pylab.axis(plotview)
# DONE
|