"""Python code in support of array's, umath, and numeric 
"""
__version__ = "11"
__LLNLDistribution__ = __version__  

import multiarray
from umath import * # Substitute fast_umath for "unsafe" math
from Precision import *

import string, types, math

#Use this to add a new axis to an array
# NewAxis is just None under a descriptive name, intended for use as an
# index (e.g. a[:, NewAxis]) where a new axis should be inserted.
NewAxis = None

#The following functions are considered builtin, they all might be
#in C some day

def arrayrange(start, stop=None, step=1, typecode=None):
    """Like range(), but return an array; typecode selects the element type.

    arrayrange(stop) counts from 0 up to (not including) stop;
    arrayrange(start, stop, step) counts from start in increments of step.

    Without typecode, the element type is whatever the arithmetic of
    start, stop and step produces.  With typecode, the result is converted
    with astype() when its type differs.  A step of 0 raises
    ZeroDivisionError from the element-count computation.
    """
    if stop is None:
        start, stop = 0, start
    n = int(math.ceil(float(stop - start) / step))
    if n <= 0:
        # Adding the scalars to an empty array gives the empty result the
        # same type the non-empty arithmetic would have produced.
        m = zeros((0,)) + (step + start + stop)
    else:
        # (stop - stop) is a zero of stop's type, so e.g. Long bounds give a
        # Long result (3L - 3L == 0L).
        m = (add.accumulate(ones((n,), Int)) - 1) * step + start + (stop - stop)
    if typecode is not None and m.typecode() != typecode:
        return m.astype(typecode)
    return m

# Re-export these multiarray functions unchanged under the same names.
from multiarray import array, zeros, fromstring, take
from multiarray import reshape, repeat, choose, cross_correlate
def convolve(a, v, mode=0):
    """Convolve a with v by cross-correlating a with v reversed.

    The inputs are exchanged when v is longer than a, so the first argument
    given to cross_correlate is never the shorter one.  mode is passed
    through to cross_correlate unchanged.
    """
    if len(v) > len(a):
        a, v = v, a
    flipped = asarray(v)[::-1]
    return cross_correlate(a, flipped, mode)

# The array type object; used below as the key for pickling arrays
# (Pickler.dispatch and copy_reg.pickle).  Same object as `arraytype`.
ArrayType = multiarray.arraytype

def swapaxes(a, axis1, axis2):
    """Return a with axes axis1 and axis2 interchanged.

    Negative axis numbers count from the end, as in sequence indexing.
    Inputs of rank 0 or 1 are returned unchanged.

    Raises ValueError if either axis is out of range for a's rank.
    """
    n = len(shape(a))
    if n <= 1:
        return a
    # Normalize negative axes so the permutation handed to
    # multiarray.transpose only contains non-negative axis numbers; callers
    # such as sort() pass axis2=-1, which previously wrote -1 into the
    # permutation itself.
    if axis1 < 0:
        axis1 = axis1 + n
    if axis2 < 0:
        axis2 = axis2 + n
    if not (0 <= axis1 < n and 0 <= axis2 < n):
        raise ValueError("bad axis argument to swapaxes")
    new_axes = arange(n)
    new_axes[axis1] = axis2
    new_axes[axis2] = axis1
    return multiarray.transpose(a, new_axes)

# Lower-case alias for the same type object bound to ArrayType above.
arraytype = multiarray.arraytype
#add extra intelligence to the basic C functions
def concatenate(a, axis=0):
    """Join the arrays in the sequence a along the given axis.

    Only axis 0 is handed straight to multiarray.concatenate.  For any
    other axis, each piece has that axis swapped to the front, the pieces
    are joined, and the axis is swapped back.
    """
    if axis == 0:
        return multiarray.concatenate(a)
    to_front = [swapaxes(piece, axis, 0) for piece in a]
    return swapaxes(multiarray.concatenate(to_front), axis, 0)

def transpose(a, axes=None):
    """Return a with its axes permuted by multiarray.transpose.

    axes -- permutation of a's axis numbers, passed to multiarray.transpose.
            Defaults to reversing the order of all axes.
    """
    # `is None` rather than `== None`: a user-supplied axes array would
    # otherwise be compared elementwise instead of tested for absence.
    if axes is None:
        axes = arange(len(array(a).shape))[::-1]
    return multiarray.transpose(a, axes)

