"""Matlab(tm) compatibility functions.

This will hopefully become a complete set of the basic functions available in
matlab.  The syntax is kept as close to the matlab syntax as possible.  One
fundamental difference is that in matlab the first index varies the fastest (as
in FORTRAN).  That means that matlab will usually perform reductions over
columns, whereas with this object the most natural reductions are over rows.
It's perfectly possible to make this work the way it does in matlab if that's
desired.
"""
from Numeric import *

# Elementary Matrices

# zeros is from matrixmodule in C
# ones is from Numeric.py

import RandomArray
def rand(*args):
	"""rand(d1,...,dn) returns a matrix of the given dimensions which is
	initialized to random numbers in the range [0,1).

	The dimensions are collected into a single shape tuple and handed to
	RandomArray.random_sample unchanged; no typecode argument is accepted.
	"""
	shape = args
	return RandomArray.random_sample(shape)

def eye(N, M=None, k=0, typecode=None):
	"""eye(N, M=N, k=0, typecode=None) returns a N-by-M matrix where the
	k-th diagonal is all ones, and everything else is zeros.

	N        -- number of rows.
	M        -- number of columns; defaults to N.  A string given in this
	            position is taken to be the typecode, so eye(3, 'd') is the
	            same call as eye(3, 3, typecode='d').
	k        -- diagonal to fill: 0 is the main diagonal, k > 0 selects a
	            diagonal above it and k < 0 one below it.
	typecode -- typecode used to build the row index vector; None leaves
	            arange at its default.
	"""
	# None is a singleton, so test identity rather than equality.
	if M is None:
		M = N
	if type(M) == type('d'):
		# eye(N, typecode) shorthand: the second positional slot held a
		# typecode string rather than a column count.
		typecode = M
		M = N
	if typecode is None:
		rows = arange(N)
	else:
		rows = arange(N, typecode=typecode)
	# Element (i, j) of the outer difference is i - j, which equals -k
	# exactly on the k-th diagonal.
	return outer(subtract, rows, arange(M)).equal(-k)

def tri(N, M=None, k=0, typecode=None):
	"""tri(N, M=N, k=0, typecode=None) returns a N-by-M matrix that is true
	at and below the k-th diagonal (column <= row + k) and false elsewhere.

	N        -- number of rows.
	M        -- number of columns; defaults to N.  A string given in this
	            position is taken to be the typecode, so tri(3, 'd') is the
	            same call as tri(3, 3, typecode='d').
	k        -- diagonal that bounds the filled region: 0 is the main
	            diagonal, k > 0 lies above it and k < 0 below it.
	typecode -- typecode used to build the row index vector; None leaves
	            arange at its default.
	"""
	# None is a singleton, so test identity rather than equality.
	if M is None:
		M = N
	if type(M) == type('d'):
		# tri(N, typecode) shorthand: the second positional slot held a
		# typecode string rather than a column count.
		typecode = M
		M = N
	if typecode is None:
		rows = arange(N)
	else:
		rows = arange(N, typecode=typecode)
	# Element (i, j) of the outer difference is i - j; it is >= -k exactly
	# when j <= i + k.
	return outer(subtract, rows, arange(M)).greater_equal(-k)
	

# Matrix manipulation

