"""Python code in support of arrays, umath, and Numeric.
"""

import multiarray
from umath import * # Substitute fast_umath for "unsafe" math
from Precision import *

import string, types, math

# Use this to add a new axis to an array.  It is simply an alias for None.
NewAxis = None

#The following functions are considered builtins; they might all be
#moved into C some day.

def arrayrange(start, stop=None, step=1, typecode=None):
	"""Like range(), but return an array instead of a list.

	arrayrange(stop) counts up from 0; arrayrange(start, stop[, step])
	counts from start toward stop (exclusive) in increments of step.
	The optional keyword argument typecode converts the result to that
	element type; without it the type is whatever the arithmetic on
	start, stop and step produces.
	"""
	if stop is None:
		# Single-argument form: the lone argument is the upper bound.
		start, stop = 0, start
	count = int(math.ceil(float(stop - start) / step))
	if count > 0:
		# 0, 1, ..., count-1 as a running sum of ones, then scaled and shifted.
		# Adding (stop - stop) appears intended to make the result's type
		# follow stop's type without changing any value.
		result = ((add.accumulate(ones((count,), Int)) - 1) * step
			+ (start + stop - stop))
	else:
		# Empty range: the added scalar only influences the result's type.
		result = zeros((0,)) + (step + start + stop)
	if typecode is None or result.typecode() == typecode:
		return result
	return result.astype(typecode)

# Re-export, unchanged, the array constructor and the primitive operations
# that the C module multiarray implements.
(array, zeros, fromstring, take,
 reshape, repeat, choose, convolve) = (
	multiarray.array, multiarray.zeros, multiarray.fromstring, multiarray.take,
	multiarray.reshape, multiarray.repeat, multiarray.choose, multiarray.convolve)
# The array type object; the Pickler below uses it as a dispatch key.
ArrayType = multiarray.arraytype

def swapaxes(a, axis1, axis2):
	"""Return a with the axes axis1 and axis2 interchanged.

	Negative axis numbers count from the end.  An input with at most one
	dimension is returned as given.  An axis outside the array's rank
	raises IndexError.
	"""
	n = len(shape(a))
	if n <= 1: return a
	# Normalize negative axes so that the permutation handed to
	# multiarray.transpose holds only non-negative axis numbers; assigning
	# e.g. axis2 == -1 into new_axes would otherwise put -1 into it.
	if axis1 < 0: axis1 = axis1 + n
	if axis2 < 0: axis2 = axis2 + n
	if not (0 <= axis1 < n and 0 <= axis2 < n):
		raise IndexError("swapaxes: axis out of range")
	new_axes = arange(n)
	new_axes[axis1] = axis2
	new_axes[axis2] = axis1
	return multiarray.transpose(a, new_axes)

# Lower-case alias of multiarray's array type object.
arraytype = multiarray.arraytype
#add extra intelligence to the basic C functions
def concatenate(a, axis=0):
	"""Join the sequence of arrays a along the given axis (default 0).

	multiarray.concatenate is called without an axis, so for any other
	axis each piece is swapped to bring that axis to the front, the
	pieces are joined, and the result is swapped back.
	"""
	if axis == 0:
		return multiarray.concatenate(a)
	pieces = []
	for piece in a:
		pieces.append(swapaxes(piece, axis, 0))
	joined = multiarray.concatenate(pieces)
	return swapaxes(joined, axis, 0)

def transpose(a, axes=None):
	"""Return a with its axes permuted by multiarray.transpose.

	axes is the permutation handed to multiarray.transpose; when omitted,
	the order of a's axes is reversed.
	"""
	# Test identity, not equality: axes may itself be an array (callers in
	# this module pass arange() results), and comparing an array with None
	# via == is not a plain truth test.
	if axes is None:
		axes = arange(len(array(a).shape))[::-1]
	return multiarray.transpose(a, axes)

def sort(a, axis=-1):
	"""Return the elements of a sorted along axis (default: the last axis).

	multiarray.sort is applied to the last axis, so any other axis is
	swapped into last place first and swapped back afterwards.
	"""
	if axis == -1:
		return multiarray.sort(a)
	moved = swapaxes(a, axis, -1)
	return swapaxes(multiarray.sort(moved), axis, -1)

def argsort(a, axis=-1):
	"""Return the indices that would sort a along axis (default: last).

	multiarray.argsort is applied to the last axis, so any other axis is
	swapped into last place first and swapped back afterwards.
	"""
	if axis == -1:
		return multiarray.argsort(a)
	moved = swapaxes(a, axis, -1)
	return swapaxes(multiarray.argsort(moved), axis, -1)

