"""
Routines for evaluating and manipulating piecewise polynomials in
local power basis.
"""
from .polyint import _Interpolator1D
import numpy as np
cimport cython
cimport libc.stdlib
cimport libc.math
ctypedef double complex double_complex
ctypedef fused double_or_complex:
double
double complex
cdef extern from "blas_defs.h":
void c_dgeev(char *jobvl, char *jobvr, int *n, double *a,
int *lda, double *wr, double *wi, double *vl, int *ldvl,
double *vr, int *ldvr, double *work, int *lwork,
int *info)
cdef extern from "numpy/npy_math.h":
double nan "NPY_NAN"
DEF MAX_DIMS = 64
#------------------------------------------------------------------------------
# Piecewise power basis polynomials
#------------------------------------------------------------------------------


@cython.wraparound(False)
@cython.boundscheck(False)
@cython.cdivision(True)
def evaluate(double_or_complex[:,:,::1] c,
             double[::1] x,
             double[::1] xp,
             int dx,
             bint extrapolate,
             double_or_complex[:,::1] out):
    """
    Evaluate a piecewise polynomial.

    Parameters
    ----------
    c : ndarray, shape (k, m, n)
        Coefficients of local polynomials of order `k-1` in `m` intervals.
        There are `n` polynomials in each interval.
        The coefficient of the highest-order term comes first.
    x : ndarray, shape (m+1,)
        Breakpoints of polynomials.
    xp : ndarray, shape (r,)
        Points to evaluate the piecewise polynomial at.
    dx : int
        Order of derivative to evaluate. The derivative is evaluated
        piecewise and may have discontinuities.
    extrapolate : bint, optional
        Whether to extrapolate to out-of-bounds points based on first
        and last intervals, or to return NaNs.
    out : ndarray, shape (r, n)
        Value of each polynomial at each of the input points.
        This argument is modified in-place.

    """
    cdef int ip, jp, i
    cdef int interval
    cdef double xval

    # check derivative order
    if dx < 0:
        raise ValueError("Order of derivative cannot be negative")

    # shape checks
    if out.shape[0] != xp.shape[0]:
        raise ValueError("out and xp have incompatible shapes")
    if out.shape[1] != c.shape[2]:
        raise ValueError("out and c have incompatible shapes")
    if c.shape[1] != x.shape[0] - 1:
        raise ValueError("x and c have incompatible shapes")

    # evaluate
    interval = 0

    for ip in range(len(xp)):
        xval = xp[ip]

        # Find correct interval
        i = find_interval(&x[0], x.shape[0], xval, interval, extrapolate)

        if i < 0:
            # xval was nan etc
            for jp in range(c.shape[2]):
                out[ip, jp] = nan
            continue
        else:
            interval = i

        # Evaluate the local polynomial(s)
        for jp in range(c.shape[2]):
            out[ip, jp] = evaluate_poly1(xval - x[interval], c, interval, jp, dx)
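

# Illustrative pure-Python sketch (a hypothetical helper, not used by this
# module) of what ``evaluate`` computes for ``dx == 0`` and one polynomial per
# interval (n == 1). Assumptions: ``c`` is indexed as c[row][interval] with the
# highest-order coefficient first, and the interval is located with ``bisect``
# instead of ``find_interval``. The local polynomial is evaluated with Horner's
# rule in s = xval - x[interval]. Points outside [x[0], x[-1]] extrapolate from
# the first/last interval, as ``evaluate`` does with ``extrapolate=True``.
import bisect


def _sketch_evaluate_pp(c, x, xval):
    # Interval with x[i] <= xval < x[i+1]; clamping to [0, len(x) - 2] makes
    # the last interval closed on the right and handles extrapolation.
    i = bisect.bisect_right(x, xval) - 1
    i = min(max(i, 0), len(x) - 2)
    s = xval - x[i]
    # Horner's rule, highest-order coefficient first.
    res = 0.0
    for row in c:
        res = res * s + row[i]
    return res

# Example: x**2 on [0, 2] stored as s**2 on [0, 1] and (1 + s)**2 on [1, 2],
# i.e. c = [[1.0, 1.0], [0.0, 2.0], [0.0, 1.0]] with x = [0.0, 1.0, 2.0];
# _sketch_evaluate_pp(c, x, 1.5) gives 2.25.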


@cython.wraparound(False)
@cython.boundscheck(False)
@cython.cdivision(True)
def evaluate_nd(double_or_complex[:,:,::1] c,
                tuple xs,
                int[:] ks,
                double[:,:] xp,
                int[:] dx,
                int extrapolate,
                double_or_complex[:,::1] out):
    """
    Evaluate a piecewise tensor-product polynomial.

    Parameters
    ----------
    c : ndarray, shape (k_1*...*k_d, m_1*...*m_d, n)
        Coefficients of local polynomials of order `k_i-1` in dimension `i`,
        on a grid of `m_1`, ..., `m_d` intervals. There are `n` polynomials
        in each interval.
    xs : d-tuple of ndarray of shape (m_i+1,) each
        Breakpoints of polynomials.
    ks : ndarray of int, shape (d,)
        Number of coefficients (polynomial order + 1) in each dimension.
    xp : ndarray, shape (r, d)
        Points to evaluate the piecewise polynomial at.
    dx : ndarray of int, shape (d,)
        Orders of derivative to evaluate. The derivative is evaluated
        piecewise and may have discontinuities.
    extrapolate : int, optional
        Whether to extrapolate to out-of-bounds points based on first
        and last intervals, or to return NaNs.
    out : ndarray, shape (r, n)
        Value of each polynomial at each of the input points.
        For points outside the span ``x[0] ... x[-1]`` in any dimension,
        ``nan`` is returned unless `extrapolate` is true.
        This argument is modified in-place.

    """
    cdef size_t ntot
    cdef ssize_t strides[MAX_DIMS]
    cdef ssize_t kstrides[MAX_DIMS]
    cdef double* xx[MAX_DIMS]
    cdef size_t nxx[MAX_DIMS]
    cdef double[::1] y
    cdef double_or_complex[:,:,::1] c2
    cdef int ip, jp, k, ndim, i
    cdef int interval[MAX_DIMS]
    cdef int pos, kpos, koutpos
    cdef int out_of_range
    cdef double xval

    ndim = len(xs)

    if ndim > MAX_DIMS:
        raise ValueError("Too many dimensions (maximum: %d)" % (MAX_DIMS,))

    # shape checks
    if dx.shape[0] != ndim:
        raise ValueError("dx has incompatible shape")
    if xp.shape[1] != ndim:
        raise ValueError("xp has incompatible shape")
    if out.shape[0] != xp.shape[0]:
        raise ValueError("out and xp have incompatible shapes")
    if out.shape[1] != c.shape[2]:
        raise ValueError("out and c have incompatible shapes")

    # compute interval strides
    ntot = 1
    for ip in range(ndim-1, -1, -1):
        if dx[ip] < 0:
            raise ValueError("Order of derivative cannot be negative")

        y = xs[ip]
        if y.shape[0] < 2:
            raise ValueError("each dimension must have >= 2 points")

        strides[ip] = ntot
        ntot *= y.shape[0] - 1

        # grab array pointers
        nxx[ip] = y.shape[0]
        xx[ip] = <double*>&y[0]

    y = None

    if c.shape[1] != ntot:
        raise ValueError("xs and c have incompatible shapes")

    # compute order strides: kstrides[k] = ks[0]*...*ks[k-1], the number of
    # partial sums that remain after dimension k has been collapsed
    ntot = 1
    for ip in range(ndim):
        kstrides[ip] = ntot
        ntot *= ks[ip]

    if c.shape[0] != ntot:
        raise ValueError("ks and c have incompatible shapes")

    # temporary storage
    if double_or_complex is double:
        c2 = np.zeros((c.shape[0], 1, 1), dtype=float)
    else:
        c2 = np.zeros((c.shape[0], 1, 1), dtype=complex)

    # evaluate
    for ip in range(ndim):
        interval[ip] = 0

    for ip in range(xp.shape[0]):
        out_of_range = 0

        # Find correct intervals
        for k in range(ndim):
            xval = xp[ip, k]

            i = find_interval(xx[k],
                              nxx[k],
                              xval,
                              interval[k],
                              extrapolate)
            if i < 0:
                out_of_range = 1
                break
            else:
                interval[k] = i

        if out_of_range:
            # xval was nan etc
            for jp in range(c.shape[2]):
                out[ip, jp] = nan
            continue

        pos = 0
        for k in range(ndim):
            pos += interval[k] * strides[k]

        # Evaluate the local polynomials, via nested 1D polynomial evaluation
        #
        # sum_{ijk} c[kx-i,ky-j,kz-k] x**i y**j z**k = sum_i a[i] x**i
        # a[i] = sum_j b[i,j] y**j
        # b[i,j] = sum_k c[kx-i,ky-j,kz-k] z**k
        #
        # The array c2 is used to hold the intermediate sums a,b,...
        for jp in range(c.shape[2]):
            c2[:,0,0] = c[:,pos,jp]

            for k in range(ndim-1, -1, -1):
                xval = xp[ip, k] - xx[k][interval[k]]
                kpos = 0
                for koutpos in range(kstrides[k]):
                    c2[koutpos,0,0] = evaluate_poly1(xval, c2[kpos:kpos+ks[k],:,:], 0, 0, dx[k])
                    kpos += ks[k]

            out[ip,jp] = c2[0,0,0]
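

# Illustrative sketch of the nested "collapse the last axis first" evaluation
# in ``evaluate_nd``, written for a single 2-D tensor-product cell (a
# hypothetical helper, not used by this module). ``cflat`` is assumed to hold
# the kx*ky coefficients in C order, so the y-index varies fastest. This is
# consistent with the loop above reading contiguous blocks of length ks[k],
# starting from the last dimension. Each block is reduced by Horner's rule in
# the local y-offset t, then the kx partial sums are reduced in the local
# x-offset s. Highest-order coefficients come first in both axes.
def _sketch_evaluate_2d_cell(cflat, kx, ky, s, t):
    def horner(coeffs, z):
        res = 0.0
        for a in coeffs:
            res = res * z + a
        return res

    # Collapse the y-axis: one partial sum per power of s.
    partial = [horner(cflat[a * ky:(a + 1) * ky], t) for a in range(kx)]
    # Collapse the x-axis.
    return horner(partial, s)

# Example: (s + 1) * (2*t + 3) has cflat = [2.0, 3.0, 2.0, 3.0] (kx = ky = 2);
# at s = 0.5, t = 2.0 the sketch gives 1.5 * 7 = 10.5.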


@cython.wraparound(False)
@cython.boundscheck(False)
@cython.cdivision(True)
def fix_continuity(double_or_complex[:,:,::1] c,
                   double[::1] x,
                   int order):
    """
    Make a piecewise polynomial continuously differentiable to given order.

    Parameters
    ----------
    c : ndarray, shape (k, m, n)
        Coefficients of local polynomials of order `k-1` in `m` intervals.
        There are `n` polynomials in each interval.
        The coefficient of the highest-order term comes first.

        Coefficients c[-order-1:] are modified in-place.
    x : ndarray, shape (m+1,)
        Breakpoints of polynomials.
    order : int
        Order up to which to enforce piecewise differentiability.

    """

    cdef int ip, jp, kp, dx
    cdef int interval
    cdef double_or_complex res
    cdef double xval

    # check derivative order
    if order < 0:
        raise ValueError("Order of derivative cannot be negative")

    # shape checks
    if c.shape[1] != x.shape[0] - 1:
        raise ValueError("x and c have incompatible shapes")
    if order >= c.shape[0] - 1:
        raise ValueError("order too large")

    # evaluate
    for ip in range(1, len(x)-1):
        xval = x[ip]
        interval = ip - 1

        for jp in range(c.shape[2]):
            # ensure continuity for derivatives, starting at the
            # highest one (the lower derivatives depend on the higher
            # ones, but not vice versa)
            for dx in range(order, -1, -1):
                # evaluate dx-th derivative of the polynomial in previous interval
                res = evaluate_poly1(xval - x[interval], c, interval, jp, dx)

                # set dx-th coefficient of polynomial in current
                # interval so that the dx-th derivative is continuous
                for kp in range(dx):
                    res /= kp + 1

                c[c.shape[0] - dx - 1, ip, jp] = res
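

# Illustrative pure-Python sketch of the continuity fix above (a hypothetical
# helper, not used by this module), assuming one polynomial per interval with
# ``c`` indexed as c[row][interval], highest-order first. At each interior
# breakpoint the d-th derivative of the previous piece, evaluated at its right
# end h, is sum_p a_p * p!/(p-d)! * h**(p-d). The d-th derivative of the
# current piece at s = 0 is d! * a_d, so setting a_d = f_prev^(d)(h) / d!
# makes the two derivatives agree.
import math


def _sketch_fix_continuity(c, x, order):
    k = len(c)
    for ip in range(1, len(x) - 1):
        h = x[ip] - x[ip - 1]
        for d in range(order, -1, -1):
            res = 0.0
            for p in range(d, k):
                a_p = c[k - 1 - p][ip - 1]          # coefficient of s**p
                falling = float(math.factorial(p) // math.factorial(p - d))
                res += a_p * falling * h ** (p - d)
            c[k - 1 - d][ip] = res / math.factorial(d)
    return c

# Example: with x = [0, 1, 2], c = [[1.0, 1.0], [0.0, 5.0], [0.0, 7.0]] and
# order = 1, the second piece becomes 1 + 2*s + s**2, continuing x**2.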


@cython.wraparound(False)
@cython.boundscheck(False)
@cython.cdivision(True)
def integrate(double_or_complex[:,:,::1] c,
              double[::1] x,
              double a,
              double b,
              bint extrapolate,
              double_or_complex[::1] out):
    """
    Compute integral over a piecewise polynomial.

    Parameters
    ----------
    c : ndarray, shape (k, m, n)
        Coefficients of local polynomials of order `k-1` in `m` intervals.
    x : ndarray, shape (m+1,)
        Breakpoints of polynomials.
    a : double
        Start point of integration.
    b : double
        End point of integration.
    extrapolate : bint, optional
        Whether to extrapolate to out-of-bounds points based on first
        and last intervals, or to return NaNs.
    out : ndarray, shape (n,)
        Integral of each piecewise polynomial from `a` to `b`. If
        `extrapolate` is false and either bound lies outside
        ``[x[0], x[-1]]``, NaN is returned.
        This argument is modified in-place.

    """
    cdef int jp
    cdef int start_interval, end_interval, interval
    cdef double_or_complex va, vb, vtot

    # shape checks
    if c.shape[1] != x.shape[0] - 1:
        raise ValueError("x and c have incompatible shapes")
    if out.shape[0] != c.shape[2]:
        raise ValueError("out and c have incompatible shapes")

    # check integration order
    if not (b >= a):
        raise ValueError("Integral bounds not in order")

    # find intervals
    start_interval = find_interval(&x[0], x.shape[0], a, 0, extrapolate)
    if start_interval < 0:
        out[:] = nan
        return

    end_interval = find_interval(&x[0], x.shape[0], b, 0, extrapolate)
    if end_interval < 0:
        out[:] = nan
        return

    # evaluate
    for jp in range(c.shape[2]):
        vtot = 0
        for interval in range(start_interval, end_interval+1):
            # local antiderivative, end point
            if interval == end_interval:
                vb = evaluate_poly1(b - x[interval], c, interval, jp, -1)
            else:
                vb = evaluate_poly1(x[interval+1] - x[interval], c, interval, jp, -1)

            # local antiderivative, start point
            if interval == start_interval:
                va = evaluate_poly1(a - x[interval], c, interval, jp, -1)
            else:
                va = evaluate_poly1(0, c, interval, jp, -1)

            # integral
            vtot = vtot + (vb - va)

        out[jp] = vtot
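

# Illustrative pure-Python sketch of ``integrate`` (a hypothetical helper, not
# used by this module). It assumes x[0] <= a <= b <= x[-1] and one polynomial
# per interval, with ``c`` indexed as c[row][interval], highest-order first.
# Each piece contributes F(hi - x[i]) - F(lo - x[i]), where F is the local
# antiderivative with zero integration constant and [lo, hi] is the overlap
# of [a, b] with the piece.
def _sketch_integrate(c, x, a, b):
    k = len(c)

    def antider(i, s):
        # integral from 0 to s of sum_p a_p s**p = sum_p a_p s**(p+1) / (p+1)
        return sum(c[k - 1 - p][i] * s ** (p + 1) / (p + 1.0) for p in range(k))

    total = 0.0
    for i in range(len(x) - 1):
        lo, hi = max(a, x[i]), min(b, x[i + 1])
        if lo < hi:
            total += antider(i, hi - x[i]) - antider(i, lo - x[i])
    return total

# Example: for x**2 stored piecewise on x = [0, 1, 2], i.e.
# c = [[1.0, 1.0], [0.0, 2.0], [0.0, 1.0]], the integral over [0, 2] is 8/3.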


@cython.wraparound(False)
@cython.boundscheck(False)
@cython.cdivision(True)
def real_roots(double[:,:,::1] c, double[::1] x, double y, bint report_discont,
               bint extrapolate):
    """
    Compute real roots of a real-valued piecewise polynomial function.

    If a section of the piecewise polynomial is identically zero, the
    values (x[begin], nan) are appended to the root list.

    If the piecewise polynomial is not continuous, and the sign
    changes across a breakpoint, the breakpoint is added to the root
    set if `report_discont` is True.

    Parameters
    ----------
    c, x
        Polynomial coefficients and breakpoints, as above.
    y : float
        Find roots of ``pp(x) == y``.
    report_discont : bint, optional
        Whether to report breakpoints where the function jumps across
        ``y`` as roots.
    extrapolate : bint, optional
        Whether to consider roots obtained by extrapolating based
        on first and last intervals.

    Returns
    -------
    roots : list of ndarray
        One array of root locations for each of the `n` polynomials.

    """
    cdef list roots
    cdef list cur_roots
    cdef int interval, jp, k, i
    cdef double *wr
    cdef double *wi
    cdef double last_root, va, vb
    cdef double f, df, dx
    cdef void *workspace

    if c.shape[1] != x.shape[0] - 1:
        raise ValueError("x and c have incompatible shapes")

    if c.shape[0] == 0:
        return np.array([], dtype=float)

    wr = <double*>libc.stdlib.malloc(c.shape[0] * sizeof(double))
    wi = <double*>libc.stdlib.malloc(c.shape[0] * sizeof(double))
    if wr == NULL or wi == NULL:
        libc.stdlib.free(wr)
        libc.stdlib.free(wi)
        raise MemoryError("failed to allocate root-finding buffers")

    workspace = NULL

    last_root = nan

    roots = []
    try:
        for jp in range(c.shape[2]):
            cur_roots = []
            for interval in range(c.shape[1]):
                # Check for sign change across intervals
                if interval > 0 and report_discont:
                    va = evaluate_poly1(x[interval] - x[interval-1],
                                        c, interval-1, jp, 0) - y
                    vb = evaluate_poly1(0, c, interval, jp, 0) - y
                    if (va < 0 and vb > 0) or (va > 0 and vb < 0):
                        # sign change between intervals
                        if x[interval] != last_root:
                            last_root = x[interval]
                            cur_roots.append(float(last_root))

                # Compute first the complex roots
                k = croots_poly1(c, y, interval, jp, wr, wi, &workspace)

                # Check for errors and identically zero values
                if k == -1:
                    # Zero everywhere
                    if x[interval] == x[interval+1]:
                        # Only a point
                        if x[interval] != last_root:
                            last_root = x[interval]
                            cur_roots.append(x[interval])
                    else:
                        # A real interval
                        cur_roots.append(x[interval])
                        cur_roots.append(np.nan)
                        last_root = nan
                    continue
                elif k < -1:
                    # An error occurred
                    raise RuntimeError("Internal error in root finding; "
                                       "please report this bug")
                elif k == 0:
                    # No roots
                    continue

                # Filter real roots
                for i in range(k):
                    # Check real root
                    #
                    # The reality of a root is a decision that can be left to LAPACK,
                    # which has to determine this in any case.
                    if wi[i] != 0:
                        continue

                    # Refine root by one Newton iteration
                    f = evaluate_poly1(wr[i], c, interval, jp, 0) - y
                    df = evaluate_poly1(wr[i], c, interval, jp, 1)
                    if df != 0:
                        dx = f/df
                        if abs(dx) < abs(wr[i]):
                            wr[i] = wr[i] - dx

                    # Check interval
                    wr[i] += x[interval]
                    if interval == 0 and extrapolate:
                        # Half-open to the left
                        if not wr[i] <= x[interval+1]:
                            continue
                    elif interval == c.shape[1] - 1 and extrapolate:
                        # Half-open to the right
                        if not wr[i] >= x[interval]:
                            continue
                    else:
                        if not (x[interval] <= wr[i] <= x[interval+1]):
                            continue

                    # Add to list
                    if wr[i] != last_root:
                        last_root = wr[i]
                        cur_roots.append(float(last_root))

            # Construct roots
            roots.append(np.array(cur_roots, dtype=float))
    finally:
        if workspace != NULL:
            libc.stdlib.free(workspace)
        libc.stdlib.free(wr)
        libc.stdlib.free(wi)

    return roots
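

# Illustrative NumPy sketch of the per-interval root search in ``real_roots``
# (a hypothetical helper, not used by this module). It swaps ``croots_poly1``
# for ``numpy.roots``, which also finds roots as companion-matrix eigenvalues.
# It omits the Newton refinement, the report_discont handling and
# extrapolation. ``c`` is assumed to have shape (k, m) for a single
# polynomial, highest-order coefficient first. Roots with |imag| > tol are
# treated as complex.
import numpy as np


def _sketch_real_roots(c, x, y=0.0, tol=1e-12):
    c = np.asarray(c, dtype=float)
    roots = []
    for i in range(c.shape[1]):
        coeffs = c[:, i].copy()
        coeffs[-1] -= y                      # solve p(s) - y == 0
        if not np.any(coeffs):
            # identically equal to y; ``real_roots`` reports (x[i], nan) here
            continue
        for r in np.sort_complex(np.roots(coeffs)):
            if abs(r.imag) > tol:
                continue
            xr = float(r.real) + x[i]        # local offset -> global x
            # keep roots inside the closed interval, dropping exact repeats
            if x[i] <= xr <= x[i + 1] and (not roots or xr != roots[-1]):
                roots.append(xr)
    return roots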


@cython.wraparound(False)
@cython.boundscheck(False)
@cython.cdivision(True)
cdef int find_interval(double *x,
                       size_t nx,
                       double xval,
                       int prev_interval=0,
                       bint extrapolate=1) nogil:
    """
    Find an interval such that x[interval] <= xval < x[interval+1]
    or interval == 0 and xval < x[0]
    or interval == nx-2 and xval >= x[nx-1].

    Parameters
    ----------
    x : array of double, shape (nx,)
        Piecewise polynomial breakpoints.
    nx : size_t
        Number of breakpoints in `x`.
    xval : double
        Point to find.
    prev_interval : int, optional
        Interval where a previous point was found.
    extrapolate : bint, optional
        Whether to return the last or the first interval if the
        point is out-of-bounds.

    Returns
    -------
    interval : int
        Suitable interval, or -1 if `xval` is nan or is out of bounds
        and `extrapolate` is false.

    """
    cdef int interval, high, low, mid
    cdef double a, b

    a = x[0]
    b = x[nx-1]

    interval = prev_interval
    if interval < 0 or interval >= nx:
        interval = 0

    if not (a <= xval <= b):
        # Out-of-bounds (or nan)
        if xval < a and extrapolate:
            # below
            interval = 0
        elif xval > b and extrapolate:
            # above
            interval = nx - 2
        else:
            # nan or no extrapolation
            interval = -1
    elif xval == b:
        # Make the interval closed from the right
        interval = nx - 2
    else:
        # Find the interval the coordinate is in
        # (binary search with locality)
        if xval >= x[interval]:
            low = interval
            high = nx - 2
        else:
            low = 0
            high = interval

        if xval < x[low+1]:
            high = low

        while low < high:
            mid = (high + low)//2
            if xval < x[mid]:
                # mid < high
                high = mid
            elif xval >= x[mid + 1]:
                low = mid + 1
            else:
                # x[mid] <= xval < x[mid+1]
                low = mid
                break

        interval = low

    return interval
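

# Illustrative pure-Python sketch of the "binary search with locality" in
# ``find_interval`` (a hypothetical helper, not used by this module). It
# follows the same rules. Points below x[0] or above x[-1] map to the first
# or last interval when extrapolating, and xval == x[-1] maps to the last
# interval. nan, or an out-of-bounds point without extrapolation, gives -1.
# The search is bracketed using ``prev_interval`` as a hint, which keeps it
# short when successive query points are close together.
def _sketch_find_interval(x, xval, prev_interval=0, extrapolate=True):
    nx = len(x)
    a, b = x[0], x[nx - 1]
    if not (a <= xval <= b):
        # out of bounds, or nan (every comparison with nan is False)
        if xval < a and extrapolate:
            return 0
        if xval > b and extrapolate:
            return nx - 2
        return -1
    if xval == b:
        # the last interval is closed on the right
        return nx - 2
    interval = prev_interval if 0 <= prev_interval < nx - 1 else 0
    # bracket the answer using the previous interval as a hint
    if xval >= x[interval]:
        low, high = interval, nx - 2
    else:
        low, high = 0, interval
    if xval < x[low + 1]:
        high = low                  # the lower bracket is already the answer
    while low < high:
        mid = (low + high) // 2
        if xval < x[mid]:
            high = mid
        elif xval >= x[mid + 1]:
            low = mid + 1
        else:
            return mid              # x[mid] <= xval < x[mid+1]
    return low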


@cython.wraparound(False)
@cython.boundscheck(False)
@cython.cdivision(True)
cdef double_or_complex evaluate_poly1(double s, double_or_complex[:,:,::1] c, int ci, int cj, int dx) nogil:
    """
    Evaluate polynomial, derivative, or antiderivative in a single interval.

    Antiderivatives are evaluated assuming zero integration constants.

    Parameters
    ----------
    s : double
        Local coordinate, i.e. the offset from the left breakpoint of
        the interval.
    c : double[:,:,:]
        Polynomial coefficients. c[:,ci,cj] will be used.
    ci, cj : int
        Which of the coefs to use.
    dx : int
        Order of derivative (> 0) or antiderivative (< 0) to evaluate.

    """
    cdef int kp, k
    cdef double_or_complex res, z
    cdef double prefactor

    res = 0.0
    z = 1.0

    if dx < 0:
        for k in range(-dx):
            z *= s

    for kp in range(c.shape[0]):
        # prefactor of term after differentiation
        if dx == 0:
            prefactor = 1.0
        elif dx > 0:
            # derivative
            if kp < dx:
                continue
            else:
                prefactor = 1.0
                for k in range(kp, kp - dx, -1):
                    prefactor *= k
        else:
            # antiderivative
            prefactor = 1.0
            for k in range(kp, kp - dx):
                prefactor /= k + 1

        res = res + c[c.shape[0] - kp - 1, ci, cj] * z * prefactor

        # update z to s**(kp + 1 - dx) for the next term
        if kp < c.shape[0] - 1 and kp >= dx:
            z *= s

    return res
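

# Illustrative pure-Python sketch of the prefactors used in ``evaluate_poly1``
# (a hypothetical helper, not used by this module). ``coeffs`` holds one local
# polynomial, highest-order first, so the term a_p s**p becomes
# a_p * p!/(p-dx)! * s**(p-dx) after dx >= 0 differentiations. It becomes
# a_p * p!/(p-dx)! * s**(p-dx), with -dx > 0, after -dx integrations with
# zero integration constants.
def _sketch_eval_poly1(coeffs, s, dx):
    k = len(coeffs)
    res = 0.0
    for p in range(k):
        a_p = coeffs[k - 1 - p]          # coefficient of s**p
        if dx > 0 and p < dx:
            continue                     # term differentiated away
        fac = 1.0
        if dx >= 0:
            for j in range(p, p - dx, -1):
                fac *= j                 # p (p-1) ... (p-dx+1)
        else:
            for j in range(p, p - dx):
                fac /= j + 1             # 1 / ((p+1) (p+2) ... (p-dx))
        res += a_p * fac * s ** (p - dx)
    return res

# Example: for s**2 + 2*s + 3 (coeffs [1.0, 2.0, 3.0]) at s = 2, dx = 1 gives
# 2*s + 2 = 6, and dx = -1 gives s**3/3 + s**2 + 3*s = 38/3.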


@cython.wraparound(False)
@cython.boundscheck(False)
@cython.cdivision(True)
cdef int croots_poly1(double[:,:,::1] c, double y, int ci, int cj,
                      double* wr, double* wi, void **workspace):
    """
    Find all complex roots of a local polynomial.

    Parameters
    ----------
    c : ndarray, shape (k, m, n)
        Coefficients of polynomials of order k-1.
    y : float
        Right-hand side of ``pp(x) == y``.
    ci, cj : int
        Index of the local polynomial whose coefficients c[:,ci,cj] to use.
    wr, wi : double*
        Allocated double arrays of size `k`. The complex roots are stored
        here after call. The roots are sorted in increasing order according
        to the real part.
    workspace : void**
        Work space pointer. workspace[0] should be NULL on initial
        call. Multiple subsequent calls with same `k` can share the
        same `workspace`. If workspace[0] is non-NULL after the
        calls, it must be freed with libc.stdlib.free.

    Returns
    -------
    nroots : int
        How many roots found for the polynomial.
        If `-1`, the polynomial is identically equal to `y`.
        If `< -1`, an error occurred.

    Notes
    -----
    Uses LAPACK + the companion matrix method.

    """
    cdef double *a
    cdef double *work
    cdef double a0, a1, a2, d, br, bi, cc
    cdef int lwork, n, i, j, order
    cdef int nworkspace, info

    n = c.shape[0]

    # Check actual polynomial order
    for j in range(n):
        if c[j,ci,cj] != 0:
            order = n - 1 - j
            break
    else:
        order = -1

    if order < 0:
        # Zero everywhere
        if y == 0:
            return -1
        else:
            return 0
    elif order == 0:
        # Nonzero constant polynomial: no roots
        # (unless r.h.s. is exactly equal to the coefficient, that is.)
        if c[n-1, ci, cj] == y:
            return -1
        else:
            return 0
    elif order == 1:
        # Low-order polynomial: a0*x + a1
        a0 = c[n-1-order,ci,cj]
        a1 = c[n-1-order+1,ci,cj] - y
        wr[0] = -a1 / a0
        wi[0] = 0
        return 1
    elif order == 2:
        # Low-order polynomial: a0*x**2 + a1*x + a2
        a0 = c[n-1-order,ci,cj]
        a1 = c[n-1-order+1,ci,cj]
        a2 = c[n-1-order+2,ci,cj] - y

        d = a1*a1 - 4*a0*a2
        if d < 0:
            # no real roots
            d = libc.math.sqrt(-d)
            wr[0] = -a1/(2*a0)
            wi[0] = -d/(2*a0)
            wr[1] = -a1/(2*a0)
            wi[1] = d/(2*a0)
            return 2

        d = libc.math.sqrt(d)

        # avoid cancellation in subtractions
        if d == 0:
            wr[0] = -a1/(2*a0)
            wi[0] = 0
            wr[1] = -a1/(2*a0)
            wi[1] = 0
        elif a1 < 0:
            wr[0] = (2*a2) / (-a1 + d) # == (-a1 - d)/(2*a0)
            wi[0] = 0
            wr[1] = (-a1 + d) / (2*a0)
            wi[1] = 0
        else:
            wr[0] = (-a1 - d)/(2*a0)
            wi[0] = 0
            wr[1] = (2*a2) / (-a1 - d) # == (-a1 + d)/(2*a0)
            wi[1] = 0

        return 2

    # Compute required workspace and allocate it
    lwork = 1 + 8*n

    if workspace[0] == NULL:
        nworkspace = n*n + lwork
        workspace[0] = libc.stdlib.malloc(nworkspace * sizeof(double))

    a = <double*>workspace[0]
    work = a + n*n

    # Initialize the companion matrix, Fortran order
    for j in range(order*order):
        a[j] = 0
    for j in range(order):
        cc = c[n-1-j,ci,cj]
        if j == 0:
            cc -= y
        a[j + (order-1)*order] = -cc / c[n-1-order,ci,cj]
        if j + 1 < order:
            a[j+1 + order*j] = 1

    # Compute companion matrix eigenvalues
    info = 0
    c_dgeev("N", "N", &order, a, &order, <double*>wr, <double*>wi,
            NULL, &order, NULL, &order, work, &lwork, &info)
    if info != 0:
        # Failure
        return -2

    # Sort roots (insertion sort)
    for i in range(order):
        br = wr[i]
        bi = wi[i]
        for j in range(i - 1, -1, -1):
            if wr[j] > br:
                wr[j+1] = wr[j]
                wi[j+1] = wi[j]
            else:
                wr[j+1] = br
                wi[j+1] = bi
                break
        else:
            wr[0] = br
            wi[0] = bi

    # Return with roots
    return order
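

# Illustrative NumPy sketch of the two root-finding paths in ``croots_poly1``
# (hypothetical helpers, not used by this module). ``coeffs`` is assumed to
# be highest-order first, with ``y`` already subtracted from the constant
# term. The companion matrix is built as above: ones on the subdiagonal and
# -a_j / a_order in the last column. Here ``numpy.linalg.eigvals`` stands in
# for the direct LAPACK ``dgeev`` call. The quadratic helper shows the
# cancellation-avoiding formula for real roots (discriminant >= 0 assumed).
import math
import numpy as np


def _sketch_companion_roots(coeffs):
    coeffs = np.trim_zeros(np.asarray(coeffs, dtype=float), 'f')
    order = len(coeffs) - 1
    if order < 1:
        return np.array([], dtype=complex)
    a = np.zeros((order, order))
    a[1:, :-1] = np.eye(order - 1)                 # ones on the subdiagonal
    a[:, -1] = -coeffs[::-1][:-1] / coeffs[0]      # -a_j / a_order, j < order
    # sorted by real part (then imaginary part), like the insertion sort above
    return np.sort_complex(np.linalg.eigvals(a))


def _sketch_stable_quadratic(a0, a1, a2):
    d = math.sqrt(a1 * a1 - 4.0 * a0 * a2)
    # pick the sign for which -a1 and +-d add in magnitude (no cancellation)
    q = -a1 + d if a1 < 0 else -a1 - d
    if q == 0.0:
        return (0.0, 0.0)         # a1 == a2 == 0: double root at zero
    r1 = q / (2.0 * a0)
    r2 = (2.0 * a2) / q           # Vieta: r1 * r2 == a2 / a0
    return tuple(sorted((r1, r2)))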


def _croots_poly1(double[:,:,::1] c, double_complex[:,:,::1] w, double y=0):
    """
    Find roots of polynomials.

    This function is for testing croots_poly1.

    Parameters
    ----------
    c : ndarray, (k, m, n)
        Coefficients of several polynomials of order k-1.
    w : ndarray, (k, m, n)
        Output argument: roots of the polynomials.
    y : float, optional
        Right-hand side; the roots of ``p(x) == y`` are computed.

    """

    cdef double *wr
    cdef double *wi
    cdef void *workspace
    cdef int i, j, k, nroots

    if (c.shape[0] != w.shape[0] or c.shape[1] != w.shape[1]
            or c.shape[2] != w.shape[2]):
        raise ValueError("c and w have incompatible shapes")

    if c.shape[0] <= 0:
        return

    wr = <double*>libc.stdlib.malloc(c.shape[0] * sizeof(double))
    wi = <double*>libc.stdlib.malloc(c.shape[0] * sizeof(double))
    if wr == NULL or wi == NULL:
        libc.stdlib.free(wr)
        libc.stdlib.free(wi)
        raise MemoryError("failed to allocate root-finding buffers")
    workspace = NULL

    try:
        for i in range(c.shape[1]):
            for j in range(c.shape[2]):
                for k in range(c.shape[0]):
                    w[k,i,j] = nan

                nroots = croots_poly1(c, y, i, j, wr, wi, &workspace)

                if nroots == -1:
                    continue
                elif nroots < -1 or nroots >= c.shape[0]:
                    raise RuntimeError("root-finding failed")

                for k in range(nroots):
                    w[k,i,j].real = wr[k]
                    w[k,i,j].imag = wi[k]
    finally:
        if workspace != NULL:
            libc.stdlib.free(workspace)
        libc.stdlib.free(wr)
        libc.stdlib.free(wi)

#------------------------------------------------------------------------------
# Piecewise Bernstein basis polynomials
#------------------------------------------------------------------------------


@cython.wraparound(False)
@cython.boundscheck(False)
@cython.cdivision(True)
cdef double_or_complex evaluate_bpoly1(double_or_complex s,
                                       double_or_complex[:,:,::1] c,
                                       int ci, int cj) nogil:
    """
    Evaluate polynomial in the Bernstein basis in a single interval.

    A Bernstein polynomial is defined as

        .. math:: b_{j, k} = comb(k, j) x^{j} (1-x)^{k-j}

    with ``0 <= x <= 1``.

    Parameters
    ----------
    s : double
        Polynomial x-value
    c : double[:,:,:]
        Polynomial coefficients. c[:,ci,cj] will be used
    ci, cj : int
        Which of the coefs to use

    """
    cdef int k, j
    cdef double_or_complex res, s1, comb

    k = c.shape[0] - 1  # polynomial order
    s1 = 1. - s

    # special-case lowest orders
    if k == 0:
        res = c[0, ci, cj]
    elif k == 1:
        res = c[0, ci, cj] * s1 + c[1, ci, cj] * s
    elif k == 2:
        res = c[0, ci, cj] * s1*s1 + c[1, ci, cj] * 2.*s1*s + c[2, ci, cj] * s*s
    elif k == 3:
        res = (c[0, ci, cj] * s1*s1*s1 + c[1, ci, cj] * 3.*s1*s1*s +
               c[2, ci, cj] * 3.*s1*s*s + c[3, ci, cj] * s*s*s)
    else:
        # XX: replace with de Casteljau's algorithm if needs be
        res, comb = 0., 1.
        for j in range(k+1):
            res += comb * s**j * s1**(k-j) * c[j, ci, cj]
            comb *= 1. * (k-j) / (j+1.)

    return res
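

# Illustrative sketch of de Casteljau's algorithm, the replacement suggested
# in the comment above (hypothetical helpers, not used by this module).
# ``coeffs[j]`` is assumed to multiply comb(k, j) s**j (1-s)**(k-j) with
# k = len(coeffs) - 1, as in ``evaluate_bpoly1``. The recurrence repeatedly
# takes convex combinations of neighbouring coefficients. It needs O(k**2)
# operations but no binomial coefficients or powers. The direct sum is
# included only to cross-check it.
import math


def _sketch_de_casteljau(coeffs, s):
    b = [float(v) for v in coeffs]
    k = len(b) - 1
    for r in range(1, k + 1):
        for j in range(k - r + 1):
            b[j] = (1.0 - s) * b[j] + s * b[j + 1]
    return b[0]


def _sketch_bernstein_direct(coeffs, s):
    k = len(coeffs) - 1
    return sum(math.comb(k, j) * s ** j * (1.0 - s) ** (k - j) * coeffs[j]
               for j in range(k + 1))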


@cython.wraparound(False)
@cython.boundscheck(False)
@cython.cdivision(True)
cdef double_or_complex evaluate_bpoly1_deriv(double_or_complex s,
                                             double_or_complex[:,:,::1] c,
                                             int ci, int cj,
                                             int nu,
                                             double_or_complex[:,:,::1] wrk) nogil:
    """
    Evaluate the derivative of a polynomial in the Bernstein basis
    in a single interval.

    A Bernstein polynomial is defined as

        .. math:: b_{j, k} = comb(k, j) x^{j} (1-x)^{k-j}

    with ``0 <= x <= 1``.

    The algorithm is detailed in BPoly._construct_from_derivatives.

    Parameters
    ----------
    s : double
        Polynomial x-value
    c : double[:,:,:]
        Polynomial coefficients. c[:,ci,cj] will be used
    ci, cj : int
        Which of the coefs to use
    nu : int
        Order of the derivative to evaluate, taken with respect to the
        local coordinate `s`. Assumed non-negative (no checks are made);
        for ``nu == 0`` the polynomial itself is evaluated.
    wrk : double[:,:,::1]
        A work array, shape (c.shape[0]-nu, 1, 1).

    """
    cdef int k, j, a
    cdef double_or_complex res, term
    cdef double comb, poch

    k = c.shape[0] - 1  # polynomial order

    if nu == 0:
        res = evaluate_bpoly1(s, c, ci, cj)
    else:
        poch = 1.
        for a in range(nu):
            poch *= k - a

        term = 0.
        for a in range(k - nu + 1):
            term, comb = 0., 1.
            for j in range(nu+1):
                term += c[j+a, ci, cj] * (-1)**(j+nu) * comb
                comb *= 1. * (nu-j) / (j+1)
            wrk[a, 0, 0] = term * poch
        res = evaluate_bpoly1(s, wrk, 0, 0)
    return res
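

# Illustrative sketch of the derivative rule in ``evaluate_bpoly1_deriv``
# (a hypothetical helper, not used by this module). The nu-th s-derivative of
# sum_j c_j b_{j,k}(s) equals k (k-1) ... (k-nu+1) * sum_a (D**nu c)_a
# b_{a,k-nu}(s), where D**nu c is the nu-th forward difference of the
# coefficient sequence. The code above builds D**nu c from an explicit
# signed binomial sum. This sketch takes repeated differences instead, which
# gives the same sequence.
import math


def _sketch_bernstein_deriv(coeffs, s, nu):
    k = len(coeffs) - 1
    if nu > k:
        return 0.0
    d = [float(v) for v in coeffs]
    for _ in range(nu):
        d = [d[j + 1] - d[j] for j in range(len(d) - 1)]   # forward difference
    poch = 1.0
    for a in range(nu):
        poch *= k - a                                     # k (k-1) ... (k-nu+1)
    m = k - nu
    return poch * sum(math.comb(m, j) * s ** j * (1.0 - s) ** (m - j) * d[j]
                      for j in range(m + 1))

# Example: coeffs [0.0, 0.0, 1.0] represent s**2, so nu = 1 at s = 0.3
# gives 0.6 and nu = 2 gives 2.0.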
#
# Evaluation; only differs from _ppoly by evaluate_poly1 -> evaluate_bpoly1
#

@cython.wraparound(False)
@cython.boundscheck(False)
@cython.cdivision(True)
def evaluate_bernstein(double_or_complex[:,:,::1] c,
                       double[::1] x,
                       double[::1] xp,
                       int nu,
                       bint extrapolate,
                       double_or_complex[:,::1] out):
    """
    Evaluate a piecewise polynomial in the Bernstein basis.

    Parameters
    ----------
    c : ndarray, shape (k, m, n)
        Coefficients of local polynomials of order `k-1` in `m` intervals.
        There are `n` polynomials in each interval.
        Coefficient ``c[j]`` multiplies the Bernstein basis polynomial
        ``b_{j, k-1}`` of the local coordinate
        ``s = (x - x[i]) / (x[i+1] - x[i])``.
    x : ndarray, shape (m+1,)
        Breakpoints of polynomials.
    xp : ndarray, shape (r,)
        Points to evaluate the piecewise polynomial at.
    nu : int
        Order of derivative to evaluate. The derivative is evaluated
        piecewise and may have discontinuities.
    extrapolate : bint, optional
        Whether to extrapolate to out-of-bounds points based on first
        and last intervals, or to return NaNs.
    out : ndarray, shape (r, n)
        Value of each polynomial at each of the input points.
        This argument is modified in-place.

    """
    cdef int ip, jp, i
    cdef int interval
    cdef double xval
    cdef double_or_complex s, ds, ds_nu
    cdef double_or_complex[:,:,::1] wrk

    # check derivative order
    if nu < 0:
        raise NotImplementedError("Cannot do antiderivatives in the B-basis yet.")

    # shape checks
    if out.shape[0] != xp.shape[0]:
        raise ValueError("out and xp have incompatible shapes")
    if out.shape[1] != c.shape[2]:
        raise ValueError("out and c have incompatible shapes")
    if c.shape[1] != x.shape[0] - 1:
        raise ValueError("x and c have incompatible shapes")

    if nu > 0:
        if double_or_complex is double_complex:
            wrk = np.empty((c.shape[0]-nu, 1, 1), dtype=np.complex128)
        else:
            wrk = np.empty((c.shape[0]-nu, 1, 1), dtype=np.float64)

    # evaluate
    interval = 0

    for ip in range(len(xp)):
        xval = xp[ip]

        # Find correct interval
        i = find_interval(&x[0], x.shape[0], xval, interval, extrapolate)

        if i < 0:
            # xval was nan etc
            for jp in range(c.shape[2]):
                out[ip, jp] = nan
            continue
        else:
            interval = i

        # Evaluate the local polynomial(s); derivatives in the local
        # coordinate s are rescaled to x by dividing by ds**nu
        ds = x[interval+1] - x[interval]
        ds_nu = ds**nu
        for jp in range(c.shape[2]):
            s = (xval - x[interval]) / ds
            if nu == 0:
                out[ip, jp] = evaluate_bpoly1(s, c, interval, jp)
            else:
                out[ip, jp] = evaluate_bpoly1_deriv(s, c, interval, jp,
                                                    nu, wrk) / ds_nu
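

# Illustrative end-to-end sketch of ``evaluate_bernstein`` for one point and
# one polynomial per interval (a hypothetical helper, not used by this
# module). ``c`` is assumed to be indexed as c[j][interval], and ``bisect``
# stands in for ``find_interval``. The point is mapped to s = (xval - x[i]) / ds.
# The coefficients are differentiated nu times in s with
# d/ds sum_j c_j b_{j,k} = k * sum_j (c_{j+1} - c_j) b_{j,k-1}. The result is
# evaluated by de Casteljau's recurrence and divided by ds**nu (chain rule).
import bisect


def _sketch_evaluate_bernstein(c, x, xval, nu=0):
    i = min(max(bisect.bisect_right(x, xval) - 1, 0), len(x) - 2)
    ds = float(x[i + 1] - x[i])
    s = (xval - x[i]) / ds
    coeffs = [float(row[i]) for row in c]
    for _ in range(nu):
        k = len(coeffs) - 1
        coeffs = [k * (coeffs[j + 1] - coeffs[j]) for j in range(k)]
    if not coeffs:
        return 0.0                        # derivative order exceeds degree
    b = list(coeffs)
    for r in range(1, len(b)):
        for j in range(len(b) - r):
            b[j] = (1.0 - s) * b[j] + s * b[j + 1]
    return b[0] / ds ** nu

# Example: x**2 on x = [0, 2, 4] has Bernstein coefficients
# c = [[0.0, 4.0], [0.0, 8.0], [4.0, 16.0]]; at xval = 3.0 the sketch gives
# 9.0 for nu = 0 and 6.0 for nu = 1.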