"""Tests of matrix manipulating functions."""

from __future__ import annotations

from pathlib import Path
from typing import cast

import numpy as np
import pytest
from numpy.typing import NDArray

from symfc.spg_reps import SpgRepsBase
from symfc.utils.utils import (
    SymfcAtoms,
    compute_sg_permutations,
    compute_sg_permutations_stable,
    get_indep_atoms_by_lat_trans,
)

# Directory containing this test module; reference data files are located
# relative to it.
cwd = Path(__file__).parent


def test_get_indep_atoms_by_lattice_translation(
    ph_nacl_222: tuple[SymfcAtoms, NDArray, NDArray],
):
    """Test get_indep_atoms_by_lat_trans using the ``ph_nacl_222`` supercell."""
    supercell, _, _ = ph_nacl_222
    # Translation permutations are taken from the space-group representation
    # built for the supercell.
    translation_perms = SpgRepsBase(supercell).translation_permutations
    assert translation_perms.shape == (32, 64)
    # Atoms 0 and 32 are the expected translationally independent atoms.
    np.testing.assert_array_equal(
        get_indep_atoms_by_lat_trans(translation_perms), [0, 32]
    )


def test_compute_sg_permutations(
    ph_gan_222: tuple[SymfcAtoms, NDArray, NDArray], cell_gan_111: SymfcAtoms
):
    """Test compute_sg_permutations against reference permutations.

    Two cells are checked:

    * the ``cell_gan_111`` cell, compared with a table written inline, and
    * the ``ph_gan_222`` supercell, compared with the contents of
      ``perms_super.dat`` stored one directory above this test module.

    Symmetry operations are obtained from spglib; the test is skipped when
    spglib cannot be imported.
    """
    pytest.importorskip("spglib")
    import spglib
    from spglib.spglib import Cell as SpgCell

    supercell, _, _ = ph_gan_222
    primitive = cell_gan_111
    dataset = spglib.get_symmetry_dataset(cast(SpgCell, supercell.totuple()))
    assert dataset is not None
    primitive_dataset = spglib.get_symmetry_dataset(cast(SpgCell, primitive.totuple()))
    assert primitive_dataset is not None
    perms = compute_sg_permutations(
        primitive.scaled_positions,
        primitive_dataset.rotations,
        primitive_dataset.translations,
        primitive.cell,
    )
    ref_perms = [
        [0, 1, 2, 3],
        [1, 0, 3, 2],
        [0, 1, 2, 3],
        [1, 0, 3, 2],
        [0, 1, 2, 3],
        [1, 0, 3, 2],
        [1, 0, 3, 2],
        [0, 1, 2, 3],
        [1, 0, 3, 2],
        [0, 1, 2, 3],
        [1, 0, 3, 2],
        [0, 1, 2, 3],
    ]
    # Computed value first, reference second, so that the ACTUAL/DESIRED labels
    # in a failure report are attached to the right arrays.
    np.testing.assert_array_equal(perms, ref_perms)

    perms_super = compute_sg_permutations(
        supercell.scaled_positions,
        dataset.rotations,
        dataset.translations,
        supercell.cell,
    )
    # The reference file can be regenerated by temporarily enabling:
    # np.savetxt("perms_super.dat", perms_super, fmt="%d")
    perms_super_ref = np.loadtxt(cwd / ".." / "perms_super.dat", dtype=int)
    np.testing.assert_array_equal(perms_super, perms_super_ref)


