File: lambertw.pyx

package info (click to toggle)
python-scipy 0.10.1%2Bdfsg2-1
  • links: PTS, VCS
  • area: main
  • in suites: wheezy
  • size: 42,232 kB
  • sloc: cpp: 224,773; ansic: 103,496; python: 85,210; fortran: 79,130; makefile: 272; sh: 43
file content (341 lines) | stat: -rw-r--r-- 10,526 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
# Implementation of the Lambert W function [1]. Based on the MPMath 
# implementation [2], and documentaion [3].
#
# Copyright: Yosef Meller, 2009
# Author email: mellerf@netvision.net.il
# 
# Distributed under the same license as SciPy
#
# References:
# [1] On the Lambert W function, Adv. Comp. Math. 5 (1996) 329-359,
#     available online: http://www.apmaths.uwo.ca/~djeffrey/Offprints/W-adv-cm.pdf
# [2] mpmath source code, Subversion revision 990
#     http://code.google.com/p/mpmath/source/browse/trunk/mpmath/functions.py?spec=svn994&r=992
# [3] mpmath source code, Subversion revision 994
#     http://code.google.com/p/mpmath/source/browse/trunk/mpmath/function_docs.py?spec=svn994&r=994

# NaN checking as per suggestions of the cython-users list,
# http://groups.google.com/group/cython-users/browse_thread/thread/ff03eed8221bc36d

# TODO: use a series expansion when extremely close to the branch point
# at `-1/e` and make sure that the proper branch is chosen there

import cython
import warnings

cdef extern from "math.h":
    double exp(double x) nogil
    double log(double x) nogil

# Use Numpy's portable C99-compatible complex functios

cdef extern from "numpy/npy_math.h":
    ctypedef struct npy_cdouble:
        double real
        double imag

    double npy_cabs(npy_cdouble z) nogil
    npy_cdouble npy_clog(npy_cdouble z) nogil
    npy_cdouble npy_cexp(npy_cdouble z) nogil
    int npy_isnan(double x) nogil
    double NPY_INFINITY
    double NPY_PI

cdef inline bint zisnan(double complex x) nogil:
    return npy_isnan(x.real) or npy_isnan(x.imag)

cdef inline double zabs(double complex x) nogil:
    cdef double r
    r = npy_cabs((<npy_cdouble*>&x)[0])
    return r

cdef inline double complex zlog(double complex x) nogil:
    cdef npy_cdouble r
    r = npy_clog((<npy_cdouble*>&x)[0])
    return (<double complex*>&r)[0]

cdef inline double complex zexp(double complex x) nogil:
    cdef npy_cdouble r
    r = npy_cexp((<npy_cdouble*>&x)[0])
    return (<double complex*>&r)[0]

cdef void lambertw_raise_warning(double complex z) with gil:
    warnings.warn("Lambert W iteration failed to converge: %r" % z)

# Heavy lifting is here:

