File: l1reg.go

package info (click to toggle)
golang-github-kshedden-statmodel 0.0~git20210519.ee97d3e-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 892 kB
  • sloc: makefile: 3
file content (207 lines) | stat: -rw-r--r-- 3,987 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
package statmodel

import (
	"fmt"
	"math"
)

// Focuser is implemented by models that can be restricted ("focused")
// to a single parameter.  FitL1Reg relies on this to carry out
// coordinate-wise optimization.
type Focuser interface {
	// NumParams returns the number of parameters of the model.
	// FitL1Reg visits this many coordinates in each sweep.
	NumParams() int

	// NumObs returns the number of observations.  FitL1Reg uses it to
	// scale the convergence tolerance and the L1 penalty weights.
	NumObs() int

	// Focus returns a one-parameter model for the given position.
	// FitL1Reg calls it as Focus(j, coeff, offset), passing the
	// coordinate index, the current coefficient vector and the offset.
	Focus(int, []float64, []float64) RegFitter

	// LogLike returns a log-likelihood value for the given parameter.
	// NOTE(review): the meaning of the bool argument is not visible in
	// this file -- confirm against the implementations.
	LogLike(Parameter, bool) float64

	// Score presumably writes the score (gradient of the
	// log-likelihood) into the provided slice -- not called on a
	// Focuser in this file, confirm against the implementations.
	Score(Parameter, []float64)

	// Hessian presumably writes the Hessian of the requested type into
	// the provided slice -- not called on a Focuser in this file,
	// confirm against the implementations.
	Hessian(Parameter, HessType, []float64)
}

// FitL1Reg fits an L1-penalized model by cyclic coordinate descent and
// returns param.
//
// In every sweep each coordinate j is updated in turn by solving the
// one-dimensional problem obtained from model.Focus(j, coeff, offset),
// using the penalty weight NumObs()*l1wgt[j].  l1wgt must therefore
// have at least NumParams() elements.  checkstep is passed through to
// the one-dimensional optimizer and requests verification that each
// step does not increase the penalized objective.
//
// The iterations stop after 400 sweeps, or as soon as the largest
// absolute change of any coordinate within a sweep drops below the
// tolerance.  The coefficients are updated in place through the slice
// returned by param.GetCoeff().
// NOTE(review): this presumably relies on GetCoeff returning the
// parameter's own storage rather than a copy -- confirm against the
// Parameter implementations.
func FitL1Reg(model Focuser, param Parameter, l1wgt, offset []float64, checkstep bool) Parameter {

	const maxSweeps = 400

	// Scratch parameter reused for every one-dimensional sub-problem.
	scratch := param.Clone()
	scratch.SetCoeff([]float64{0})

	numVar := model.NumParams()
	numObs := float64(model.NumObs())

	// The log-likelihood is not normalized by the sample size, so the
	// tolerance grows with the number of observations, up to a cap.
	tol := math.Min(1e-7*numObs, 0.1)

	coeff := param.GetCoeff()

	for sweep := 0; sweep < maxSweeps; sweep++ {

		// Largest absolute coordinate change seen in this sweep.
		maxStep := 0.0

		for j := 0; j < numVar; j++ {

			// Solve the focused problem for coordinate j.
			focused := model.Focus(j, coeff, offset)
			updated := opt1d(focused, coeff[j], scratch, numObs*l1wgt[j], checkstep)

			if step := math.Abs(updated - coeff[j]); step > maxStep {
				maxStep = step
			}

			coeff[j] = updated
		}

		if maxStep < tol {
			return param
		}
	}

	return param
}

