# fmt: off
from dataclasses import is_dataclass

import numpy as np

from ase.build import bulk, cut, fcc111
from ase.cell import Cell
from ase.geometry import get_layers, wrap_positions
from ase.spacegroup import crystal
from ase.utils import atoms_to_spglib_cell


def test_geometry():
    """Test the ase.geometry module and ase.build.cut() function."""

    al = crystal('Al', [(0, 0, 0)], spacegroup=225, cellpar=4.05)

    # Cut out slab of 5 Al(001) layers
    al001 = cut(al, nlayers=5)
    correct_pos = np.array([[0., 0., 0.],
                            [0., 0.5, 0.2],
                            [0.5, 0., 0.2],
                            [0.5, 0.5, 0.],
                            [0., 0., 0.4],
                            [0., 0.5, 0.6],
                            [0.5, 0., 0.6],
                            [0.5, 0.5, 0.4],
                            [0., 0., 0.8],
                            [0.5, 0.5, 0.8]])
    assert np.allclose(correct_pos, al001.get_scaled_positions())

    # Check layers along 001
    tags, levels = get_layers(al001, (0, 0, 1))
    assert np.allclose(tags, [0, 1, 1, 0, 2, 3, 3, 2, 4, 4])
    assert np.allclose(levels, [0., 2.025, 4.05, 6.075, 8.1])

    # Check layers along 101
    tags, levels = get_layers(al001, (1, 0, 1))
    assert np.allclose(tags, [0, 1, 5, 3, 2, 4, 8, 7, 6, 9])
    assert np.allclose(levels, [0.000, 0.752, 1.504, 1.880, 2.256, 2.632, 3.008,
                                3.384, 4.136, 4.888],
                       atol=0.001)

    # Check layers along 111
    tags, levels = get_layers(al001, (1, 1, 1))
    assert np.allclose(tags, [0, 2, 2, 4, 1, 5, 5, 6, 3, 7])
    assert np.allclose(levels, [0.000, 1.102, 1.929, 2.205, 2.756, 3.031, 3.858,
                                4.960],
                       atol=0.001)

    # Cut out slab of three Al(111) layers
    al111 = cut(al, (1, -1, 0), (0, 1, -1), nlayers=3)
    correct_pos = np.array([[0.5, 0., 0.],
                            [0., 0.5, 0.],
                            [0.5, 0.5, 0.],
                            [0., 0., 0.],
                            [1 / 6., 1 / 3., 1 / 3.],
                            [1 / 6., 5 / 6., 1 / 3.],
                            [2 / 3., 5 / 6., 1 / 3.],
                            [2 / 3., 1 / 3., 1 / 3.],
                            [1 / 3., 1 / 6., 2 / 3.],
                            [5 / 6., 1 / 6., 2 / 3.],
                            [5 / 6., 2 / 3., 2 / 3.],
                            [1 / 3., 2 / 3., 2 / 3.]])
    assert np.allclose(correct_pos, al111.get_scaled_positions())

    # Cut out cell including all corner and edge atoms (non-periodic structure)
    al = cut(al, extend=1.1)
    correct_pos = np.array([[0., 0., 0.],
                            [0., 2.025, 2.025],
                            [2.025, 0., 2.025],
                            [2.025, 2.025, 0.],
                            [0., 0., 4.05],
                            [2.025, 2.025, 4.05],
                            [0., 4.05, 0.],
                            [2.025, 4.05, 2.025],
                            [0., 4.05, 4.05],
                            [4.05, 0., 0.],
                            [4.05, 2.025, 2.025],
                            [4.05, 0., 4.05],
                            [4.05, 4.05, 0.],
                            [4.05, 4.05, 4.05]])
    assert np.allclose(correct_pos, al.positions)

    # Create an Ag(111)/Si(111) interface
    ag = crystal(['Ag'], basis=[(0, 0, 0)], spacegroup=225, cellpar=4.09)
    si = crystal(['Si'], basis=[(0, 0, 0)], spacegroup=227, cellpar=5.43)
    try:
        import spglib
        for atoms, no_ref in zip([ag, si], [225, 227]):
            dataset = spglib.get_symmetry_dataset(atoms_to_spglib_cell(atoms))
            no = dataset.number if is_dataclass(dataset) else dataset['number']
            assert no == no_ref
    except ImportError:
        pass

    ag111 = cut(ag, a=(4, -4, 0), b=(4, 4, -8), nlayers=5)  # noqa
    si111 = cut(si, a=(3, -3, 0), b=(3, 3, -6), nlayers=5)  # noqa
    #
    # interface = stack(ag111, si111)
    # assert len(interface) == 1000
    # assert np.allclose(interface.positions[::100],
    #                   [[  4.08125   ,  -2.040625  ,   -2.040625  ],
    #                    [  8.1625    ,   6.121875  ,  -14.284375  ],
    #                    [ 10.211875  ,   0.00875   ,    2.049375  ],
    #                    [ 24.49041667,  -4.07833333,  -16.32208333],
    #                    [ 18.37145833,  14.29020833,  -24.48166667],
    #                    [ 24.49916667,  12.25541667,  -20.39458333],
    #                    [ 18.36854167,  16.32791667,  -30.60645833],
    #                    [ 19.0575    ,   0.01166667,    5.45333333],
    #                    [ 23.13388889,   6.80888889,    1.36722222],
    #                    [ 35.3825    ,   5.45333333,  -16.31333333]])
    #

    # Test the wrap_positions function.
    positions = np.array([
        [4.0725, -4.0725, -1.3575],
        [1.3575, -1.3575, -1.3575],
        [2.715, -2.715, 0.],
        [4.0725, 1.3575, -1.3575],
        [0., 0., 0.],
        [2.715, 2.715, 0.],
        [6.7875, -1.3575, -1.3575],
        [5.43, 0., 0.]])
    cell = np.array([[5.43, 5.43, 0.0], [5.43, -5.43, 0.0], [0.00, 0.00, 40.0]])
    positions += np.array([6.1, -0.1, 10.1])
    result_positions = wrap_positions(positions=positions, cell=cell)
    correct_pos = np.array([
        [4.7425, 1.2575, 8.7425],
        [7.4575, -1.4575, 8.7425],
        [3.385, 2.615, 10.1],
        [4.7425, -4.1725, 8.7425],
        [6.1, -0.1, 10.1],
        [3.385, -2.815, 10.1],
        [2.0275, -1.4575, 8.7425],
        [0.67, -0.1, 10.1]])
    assert np.allclose(correct_pos, result_positions)

    positions = wrap_positions(positions, cell, pbc=[False, True, False])
    correct_pos = np.array([
        [4.7425, 1.2575, 8.7425],
        [7.4575, -1.4575, 8.7425],
        [3.385, 2.615, 10.1],
        [10.1725, 1.2575, 8.7425],
        [6.1, -0.1, 10.1],
        [8.815, 2.615, 10.1],
        [7.4575, 3.9725, 8.7425],
        [6.1, 5.33, 10.1]])
    assert np.allclose(correct_pos, positions)

    # Test center away from values 0, 0.5
    result_positions = wrap_positions(positions, cell,
                                      pbc=[True, True, False],
                                      center=0.2)
    correct_pos = [[4.7425, 1.2575, 8.7425],
                   [2.0275, 3.9725, 8.7425],
                   [3.385, 2.615, 10.1],
                   [-0.6875, 1.2575, 8.7425],
                   [6.1, -0.1, 10.1],
                   [3.385, -2.815, 10.1],
                   [2.0275, -1.4575, 8.7425],
                   [0.67, -0.1, 10.1]]
    assert np.allclose(correct_pos, result_positions)

    # Test pretty_translation keyword
    positions = np.array([
        [0, 0, 0],
        [0, 1, 1],
        [1, 0, 1],
        [1, 1, 0.]])
    cell = np.diag([2, 2, 2])
    result = wrap_positions(positions, cell, pbc=[True, True, True],
                            pretty_translation=True)
    assert np.max(result) < 1 + 1E-10
    assert np.min(result) > -1E-10

    result = wrap_positions(positions - 5, cell, pbc=[True, True, True],
                            pretty_translation=True)
    assert np.max(result) < 1 + 1E-10

    result = wrap_positions(positions - 5, cell, pbc=[False, True, True],
                            pretty_translation=True)
    assert np.max(result[:, 0]) < -3
    assert np.max(result[:, 1:]) < 1 + 1E-10

    # Get the correct crystal structure from a range of different cells

    def checkcell(cell, name):
        cell = Cell.ascell(cell)
        lat = cell.get_bravais_lattice()
        assert lat.name == name, (lat.name, name)

    checkcell(bulk('Al').cell, 'FCC')
    checkcell(bulk('Fe').cell, 'BCC')
    checkcell(bulk('Zn').cell, 'HEX')
    checkcell(fcc111('Au', size=(1, 1, 3), periodic=True).cell, 'HEX')
    checkcell([[1, 0, 0], [0, 1, 0], [0, 0, 1]], 'CUB')
    checkcell([[1, 0, 0], [0, 1, 0], [0, 0, 2]], 'TET')
    checkcell([[1, 0, 0], [0, 2, 0], [0, 0, 3]], 'ORC')
    checkcell([[1, 0, 0], [0, 2, 0], [0.5, 0, 3]], 'ORCC')
    checkcell([[1, 0, 0], [0, 2, 0], [0.501, 0, 3]], 'MCL')
    checkcell([[1, 0, 0], [0.5, 3**0.5 / 2, 0], [0, 0, 3]], 'HEX')