@cython.cdivision(True)
cdef double complex lambertw_scalar(double complex z, long k, double tol) nogil:
    """
    This is just the implementation of W for a single input z.
    See the docstring for lambertw() below for the full description.
    """
    # Comments copied verbatim from [2] are marked with '>'
    if zisnan(z):
        return z
    
    # Return value:
    cdef double complex w
    
    #> We must be extremely careful near the singularities at -1/e and 0
    cdef double u
    u = exp(-1)
    
    cdef double absz
    absz = zabs(z)
    if absz <= u:
        if z == 0:
            #> w(0,0) = 0; for all other branches we hit the pole
            if k == 0:
                return z
            return -NPY_INFINITY
        
        if k == 0:
            w = z # Initial guess for iteration
        #> For small real z < 0, the -1 branch beaves roughly like log(-z)
        elif k == -1 and z.imag ==0 and z.real < 0:
            w = log(-z.real)
        #> Use a simple asymptotic approximation.
        else:
            w = zlog(z)
            #> The branches are roughly logarithmic. This approximation
            #> gets better for large |k|; need to check that this always
            #> works for k ~= -1, 0, 1.
            if k: w = w + k*2*NPY_PI*1j
    
    elif k == 0 and z.imag and zabs(z) <= 0.7:
        #> Both the W(z) ~= z and W(z) ~= ln(z) approximations break
        #> down around z ~= -0.5 (converging to the wrong branch), so patch
        #> with a constant approximation (adjusted for sign)
        if zabs(z+0.5) < 0.1:
            if z.imag > 0:
                w = 0.7 + 0.7j
            else:
                w = 0.7 - 0.7j
        else:
            w = z
    
    else:
        if z.real == NPY_INFINITY:
            if k == 0:
                return z
            else:
                return z + 2*k*NPY_PI*1j
        
        if z.real == -NPY_INFINITY:
            return (-z) + (2*k+1)*NPY_PI*1j
                
        #> Simple asymptotic approximation as above
        w = zlog(z)
        if k: w = w + k*2*NPY_PI*1j

    #> Use Halley iteration to solve w*exp(w) = z
    cdef double complex ew, wew, wewz, wn
    cdef int i
    for i in range(100):
        ew = zexp(w)
        wew = w*ew
        wewz = wew-z
        wn = w - wewz / (wew + ew - (w + 2)*wewz/(2*w + 2))
        if zabs(wn-w) < tol*zabs(wn):
            return wn
        else:
            w = wn

    lambertw_raise_warning(z)
    return wn


# Turn the above function into a Ufunc:
#--------------------------------------
cdef extern from "numpy/arrayobject.h":
    void import_array()
    ctypedef int npy_intp
    cdef enum NPY_TYPES:
        NPY_LONG
        NPY_CDOUBLE
        NPY_DOUBLE

cdef extern from "numpy/ufuncobject.h":
    void import_ufunc()
    ctypedef void (*PyUFuncGenericFunction)(char**, npy_intp*, npy_intp*, void*)
    object PyUFunc_FromFuncAndData(PyUFuncGenericFunction* func, void** data,
        char* types, int ntypes, int nin, int nout,
        int identity, char* name, char* doc, int c)

cdef void _apply_func_to_1d_vec(char **args, npy_intp *dimensions, npy_intp *steps,
                     void *func) nogil:
    cdef npy_intp i
    cdef char *ip1=args[0], *ip2=args[1], *ip3=args[2], *op=args[3]
    for i in range(0, dimensions[0]):
        (<double complex*>op)[0] = (<double complex(*)(double complex, long, double) nogil>func)(
            (<double complex*>ip1)[0], (<long*>ip2)[0], (<double*>ip3)[0])
        ip1 += steps[0]; ip2 += steps[1]; ip3 += steps[2]; op += steps[3]

cdef PyUFuncGenericFunction _loop_funcs[1]
_loop_funcs[0] = _apply_func_to_1d_vec

cdef char _inp_outp_types[4]
_inp_outp_types[0] = NPY_CDOUBLE
_inp_outp_types[1] = NPY_LONG
_inp_outp_types[2] = NPY_DOUBLE
_inp_outp_types[3] = NPY_CDOUBLE

import_array()
import_ufunc()

# The actual ufunc declaration:
cdef void *the_func_to_apply[1]
the_func_to_apply[0] = <void*>lambertw_scalar
_lambertw = PyUFunc_FromFuncAndData(_loop_funcs, the_func_to_apply,
    _inp_outp_types, 1, 3, 1, 0, "", "", 0)