def sort(a, axis=-1):
    """Return the values of a sorted along axis (the last axis by default).

    multiarray.sort is applied to the last axis; any other axis is swapped
    into last position first and swapped back afterwards.
    """
    needs_swap = axis != -1
    if needs_swap:
        a = swapaxes(a, axis, -1)
    result = multiarray.sort(a)
    if needs_swap:
        result = swapaxes(result, axis, -1)
    return result

def argsort(a, axis=-1):
    """Return the indices that would sort a along axis (last by default).

    As in sort(), a non-last axis is swapped to the end for
    multiarray.argsort and the result is swapped back.
    """
    needs_swap = axis != -1
    if needs_swap:
        a = swapaxes(a, axis, -1)
    order = multiarray.argsort(a)
    if needs_swap:
        order = swapaxes(order, axis, -1)
    return order

def argmax(a, axis=-1):
    """Return the indices of the maximum values of a along axis.

    Negative axis numbers count from the end.  multiarray.argmax reduces
    the last axis, so any other axis is first swapped into last position.
    The result has a's remaining axes in their original order.
    """
    n = len(shape(a))
    if axis < 0:
        axis = axis + n
    if n <= 1 or axis == n - 1:
        return multiarray.argmax(a)
    s = multiarray.argmax(swapaxes(a, axis, n - 1))
    if n > 2:
        # The swap parked the original last axis at position `axis` of the
        # reduced result; move it back to the end so the remaining axes are
        # in their original order (a plain swap is only enough for 2-d).
        order = list(range(axis)) + list(range(axis + 1, n - 1)) + [axis]
        s = multiarray.transpose(s, order)
    return s

def argmin(x, axis=-1):
    """Return the indices of the minimum values of x along axis.

    Computed as the argmax of the elementwise negation of x.
    """
    flipped = negative(x)
    return argmax(flipped, axis)


# Public name for multiarray.binarysearch.  Presumably expects its first
# argument to be sorted (binary search) -- confirm against multiarray.
searchsorted = multiarray.binarysearch

def innerproduct(a, b):
    """Return multiarray.innerproduct(a, b), with a scalar fallback.

    If multiarray.innerproduct raises TypeError and either argument is a
    scalar (shape ()), the plain product a*b is returned instead.

    Raises TypeError("invalid types for dot") when neither argument is a
    scalar.
    """
    try:
        return multiarray.innerproduct(a, b)
    except TypeError:
        if array(a).shape == () or array(b).shape == ():
            return a * b
        # Call syntax (not `raise E, msg`) is valid on every Python version.
        raise TypeError("invalid types for dot")

def dot(a, b):
    """Dot product of a and b.

    b's last two axes are swapped before calling innerproduct.  For 2-d
    inputs this gives the usual matrix product, assuming innerproduct
    contracts the last axes of both arguments.
    """
    b_swapped = swapaxes(b, -1, -2)
    return innerproduct(a, b_swapped)

#This is obsolete, don't use in new code
# Old name kept for backward compatibility; identical to dot().
matrixmultiply = dot

#Use Konrad's printing function (modified for both str and repr now)
from ArrayPrinter import array2string
def array_repr(a, max_line_width=None, precision=None, suppress_small=None):
    """repr()-style text for array a, produced by array2string.

    Elements are separated by ', '.  The trailing flag passed to
    array2string is 1 here (array_str passes 0).
    """
    separator = ', '
    return array2string(a, max_line_width, precision, suppress_small,
                        separator, 1)

def array_str(a, max_line_width=None, precision=None, suppress_small=None):
    """str()-style text for array a, produced by array2string.

    Elements are separated by a single space.  The trailing flag passed to
    array2string is 0 here (array_repr passes 1).
    """
    separator = ' '
    return array2string(a, max_line_width, precision, suppress_small,
                        separator, 0)
    