def test_compute_sg_permutations_compare_stable():
    """Compare compute_sg_permutations with compute_sg_permutations_stable.

    Both routines are run on an eight-atom, single-species cell whose atoms
    sit at fractional coordinates (0, 0, k/8). Their results must agree with
    each other and with the reference permutation table below.

    The test is skipped when spglib cannot be imported.
    """
    pytest.importorskip("spglib")
    import spglib
    from spglib.spglib import Cell as SpgCell

    lattice = np.array([[0.0, 1.0, 1.0], [7.0, 6.0, 7.0], [8.0, 8.0, 8.0]])
    # Fractional coordinates (0, 0, k/8) for k = 0..7; eighths are exactly
    # representable in binary floating point.
    frac_coords = np.zeros((8, 3))
    frac_coords[:, 2] = np.arange(8) / 8
    numbers = np.zeros(8, dtype=int)

    spg_input = (lattice, frac_coords, numbers)
    dataset = spglib.get_symmetry_dataset(cast(SpgCell, spg_input))
    assert dataset is not None

    # Evaluate both implementations on identical arguments.
    perms, perms_stable = [
        func(frac_coords, dataset.rotations, dataset.translations, lattice)
        for func in (compute_sg_permutations, compute_sg_permutations_stable)
    ]

    expected = [
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 7, 6, 5, 4, 3, 2, 1],
        [0, 7, 6, 5, 4, 3, 2, 1],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 7, 6, 5, 4, 3, 2, 1],
        [0, 7, 6, 5, 4, 3, 2, 1],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 7, 6, 5, 4, 3, 2, 1],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 7, 6, 5, 4, 3, 2, 1],
        [1, 2, 3, 4, 5, 6, 7, 0],
        [1, 0, 7, 6, 5, 4, 3, 2],
        [1, 0, 7, 6, 5, 4, 3, 2],
        [1, 2, 3, 4, 5, 6, 7, 0],
        [1, 2, 3, 4, 5, 6, 7, 0],
        [1, 0, 7, 6, 5, 4, 3, 2],
        [1, 0, 7, 6, 5, 4, 3, 2],
        [1, 2, 3, 4, 5, 6, 7, 0],
        [1, 0, 7, 6, 5, 4, 3, 2],
        [1, 2, 3, 4, 5, 6, 7, 0],
        [1, 2, 3, 4, 5, 6, 7, 0],
        [1, 0, 7, 6, 5, 4, 3, 2],
        [2, 3, 4, 5, 6, 7, 0, 1],
        [2, 1, 0, 7, 6, 5, 4, 3],
        [2, 1, 0, 7, 6, 5, 4, 3],
        [2, 3, 4, 5, 6, 7, 0, 1],
        [2, 3, 4, 5, 6, 7, 0, 1],
        [2, 1, 0, 7, 6, 5, 4, 3],
        [2, 1, 0, 7, 6, 5, 4, 3],
        [2, 3, 4, 5, 6, 7, 0, 1],
        [2, 1, 0, 7, 6, 5, 4, 3],
        [2, 3, 4, 5, 6, 7, 0, 1],
        [2, 3, 4, 5, 6, 7, 0, 1],
        [2, 1, 0, 7, 6, 5, 4, 3],
        [3, 4, 5, 6, 7, 0, 1, 2],
        [3, 2, 1, 0, 7, 6, 5, 4],
        [3, 2, 1, 0, 7, 6, 5, 4],
        [3, 4, 5, 6, 7, 0, 1, 2],
        [3, 4, 5, 6, 7, 0, 1, 2],
        [3, 2, 1, 0, 7, 6, 5, 4],
        [3, 2, 1, 0, 7, 6, 5, 4],
        [3, 4, 5, 6, 7, 0, 1, 2],
        [3, 2, 1, 0, 7, 6, 5, 4],
        [3, 4, 5, 6, 7, 0, 1, 2],
        [3, 4, 5, 6, 7, 0, 1, 2],
        [3, 2, 1, 0, 7, 6, 5, 4],
        [4, 5, 6, 7, 0, 1, 2, 3],
        [4, 3, 2, 1, 0, 7, 6, 5],
        [4, 3, 2, 1, 0, 7, 6, 5],
        [4, 5, 6, 7, 0, 1, 2, 3],
        [4, 5, 6, 7, 0, 1, 2, 3],
        [4, 3, 2, 1, 0, 7, 6, 5],
        [4, 3, 2, 1, 0, 7, 6, 5],
        [4, 5, 6, 7, 0, 1, 2, 3],
        [4, 3, 2, 1, 0, 7, 6, 5],
        [4, 5, 6, 7, 0, 1, 2, 3],
        [4, 5, 6, 7, 0, 1, 2, 3],
        [4, 3, 2, 1, 0, 7, 6, 5],
        [5, 6, 7, 0, 1, 2, 3, 4],
        [5, 4, 3, 2, 1, 0, 7, 6],
        [5, 4, 3, 2, 1, 0, 7, 6],
        [5, 6, 7, 0, 1, 2, 3, 4],
        [5, 6, 7, 0, 1, 2, 3, 4],
        [5, 4, 3, 2, 1, 0, 7, 6],
        [5, 4, 3, 2, 1, 0, 7, 6],
        [5, 6, 7, 0, 1, 2, 3, 4],
        [5, 4, 3, 2, 1, 0, 7, 6],
        [5, 6, 7, 0, 1, 2, 3, 4],
        [5, 6, 7, 0, 1, 2, 3, 4],
        [5, 4, 3, 2, 1, 0, 7, 6],
        [6, 7, 0, 1, 2, 3, 4, 5],
        [6, 5, 4, 3, 2, 1, 0, 7],
        [6, 5, 4, 3, 2, 1, 0, 7],
        [6, 7, 0, 1, 2, 3, 4, 5],
        [6, 7, 0, 1, 2, 3, 4, 5],
        [6, 5, 4, 3, 2, 1, 0, 7],
        [6, 5, 4, 3, 2, 1, 0, 7],
        [6, 7, 0, 1, 2, 3, 4, 5],
        [6, 5, 4, 3, 2, 1, 0, 7],
        [6, 7, 0, 1, 2, 3, 4, 5],
        [6, 7, 0, 1, 2, 3, 4, 5],
        [6, 5, 4, 3, 2, 1, 0, 7],
        [7, 0, 1, 2, 3, 4, 5, 6],
        [7, 6, 5, 4, 3, 2, 1, 0],
        [7, 6, 5, 4, 3, 2, 1, 0],
        [7, 0, 1, 2, 3, 4, 5, 6],
        [7, 0, 1, 2, 3, 4, 5, 6],
        [7, 6, 5, 4, 3, 2, 1, 0],
        [7, 6, 5, 4, 3, 2, 1, 0],
        [7, 0, 1, 2, 3, 4, 5, 6],
        [7, 6, 5, 4, 3, 2, 1, 0],
        [7, 0, 1, 2, 3, 4, 5, 6],
        [7, 0, 1, 2, 3, 4, 5, 6],
        [7, 6, 5, 4, 3, 2, 1, 0],
    ]
    # First the two implementations against each other, then against the table.
    np.testing.assert_array_equal(perms, perms_stable)
    np.testing.assert_array_equal(perms, expected)


