1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
|
// Copyright ©2017 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fd
import (
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/mat"
)
// ConstFunc is a constant function returning the value held by the type.
type ConstFunc float64
func (c ConstFunc) Func(x []float64) float64 {
return float64(c)
}
func (c ConstFunc) Grad(grad, x []float64) {
for i := range grad {
grad[i] = 0
}
}
func (c ConstFunc) Hess(dst mat.MutableSymmetric, x []float64) {
n := len(x)
for i := 0; i < n; i++ {
for j := i; j < n; j++ {
dst.SetSym(i, j, 0)
}
}
}
// LinearFunc is a linear function returning w*x+c.
type LinearFunc struct {
w []float64
c float64
}
func (l LinearFunc) Func(x []float64) float64 {
return floats.Dot(l.w, x) + l.c
}
func (l LinearFunc) Grad(grad, x []float64) {
copy(grad, l.w)
}
func (l LinearFunc) Hess(dst mat.MutableSymmetric, x []float64) {
n := len(x)
for i := 0; i < n; i++ {
for j := i; j < n; j++ {
dst.SetSym(i, j, 0)
}
}
}
// QuadFunc is a quadratic function returning 0.5*x'*a*x + b*x + c.
type QuadFunc struct {
a *mat.SymDense
b *mat.VecDense
c float64
}
func (q QuadFunc) Func(x []float64) float64 {
v := mat.NewVecDense(len(x), x)
var tmp mat.VecDense
tmp.MulVec(q.a, v)
return 0.5*mat.Dot(&tmp, v) + mat.Dot(q.b, v) + q.c
}
func (q QuadFunc) Grad(grad, x []float64) {
var tmp mat.VecDense
v := mat.NewVecDense(len(x), x)
tmp.MulVec(q.a, v)
for i := range grad {
grad[i] = tmp.At(i, 0) + q.b.At(i, 0)
}
}
func (q QuadFunc) Hess(dst mat.MutableSymmetric, x []float64) {
n := len(x)
for i := 0; i < n; i++ {
for j := i; j < n; j++ {
dst.SetSym(i, j, q.a.At(i, j))
}
}
}
|