# Install the Python-level formatters in multiarray: array_str under flag 0
# and array_repr under flag 1 (presumably str() and repr() of arrays
# respectively -- confirm against multiarray.set_string_function).
multiarray.set_string_function(array_str, 0)
multiarray.set_string_function(array_repr, 1)

#This is a nice value to have around
#Maybe in sys some day
# True on little-endian machines.  The bytes 01 00 00 00 ... are decoded as
# native 'i' integers; the first one equals 1 only when the least
# significant byte is stored first.  (Python 2.0+ also exposes this as
# sys.byteorder.)
LittleEndian = fromstring("\001"+"\000"*7, 'i')[0] == 1

def resize(a, new_shape):
    """Return a new array of shape new_shape filled from a's elements.

    The flattened elements of a are repeated cyclically as often as needed
    and truncated to fit, so a's total size may be anything.  An empty a,
    or a new_shape with zero total size, yields zeros of a's typecode.
    """
    a = ravel(a)
    if not len(a):
        return zeros(new_shape, a.typecode())
    total_size = multiply.reduce(new_shape)
    if not total_size:
        # Avoid concatenate(()) on an empty tuple of copies.
        return zeros(new_shape, a.typecode())
    # divmod keeps the copy count an integer even under true division,
    # where `/` would give a float and break `(a,) * n_copies`.
    n_copies, extra = divmod(total_size, len(a))
    if extra:
        n_copies = n_copies + 1
    a = concatenate((a,) * n_copies)
    return reshape(a[:total_size], new_shape)

def indices(dimensions, typecode=None):
    """Return an array of index grids, one per axis of `dimensions`.

    Element i of the result has shape `dimensions` and holds, at every
    position, that position's coordinate along axis i.  It is computed as a
    running sum of ones along axis i, minus one.
    """
    unit = ones(dimensions, typecode)
    grids = [add.accumulate(unit, axis) - 1
             for axis in range(len(dimensions))]
    return array(grids)

def fromfunction(function, dimensions):
    """Call function with one index array per axis and return its result.

    The arguments are the grids from indices(dimensions): the k-th argument
    holds each position's coordinate along axis k.
    """
    # Extended call syntax replaces apply(), which is deprecated in Python 2
    # and removed in Python 3.
    return function(*tuple(indices(dimensions)))
    

def _diagonal_last_two(a, offset):
    """Return the offset diagonals taken over the last two axes of array a."""
    s = a.shape
    if len(s) > 2:
        # Peel off the leading axis; every slice keeps the diagonal pair last.
        return array([_diagonal_last_two(a[i], offset) for i in range(s[0])])
    n1, n2 = s
    flat = reshape(a, (n1 * n2,))
    # Row-major: element (r, c) is at r*n2 + c, so successive diagonal
    # elements are n2+1 apart in the flattened array.
    if offset < 0:
        start = -offset * n2            # row -offset, column 0
        count = min(n2, n1 + offset)
    else:
        start = offset                  # row 0, column offset
        count = min(n1, n2 - offset)
    # A non-positive count gives an empty range, i.e. an empty diagonal.
    return take(flat, range(start, start + count * (n2 + 1), n2 + 1), 0)


def diagonal(a, offset=0, axis1=0, axis2=1):
    """Return the offset diagonals of a over axes axis1 and axis2.

    offset > 0 selects a diagonal starting at position `offset` along the
    higher-numbered axis; offset < 0 selects one starting at position
    -offset along the lower-numbered axis.  axis1 and axis2 are put in
    increasing order first, so offset is always measured the same way.
    Negative axis numbers count from the end.

    The result has a's other axes, in their original order, followed by a
    single axis running along the diagonal.

    Raises ValueError if a has fewer than two dimensions, if an axis is out
    of range, or if axis1 and axis2 name the same axis.
    """
    a = array(a)
    n = len(a.shape)
    if n < 2:
        raise ValueError("diagonal requires an array of at least two dimensions")
    if axis1 < 0:
        axis1 = axis1 + n
    if axis2 < 0:
        axis2 = axis2 + n
    if not (0 <= axis1 < n and 0 <= axis2 < n):
        raise ValueError("bad axis argument to diagonal")
    if axis1 == axis2:
        raise ValueError("axis1 and axis2 must be different")
    if axis2 < axis1:
        axis1, axis2 = axis2, axis1
    if (axis1, axis2) != (n - 2, n - 1):
        # Move the chosen pair to the end.  The old code moved it to the
        # front and then recursed over the first axis with default axes,
        # which diagonalised the wrong pair for rank > 2 inputs.
        rest = [i for i in range(n) if i != axis1 and i != axis2]
        a = transpose(a, rest + [axis1, axis2])
    return _diagonal_last_two(a, offset)