// opt1d updates a single coefficient of an L1-penalized model.
//
// m1 is the one-parameter model, coeff the current value of its
// coefficient, par a scratch Parameter whose coefficient is
// overwritten, and l1wgt the penalty weight.  The negative
// log-likelihood is replaced by its quadratic approximation at coeff
// and the penalized quadratic is minimized in closed form.  If
// checkstep is true and the resulting point increases the penalized
// objective, a derivative-free search around coeff is used instead.
func opt1d(m1 RegFitter, coeff float64, par Parameter, l1wgt float64, checkstep bool) float64 {

	// Penalized negative log-likelihood as a function of the scalar
	// coefficient.
	objective := func(z float64) float64 {
		par.SetCoeff([]float64{z})
		return -m1.LogLike(par, false) + l1wgt*math.Abs(z)
	}

	// First and second derivative of the negative log-likelihood at
	// the current point.
	par.SetCoeff([]float64{coeff})
	score := make([]float64, 1)
	m1.Score(par, score)
	hess := make([]float64, 1)
	m1.Hessian(par, ObsHess, hess)
	b, c := -score[0], -hess[0]

	// Gradient of the quadratic approximation at zero; its sign and
	// size decide between thresholding and the two shifted solutions.
	d := b - c*coeff

	if l1wgt > math.Abs(d) {
		// The penalty dominates, so the solution is exactly zero.
		return 0
	}

	// coeff + h minimizes the quadratic plus l1wgt*abs(x).
	var h float64
	switch {
	case d >= 0:
		h = (l1wgt - b) / c
	case d < 0:
		h = -(l1wgt + b) / c
	default:
		// Only reached when d is NaN.
		panic(fmt.Sprintf("d=%f\n", d))
	}
	candidate := coeff + h

	if !checkstep {
		return candidate
	}

	// Accept the candidate unless it makes the objective worse.  The
	// two extra likelihood evaluations are wasted effort when the loss
	// really is quadratic (e.g. OLS).
	before := objective(coeff)
	after := objective(candidate)
	if after <= before+1e-10 {
		return candidate
	}

	// The quadratic approximation failed; search directly, starting
	// from an interval of half-width one around the current value.
	const (
		halfWidth = 1.0
		searchTol = 1e-7
	)
	return bisection(objective, coeff-halfWidth, coeff+halfWidth, searchTol)
}

// bisection minimizes the scalar function f without using derivatives.
//
// The search starts from the three points xl, (xl+xu)/2 and xu (xl < xu
// is expected).  In a first phase the triple is moved and widened, for
// at most 100 rounds, until the middle point has a strictly smaller
// function value than both outer points.  In the second phase the
// larger half of that bracket is repeatedly bisected, always keeping a
// triple whose middle point is the best one, until the bracket is no
// wider than tol.  The middle point is returned.
//
// If no bracket is found, a message is printed to standard output and
// the best of the three current points is returned.
func bisection(f func(float64) float64, xl, xu, tol float64) float64 {

	// Current triple x0, x1, x2 and the function values at exactly
	// those points.  The values are carried along whenever a point is
	// reused, so f is evaluated only at new points.
	x0, x2 := xl, xu
	x1 := (x0 + x2) / 2
	f0, f1, f2 := f(x0), f(x1), f(x2)

	// Try to find a bracket.
	success := false
	for k := 0; k < 100; k++ {

		if f1 < f0 && f1 < f2 {
			success = true
			break
		}

		if f0 > f1 && f1 > f2 {
			// Decreasing towards the right: slide right, growing
			// the step.
			x0, f0 = x1, f1
			x1, f1 = x2, f2
			x2 += 1.5 * (x1 - x0)
			f2 = f(x2)
			continue
		}

		if f0 < f1 && f1 < f2 {
			// Decreasing towards the left: slide left, growing the
			// step.  The upper point must be moved before the
			// middle one, otherwise all three points collapse.
			x2, f2 = x1, f1
			x1, f1 = x0, f0
			x0 -= 1.5 * (x2 - x1)
			f0 = f(x0)
			continue
		}

		// Ties (or NaN values): widen on both sides of the middle.
		x0 = x1 - 2*(x1-x0)
		x2 = x1 + 2*(x2-x1)
		f0, f2 = f(x0), f(x2)
	}

	if !success {
		fmt.Printf("Did not find bracket...\n")
		if f0 < f1 && f0 < f2 {
			return x0
		}
		if f1 < f0 && f1 < f2 {
			return x1
		}
		return x2
	}

	// Shrink the bracket by bisecting its larger half.
	for x2-x0 > tol {
		if x1-x0 > x2-x1 {
			xx := (x0 + x1) / 2
			if xx <= x0 || xx >= x1 {
				// Floating-point resolution exhausted; no
				// further progress is possible.
				break
			}
			ff := f(xx)
			if ff < f1 {
				x2 = x1
				x1, f1 = xx, ff
			} else {
				x0 = xx
			}
		} else {
			xx := (x1 + x2) / 2
			if xx <= x1 || xx >= x2 {
				// Floating-point resolution exhausted; no
				// further progress is possible.
				break
			}
			ff := f(xx)
			if ff < f1 {
				x0 = x1
				x1, f1 = xx, ff
			} else {
				x2 = xx
			}
		}
	}

	return x1
}