""" Test functions for matrix module

"""

from numpy.testing import *
set_package_path()
from numpy import arange, rot90, add, fliplr, flipud, zeros, ones, eye, \
     array, diag, histogram2d
import numpy as np
restore_path()

##################################################


def get_mat(n):
    """Return the n-by-n array whose entry [i, j] equals i + j.

    Built as the outer sum of arange(n) with itself, so the result is
    symmetric and its first row and first column are both 0..n-1.
    """
    ramp = arange(n)
    return add.outer(ramp, ramp)

class test_eye(NumpyTestCase):
    """Checks for eye(): square, shifted-diagonal and rectangular results."""

    def check_basic(self):
        # 4x4 identity, compared once with the default dtype and once as 'f'.
        identity = [[1,0,0,0],
                    [0,1,0,0],
                    [0,0,1,0],
                    [0,0,0,1]]
        assert_equal(eye(4), array(identity))
        assert_equal(eye(4, dtype='f'), array(identity, 'f'))

    def check_diag(self):
        # k=1 places the ones on the first super-diagonal.
        superdiag = array([[0,1,0,0],
                           [0,0,1,0],
                           [0,0,0,1],
                           [0,0,0,0]])
        assert_equal(eye(4, k=1), superdiag)

        # k=-1 places the ones on the first sub-diagonal.
        subdiag = array([[0,0,0,0],
                         [1,0,0,0],
                         [0,1,0,0],
                         [0,0,1,0]])
        assert_equal(eye(4, k=-1), subdiag)

    def check_2d(self):
        # More rows than columns.
        tall = array([[1,0,0],
                      [0,1,0],
                      [0,0,1],
                      [0,0,0]])
        assert_equal(eye(4, 3), tall)

        # More columns than rows.
        wide = array([[1,0,0,0],
                      [0,1,0,0],
                      [0,0,1,0]])
        assert_equal(eye(3, 4), wide)

    def check_diag2d(self):
        # Rectangular output combined with a shifted diagonal.
        wide_shifted = array([[0,0,1,0],
                              [0,0,0,1],
                              [0,0,0,0]])
        assert_equal(eye(3, 4, k=2), wide_shifted)

        tall_shifted = array([[0,0,0],
                              [0,0,0],
                              [1,0,0],
                              [0,1,0]])
        assert_equal(eye(4, 3, k=-2), tall_shifted)

class test_diag(NumpyTestCase):
    """Checks for diag() in both directions: vector -> matrix and matrix -> vector."""

    def check_vector(self):
        # Building a matrix from a 1-d vector of five values.
        vals = (100*arange(5)).astype('l')

        # Main diagonal: a 5x5 result.
        on_main = zeros((5,5))
        for i, v in enumerate(vals):
            on_main[i,i] = v
        assert_equal(diag(vals), on_main)

        # Offsets of +/-2 grow the result to 7x7.
        above = zeros((7,7))
        below = zeros((7,7))
        for i, v in enumerate(vals):
            above[i,i+2] = v
            below[i+2,i] = v
        assert_equal(diag(vals, k=2), above)
        assert_equal(diag(vals, k=-2), below)

    def check_matrix(self):
        # Extracting diagonals from a 5x5 matrix.
        vals = (100*get_mat(5)+1).astype('l')

        main = zeros((5,))
        for i in range(5):
            main[i] = vals[i,i]
        assert_equal(diag(vals), main)

        # The second super- and sub-diagonals each hold three entries.
        upper = zeros((3,))
        for i in range(3):
            upper[i] = vals[i,i+2]
        assert_equal(diag(vals, 2), upper)

        lower = zeros((3,))
        for i in range(3):
            lower[i] = vals[i+2,i]
        assert_equal(diag(vals, -2), lower)

class test_fliplr(NumpyTestCase):
    """Checks for fliplr(): rejection of 1-d input and left/right reversal."""

    def check_basic(self):
        # A 1-d input has no columns to swap and must raise ValueError.
        # assertRaises is the supported unittest spelling; failUnlessRaises
        # is a deprecated alias of the same method.
        self.assertRaises(ValueError, fliplr, ones(4))

        # For an array input, the flip must equal reversing the column axis.
        a = get_mat(4)
        b = a[:,::-1]
        assert_equal(fliplr(a), b)

        # Nested lists are accepted too; each row is reversed.
        a = [[0,1,2],
             [3,4,5]]
        b = [[2,1,0],
             [5,4,3]]
        assert_equal(fliplr(a), b)

