File: lambertw.pxd

package info (click to toggle)
python-scipy 0.14.0-2
  • links: PTS, VCS
  • area: main
  • in suites: jessie, jessie-kfreebsd
  • size: 52,228 kB
  • ctags: 63,719
  • sloc: python: 112,726; fortran: 88,685; cpp: 86,979; ansic: 85,860; makefile: 530; sh: 236
file content (119 lines) | stat: -rw-r--r-- 3,808 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# -*-cython-*-
#
# Implementation of the Lambert W function [1]. Based on the MPMath
# implementation [2] and documentation [3].
#
# Copyright: Yosef Meller, 2009
# Author email: mellerf@netvision.net.il
# 
# Distributed under the same license as SciPy
#
# References:
# [1] On the Lambert W function, Adv. Comp. Math. 5 (1996) 329-359,
#     available online: http://www.apmaths.uwo.ca/~djeffrey/Offprints/W-adv-cm.pdf
# [2] mpmath source code, Subversion revision 990
#     http://code.google.com/p/mpmath/source/browse/trunk/mpmath/functions.py?spec=svn994&r=992
# [3] mpmath source code, Subversion revision 994
#     http://code.google.com/p/mpmath/source/browse/trunk/mpmath/function_docs.py?spec=svn994&r=994

# NaN checking as per suggestions of the cython-users list,
# http://groups.google.com/group/cython-users/browse_thread/thread/ff03eed8221bc36d

# TODO: use a series expansion when extremely close to the branch point
# at `-1/e` and make sure that the proper branch is chosen there

import cython

cimport sf_error

# Real-valued exp/log from the C math library. They are declared nogil
# so that they can be called from the nogil lambertw_scalar routine.
cdef extern from "math.h":
    double exp(double x) nogil  # used to form the branch-point radius exp(-1)
    double log(double x) nogil  # used for the real initial guess on branch k = -1

from _complexstuff cimport *

# Heavy lifting is here:

@cython.cdivision(True)
cdef inline double complex lambertw_scalar(double complex z, long k, double tol) nogil:
    """
    Evaluate branch `k` of the Lambert W function at a single point `z`.

    This is just the implementation of W for a single input z.
    See the docstring of lambertw() for the full description.

    Parameters
    ----------
    z : double complex
        Point at which W is evaluated.
    k : long
        Branch index.
    tol : double
        Relative tolerance; the Halley iteration stops as soon as
        ``|w_new - w| <= tol*|w_new|``.

    Returns
    -------
    double complex
        An approximate solution ``w`` of ``w*exp(w) = z``.  Special cases:
        a NaN input is returned unchanged; ``z == 0`` gives ``z`` for
        ``k == 0`` and ``-inf`` (after reporting ``sf_error.SINGULAR``)
        for every other branch; ``z.real == +-inf`` is answered directly
        from the logarithmic asymptotics.  If the iteration does not
        converge within 100 steps, ``sf_error.SLOW`` is reported and
        ``nan`` is returned.
    """
    # Comments copied verbatim from [2] are marked with '>'
    cdef double complex w              # current iterate / return value
    cdef double complex ew, wew, wewz, wn
    cdef double u, absz
    cdef int i

    if zisnan(z):
        return z

    #> We must be extremely careful near the singularities at -1/e and 0
    u = exp(-1)
    absz = zabs(z)

    if absz <= u:
        if z == 0:
            #> w(0,0) = 0; for all other branches we hit the pole
            if k == 0:
                return z
            sf_error.error("lambertw", sf_error.SINGULAR, NULL)
            return -inf

        if k == 0:
            w = z # Initial guess for iteration
        #> For small real z < 0, the -1 branch behaves roughly like log(-z)
        elif k == -1 and z.imag == 0 and z.real < 0:
            w = log(-z.real)
        #> Use a simple asymptotic approximation.
        else:
            w = zlog(z)
            #> The branches are roughly logarithmic. This approximation
            #> gets better for large |k|; need to check that this always
            #> works for k ~= -1, 0, 1.
            if k: w = w + k*2*pi*1j

    elif k == 0 and z.imag and absz <= 0.7:
        #> Both the W(z) ~= z and W(z) ~= ln(z) approximations break
        #> down around z ~= -0.5 (converging to the wrong branch), so patch
        #> with a constant approximation (adjusted for sign)
        if zabs(z+0.5) < 0.1:
            if z.imag > 0:
                w = 0.7 + 0.7j
            else:
                w = 0.7 - 0.7j
        else:
            w = z

    else:
        if z.real == inf:
            if k == 0:
                return z
            else:
                return z + 2*k*pi*1j

        if z.real == -inf:
            return (-z) + (2*k+1)*pi*1j

        #> Simple asymptotic approximation as above
        w = zlog(z)
        if k: w = w + k*2*pi*1j

    #> Use Halley iteration to solve w*exp(w) = z
    for i in range(100):
        if w.real >= 0:
            # Same Halley step with numerator and denominator scaled by
            # exp(-w).  |exp(-w)| <= 1 here, so nothing overflows, whereas
            # w*exp(w) overflows to inf (and the step to nan) for large |z|.
            ew = zexp(-w)
            wewz = w - z*ew
            wn = w - wewz / (w + 1 - (w + 2)*wewz/(2*w + 2))
        else:
            # |exp(w)| < 1 here, so the direct form is safe.
            ew = zexp(w)
            wew = w*ew
            wewz = wew - z
            wn = w - wewz / (wew + ew - (w + 2)*wewz/(2*w + 2))
        # `<=` rather than `<`: an exact fixed point (wn == w) must count as
        # converged even when tol*|wn| underflows to zero (tiny |z|) or
        # tol == 0; a strict comparison would spin for all 100 iterations
        # and then report a spurious failure.
        if zabs(wn - w) <= tol*zabs(wn):
            return wn
        w = wn

    sf_error.error("lambertw", sf_error.SLOW,
                   "iteration failed to converge: %g + %gj",
                   <double>z.real, <double>z.imag)
    return nan