from __future__ import absolute_import, division
import six

import numpy as np
from numpy.testing import (assert_array_equal, assert_array_almost_equal,
                           assert_almost_equal)

import pytest

from gridData import Grid

def f_arithmetic(g):
    """Evaluate ``g + g - 2.5 * g / (g + 5.3)`` for *g*.

    The expression mixes addition, subtraction, multiplication and
    division (including operations with scalars on either side), so it
    exercises several arithmetic operators of whatever object is passed in.
    """
    doubled = g + g
    scaled = 2.5 * g
    shifted = g + 5.3
    return doubled - scaled / shifted

@pytest.fixture(scope="class")
def data():
    """Provide the reference inputs and the Grid built from them.

    Returns a dict with the keys

    * ``griddata``: the integers 1..27 arranged as a 3x3x3 array
    * ``origin``: a length-3 array of zeros
    * ``delta``: a length-3 array of ones
    * ``grid``: ``Grid(griddata, origin=origin, delta=delta)``
    """
    values = np.arange(1, 28).reshape(3, 3, 3)
    origin = np.zeros(3)
    delta = np.ones(3)
    return {
        'griddata': values,
        'origin': origin,
        'delta': delta,
        'grid': Grid(values, origin=origin, delta=delta),
    }

class TestGrid(object):
    """Tests for construction, arithmetic, resampling and I/O of Grid."""

    @pytest.fixture
    def pklfile(self, data, tmpdir):
        """Save the reference grid to ``grid/grid.dat`` in *tmpdir*.

        Returns the path object of the written file.
        """
        g = data['grid']
        fn = tmpdir.mkdir('grid').join('grid.dat')
        g.save(fn)  # always saves as pkl
        return fn

    def test_init(self, data):
        """A scalar ``delta=1`` must give the same delta as an array of ones."""
        g = Grid(data['griddata'], origin=data['origin'],
                 delta=1)
        assert_array_equal(g.delta, data['delta'])

    def test_init_wrong_origin(self, data):
        """A length-4 origin for 3D data must raise TypeError."""
        with pytest.raises(TypeError):
            Grid(data['griddata'], origin=np.ones(4), delta=data['delta'])

    def test_init_wrong_delta(self, data):
        """A length-4 delta for 3D data must raise TypeError."""
        with pytest.raises(TypeError):
            Grid(data['griddata'], origin=data['origin'], delta=np.ones(4))

    def test_empty_Grid(self):
        """Grid() without any arguments must be constructible."""
        g = Grid()
        assert isinstance(g, Grid)

    def test_init_missing_delta_ValueError(self, data):
        """Array data with an origin but no delta must raise ValueError."""
        with pytest.raises(ValueError):
            Grid(data['griddata'], origin=data['origin'])

    def test_init_missing_origin_ValueError(self, data):
        """Array data with a delta but no origin must raise ValueError."""
        with pytest.raises(ValueError):
            Grid(data['griddata'], delta=data['delta'])

    def test_init_wrong_data_exception(self):
        """A string naming a non-existent file must raise IOError."""
        with pytest.raises(IOError):
            Grid("__does_not_exist__")

    def test_load_wrong_fileformat_ValueError(self):
        """An unknown ``file_format`` must raise ValueError."""
        with pytest.raises(ValueError):
            Grid(grid=True, file_format="xxx")

    def test_equality(self, data):
        """Check ``==`` and ``!=`` against itself, a string, and a shifted grid."""
        assert data['grid'] == data['grid']
        assert data['grid'] != 'foo'
        g = Grid(data['griddata'], origin=data['origin'] + 1, delta=data['delta'])
        assert data['grid'] != g

    def test_addition(self, data):
        """Check grid + grid, scalar + grid, and chained addition."""
        g = data['grid'] + data['grid']
        assert_array_equal(g.grid.flat, (2 * data['griddata']).flat)
        g = 2 + data['grid']
        assert_array_equal(g.grid.flat, (2 + data['griddata']).flat)
        g = g + data['grid']
        assert_array_equal(g.grid.flat, (2 + (2 * data['griddata'])).flat)

    def test_subtraction(self, data):
        """Check grid - grid and scalar - grid."""
        g = data['grid'] - data['grid']
        assert_array_equal(g.grid.flat, np.zeros(27))
        g = 2 - data['grid']
        assert_array_equal(g.grid.flat, (2 - data['griddata']).flat)

    def test_multiplication(self, data):
        """Check grid * grid and scalar * grid."""
        g = data['grid'] * data['grid']
        assert_array_equal(g.grid.flat, (data['griddata'] ** 2).flat)
        g = 2 * data['grid']
        assert_array_equal(g.grid.flat, (2 * data['griddata']).flat)

    def test_division(self, data):
        """Check true division by calling the operator methods directly."""
        # __truediv__ is used in py3 by default and py2 if division
        # is imported from __future__; to make testing easier let's call
        # them explicitly
        #
        g = data['grid'].__truediv__(data['grid'])
        assert_array_equal(g.grid.flat, np.ones(27))
        g = data['grid'].__rtruediv__(2)
        assert_array_equal(g.grid.flat, (2 / data['griddata']).flat)

    @pytest.mark.skipif(not six.PY2, reason="classic division only in Python 2")
    def test_classic_division(self, data):
        """Check ``__div__``/``__rdiv__`` (only run under Python 2)."""
        # this is normally ONLY invoked in python 2 and will ONLY
        # work in Python 2; we test the operator methods directly
        # because '/' always performs truedivision in GridDataFormats
        # (we use __future__.division everywhere, also in this test)
        g = data['grid'].__div__(data['grid'])
        assert_array_equal(g.grid.flat, np.ones(27, dtype=np.int64))

        # performs floordivision
        g = data['grid'].__rdiv__(2)
        assert_array_equal(g.grid.flat, (2 // data['griddata']).flat)

        # performs truedivision
        # (note: '/' performs truedivision because of __future__.division!)
        g = data['grid'].__rdiv__(2.0)
        assert_array_equal(g.grid.flat, (2.0 / data['griddata']).flat)

    @pytest.mark.skipif(six.PY2, reason="classic division present in Python 2")
    def test_classic_division_NotImplementedError(self, data):
        """Outside Python 2, ``__div__``/``__rdiv__`` must raise NotImplementedError."""
        with pytest.raises(NotImplementedError):
            data['grid'].__div__(2)
        with pytest.raises(NotImplementedError):
            data['grid'].__rdiv__(2)

    def test_floordivision(self, data):
        """Check grid // grid (via ``__floordiv__``) and scalar // grid."""
        g = data['grid'].__floordiv__(data['grid'])
        assert_array_equal(g.grid.flat, np.ones(27, dtype=np.int64))
        g = 2 // data['grid']
        assert_array_equal(g.grid.flat, (2 // data['griddata']).flat)

    def test_power(self, data):
        """Check grid ** scalar and scalar ** grid."""
        g = data['grid'] ** 2
        assert_array_equal(g.grid.flat, (data['griddata'] ** 2).flat)
        g = 2 ** data['grid']
        assert_array_equal(g.grid.flat, (2 ** data['griddata']).flat)

    def test_compatibility_type(self, data):
        """check_compatible must accept itself, a scalar, and a shifted grid."""
        assert data['grid'].check_compatible(data['grid'])
        assert data['grid'].check_compatible(3)
        g = Grid(data['griddata'], origin=data['origin'] - 1, delta=data['delta'])
        assert data['grid'].check_compatible(g)

    def test_wrong_compatibile_type(self, data):
        """check_compatible must raise TypeError for a string argument."""
        with pytest.raises(TypeError):
            data['grid'].check_compatible("foo")

    def test_non_orthonormal_boxes(self, data):
        """A 3x3 matrix as delta must raise NotImplementedError."""
        delta = np.eye(3)
        with pytest.raises(NotImplementedError):
            Grid(data['griddata'], origin=data['origin'], delta=delta)

    def test_centers(self, data):
        """Check the first and last points produced by ``centers()``."""
        # this only checks the edges. If you know an alternative
        # algorithm that isn't an exact duplicate of the one in
        # g.centers to test this please implement it.
        g = Grid(data['griddata'], origin=np.ones(3), delta=data['delta'])
        centers = np.array(list(g.centers()))
        assert_array_equal(centers[0], g.origin)
        assert_array_equal(centers[-1] - g.origin,
                           (np.array(g.grid.shape) - 1) * data['delta'])

    def test_resample_factor_failure(self, data):
        """A resampling factor of 0 must raise ValueError (needs scipy)."""
        pytest.importorskip('scipy')

        with pytest.raises(ValueError):
            # the return value is irrelevant; only the exception matters
            data['grid'].resample_factor(0)

    def test_resample_factor(self, data):
        """Resampling by a factor of 2 halves delta and keeps the old values.

        Skipped when scipy is not importable.
        """
        pytest.importorskip('scipy')

        g = data['grid'].resample_factor(2)
        assert_array_equal(g.delta, np.ones(3) * .5)
        # zooming in by a factor of 2. Each subinterval is
        # split in half, so 3 gridpoints (2 subintervals)
        # becomes 5 gridpoints (4 subintervals)
        assert_array_equal(g.grid.shape, np.ones(3) * 5)
        # check that the values are identical with the
        # correct stride.
        assert_array_almost_equal(g.grid[::2, ::2, ::2],
                                  data['grid'].grid)

    def test_load_pickle(self, data, tmpdir):
        """Round-trip through ``save()`` and ``load()`` using a str filename."""
        g = data['grid']
        fn = str(tmpdir.mkdir('grid').join('grid.pkl'))
        g.save(fn)

        h = Grid()
        h.load(fn)

        assert h == g

    def test_init_pickle_pathobjects(self, data, tmpdir):
        """Round-trip through ``save()`` and ``Grid(path)`` using a path object."""
        g = data['grid']
        fn = tmpdir.mkdir('grid').join('grid.pickle')
        g.save(fn)

        h = Grid(fn)

        assert h == g

    @pytest.mark.parametrize("fileformat", ("pkl", "PKL", "pickle", "python"))
    def test_load_fileformat(self, data, pklfile, fileformat):
        """Load the saved file with each parametrized pickle format name."""
        # pass the parametrized name through; hard-coding "pkl" here would
        # run the same check four times and never test the other spellings
        h = Grid(pklfile, file_format=fileformat)
        assert h == data['grid']

    # At the moment, reading the file with the wrong parser does not give
    # good error messages.
    @pytest.mark.xfail
    @pytest.mark.parametrize("fileformat", ("ccp4", "plt", "dx"))
    def test_load_wrong_fileformat(self, data, pklfile, fileformat):
        """Reading a pickle file with a non-pickle parser should raise ValueError.

        Marked as an expected failure (see the comment above the decorators).
        """
        # pytest.raises needs the exception class, not its name as a string
        with pytest.raises(ValueError):
            Grid(pklfile, file_format=fileformat)

    # just check that we can export without stupid failures; detailed
    # format checks in separate tests
    @pytest.mark.parametrize("fileformat", ("dx", "pkl"))
    def test_export(self, data, fileformat, tmpdir):
        """Export to ``grid.<fileformat>`` and read it back unchanged."""
        g = data['grid']
        fn = tmpdir.mkdir('grid_export').join("grid.{}".format(fileformat))
        g.export(fn)   # check that path objects work
        h = Grid(fn)   # use format autodetection
        assert g == h

    @pytest.mark.parametrize("fileformat", ("ccp4", "plt"))
    def test_export_not_supported(self, data, fileformat, tmpdir):
        """Exporting to a ``.ccp4`` or ``.plt`` filename must raise ValueError."""
        g = data['grid']
        fn = tmpdir.mkdir('grid_export').join("grid.{}".format(fileformat))
        with pytest.raises(ValueError):
            g.export(fn)


def test_inheritance(data):
    """Arithmetic on a Grid subclass must return the subclass.

    The numerical result is also compared against the same expression
    evaluated on the plain reference Grid.
    """
    class DerivedGrid(Grid):
        pass

    derived = DerivedGrid(data['griddata'],
                          origin=data['origin'],
                          delta=data['delta'])
    derived_result = f_arithmetic(derived)

    # the subclass type has to survive the chain of operators
    assert isinstance(derived_result, DerivedGrid)

    expected = f_arithmetic(data['grid'])
    assert_almost_equal(derived_result.grid, expected.grid)

def test_anyarray(data):
    """A Grid built from a masked array must keep the masked-array type.

    Arithmetic on that Grid is then compared against the same expression
    evaluated on the plain reference Grid.
    """
    masked = np.ma.MaskedArray(data['griddata'])
    masked_grid = Grid(masked, origin=data['origin'], delta=data['delta'])

    # the stored grid must be of the same array class that was passed in
    assert isinstance(masked_grid.grid, type(masked))

    masked_result = f_arithmetic(masked_grid)
    expected = f_arithmetic(data['grid'])

    assert_almost_equal(masked_result.grid, expected.grid)