def test_compute_sg_permutations_compare_stable_nacl(
    ph_nacl_222: tuple[SymfcAtoms, NDArray, NDArray],
):
    """Check that both permutation routines agree on the ``ph_nacl_222`` supercell.

    The test is skipped when spglib cannot be imported.
    """
    pytest.importorskip("spglib")
    import spglib
    from spglib.spglib import Cell as SpgCell

    supercell, _, _ = ph_nacl_222
    spg_cell = cast(SpgCell, supercell.totuple())
    dataset = spglib.get_symmetry_dataset(spg_cell)
    assert dataset is not None

    # Run the regular implementation first, then the stable one, on the same
    # symmetry operations.
    perms, perms_stable = [
        func(
            supercell.scaled_positions,
            dataset.rotations,
            dataset.translations,
            supercell.cell,
        )
        for func in (compute_sg_permutations, compute_sg_permutations_stable)
    ]
    np.testing.assert_array_equal(perms, perms_stable)


def test_compute_sg_permutations_compare_stable_sio2(
    ph_sio2_221: tuple[SymfcAtoms, NDArray, NDArray],
):
    """Check that both permutation routines agree on the ``ph_sio2_221`` supercell.

    The test is skipped when spglib cannot be imported.
    """
    pytest.importorskip("spglib")
    import spglib
    from spglib.spglib import Cell as SpgCell

    supercell, _, _ = ph_sio2_221
    symmetry = spglib.get_symmetry_dataset(cast(SpgCell, supercell.totuple()))
    assert symmetry is not None

    implementations = (compute_sg_permutations, compute_sg_permutations_stable)
    # One result per implementation, in the order listed above.
    regular, stable = [
        permute(
            supercell.scaled_positions,
            symmetry.rotations,
            symmetry.translations,
            supercell.cell,
        )
        for permute in implementations
    ]
    np.testing.assert_array_equal(regular, stable)