class test_flipud(NumpyTestCase):
    """Checks for flipud(): up/down reversal of arrays and nested lists."""

    def check_basic(self):
        # For an array input, the flip must equal reversing the row axis.
        mat = get_mat(4)
        rows_reversed = mat[::-1,:]
        assert_equal(flipud(mat), rows_reversed)

        # Nested lists are accepted too; the row order is reversed.
        rows = [[0,1,2],
                [3,4,5]]
        swapped = [[3,4,5],
                   [0,1,2]]
        assert_equal(flipud(rows), swapped)

class test_rot90(NumpyTestCase):
    """Checks for rot90(): rejection of 1-d input and rotation by any count k."""

    def check_basic(self):
        # A 1-d input cannot be rotated in a plane and must raise ValueError.
        # assertRaises is the supported unittest spelling; failUnlessRaises
        # is a deprecated alias of the same method.
        self.assertRaises(ValueError, rot90, ones(4))

        a = [[0,1,2],
             [3,4,5]]
        # Expected results after one, two, three and four quarter turns.
        b1 = [[2,5],
              [1,4],
              [0,3]]
        b2 = [[5,4,3],
              [2,1,0]]
        b3 = [[3,0],
              [4,1],
              [5,2]]
        b4 = [[0,1,2],
              [3,4,5]]

        # Every k in the same residue class modulo 4 must give the same
        # result, including negative k.  Each entry pairs the smallest k
        # that is tried with the result expected for its whole class.
        cases = ((-3, b1), (-2, b2), (-1, b3), (0, b4))
        for first_k, expected in cases:
            for k in range(first_k, 13, 4):
                assert_equal(rot90(a, k=k), expected)


class test_histogram2d(NumpyTestCase):
    """Checks for histogram2d(): explicit edges, asymmetric bins, normalisation."""

    def check_simple(self):
        # Five points binned on a 9x9 grid given by explicit edge arrays.
        xs = array([ 0.41702200,  0.72032449,  0.00011437481, 0.302332573,  0.146755891])
        ys = array([ 0.09233859,  0.18626021,  0.34556073,  0.39676747,  0.53881673])
        x_bins = np.linspace(0,1,10)
        y_bins = np.linspace(0,1,10)
        counts = histogram2d(xs, ys, (x_bins, y_bins))[0]
        # The expected table is written with y along rows, so the returned
        # histogram is transposed before the comparison.
        expected = array([[0, 0, 0, 1, 0, 0, 0, 0, 0],
                          [0, 0, 0, 0, 0, 0, 1, 0, 0],
                          [0, 0, 0, 0, 0, 0, 0, 0, 0],
                          [1, 0, 1, 0, 0, 0, 0, 0, 0],
                          [0, 1, 0, 0, 0, 0, 0, 0, 0],
                          [0, 0, 0, 0, 0, 0, 0, 0, 0],
                          [0, 0, 0, 0, 0, 0, 0, 0, 0],
                          [0, 0, 0, 0, 0, 0, 0, 0, 0],
                          [0, 0, 0, 0, 0, 0, 0, 0, 0]])
        assert_array_equal(counts.T, expected)

    def check_asym(self):
        # Different bin counts per axis (6 and 5) with an explicit range.
        xs = array([1, 1, 2, 3, 4, 4, 4, 5])
        ys = array([1, 3, 2, 0, 1, 2, 3, 4])
        hist, x_edges, y_edges = histogram2d(xs, ys, (6, 5),
                                             range = [[0,6],[0,5]],
                                             normed=True)
        raw_counts = array([[0.,0,0,0,0],
                            [0,1,0,1,0],
                            [0,0,1,0,0],
                            [1,0,0,0,0],
                            [0,1,1,1,0],
                            [0,0,0,0,1]])
        # Eight samples in total; the normalised result is compared
        # against the raw counts divided by 8.
        expected = raw_counts/8.
        assert_array_almost_equal(hist, expected, 3)

    def check_norm(self):
        # Non-uniform bin edges: the last bin on each axis is twice as wide.
        xs = array([1,2,3,1,2,3,1,2,3])
        ys = array([1,1,1,2,2,2,3,3,3])
        edges = [[1,2,3,5], [1,2,3,5]]
        hist, x_edges, y_edges = histogram2d(xs, ys, edges, normed=True)
        expected = array([[1,1,.5],
                          [1,1,.5],
                          [.5,.5,.25]])/9.
        assert_array_almost_equal(hist, expected, 3)
        
# When this file is executed directly (rather than imported), hand control
# to the NumpyTest runner.
if __name__ == "__main__":
    NumpyTest().run()