def lambertw(z, k=0, tol=1e-8):
    r"""
    lambertw(z, k=0, tol=1e-8)

    Lambert W function.

    The Lambert W function `W(z)` is defined as the inverse function
    of :math:`w \exp(w)`. In other words, the value of :math:`W(z)` is
    such that :math:`z = W(z) \exp(W(z))` for any complex number
    :math:`z`.

    The Lambert W function is a multivalued function with infinitely
    many branches. Each branch gives a separate solution of the
    equation :math:`w \exp(w)`. Here, the branches are indexed by the
    integer `k`.
    
    Parameters
    ----------
    z : array_like
        Input argument
    k : integer, optional
        Branch index
    tol : float
        Evaluation tolerance

    Notes
    -----
    All branches are supported by `lambertw`:

    * ``lambertw(z)`` gives the principal solution (branch 0)
    * ``lambertw(z, k)`` gives the solution on branch `k`

    The Lambert W function has two partially real branches: the
    principal branch (`k = 0`) is real for real `z > -1/e`, and the
    `k = -1` branch is real for `-1/e < z < 0`. All branches except
    `k = 0` have a logarithmic singularity at `z = 0`.

    .. rubric:: Possible issues
    
    The evaluation can become inaccurate very close to the branch point
    at `-1/e`. In some corner cases, :func:`lambertw` might currently
    fail to converge, or can end up on the wrong branch.

    .. rubric:: Algorithm

    Halley's iteration is used to invert `w \exp(w)`, using a first-order
    asymptotic approximation (`O(\log(w))` or `O(w)`) as the initial
    estimate.

    The definition, implementation and choice of branches is based
    on Corless et al, "On the Lambert W function", Adv. Comp. Math. 5
    (1996) 329-359, available online here:
    http://www.apmaths.uwo.ca/~djeffrey/Offprints/W-adv-cm.pdf
    
    TODO: use a series expansion when extremely close to the branch point
    at `-1/e` and make sure that the proper branch is chosen there

    Examples
    --------
    The Lambert W function is the inverse of `w \exp(w)`::

    >>> from scipy.special import lambertw
    >>> w = lambertw(1)
    >>> w
    0.56714329040978387299996866221035555
    >>> w*exp(w)
    1.0

    Any branch gives a valid inverse::

    >>> w = lambertw(1, k=3)
    >>> w
    (-2.8535817554090378072068187234910812 +
    17.113535539412145912607826671159289j)
    >>> w*exp(w)
    (1.0 + 3.5075477124212226194278700785075126e-36j)

    .. rubric:: Applications to equation-solving

    The Lambert W function may be used to solve various kinds of
    equations, such as finding the value of the infinite power
    tower `z^{z^{z^{\ldots}}}`::

    >>> def tower(z, n):
    ... if n == 0:
    ... return z
    ... return z ** tower(z, n-1)
    ...
    >>> tower(0.5, 100)
    0.641185744504986
    >>> -lambertw(-log(0.5))/log(0.5)
    0.6411857445049859844862004821148236665628209571911

    .. rubric:: Properties

    The Lambert W function grows roughly like the natural logarithm
    for large arguments::

    >>> lambertw(1000)
    5.2496028524016
    >>> log(1000)
    6.90775527898214
    >>> lambertw(10**100)
    224.843106445119
    >>> log(10**100)
    230.258509299405
    
    The principal branch of the Lambert W function has a rational
    Taylor series expansion around `z = 0`::
    
    >>> nprint(taylor(lambertw, 0, 6), 10)
    [0.0, 1.0, -1.0, 1.5, -2.666666667, 5.208333333, -10.8]
    
    Some special values and limits are::
    
    >>> lambertw(0)
    0.0
    >>> lambertw(1)
    0.567143290409784
    >>> lambertw(e)
    1.0
    >>> lambertw(inf)
    +inf
    >>> lambertw(0, k=-1)
    -inf
    >>> lambertw(0, k=3)
    -inf
    >>> lambertw(inf, k=3)
    (+inf + 18.8495559215388j)

    The `k = 0` and `k = -1` branches join at `z = -1/e` where
    `W(z) = -1` for both branches. Since `-1/e` can only be represented
    approximately with mpmath numbers, evaluating the Lambert W function
    at this point only gives `-1` approximately::

    >>> lambertw(-1/e, 0)
    -0.999999999999837133022867
    >>> lambertw(-1/e, -1)
    -1.00000000000016286697718
    
    If `-1/e` happens to round in the negative direction, there might be
    a small imaginary part::
    
    >>> lambertw(-1/e)
    (-1.0 + 8.22007971511612e-9j)

    """
    return _lambertw(z, k, tol)