def argmax(a, axis=-1):
	"""Return the indices of the maximum values of a along axis.

	The result has the shape of a with axis removed; the remaining axes
	keep their original order.  Negative axis numbers count from the end.
	An axis outside the array's rank raises IndexError.
	"""
	a = asarray(a)
	n = len(a.shape)
	if axis < 0: axis = axis + n
	if n > 0 and not (0 <= axis < n):
		raise IndexError("argmax: axis out of range")
	if n > 0 and axis != n-1:
		# multiarray.argmax reduces over the last axis.  Move axis to the
		# end with a full permutation rather than a single swap: a swap
		# leaves the original last axis at position axis, which scrambles
		# the order of the result's axes for arrays of rank > 2.
		perm = []
		for i in range(n):
			if i != axis: perm.append(i)
		perm.append(axis)
		a = multiarray.transpose(a, perm)
	return multiarray.argmax(a)

def argmin(x, axis=-1):
	"""Return the indices of the minimum values of x along axis.

	Implemented as the argmax of the negated array.
	"""
	negated = negative(x)
	return argmax(negated, axis)


# Alias of multiarray.binarysearch (presumably returns insertion indices into
# a sorted array -- confirm against multiarray's documentation).
searchsorted = multiarray.binarysearch

def innerproduct(a,b):
	"""Return multiarray.innerproduct(a, b).

	If multiarray rejects the arguments with TypeError and either of them
	is a scalar (shape ()), the ordinary product a*b is returned instead;
	otherwise TypeError is raised.
	"""
	try:
		return multiarray.innerproduct(a,b)
	except TypeError:
		if array(a).shape == () or array(b).shape == ():
			return a*b
		# Call-form raise: valid on every Python version, unlike the
		# 'raise E, msg' statement form (removed in Python 3).
		raise TypeError("invalid types for dot")

def dot(a, b):
	"""Return innerproduct(a, b') where b' is b with its last two axes swapped."""
	b_swapped = swapaxes(b, -1, -2)
	return innerproduct(a, b_swapped)

# Obsolete alias for dot(); don't use it in new code.
matrixmultiply = dot

#Use Konrad's printing function (modified for both str and repr now)
from ArrayPrinter import array2string
def array_repr(a, max_line_width=None, precision=None, suppress_small=None):
	"""Format a via array2string with ', ' between elements.

	The trailing flag 1 presumably selects the repr()-style layout
	(array_str passes 0) -- confirm in ArrayPrinter.
	"""
	separator, repr_flag = ', ', 1
	return array2string(a, max_line_width, precision, suppress_small,
			separator, repr_flag)

def array_str(a, max_line_width=None, precision=None, suppress_small=None):
	"""Format a via array2string with ' ' between elements.

	The trailing flag 0 presumably selects the str()-style layout
	(array_repr passes 1) -- confirm in ArrayPrinter.
	"""
	separator, repr_flag = ' ', 0
	return array2string(a, max_line_width, precision, suppress_small,
			separator, repr_flag)
	
# Register the formatters with multiarray: array_str under mode 0 and
# array_repr under mode 1 (presumably str() and repr() respectively, going
# by the function names -- confirm in multiarray).
multiarray.set_string_function(array_str, 0)
multiarray.set_string_function(array_repr, 1)

# True when this machine stores integers least-significant byte first.
# Detected by decoding the bytes 01 00 00 00 ... as native 'i' integers:
# the first one reads as 1 only when the low-order byte comes first.
# This is a nice value to have around; maybe it will live in sys some day.
LittleEndian = fromstring("\001"+"\000"*7, 'i')[0] == 1

def resize(a, new_shape):
	"""Returns a new array with the specified shape, filled from a.

	The original array's total size can be any size: its flattened
	elements are repeated as often as needed (the last copy truncated)
	to fill new_shape.  A zero-sized new_shape gives an empty array of
	a's type; an empty a gives an all-zero array of a's type.
	"""
	a = ravel(a)
	total_size = multiply.reduce(new_shape)
	if total_size == 0:
		# Nothing to fill; concatenating zero copies below would fail.
		return reshape(a[:0], new_shape)
	if len(a) == 0:
		# No data to repeat (the divmod below would divide by zero).
		return zeros(new_shape, a.typecode())
	# divmod keeps this integer arithmetic under both classic and true
	# division, so n_copies stays usable as a repeat count.
	n_copies, extra = divmod(total_size, len(a))
	if extra != 0:
		n_copies = n_copies+1
		extra = len(a)-extra

	a = concatenate( (a,)*n_copies)
	if extra > 0:
		a = a[:-extra]

	return reshape(a, new_shape)

