# This module defines a multivariate polynomial class
#
# Written by Konrad Hinsen <hinsen@cnrs-orleans.fr>
# last revision: 1999-7-21
#

import LinearAlgebra, Numeric, umath
from Scientific.indexing import index_expression

# Class definition

class Polynomial:

    """Multivariate polynomial

    Instances of this class represent polynomials of any order and
    in any number of variables. They can be evaluated like functions.

    Constructor: Polynomial(|coefficients|), where |coefficients| is
    an array whose dimension defines the number of variables and whose
    length along each axis is one more than the order in the
    corresponding variable. The element with indices (i, j, ...) is
    the coefficient of x**i * y**j * ...
    """

    def __init__(self, coefficients):
        self.coeff = Numeric.array(coefficients)
        # number of variables = number of array axes
        self.dim = len(self.coeff.shape)

    def __call__(self, *args):
        """Returns the value of the polynomial at the point whose
        coordinates are given as arguments, one per variable.
        Raises TypeError if the number of arguments differs from
        the number of variables."""
        if len(args) != self.dim:
            raise TypeError('Wrong number of arguments')
        p = _powers(args, self.coeff.shape)
        return umath.add.reduce(Numeric.ravel(p*self.coeff))

    def _checkVariable(self, variable):
        """Returns |variable| as a non-negative axis number. Negative
        values count from the last variable. Raises IndexError if
        there is no such variable."""
        if variable < 0:
            variable = variable + self.dim
        if variable < 0 or variable >= self.dim:
            raise IndexError('variable index out of range')
        return variable

    def derivative(self, variable=0):
        "Returns the derivative with respect to |variable|."
        variable = self._checkVariable(variable)
        n = self.coeff.shape[variable]
        if n == 1:
            # The polynomial does not depend on |variable|, so the
            # derivative is the zero polynomial. zeros() expects the
            # shape as one tuple; 'd' is the double precision typecode.
            return Polynomial(Numeric.zeros(self.dim*(1,), 'd'))
        # Drop the order-0 coefficients along the axis of |variable|.
        index = variable*index_expression[::] + \
                index_expression[1::] + index_expression[...]
        # Multiply each remaining coefficient by its former exponent;
        # the added axes make the factors broadcast along |variable|.
        factor = Numeric.arange(1.,n)
        factor = factor[index_expression[::] +
                        (self.dim-variable-1) *
                        index_expression[Numeric.NewAxis]]
        return Polynomial(factor*self.coeff[index])

    def integral(self, variable=0):
        """Returns the indefinite integral with respect to |variable|.
        The integration constant is zero."""
        variable = self._checkVariable(variable)
        n = self.coeff.shape[variable]
        # Divide each coefficient by its new exponent.
        factor = 1./Numeric.arange(1.,n+1)
        factor = factor[index_expression[::] +
                        (self.dim-variable-1) *
                        index_expression[Numeric.NewAxis]]
        # A slab of zeros of thickness one supplies the new order-0
        # coefficients along the axis of |variable|.
        shape = list(self.coeff.shape)
        shape[variable] = 1
        z = Numeric.zeros(tuple(shape), 'd')
        intcoeff = _concatenate((z, factor*self.coeff), variable)
        return Polynomial(intcoeff)

# Polynomial fit constructor

def fitPolynomial(order, points, values):
    """Returns a Polynomial of the given |order| fitted to data

    |order| is an integer (one variable) or a tuple of integers
    giving the order in each variable. |points| is a sequence of
    numbers (one variable) or of coordinate sequences (several
    variables), and |values| is a sequence of the same length with
    the value at each point. The coefficients are obtained from a
    linear least-squares fit.

    Raises ValueError if |points| and |values| differ in length, if
    |order| and the points disagree about the number of variables,
    or if there are fewer points than coefficients to determine.
    """
    if len(points) != len(values):
        raise ValueError('Inconsistent arguments')
    if len(points) == 0:
        # Without this test, points[0] below fails with an IndexError
        # instead of reporting the real problem.
        raise ValueError('Not enough points')
    if not isinstance(order, type(())):
        order = (order,)
    # An order n in some variable needs n+1 coefficients along its axis.
    order = tuple(map(lambda n: n+1, order))
    if not _isSequence(points[0]):
        # One variable: wrap each number into a one-element point.
        points = map(lambda p: (p,), points)
    if len(order) != len(points[0]):
        raise ValueError('Inconsistent arguments')
    if umath.multiply.reduce(order) > len(points):
        raise ValueError('Not enough points')
    # Each row of the matrix holds all monomials evaluated at one point.
    matrix = []
    for p in points:
        matrix.append(Numeric.ravel(_powers(p, order)))
    matrix = Numeric.array(matrix)
    values = Numeric.array(values)
    coeff = LinearAlgebra.linear_least_squares(matrix, values)[0]
    coeff = Numeric.reshape(coeff, order)
    return Polynomial(coeff)

# Helper functions

def _powers(x, n):
    """Returns the products of powers of the numbers in |x|

    The result has one axis per element of |x|. Along axis k it
    runs over the powers 0 to n[k]-1 of x[k], so the element with
    indices (i, j, ...) is x[0]**i * x[1]**j * ...
    """
    nvars = len(x)
    full = index_expression[::]
    new = index_expression[Numeric.NewAxis]
    product = 1.
    for axis in range(nvars):
        # The running product of [1, x, x, ...] is [1, x, x**2, ...].
        base = [1.] + (n[axis]-1)*[x[axis]]
        column = umath.multiply.accumulate(Numeric.array(base))
        # Put the powers on their own axis; all other axes have
        # length one and broadcast in the multiplication.
        selector = axis*new + full + (nvars-axis-1)*new
        product = product*column[selector]
    return product

def _isSequence(object):
    "Returns true if len() can be applied to |object|."
    try:
        size = len(object)
    except:
        # Any failure of len() means "not a sequence".
        size = -1
    return size >= 0

def _concatenate(arrays, axis):
    """Returns the |arrays| joined along |axis|

    |arrays| is a sequence of arrays whose shapes agree except
    along |axis|.
    """
    # Numeric.concatenate() handles the axis itself, so there is no
    # need to swap |axis| to the front and back again by hand.
    return Numeric.concatenate(tuple(arrays), axis)

# Test code

if __name__ == '__main__':

    p1 = Polynomial([1.,0.3,1.,-0.4])
    x = -1.9
    print p1(x), ((-0.4*x+1.)*x+0.3)*x+1.
    p2 = Polynomial([[1.,0.3],[-0.2,0.5]])
    y = 0.3
    print p2(x,y), 1. + 0.3*y - 0.2*x + 0.5*x*y
    fit = fitPolynomial(2, [1.,2.,3.,4.], [1.,4.,9.,16.])
    print fit.coeff

