File: _trlib.pyx

package info (click to toggle)
python-scipy 1.1.0-7
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 93,828 kB
  • sloc: python: 156,854; ansic: 82,925; fortran: 80,777; cpp: 7,505; makefile: 427; sh: 294
file content (150 lines) | stat: -rw-r--r-- 6,177 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from scipy.optimize._trustregion import (_minimize_trust_region, BaseQuadraticSubproblem)
import numpy as np
cimport ctrlib
cimport libc.stdio
cimport numpy as np

from scipy._lib.messagestream cimport MessageStream


class TRLIBQuadraticSubproblem(BaseQuadraticSubproblem):

    def __init__(self, x, fun, jac, hess, hessp, tol_rel_i=-2.0, tol_rel_b=-3.0,
                 disp=False):
        super(TRLIBQuadraticSubproblem, self).__init__(x, fun, jac, hess, hessp)
        self.tol_rel_i = tol_rel_i
        self.tol_rel_b = tol_rel_b
        self.disp = disp
        self.itmax = int(min(1e9/self.jac.shape[0], 2*self.jac.shape[0]))
        cdef long itmax, iwork_size, fwork_size, h_pointer
        itmax = self.itmax
        ctrlib.trlib_krylov_memory_size(itmax, &iwork_size, &fwork_size,
                                        &h_pointer)
        self.h_pointer = h_pointer
        self.fwork = np.empty([fwork_size])
        cdef double [:] fwork_view = self.fwork
        cdef double *fwork_ptr = NULL
        if fwork_view.shape[0] > 0:
            fwork_ptr = &fwork_view[0]
        ctrlib.trlib_krylov_prepare_memory(itmax, fwork_ptr)
        self.iwork = np.zeros([iwork_size], dtype=np.int)
        self.s  = np.empty(self.jac.shape)
        self.g  = np.empty(self.jac.shape)
        self.v  = np.empty(self.jac.shape)
        self.gm = np.empty(self.jac.shape)
        self.p  = np.empty(self.jac.shape)
        self.Hp = np.empty(self.jac.shape)
        self.Q  = np.empty([self.itmax+1, self.jac.shape[0]])
        self.timing = np.zeros([ctrlib.trlib_krylov_timing_size()],
                               dtype=np.int)
        self.init = ctrlib._TRLIB_CLS_INIT

    def solve(self, double trust_radius):

        cdef long equality = 0
        cdef long itmax_lanczos = 100
        cdef double tol_r_i = self.tol_rel_i
        cdef double tol_a_i =  0.0 
        cdef double tol_r_b = self.tol_rel_b
        cdef double tol_a_b =  0.0 
        cdef double zero = 2e-16
        cdef double obj_lb = -1e20
        cdef long ctl_invariant = 0
        cdef long convexify = 1
        cdef long earlyterm = 1
        cdef double g_dot_g = 0.0
        cdef double v_dot_g = 0.0
        cdef double p_dot_Hp = 0.0
        cdef long refine = 1
        cdef long verbose = 0
        cdef long unicode = 1
        cdef long ret = 0
        cdef long action = 0
        cdef long it = 0
        cdef long ityp = 0
        cdef long itmax = self.itmax
        cdef long init  = self.init
        cdef double flt1 = 0.0
        cdef double flt2 = 0.0
        cdef double flt3 = 0.0
        prefix = b""
        cdef long   [:] iwork_view  = self.iwork
        cdef double [:] fwork_view  = self.fwork
        cdef long   [:] timing_view = self.timing
        cdef long   *iwork_ptr = NULL
        cdef double *fwork_ptr = NULL
        cdef long   *timing_ptr = NULL

        if self.disp:
            verbose = 2

        if iwork_view.shape[0] > 0:
            iwork_ptr = &iwork_view[0]
        if fwork_view.shape[0] > 0:
            fwork_ptr = &fwork_view[0]
        if timing_view.shape[0] > 0:
            timing_ptr = &timing_view[0]

        cdef MessageStream messages = MessageStream()
        try:
            while True:
                messages.clear()
                ret = ctrlib.trlib_krylov_min(init, trust_radius, equality,
                          itmax, itmax_lanczos, tol_r_i, tol_a_i,
                          tol_r_b, tol_a_b, zero, obj_lb, ctl_invariant,
                          convexify, earlyterm, g_dot_g, v_dot_g, p_dot_Hp,
                          iwork_ptr, fwork_ptr, refine, verbose, unicode,
                          prefix, messages.handle,
                          timing_ptr, &action, &it, &ityp, &flt1, &flt2, &flt3)
                if self.disp:
                    msg = messages.get()
                    if msg:
                        print(msg)

                init = 0
                if action == ctrlib._TRLIB_CLA_INIT:
                    self.s[:]  = 0.0
                    self.gm[:] = 0.0
                    self.g[:]  = self.jac
                    self.v[:]  = self.g
                    g_dot_g = np.dot(self.g, self.g)
                    v_dot_g = np.dot(self.v, self.g)
                    self.p[:]  = - self.v
                    self.Hp[:] = self.hessp(self.p)
                    p_dot_Hp = np.dot(self.p, self.Hp)
                    self.Q[0,:] = self.v/np.sqrt(v_dot_g)
                if action == ctrlib._TRLIB_CLA_RETRANSF:
                    self.s[:] = np.dot(self.fwork[self.h_pointer:self.h_pointer+it+1],
                                       self.Q[:it+1,:])
                if action == ctrlib._TRLIB_CLA_UPDATE_STATIO:
                    if ityp == ctrlib._TRLIB_CLT_CG:
                        self.s += flt1 * self.p
                if action == ctrlib._TRLIB_CLA_UPDATE_GRAD:
                    if ityp == ctrlib._TRLIB_CLT_CG:
                        self.Q[it,:] = flt2*self.v
                        self.gm[:] = self.g
                        self.g    += flt1*self.Hp
                    if ityp == ctrlib._TRLIB_CLT_L:
                        self.s[:]  = self.Hp + flt1*self.g + flt2*self.gm
                        self.gm[:] = flt3*self.g
                        self.g[:]  = self.s
                    self.v[:] = self.g
                    g_dot_g = np.dot(self.g, self.g)
                    v_dot_g = np.dot(self.v, self.g)
                if action == ctrlib._TRLIB_CLA_UPDATE_DIR:
                    self.p[:]  = flt1 * self.v + flt2 * self.p
                    self.Hp[:] = self.hessp(self.p)
                    p_dot_Hp = np.dot(self.p, self.Hp)
                    if ityp == ctrlib._TRLIB_CLT_L:
                        self.Q[it,:] = self.p
                if action == ctrlib._TRLIB_CLA_OBJVAL:
                    g_dot_g = .5*np.dot(self.s, self.hessp(self.s))
                    g_dot_g += np.dot(self.s, self.jac)
                if ret < 10:
                    break
                self.init = ctrlib._TRLIB_CLS_HOTSTART
            self.lam = self.fwork[7]
        finally:
            messages.close()

        return self.s, self.lam > 0.0