def diag(v, k=0):
	"""diag(v, k=0) builds a matrix from a vector, or pulls a diagonal out
	of a matrix, depending on the rank of v.

	v -- a rank-1 or rank-2 array (anything else falls through both
	     branches and the function returns None).
	k -- diagonal offset, forwarded to eye(); 0 is the main diagonal.

	Rank 1: v is padded with abs(k) zeros and combined with an n-by-n
	eye(n, k=k) mask, where n = len(v) + abs(k).
	Rank 2: v is multiplied by an eye() mask of the same shape, collapsed
	with add.reduce, and abs(k) entries are trimmed from one end.
	"""
	s = v.shape
	if len(s)==1:
		# The result must be large enough to hold v on the k-th diagonal.
		n = s[0]+abs(k)
		# NOTE(review): v.typecode is passed without being called here,
		# while tril/triu/cov in this file call m.typecode() -- confirm
		# which form the array type in use expects.
		if k > 0:
			v = v.concat(zeros(k, v.typecode))
		elif k < 0:
			v = zeros(-k, v.typecode).concat(v)
		# NOTE(review): multiply[[1,0]] presumably selects an axis pairing
		# for the product of the mask and the padded vector -- verify
		# against the Numeric version in use.
		return multiply[[1,0]](eye(n, k=k), v)
	elif len(s)==2:
		# Zero everything off the k-th diagonal, then collapse one axis.
		v = add.reduce(eye(s[0], s[1], k=k)*v)
		# NOTE(review): which end holds the padding zeros depends on the
		# axis add.reduce collapses -- confirm these trims match it.
		if k > 0: return v[:-k]
		elif k < 0: return v[-k:]
		else: return v

def fliplr(m):
	"""fliplr(m) returns m with the order of its second index reversed,
	i.e. flipped left to right.  m must have at least two dimensions.
	"""
	backwards = slice(None, None, -1)
	return m[:, backwards]

def flipud(m):
	"""flipud(m) returns m with the order of its first index reversed,
	i.e. flipped top to bottom.
	"""
	backwards = slice(None, None, -1)
	return m[backwards]

# reshape(x, m, n) is not used, instead use reshape(x, (m, n))

def rot90(m, k=1):
	"""rot90(m, k=1) returns m rotated counterclockwise by k*90 degrees.

	m -- a rank-2 array.
	k -- number of quarter turns; it is reduced modulo 4, so negative
	     values rotate clockwise.

	For k % 4 == 0 the input object itself is returned.
	"""
	k = k % 4
	if k == 0:
		return m
	elif k == 1:
		# One counterclockwise turn: transpose, then reverse the rows.
		return m.transpose()[::-1]
	elif k == 2:
		# Half turn: reverse both indices of m itself.
		return m[::-1, ::-1]
	elif k == 3:
		# One clockwise turn: transpose, then reverse the columns.
		return m.transpose()[:, ::-1]

def tril(m, k=0):
	"""tril(m, k=0) returns m multiplied elementwise by a tri() mask of
	the same shape, which keeps the entries at and below the k-th diagonal.
	"""
	mask = tri(m.shape[0], m.shape[1], k=k, typecode=m.typecode())
	return mask*m

def triu(m, k=0):
	"""triu(m, k=0) returns m multiplied elementwise by the complement of
	a tri() mask, which keeps the entries at and above the k-th diagonal.
	"""
	# tri(..., k-1) marks everything strictly below the k-th diagonal.
	below = tri(m.shape[0], m.shape[1], k-1, m.typecode())
	keep = 1-below
	return keep*m

# Data analysis

# Basic operations
def max(m):
	"""max(m) collapses m with maximum.reduce, using that ufunc's default
	axis.  Note that this shadows the builtin max inside this module.
	"""
	collapse = maximum.reduce
	return collapse(m)

def min(m):
	"""min(m) collapses m with minimum.reduce, using that ufunc's default
	axis.  Note that this shadows the builtin min inside this module.
	"""
	collapse = minimum.reduce
	return collapse(m)

# Actually from BASIS, but it fits in so naturally here...

def ptp(m):
	"""ptp(m) returns the peak-to-peak range of m: max(m) minus min(m)."""
	highest = max(m)
	lowest = min(m)
	return highest-lowest

def mean(m):
	"""mean(m) returns the arithmetic mean of m: add.reduce(m) divided by
	len(m), the length of the first index.

	The divisor is converted to a float so that integer input is not
	truncated by integer division.
	"""
	count = float(len(m))
	return add.reduce(m)/count