def trace(a, offset=0, axis1=0, axis2=1):
    """Return the sum along the offset diagonal(s) of a.

    diagonal() returns each diagonal along the last axis of its result, so
    the sum runs over that last axis.  For 2-d input this is a single
    number; for higher ranks there is one sum per remaining index.
    """
    d = diagonal(a, offset, axis1, axis2)
    # Reduce the diagonal axis itself.  Reducing axis 0 (the old behaviour)
    # summed across slices instead of along the diagonals when rank > 2.
    return add.reduce(d, len(shape(d)) - 1)


# These two functions are used in my modified pickle.py so that
# arrays can be pickled.  Notice that arrays are written in
# binary format for efficiency, but that they pay attention to
# byte-order issues for portability.

def DumpArray(m, fp):
    """Write array m to the file-like fp in the 'A' binary format.

    Layout: a text header "A<typecode><L|B><itemsize> " followed by each
    dimension and a space, then a newline, then m.tostring().  L or B
    records this machine's byte order so LoadArray can byteswap if needed.

    Raises TypeError for arrays of typecode 'O' (Python objects).
    """
    if m.typecode() == 'O':
        raise TypeError("Numeric Pickler can't pickle arrays of Objects")
    if LittleEndian:
        endian = "L"
    else:
        endian = "B"
    header = "A%s%s%d " % (m.typecode(), endian, m.itemsize())
    header = header + "".join(["%d " % d for d in m.shape])
    fp.write(header + "\n")
    fp.write(m.tostring())

def LoadArray(fp):
    """Read one array written by DumpArray from the file-like fp.

    Parses the header line (a leading 'A' is optional), reads
    itemsize * prod(shape) bytes, and byteswaps the result when the stored
    byte order differs from this machine's.
    """
    header = fp.readline().split()
    tag = header[0]
    if tag[0] == 'A':
        # Present when reading a DumpArray file directly; presumably already
        # consumed when called from Unpickler.load_array.
        tag = tag[1:]
    typecode = tag[0]
    endian = tag[1]
    itemsize = int(tag[2:])
    # A real list: a lazy map() would be exhausted before reshape() used it.
    shape = [int(d) for d in header[1:]]
    # Start from itemsize so a 0-d array (no dimensions) reads one item
    # instead of failing in reduce() on an empty sequence.
    nbytes = itemsize
    for d in shape:
        nbytes = nbytes * d
    m = reshape(fromstring(fp.read(nbytes), typecode), shape)
    if (LittleEndian and endian == 'B') or (not LittleEndian and endian == 'L'):
        return m.byteswapped()
    return m

import pickle, copy
class Unpickler(pickle.Unpickler):
    """pickle.Unpickler that also understands the 'A' array records."""

    def load_array(self):
        # LoadArray reads the header and raw bytes through this unpickler.
        loaded = LoadArray(self)
        self.stack.append(loaded)

    # Private shallow copy of the base table, so the stock unpickler's
    # dispatch is left untouched.
    dispatch = pickle.Unpickler.dispatch.copy()
    dispatch['A'] = load_array

class Pickler(pickle.Pickler):
    """pickle.Pickler that writes arrays with DumpArray's 'A' format."""

    def save_array(self, object):
        # DumpArray writes via self.write, i.e. straight into the stream.
        DumpArray(object, self)

    # Private shallow copy of the base table with arrays routed to save_array.
    dispatch = pickle.Pickler.dispatch.copy()
    dispatch[ArrayType] = save_array

#Convenience functions
from StringIO import StringIO

def dump(object, file):
    """Pickle object into the open file using the array-aware Pickler."""
    pickler = Pickler(file)
    pickler.dump(object)

