1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
|
// Copyright ©2017 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fd
import "gonum.org/v1/gonum/floats"
// Gradient estimates the gradient of the multivariate function f at the
// location x. If dst is not nil, the result will be stored in-place into dst
// and returned, otherwise a new slice will be allocated first. Finite
// difference formula and other options are specified by settings. If settings is
// nil, the gradient will be estimated using the Forward formula and a default
// step size.
//
// Gradient panics if the length of dst and x is not equal, or if the derivative
// order of the formula is not 1.
func Gradient(dst []float64, f func([]float64) float64, x []float64, settings *Settings) []float64 {
if dst == nil {
dst = make([]float64, len(x))
}
if len(dst) != len(x) {
panic("fd: slice length mismatch")
}
// Default settings.
formula := Forward
step := formula.Step
var originValue float64
var originKnown, concurrent bool
// Use user settings if provided.
if settings != nil {
if !settings.Formula.isZero() {
formula = settings.Formula
step = formula.Step
checkFormula(formula)
if formula.Derivative != 1 {
panic(badDerivOrder)
}
}
if settings.Step != 0 {
step = settings.Step
}
originKnown = settings.OriginKnown
originValue = settings.OriginValue
concurrent = settings.Concurrent
}
evals := len(formula.Stencil) * len(x)
nWorkers := computeWorkers(concurrent, evals)
hasOrigin := usesOrigin(formula.Stencil)
// Copy x in case it is modified during the call.
xcopy := make([]float64, len(x))
if hasOrigin && !originKnown {
copy(xcopy, x)
originValue = f(xcopy)
}
if nWorkers == 1 {
for i := range xcopy {
var deriv float64
for _, pt := range formula.Stencil {
if pt.Loc == 0 {
deriv += pt.Coeff * originValue
continue
}
// Copying the data anew has two benefits. First, it
// avoids floating point issues where adding and then
// subtracting the step don't return to the exact same
// location. Secondly, it protects against the function
// modifying the input data.
copy(xcopy, x)
xcopy[i] += pt.Loc * step
deriv += pt.Coeff * f(xcopy)
}
dst[i] = deriv / step
}
return dst
}
sendChan := make(chan fdrun, evals)
ansChan := make(chan fdrun, evals)
quit := make(chan struct{})
defer close(quit)
// Launch workers. Workers receive an index and a step, and compute the answer.
for i := 0; i < nWorkers; i++ {
go func(sendChan <-chan fdrun, ansChan chan<- fdrun, quit <-chan struct{}) {
xcopy := make([]float64, len(x))
for {
select {
case <-quit:
return
case run := <-sendChan:
// See above comment on the copy.
copy(xcopy, x)
xcopy[run.idx] += run.pt.Loc * step
run.result = f(xcopy)
ansChan <- run
}
}
}(sendChan, ansChan, quit)
}
// Launch the distributor. Distributor sends the cases to be computed.
go func(sendChan chan<- fdrun, ansChan chan<- fdrun) {
for i := range x {
for _, pt := range formula.Stencil {
if pt.Loc == 0 {
// Answer already known. Send the answer on the answer channel.
ansChan <- fdrun{
idx: i,
pt: pt,
result: originValue,
}
continue
}
// Answer not known, send the answer to be computed.
sendChan <- fdrun{
idx: i,
pt: pt,
}
}
}
}(sendChan, ansChan)
for i := range dst {
dst[i] = 0
}
// Read in all of the results.
for i := 0; i < evals; i++ {
run := <-ansChan
dst[run.idx] += run.pt.Coeff * run.result
}
floats.Scale(1/step, dst)
return dst
}
type fdrun struct {
idx int
pt Point
result float64
}
|