# sort is done in C but is done row-wise rather than column-wise
def msort(m):
	"""msort(m) returns m sorted along its first index (column-wise)."""
	# sort works along rows, so swap the axes around the call.
	flipped = m.transpose()
	return sort(flipped).transpose()

def median(m):
	"""median(m) returns the median of m along its first index.

	With an odd number of entries this is the middle entry of the sorted
	data; with an even number it is the average of the two middle entries.
	"""
	ordered = msort(m)
	# divmod keeps the index an integer under both classic and true division.
	half, odd = divmod(ordered.shape[0], 2)
	if odd:
		return ordered[half]
	return (ordered[half-1]+ordered[half])/2.0

def std(m):
	"""std(m) returns the sample standard deviation of m: the root of the
	summed squared deviations from mean(m), scaled by sqrt(len(m)-1).
	"""
	deviations = m-mean(m)
	sum_sq = add.reduce(pow(deviations, 2))
	return sqrt(sum_sq)/sqrt(len(m)-1)

def sum(m):
	"""sum(m) collapses m with add.reduce, using that ufunc's default
	axis.  Note that this shadows the builtin sum inside this module.
	"""
	collapse = add.reduce
	return collapse(m)

def cumsum(m):
	"""cumsum(m) returns the running sums of m as produced by
	add.accumulate.
	"""
	running = add.accumulate
	return running(m)

def prod(m):
	"""prod(m) collapses m with multiply.reduce, using that ufunc's
	default axis.
	"""
	collapse = multiply.reduce
	return collapse(m)

def cumprod(m):
	"""cumprod(m) returns the running products of m as produced by
	multiply.accumulate.
	"""
	running = multiply.accumulate
	return running(m)

def trapz(y, x=None):
	"""Integrate f using the trapezoidal rule, where y is f(x).

	y -- samples of f.
	x -- sample positions matching y; when omitted, a spacing of 1 between
	     consecutive samples is used.
	"""
	# Test identity, not equality: x is normally an array, and comparing an
	# array with == is an elementwise operation rather than a yes/no answer.
	if x is None:
		d = 1
	else:
		d = diff(x)
	# Divide by 2.0 so integer samples are not truncated by integer division.
	return sum(d * (y[1:]+y[0:-1])/2.0)

def diff(x, n=1):
	"""Discrete difference approximation to the derivative.

	x -- array to difference along its first index.
	n -- number of times the difference is applied; any n below 2 yields
	     the first difference.
	"""
	result = x[1:]-x[:-1]
	remaining = n-1
	# Each further pass differences the previous result once more.
	while remaining > 0:
		result = result[1:]-result[:-1]
		remaining = remaining-1
	return result
	
def dot(x, y):
	"""dot(x, y) returns the elementwise product of x and y collapsed
	with add.reduce.
	"""
	products = x*y
	return add.reduce(products)

def corrcoef(x, y=None):
	"""The correlation coefficients.

	Computes cov(x, y) and divides each entry by the square root of the
	product of the two matching diagonal entries of that result.
	"""
	covariance = cov(x, y)
	variances = diag(covariance)
	scale = sqrt(outer(multiply, variances, variances))
	return covariance/scale

def cov(m,y=None):
	"""cov(m, y=None) returns the covariance estimate of m.

	m -- array whose items along the first index are each treated as one
	     observation vector.
	y -- optional second array; when given, m and y are stacked into a
	     single array (with m's typecode) before the estimate is computed.

	The result is the sum of the outer products of the observations minus
	len(m) times the outer product of the mean with itself, divided by
	len(m)-1.
	"""
	# Test identity, not equality: y is normally an array, and comparing an
	# array with != is an elementwise operation rather than a yes/no answer.
	if y is not None:
		m = array([m,y], m.typecode())
	mu = mean(m)
	n = len(m)
	sum_cov = 0.0
	for v in m:
		sum_cov = sum_cov+outer(multiply, v,v)
	return (sum_cov-n*outer(multiply,mu,mu))/(n-1)