def indices(dimensions, typecode=None):
	"""Return an array of index grids for an array of shape dimensions.

	result[i] has shape dimensions and holds, at every position, that
	position's index along axis i.  typecode optionally sets the type.
	"""
	grid = ones(dimensions)
	planes = []
	for axis in range(len(dimensions)):
		# Running sum of ones along axis, shifted so it starts at 0.
		planes.append(add.accumulate(grid, axis) - 1)
	if typecode is None:
		return array(planes)
	return array(planes, typecode)


def fromfunction(function, dimensions):
	"""Return function applied to the index grids of an array of shape dimensions.

	function is called once with len(dimensions) positional arguments,
	the i-th being the grid of indices along axis i (as built by
	indices()); its return value is passed through unchanged.
	"""
	# Extended call syntax replaces the deprecated apply() builtin.
	return function(*tuple(indices(dimensions)))

def __diagonal(a, offset=0):
	"""Return the offset diagonal of the last two axes of a, as its last axis.

	offset > 0 selects a diagonal above the main one, offset < 0 one
	below it.  The last two axes are flattened into one, and every
	(columns+1)-th element is taken, starting at the diagonal's first
	element.  An offset outside the matrix gives an empty diagonal.
	"""
	s = list(a.shape)
	rows, cols = s[-2], s[-1]
	if offset >= 0:
		# The diagonal starts in row 0, column offset.
		start = offset
		length = min(rows, cols-offset)
	else:
		# The diagonal starts in row -offset, column 0.
		start = -offset*cols
		length = min(rows+offset, cols)
	if length < 0: length = 0
	s[-2] = rows*cols
	r = reshape(a, s[:-1])
	# Bound the element count explicitly: stepping cols+1 through the
	# flattened matrix would otherwise wrap from the right edge into the
	# next row (tall matrices, or offsets near the column count).
	return take(r, arange(length)*(cols+1) + start, -1)
	


def diagonal(a, offset=0, axis1=-2, axis2=-1):
	"""Return the offset diagonal of a taken over the axes axis1 and axis2.

	The result keeps the other axes of a in their original order and
	holds the diagonal in its last axis (trace() sums over axis -1).
	Negative axis numbers count from the end.  Raises ValueError unless
	axis1 and axis2 are distinct, valid axes of a.
	"""
	a = asarray(a)
	n = len(a.shape)
	if axis1 < 0: axis1 = axis1 + n
	if axis2 < 0: axis2 = axis2 + n
	if axis1 == axis2 or not (0 <= axis1 < n and 0 <= axis2 < n):
		raise ValueError("diagonal: axis1 and axis2 must be distinct, valid axes")
	if axis1 != n-2 or axis2 != n-1:
		# Move axis1 and axis2 (in that order) to the end with a single
		# permutation.  Successive swapaxes calls would interfere with each
		# other's axis numbering.  Nothing is swapped back afterwards
		# because the diagonal replaces both axes.
		new_axes = []
		for i in range(n):
			if i != axis1 and i != axis2: new_axes.append(i)
		new_axes.append(axis1)
		new_axes.append(axis2)
		a = multiarray.transpose(a, new_axes)
	return __diagonal(a, offset)


def trace(a, offset=0, axis1=-2, axis2=-1):
	"""Return the sum of the offset diagonal of a over axes axis1 and axis2."""
	diag = diagonal(a, offset, axis1, axis2)
	# The diagonal elements are expected on the last axis of diag.
	return add.reduce(diag, -1)


# These two functions are used in my modified pickle.py so that
# arrays can be pickled.  Notice that arrays are written in
# binary format for efficiency, but that their byte order is
# recorded for portability.

def DumpArray(m, fp):
	"""Write array m to the file-like object fp in binary form.

	The header is 'A' immediately followed by the typecode, 'L' or 'B'
	for this machine's byte order and the item size, then a space,
	then each dimension followed by a space, then a newline.  The raw
	element bytes from m.tostring() follow.
	"""
	byte_order = "B"
	if LittleEndian:
		byte_order = "L"
	header = "A%s%s%d " % (m.typecode(), byte_order, m.itemsize())
	for dim in m.shape:
		header = header + ("%d " % dim)
	fp.write(header + '\n')
	fp.write(m.tostring())

def LoadArray(fp):
	"""Read back an array written by DumpArray from the file-like object fp.

	A leading 'A' tag is stripped if present.  When called through the
	Unpickler's 'A' dispatch entry it has presumably been consumed
	already.  Data written on a machine with the other byte order is
	returned byte-swapped.
	"""
	ln = fp.readline().split()
	if ln[0][0] == 'A': ln[0] = ln[0][1:] # Nasty hack showing my ignorance of pickle
	typecode = ln[0][0]
	endian = ln[0][1]
	itemsize = int(ln[0][2:])

	shape = []
	for field in ln[1:]:
		shape.append(int(field))

	# The element count is the product of the dimensions.  Starting from 1
	# makes a 0-d array (no dimension fields) hold one element, where
	# reduce() over an empty list would raise TypeError.
	count = 1
	for d in shape:
		count = count*d
	data = fp.read(count*itemsize)

	m = fromstring(data, typecode)
	m = reshape(m, shape)

	if (LittleEndian and endian == 'B') or (not LittleEndian and endian == 'L'):
		return m.byteswapped()
	else:
		return m