def dumps(object):
    """Return the array-aware pickle of object as a string."""
    buf = StringIO()
    pickler = Pickler(buf)
    pickler.dump(object)
    return buf.getvalue()

def load(file):
    """Read one pickled object from the open file with the array-aware Unpickler."""
    unpickler = Unpickler(file)
    return unpickler.load()

def loads(str):
    """Unpickle one object from a string such as dumps() produces."""
    source = StringIO(str)
    unpickler = Unpickler(source)
    return unpickler.load()


# These are all essentially abbreviations
# These might wind up in a special abbreviations module

def ravel(m):
    """Return a 1-d array holding all the elements of m.

    m is reshaped to the single-axis shape (-1,).
    """
    flat_shape = (-1,)
    return reshape(m, flat_shape)

def nonzero(a):
    """Return the indices of the nonzero elements of a; a must be 1-d."""
    positions = arange(len(a))
    keep = not_equal(a, 0)
    # Each position is repeated once if its element is nonzero, else dropped.
    return repeat(positions, keep)

def asarray(a, typecode=None):
    """Return a as an array, calling array() with copy=0."""
    result = array(a, typecode, copy=0)
    return result

#Move this into C to do it right!
def shape(a):
    """Return the shape of a (any array-like) after conversion with asarray()."""
    converted = asarray(a)
    return converted.shape

def where(condition, x, y):
    """where(condition, x, y) is shaped like condition, taking elements
    from x where condition is true (nonzero) and from y where it is false.
    """
    picker = not_equal(condition, 0)   # 1 selects x, 0 selects y
    return choose(picker, (y, x))

def compress(condition, m, dimension=-1):
    """compress(condition, m, dimension=-1) returns the elements of m along
    `dimension` whose corresponding entry in condition is "true" (nonzero).

    condition must be 1-d and the same length as m along that dimension.
    """
    selected = nonzero(condition)
    return take(m, selected, dimension)

def clip(m, m_min, m_max):
    """clip(m, m_min, m_max) replaces every entry of m below m_min with
    m_min and every entry above m_max with m_max.
    """
    below = less(m, m_min)
    above = greater(m, m_max)
    # Selector 0 keeps m, 1 takes m_min, 2 takes m_max.
    return choose(below + 2 * above, (m, m_min, m_max))

def ones(shape, typecode='l'):
    """ones(shape, typecode='l') returns an array of the given dimensions
    with every element set to one.
    """
    base = zeros(shape, typecode)
    return base + array(1, typecode)

def identity(n):
    """Return the n x n identity matrix.

    resize() repeats the n+1 element pattern [1, 0, ..., 0] row by row.
    That places a 1 every n+1 flat positions, which is exactly the diagonal
    of an n x n matrix.
    """
    pattern = [1] + [0] * n
    return resize(pattern, (n, n))

# Short names for common ufunc reductions and accumulations.
# NOTE: `sum` shadows the builtin sum() for this module and for
# `from ... import *` users.
sum = add.reduce                 # sum of elements
cumsum = add.accumulate          # running sums
product = multiply.reduce        # product of elements
cumproduct = multiply.accumulate # running products
alltrue = logical_and.reduce     # logical AND over elements
sometrue = logical_or.reduce     # logical OR over elements

# Short alias for arrayrange (used by swapaxes, transpose, nonzero, ...).
arange = arrayrange

### Temporary solution for pickling arrays
### Quite inefficient
### david ascher, march 1998
import copy_reg

def array_constructor(shape, typecode, thestr):
    """Rebuild an array from the (shape, typecode, string) triple that
    pickle_array produces."""
    rebuilt = fromstring(thestr, typecode)
    rebuilt.shape = shape
    return rebuilt

def pickle_array(a):
    """copy_reg reduction function: return array_constructor together with
    the (shape, typecode, raw string) arguments that rebuild a."""
    ctor_args = (a.shape, a.typecode(), a.tostring())
    return array_constructor, ctor_args

# Register pickle_array as the reduction function for ArrayType so the
# standard pickle machinery can serialize arrays.  array_constructor is
# registered as the constructor used on load.
copy_reg.pickle(ArrayType, pickle_array, array_constructor)

