File: backtracking.go

package info (click to toggle)
golang-gonum-v1-gonum 0.15.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 18,792 kB
  • sloc: asm: 6,252; fortran: 5,271; sh: 377; ruby: 211; makefile: 98
file content (84 lines) | stat: -rw-r--r-- 2,686 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
// Copyright ©2014 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package optimize

const (
	// defaultBacktrackingContraction is substituted by Init when
	// Backtracking.ContractionFactor is zero.
	defaultBacktrackingContraction = 0.5

	// defaultBacktrackingDecrease is substituted by Init when
	// Backtracking.DecreaseFactor is zero.
	defaultBacktrackingDecrease = 1e-4

	// minimumBacktrackingStepSize is the threshold below which Iterate
	// abandons the search and reports ErrLinesearcherFailure.
	minimumBacktrackingStepSize = 1e-20
)

// Compile-time assertion that *Backtracking satisfies Linesearcher.
var _ Linesearcher = (*Backtracking)(nil)

// Backtracking is a Linesearcher that looks for a step size satisfying the
// Armijo (sufficient decrease) condition with constant DecreaseFactor. Every
// time a trial step fails that condition, the step size is multiplied by
// ContractionFactor and the objective is evaluated again.
//
// The Armijo condition needs the gradient only at the start of each major
// iteration, never at the trial points, so Backtracking can be a good choice
// when gradients are expensive to compute. It does not enforce the Wolfe
// conditions, so it should not be used with optimizers that require them,
// such as BFGS.
//
// DecreaseFactor and ContractionFactor must both lie strictly between zero
// and one; Backtracking panics otherwise. A factor left at zero is replaced
// by a default (1e-4 for DecreaseFactor, 0.5 for ContractionFactor) when
// Init is called.
type Backtracking struct {
	// DecreaseFactor is the constant in the sufficient decrease (Armijo)
	// condition.
	DecreaseFactor float64
	// ContractionFactor scales the step size after each failed trial
	// (step *= ContractionFactor).
	ContractionFactor float64

	// Per-search state: Init records all three, and Iterate shrinks stepSize.
	stepSize, initF, initG float64

	// lastOp is the Operation most recently returned by Init or Iterate.
	lastOp Operation
}

// Init starts a new line search and returns FuncEvaluation, asking the caller
// for the objective value at the initial step size.
//
// f and g are the objective value and the directional derivative at the
// starting point, and step is the initial trial step size. All three are
// stored for use by Iterate. A zero DecreaseFactor or ContractionFactor is
// replaced on the receiver by its default (1e-4 and 0.5 respectively).
//
// Init panics in any of these cases:
//   - step is not positive;
//   - g is not negative;
//   - either factor does not lie strictly between 0 and 1.
//
// NaN values of step and of the factors are rejected as well. A NaN step or
// ContractionFactor would keep the step size NaN forever, and Iterate's
// minimum-step check could then never end the search.
func (b *Backtracking) Init(f, g, step float64) Operation {
	// The comparisons are negated (!(x > 0) rather than x <= 0) because every
	// comparison involving NaN is false. The negated form therefore treats
	// NaN as invalid instead of silently accepting it.
	if !(step > 0) {
		panic("backtracking: bad step size")
	}
	// The starting derivative must be negative, i.e. the search direction
	// must be a descent direction.
	if g >= 0 {
		panic("backtracking: initial derivative is non-negative")
	}

	if b.ContractionFactor == 0 {
		b.ContractionFactor = defaultBacktrackingContraction
	}
	if b.DecreaseFactor == 0 {
		b.DecreaseFactor = defaultBacktrackingDecrease
	}
	if !(b.ContractionFactor > 0 && b.ContractionFactor < 1) {
		panic("backtracking: ContractionFactor must be between 0 and 1")
	}
	if !(b.DecreaseFactor > 0 && b.DecreaseFactor < 1) {
		panic("backtracking: DecreaseFactor must be between 0 and 1")
	}

	b.stepSize = step
	b.initF = f
	b.initG = g

	b.lastOp = FuncEvaluation
	return b.lastOp
}

// Iterate advances the line search given the objective value f at the
// current trial step size. The second argument is unused: the Armijo
// condition only needs the derivative at the starting point, which Init
// recorded.
//
// When the sufficient decrease condition holds, Iterate returns
// MajorIteration together with the accepted step size. Otherwise the step
// size is multiplied by ContractionFactor, and one of two things happens:
//   - If the new step falls below minimumBacktrackingStepSize, the search
//     gives up and Iterate returns NoOperation and ErrLinesearcherFailure.
//   - Otherwise Iterate returns FuncEvaluation to request an evaluation at
//     the shorter step.
//
// Iterate panics unless the most recent call to Init or Iterate returned
// FuncEvaluation.
func (b *Backtracking) Iterate(f, _ float64) (Operation, float64, error) {
	if b.lastOp != FuncEvaluation {
		panic("backtracking: Init has not been called")
	}

	sufficient := ArmijoConditionMet(f, b.initF, b.initG, b.stepSize, b.DecreaseFactor)
	if sufficient {
		b.lastOp = MajorIteration
		return MajorIteration, b.stepSize, nil
	}

	// Backtrack: shrink the step and ask for another evaluation, unless the
	// step has become too small to be worth trying.
	b.stepSize *= b.ContractionFactor
	next := FuncEvaluation
	var err error
	if b.stepSize < minimumBacktrackingStepSize {
		next, err = NoOperation, ErrLinesearcherFailure
	}
	b.lastOp = next
	return next, b.stepSize, err
}