import pickle, copy
class Unpickler(pickle.Unpickler):
	"""pickle.Unpickler that also understands the 'A' opcode for arrays."""

	# A private copy of the base dispatch table, so that registering the
	# array loader does not modify pickle.Unpickler itself.
	dispatch = copy.copy(pickle.Unpickler.dispatch)

	def load_array(self):
		# The unpickler itself is handed to LoadArray as the file; this
		# relies on it providing readline() and read().
		loaded = LoadArray(self)
		self.stack.append(loaded)

	dispatch['A'] = load_array

class Pickler(pickle.Pickler):
	"""pickle.Pickler that writes arrays in DumpArray's binary format."""

	# A private copy of the base dispatch table, so that registering the
	# array saver does not modify pickle.Pickler itself.
	dispatch = copy.copy(pickle.Pickler.dispatch)

	def save_array(self, object):
		# The pickler itself is handed to DumpArray as the output file;
		# this relies on it providing write().
		DumpArray(object, self)

	dispatch[ArrayType] = save_array

#Convenience functions
from StringIO import StringIO

def dump(object, file):
	"""Pickle object to the open file using this module's Pickler."""
	pickler = Pickler(file)
	pickler.dump(object)

def dumps(object):
	"""Return the pickle of object, made by this module's Pickler, as a string."""
	buf = StringIO()
	dump(object, buf)
	return buf.getvalue()

def load(file):
	"""Read one pickled object from the open file using this module's Unpickler."""
	unpickler = Unpickler(file)
	return unpickler.load()

def loads(str):
	"""Return the object pickled in the string str, read with this module's Unpickler."""
	return load(StringIO(str))


# These are all essentially abbreviations
# These might wind up in a special abbreviations module

def ravel(m):
	"""Return a 1-d array holding all the elements of m."""
	flat_shape = (-1,)
	return reshape(m, flat_shape)

def nonzero(a):
	"""Return the indices of the nonzero elements of a, which must be 1-d."""
	positions = arange(len(a))
	# Each index is repeated once where a is nonzero and zero times elsewhere.
	return repeat(positions, not_equal(a, 0))

def asarray(a, typecode=None):
	"""Return array(a, copy=0), optionally with the given typecode."""
	if typecode is not None:
		return array(a, typecode, copy=0)
	return array(a, copy=0)

#Move this into C to do it right!
def shape(a):
	"""Return the shape of a, converting it to an array first if necessary."""
	return array(a, copy=0).shape

def where(condition, x, y):
	"""where(condition,x,y) is shaped like condition and has elements of x
	where condition is true and elements of y where it is false.
	"""
	# choose() indexes the sequence: 0 (false) picks y, 1 (true) picks x.
	selector = not_equal(condition, 0)
	return choose(selector, (y, x))

def compress(condition, m, dimension=-1):
	"""compress(condition, x, dimension=-1) = those elements of x, along
	dimension, whose corresponding elements of the 1-d condition are
	"true".  condition must be the same size as that dimension of x."""
	keep = nonzero(condition)
	return take(m, keep, dimension)

def clip(m, m_min, m_max):
	"""clip(m, m_min, m_max) replaces every entry of m below m_min with
	m_min and every entry above m_max with m_max.
	"""
	too_low = less(m, m_min)
	too_high = greater(m, m_max)
	# Selector 0 keeps m, 1 picks m_min, 2 picks m_max.
	return choose(too_low + 2*too_high, (m, m_min, m_max))

def ones(shape, typecode='l'):
	"""ones(shape, typecode=Int) returns an array of the given dimensions
	with every element initialized to one.
	"""
	base = zeros(shape, typecode)
	return base + array(1, typecode)

def identity(n):
	"""Return the n by n identity matrix."""
	# Repeating a 1 followed by n zeros cyclically over n*n slots puts a 1
	# at every (n+1)-th position, i.e. on the main diagonal.
	pattern = [1] + [0]*n
	return resize(pattern, (n, n))

# Short names for common reductions and accumulations of the ufuncs.
sum, cumsum = add.reduce, add.accumulate
product, cumproduct = multiply.reduce, multiply.accumulate
alltrue, sometrue = logical_and.reduce, logical_or.reduce

# Short name for arrayrange.
arange = arrayrange