def test_compute_sg_permutations_compare_stable_baal2o4():
    """Test compute_sg_permutations for BaAl2O4.

    A 112-atom cell with three species (16, 32 and 64 atoms) is defined inline.
    ``compute_sg_permutations`` and ``compute_sg_permutations_stable`` are run
    with the symmetry operations spglib finds for it, and their results must
    be identical.

    Several coordinates deviate from ideal fractions by roughly 1e-15 to 1e-14
    (e.g. ``0.4999999999999950``, ``0.9999999999999990``); presumably this is
    what makes the cell a useful comparison case -- TODO confirm.

    The test is skipped when spglib cannot be imported.
    """
    pytest.importorskip("spglib")
    import spglib
    from spglib.spglib import Cell as SpgCell

    axis = np.array(
        [
            [10.4020522437529994, 0.0000000000000000, 0.0000000000000000],
            [-5.2010261218764997, 9.0084414945829998, 0.0000000000000000],
            [0.00000000000000000, 0.0000000000000000, 17.8057491600000013],
        ]
    )
    # Species labels 0, 1, 2 for 16 + 32 + 64 = 112 atoms, one per row of
    # ``positions`` below.
    types = np.hstack(
        [np.zeros(16, dtype=int), np.ones(32, dtype=int), np.ones(64, dtype=int) * 2]
    )
    positions = np.array(
        [
            [0.0000000000000000, 0.0000000000000000, 0.1250000000000000],
            [0.5000000000000000, 0.0000000000000000, 0.1250000000000000],
            [0.0000000000000050, 0.5000000000000000, 0.1250000000000000],
            [0.4999999999999950, 0.5000000000000000, 0.1250000000000000],
            [0.0000000000000000, 0.0000000000000000, 0.6250000000000000],
            [0.5000000000000000, 0.0000000000000000, 0.6250000000000000],
            [0.0000000000000050, 0.5000000000000000, 0.6250000000000000],
            [0.4999999999999950, 0.5000000000000000, 0.6250000000000000],
            [0.0000000000000000, 0.0000000000000000, 0.3750000000000000],
            [0.5000000000000000, 0.0000000000000000, 0.3750000000000000],
            [0.0000000000000050, 0.5000000000000000, 0.3750000000000000],
            [0.4999999999999950, 0.5000000000000000, 0.3750000000000000],
            [0.0000000000000000, 0.0000000000000000, 0.8750000000000000],
            [0.5000000000000000, 0.0000000000000000, 0.8750000000000000],
            [0.0000000000000050, 0.5000000000000000, 0.8750000000000000],
            [0.4999999999999950, 0.5000000000000000, 0.8750000000000000],
            [0.1666666666666690, 0.3333333333333370, 0.2768253700000010],
            [0.6666666666666690, 0.3333333333333370, 0.2768253700000010],
            [0.1666666666666730, 0.8333333333333370, 0.2768253700000010],
            [0.6666666666666640, 0.8333333333333370, 0.2768253700000010],
            [0.1666666666666690, 0.3333333333333370, 0.7768253699999890],
            [0.6666666666666690, 0.3333333333333370, 0.7768253699999890],
            [0.1666666666666730, 0.8333333333333370, 0.7768253699999890],
            [0.6666666666666640, 0.8333333333333370, 0.7768253699999890],
            [0.3333333333333270, 0.1666666666666630, 0.0268253700000010],
            [0.8333333333333270, 0.1666666666666630, 0.0268253700000010],
            [0.3333333333333320, 0.6666666666666630, 0.0268253700000010],
            [0.8333333333333310, 0.6666666666666630, 0.0268253700000010],
            [0.3333333333333270, 0.1666666666666630, 0.5268253700000010],
            [0.8333333333333270, 0.1666666666666630, 0.5268253700000010],
            [0.3333333333333320, 0.6666666666666630, 0.5268253700000010],
            [0.8333333333333310, 0.6666666666666630, 0.5268253700000010],
            [0.3333333333333270, 0.1666666666666630, 0.2231746299999990],
            [0.8333333333333270, 0.1666666666666630, 0.2231746299999990],
            [0.3333333333333320, 0.6666666666666630, 0.2231746299999990],
            [0.8333333333333310, 0.6666666666666630, 0.2231746299999990],
            [0.3333333333333270, 0.1666666666666630, 0.7231746300000110],
            [0.8333333333333270, 0.1666666666666630, 0.7231746300000110],
            [0.3333333333333320, 0.6666666666666630, 0.7231746300000110],
            [0.8333333333333310, 0.6666666666666630, 0.7231746300000110],
            [0.1666666666666690, 0.3333333333333370, 0.4731746299999990],
            [0.6666666666666690, 0.3333333333333370, 0.4731746299999990],
            [0.1666666666666730, 0.8333333333333370, 0.4731746299999990],
            [0.6666666666666640, 0.8333333333333370, 0.4731746299999990],
            [0.1666666666666690, 0.3333333333333370, 0.9731746300000110],
            [0.6666666666666690, 0.3333333333333370, 0.9731746300000110],
            [0.1666666666666730, 0.8333333333333370, 0.9731746300000110],
            [0.6666666666666640, 0.8333333333333370, 0.9731746300000110],
            [0.1749127449999990, 0.1749127449999980, 0.2500000000000000],
            [0.6749127449999970, 0.1749127449999980, 0.2500000000000000],
            [0.1749127450000020, 0.6749127449999980, 0.2500000000000000],
            [0.6749127450000020, 0.6749127449999980, 0.2500000000000000],
            [0.1749127449999990, 0.1749127449999980, 0.7500000000000000],
            [0.6749127449999970, 0.1749127449999980, 0.7500000000000000],
            [0.1749127450000020, 0.6749127449999980, 0.7500000000000000],
            [0.6749127450000020, 0.6749127449999980, 0.7500000000000000],
            [0.0000000000000030, 0.3250872550000020, 0.2500000000000000],
            [0.5000000000000030, 0.3250872550000020, 0.2500000000000000],
            [0.9999999999999990, 0.8250872550000020, 0.2500000000000000],
            [0.5000000000000010, 0.8250872550000020, 0.2500000000000000],
            [0.0000000000000030, 0.3250872550000020, 0.7500000000000000],
            [0.5000000000000030, 0.3250872550000020, 0.7500000000000000],
            [0.9999999999999990, 0.8250872550000020, 0.7500000000000000],
            [0.5000000000000010, 0.8250872550000020, 0.7500000000000000],
            [0.9999999999999990, 0.1749127449999980, 0.0000000000000000],
            [0.5000000000000010, 0.1749127449999980, 0.0000000000000000],
            [0.9999999999999970, 0.6749127449999980, 0.0000000000000000],
            [0.4999999999999970, 0.6749127449999980, 0.0000000000000000],
            [0.9999999999999990, 0.1749127449999980, 0.5000000000000000],
            [0.5000000000000010, 0.1749127449999980, 0.5000000000000000],
            [0.9999999999999970, 0.6749127449999980, 0.5000000000000000],
            [0.4999999999999970, 0.6749127449999980, 0.5000000000000000],
            [0.1666666666666690, 0.3333333333333370, 0.3750000000000000],
            [0.6666666666666690, 0.3333333333333370, 0.3750000000000000],
            [0.1666666666666730, 0.8333333333333370, 0.3750000000000000],
            [0.6666666666666640, 0.8333333333333370, 0.3750000000000000],
            [0.1666666666666690, 0.3333333333333370, 0.8750000000000000],
            [0.6666666666666690, 0.3333333333333370, 0.8750000000000000],
            [0.1666666666666730, 0.8333333333333370, 0.8750000000000000],
            [0.6666666666666640, 0.8333333333333370, 0.8750000000000000],
            [0.3250872549999980, 0.3250872550000020, 0.0000000000000000],
            [0.8250872549999980, 0.3250872550000020, 0.0000000000000000],
            [0.3250872550000010, 0.8250872550000020, 0.0000000000000000],
            [0.8250872550000030, 0.8250872550000020, 0.0000000000000000],
            [0.3250872549999980, 0.3250872550000020, 0.5000000000000000],
            [0.8250872549999980, 0.3250872550000020, 0.5000000000000000],
            [0.3250872550000010, 0.8250872550000020, 0.5000000000000000],
            [0.8250872550000030, 0.8250872550000020, 0.5000000000000000],
            [0.3250872550000040, 0.0000000000000000, 0.2500000000000000],
            [0.8250872550000040, 0.0000000000000000, 0.2500000000000000],
            [0.3250872550000020, 0.5000000000000000, 0.2500000000000000],
            [0.8250872550000000, 0.5000000000000000, 0.2500000000000000],
            [0.3250872550000040, 0.0000000000000000, 0.7500000000000000],
            [0.8250872550000040, 0.0000000000000000, 0.7500000000000000],
            [0.3250872550000020, 0.5000000000000000, 0.7500000000000000],
            [0.8250872550000000, 0.5000000000000000, 0.7500000000000000],
            [0.3333333333333270, 0.1666666666666630, 0.1250000000000000],
            [0.8333333333333270, 0.1666666666666630, 0.1250000000000000],
            [0.3333333333333320, 0.6666666666666630, 0.1250000000000000],
            [0.8333333333333310, 0.6666666666666630, 0.1250000000000000],
            [0.3333333333333270, 0.1666666666666630, 0.6250000000000000],
            [0.8333333333333270, 0.1666666666666630, 0.6250000000000000],
            [0.3333333333333320, 0.6666666666666630, 0.6250000000000000],
            [0.8333333333333310, 0.6666666666666630, 0.6250000000000000],
            [0.1749127449999960, 0.0000000000000000, 0.0000000000000000],
            [0.6749127449999960, 0.0000000000000000, 0.0000000000000000],
            [0.1749127449999980, 0.5000000000000000, 0.0000000000000000],
            [0.6749127450000000, 0.5000000000000000, 0.0000000000000000],
            [0.1749127449999960, 0.0000000000000000, 0.5000000000000000],
            [0.6749127449999960, 0.0000000000000000, 0.5000000000000000],
            [0.1749127449999980, 0.5000000000000000, 0.5000000000000000],
            [0.6749127450000000, 0.5000000000000000, 0.5000000000000000],
        ]
    )

    # spglib takes the cell as a (lattice, positions, numbers) tuple.
    cell = axis, positions, types
    dataset = spglib.get_symmetry_dataset(cast(SpgCell, cell))
    assert dataset is not None
    perms = compute_sg_permutations(
        positions,
        dataset.rotations,
        dataset.translations,
        axis,
    )
    perms_stable = compute_sg_permutations_stable(
        positions,
        dataset.rotations,
        dataset.translations,
        axis,
    )
    np.testing.assert_array_equal(perms, perms_stable)
