import io
import numpy as np
import warnings
import pytest

from ase import Atoms
from ase.build import molecule
from ase.io import read, write
from ase.io.cif import CIFLoop, parse_loop, NoStructureData, parse_cif
from ase.calculators.calculator import compare_atoms


def check_fractional_occupancies(atoms):
    """ Checks fractional occupancy entries in atoms.info dict """
    assert atoms.info['occupancy']
    assert list(atoms.arrays['spacegroup_kinds'])

    occupancies = atoms.info['occupancy']
    kinds = atoms.arrays['spacegroup_kinds']
    for a in atoms:
        if a.symbol == 'Na':
            assert len(occupancies[kinds[a.index]]) == 2
            assert occupancies[kinds[a.index]]['K'] == 0.25
            assert occupancies[kinds[a.index]]['Na'] == 0.75
        else:
            assert len(occupancies[kinds[a.index]]) == 1
        if a.symbol == 'Cl':
            assert occupancies[kinds[a.index]]['Cl'] == 0.3


content = """
data_1


_chemical_name_common                  'Mysterious something'
_cell_length_a                         5.50000
_cell_length_b                         5.50000
_cell_length_c                         5.50000
_cell_angle_alpha                      90
_cell_angle_beta                       90
_cell_angle_gamma                      90
_space_group_name_H-M_alt              'F m -3 m'
_space_group_IT_number                 225

loop_
_space_group_symop_operation_xyz
   'x, y, z'
   '-x, -y, -z'
   '-x, -y, z'
   'x, y, -z'
   '-x, y, -z'
   'x, -y, z'
   'x, -y, -z'
   '-x, y, z'
   'z, x, y'
   '-z, -x, -y'
   'z, -x, -y'
   '-z, x, y'
   '-z, -x, y'
   'z, x, -y'
   '-z, x, -y'
   'z, -x, y'
   'y, z, x'
   '-y, -z, -x'
   '-y, z, -x'
   'y, -z, x'
   'y, -z, -x'
   '-y, z, x'
   '-y, -z, x'
   'y, z, -x'
   'y, x, -z'
   '-y, -x, z'
   '-y, -x, -z'
   'y, x, z'
   'y, -x, z'
   '-y, x, -z'
   '-y, x, z'
   'y, -x, -z'
   'x, z, -y'
   '-x, -z, y'
   '-x, z, y'
   'x, -z, -y'
   '-x, -z, -y'
   'x, z, y'
   'x, -z, y'
   '-x, z, -y'
   'z, y, -x'
   '-z, -y, x'
   'z, -y, x'
   '-z, y, -x'
   '-z, y, x'
   'z, -y, -x'
   '-z, -y, -x'
   'z, y, x'
   'x, y+1/2, z+1/2'
   '-x, -y+1/2, -z+1/2'
   '-x, -y+1/2, z+1/2'
   'x, y+1/2, -z+1/2'
   '-x, y+1/2, -z+1/2'
   'x, -y+1/2, z+1/2'
   'x, -y+1/2, -z+1/2'
   '-x, y+1/2, z+1/2'
   'z, x+1/2, y+1/2'
   '-z, -x+1/2, -y+1/2'
   'z, -x+1/2, -y+1/2'
   '-z, x+1/2, y+1/2'
   '-z, -x+1/2, y+1/2'
   'z, x+1/2, -y+1/2'
   '-z, x+1/2, -y+1/2'
   'z, -x+1/2, y+1/2'
   'y, z+1/2, x+1/2'
   '-y, -z+1/2, -x+1/2'
   '-y, z+1/2, -x+1/2'
   'y, -z+1/2, x+1/2'
   'y, -z+1/2, -x+1/2'
   '-y, z+1/2, x+1/2'
   '-y, -z+1/2, x+1/2'
   'y, z+1/2, -x+1/2'
   'y, x+1/2, -z+1/2'
   '-y, -x+1/2, z+1/2'
   '-y, -x+1/2, -z+1/2'
   'y, x+1/2, z+1/2'
   'y, -x+1/2, z+1/2'
   '-y, x+1/2, -z+1/2'
   '-y, x+1/2, z+1/2'
   'y, -x+1/2, -z+1/2'
   'x, z+1/2, -y+1/2'
   '-x, -z+1/2, y+1/2'
   '-x, z+1/2, y+1/2'
   'x, -z+1/2, -y+1/2'
   '-x, -z+1/2, -y+1/2'
   'x, z+1/2, y+1/2'
   'x, -z+1/2, y+1/2'
   '-x, z+1/2, -y+1/2'
   'z, y+1/2, -x+1/2'
   '-z, -y+1/2, x+1/2'
   'z, -y+1/2, x+1/2'
   '-z, y+1/2, -x+1/2'
   '-z, y+1/2, x+1/2'
   'z, -y+1/2, -x+1/2'
   '-z, -y+1/2, -x+1/2'
   'z, y+1/2, x+1/2'
   'x+1/2, y, z+1/2'
   '-x+1/2, -y, -z+1/2'
   '-x+1/2, -y, z+1/2'
   'x+1/2, y, -z+1/2'
   '-x+1/2, y, -z+1/2'
   'x+1/2, -y, z+1/2'
   'x+1/2, -y, -z+1/2'
   '-x+1/2, y, z+1/2'
   'z+1/2, x, y+1/2'
   '-z+1/2, -x, -y+1/2'
   'z+1/2, -x, -y+1/2'
   '-z+1/2, x, y+1/2'
   '-z+1/2, -x, y+1/2'
   'z+1/2, x, -y+1/2'
   '-z+1/2, x, -y+1/2'
   'z+1/2, -x, y+1/2'
   'y+1/2, z, x+1/2'
   '-y+1/2, -z, -x+1/2'
   '-y+1/2, z, -x+1/2'
   'y+1/2, -z, x+1/2'
   'y+1/2, -z, -x+1/2'
   '-y+1/2, z, x+1/2'
   '-y+1/2, -z, x+1/2'
   'y+1/2, z, -x+1/2'
   'y+1/2, x, -z+1/2'
   '-y+1/2, -x, z+1/2'
   '-y+1/2, -x, -z+1/2'
   'y+1/2, x, z+1/2'
   'y+1/2, -x, z+1/2'
   '-y+1/2, x, -z+1/2'
   '-y+1/2, x, z+1/2'
   'y+1/2, -x, -z+1/2'
   'x+1/2, z, -y+1/2'
   '-x+1/2, -z, y+1/2'
   '-x+1/2, z, y+1/2'
   'x+1/2, -z, -y+1/2'
   '-x+1/2, -z, -y+1/2'
   'x+1/2, z, y+1/2'
   'x+1/2, -z, y+1/2'
   '-x+1/2, z, -y+1/2'
   'z+1/2, y, -x+1/2'
   '-z+1/2, -y, x+1/2'
   'z+1/2, -y, x+1/2'
   '-z+1/2, y, -x+1/2'
   '-z+1/2, y, x+1/2'
   'z+1/2, -y, -x+1/2'
   '-z+1/2, -y, -x+1/2'
   'z+1/2, y, x+1/2'
   'x+1/2, y+1/2, z'
   '-x+1/2, -y+1/2, -z'
   '-x+1/2, -y+1/2, z'
   'x+1/2, y+1/2, -z'
   '-x+1/2, y+1/2, -z'
   'x+1/2, -y+1/2, z'
   'x+1/2, -y+1/2, -z'
   '-x+1/2, y+1/2, z'
   'z+1/2, x+1/2, y'
   '-z+1/2, -x+1/2, -y'
   'z+1/2, -x+1/2, -y'
   '-z+1/2, x+1/2, y'
   '-z+1/2, -x+1/2, y'
   'z+1/2, x+1/2, -y'
   '-z+1/2, x+1/2, -y'
   'z+1/2, -x+1/2, y'
   'y+1/2, z+1/2, x'
   '-y+1/2, -z+1/2, -x'
   '-y+1/2, z+1/2, -x'
   'y+1/2, -z+1/2, x'
   'y+1/2, -z+1/2, -x'
   '-y+1/2, z+1/2, x'
   '-y+1/2, -z+1/2, x'
   'y+1/2, z+1/2, -x'
   'y+1/2, x+1/2, -z'
   '-y+1/2, -x+1/2, z'
   '-y+1/2, -x+1/2, -z'
   'y+1/2, x+1/2, z'
   'y+1/2, -x+1/2, z'
   '-y+1/2, x+1/2, -z'
   '-y+1/2, x+1/2, z'
   'y+1/2, -x+1/2, -z'
   'x+1/2, z+1/2, -y'
   '-x+1/2, -z+1/2, y'
   '-x+1/2, z+1/2, y'
   'x+1/2, -z+1/2, -y'
   '-x+1/2, -z+1/2, -y'
   'x+1/2, z+1/2, y'
   'x+1/2, -z+1/2, y'
   '-x+1/2, z+1/2, -y'
   'z+1/2, y+1/2, -x'
   '-z+1/2, -y+1/2, x'
   'z+1/2, -y+1/2, x'
   '-z+1/2, y+1/2, -x'
   '-z+1/2, y+1/2, x'
   'z+1/2, -y+1/2, -x'
   '-z+1/2, -y+1/2, -x'
   'z+1/2, y+1/2, x'

loop_
   _atom_site_label
   _atom_site_occupancy
   _atom_site_fract_x
   _atom_site_fract_y
   _atom_site_fract_z
   _atom_site_adp_type
   _atom_site_B_iso_or_equiv
   _atom_site_type_symbol
   Na         0.7500  0.000000      0.000000      0.000000     Biso  1.000000 Na
   K          0.2500  0.000000      0.000000      0.000000     Biso  1.000000 K
   Cl         0.3000  0.500000      0.500000      0.500000     Biso  1.000000 Cl
   I          0.5000  0.250000      0.250000      0.250000     Biso  1.000000 I
"""


def test_cif():
    cif_file = io.StringIO(content)

    # legacy behavior is to not read the K atoms
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        atoms_leg = read(cif_file, format='cif', fractional_occupancies=False)
    elements = np.unique(atoms_leg.get_atomic_numbers())
    for n in (11, 17, 53):
        assert n in elements
    try:
        atoms_leg.info['occupancy']
        raise AssertionError
    except KeyError:
        pass

    cif_file = io.StringIO(content)
    # new behavior is to still not read the K atoms, but build info
    atoms = read(cif_file, format='cif', fractional_occupancies=True)

    # yield the same old atoms for fractional_occupancies case
    assert len(atoms_leg) == len(atoms)
    assert np.all(atoms_leg.get_atomic_numbers() == atoms.get_atomic_numbers())
    assert atoms_leg == atoms

    elements = np.unique(atoms_leg.get_atomic_numbers())
    for n in (11, 17, 53):
        assert n in elements

    check_fractional_occupancies(atoms)

    # read/write
    fname = 'testfile.cif'
    with open(fname, 'wb') as fd:
        write(fd, atoms, format='cif')

    with open(fname) as fd:
        atoms = read(fd, format='cif', fractional_occupancies=True)

    check_fractional_occupancies(atoms)

    # check repeating atoms
    atoms = atoms.repeat([2, 1, 1])
    assert len(atoms.arrays['spacegroup_kinds']) == len(atoms.arrays['numbers'])


# ICSD-like file from issue #293
content2 = """
data_global
_cell_length_a 9.378(5)
_cell_length_b 7.488(5)
_cell_length_c 6.513(5)
_cell_angle_alpha 90.
_cell_angle_beta 91.15(5)
_cell_angle_gamma 90.
_cell_volume 457.27
_cell_formula_units_Z 2
_symmetry_space_group_name_H-M 'P 1 n 1'
_symmetry_Int_Tables_number 7
_refine_ls_R_factor_all 0.071
loop_
_symmetry_equiv_pos_site_id
_symmetry_equiv_pos_as_xyz
1 'x+1/2, -y, z+1/2'
2 'x, y, z'
loop_
_atom_type_symbol
_atom_type_oxidation_number
Sn2+ 2
As4+ 4
Se2- -2
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_symmetry_multiplicity
_atom_site_Wyckoff_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
_atom_site_B_iso_or_equiv
_atom_site_occupancy
_atom_site_attached_hydrogens
Sn1 Sn2+ 2 a 0.5270(2) 0.3856(2) 0.7224(3) 0.0266(4) 1. 0
Sn2 Sn2+ 2 a 0.0279(2) 0.1245(2) 0.7870(2) 0.0209(4) 1. 0
As1 As4+ 2 a 0.6836(4) 0.1608(5) 0.8108(6) 0.0067(7) 1. 0
As2 As4+ 2 a 0.8174(4) 0.6447(6) 0.1908(6) 0.0057(6) 1. 0
Se1 Se2- 2 a 0.4898(4) 0.7511(6) 0.8491(6) 0.0110(6) 1. 0
Se2 Se2- 2 a 0.7788(4) 0.6462(6) 0.2750(6) 0.0097(6) 1. 0
Se3 Se2- 2 a 0.6942(4) 0.0517(5) 0.5921(6) 0.2095(6) 1. 0
Se4 Se2- 2 a 0.0149(4) 0.3437(6) 0.5497(7) 0.1123(7) 1. 0
Se5 Se2- 2 a 0.1147(4) 0.5633(4) 0.3288(6) 0.1078(6) 1. 0
Se6 Se2- 2 a 0.0050(4) 0.4480(6) 0.9025(6) 0.9102(6) 1. 0
"""


def test_cif_icsd():
    cif_file = io.StringIO(content2)
    atoms = read(cif_file, format='cif')
    # test something random so atoms is not unused
    assert 'occupancy' in atoms.info


@pytest.fixture
def cif_atoms():
    cif_file = io.StringIO(content)
    return read(cif_file, format='cif')


def test_cif_loop_keys(cif_atoms):
    data = {}
    # test case has 20 entries
    data['someKey'] = [[str(i) + "test" for i in range(20)]]
    # test case has 20 entries
    data['someIntKey'] = [[str(i) + "123" for i in range(20)]]
    cif_atoms.write('testfile.cif', loop_keys=data)

    atoms1 = read('testfile.cif', store_tags=True)
    # keys are read lowercase only
    r_data = {'someKey': atoms1.info['_somekey'],
              'someIntKey': atoms1.info['_someintkey']}
    assert r_data['someKey'] == data['someKey'][0]
    # data reading auto converts strins
    assert r_data['someIntKey'] == [int(x) for x in data['someIntKey'][0]]


# test if automatic numbers written after elements are correct
def test_cif_writer_label_numbers(cif_atoms):
    cif_atoms.write('testfile.cif')
    atoms1 = read('testfile.cif', store_tags=True)
    labels = atoms1.info['_atom_site_label']
    # cannot use atoms.symbols as K is missing there
    elements = atoms1.info['_atom_site_type_symbol']
    build_labels = [
        "{:}{:}".format(
            x, i) for x in set(elements) for i in range(
            1, elements.count(x) + 1)]
    assert build_labels.sort() == labels.sort()


def test_cif_labels(cif_atoms):
    data = [["label" + str(i) for i in range(20)]]  # test case has 20 entries
    cif_atoms.write('testfile.cif', labels=data)

    atoms1 = read('testfile.cif', store_tags=True)
    print(atoms1.info)
    assert data[0] == atoms1.info['_atom_site_label']


def test_cifloop():
    dct = {'_eggs': range(4),
           '_potatoes': [1.3, 7.1, -1, 0]}

    loop = CIFLoop()
    loop.add('_eggs', dct['_eggs'], '{:<2d}')
    loop.add('_potatoes', dct['_potatoes'], '{:.4f}')

    string = loop.tostring() + '\n\n'
    lines = string.splitlines()[::-1]
    assert lines.pop() == 'loop_'

    newdct = parse_loop(lines)
    print(newdct)
    assert set(dct) == set(newdct)
    for name in dct:
        assert dct[name] == pytest.approx(newdct[name])


@pytest.mark.parametrize('data', [b'', b'data_dummy'])
def test_empty_or_atomless(data):
    ciffile = io.BytesIO(data)

    images = read(ciffile, index=':', format='cif')
    assert len(images) == 0


def test_empty_or_atomless_cifblock():
    ciffile = io.BytesIO(b'data_dummy')
    blocks = list(parse_cif(ciffile))

    assert len(blocks) == 1
    assert not blocks[0].has_structure()
    with pytest.raises(NoStructureData):
        blocks[0].get_atoms()


def test_symbols_questionmark():
    ciffile = io.BytesIO(
        b'data_dummy\n'
        b'loop_\n'
        b'_atom_site_label\n'
        b'?\n')
    blocks = list(parse_cif(ciffile))
    assert not blocks[0].has_structure()
    with pytest.raises(NoStructureData, match='undetermined'):
        blocks[0].get_atoms()


def test_bad_occupancies(cif_atoms):
    assert 'Au' not in cif_atoms.symbols
    cif_atoms.symbols[0] = 'Au'
    with pytest.warns(UserWarning, match='no occupancy info'):
        write('tmp.cif', cif_atoms)


@pytest.mark.parametrize(
    'setting_name, ref_setting',
    [
        ('hexagonal', 1),
        ('trigonal', 2),
        ('rhombohedral', 2)
    ]
)
def test_spacegroup_named_setting(setting_name, ref_setting):
    """The rhombohedral crystal system signifies setting=2"""
    ciffile = io.BytesIO("""\
data_test
_space_group_crystal_system {}
_symmetry_space_group_name_H-M         'R-3m'
""".format(setting_name).encode('ascii'))

    blocks = list(parse_cif(ciffile))
    assert len(blocks) == 1
    spg = blocks[0].get_spacegroup(False)
    assert int(spg) == 166
    assert spg.setting == ref_setting


@pytest.fixture
def atoms():
    return Atoms('CO', cell=[2., 3., 4., 50., 60., 70.], pbc=True,
                 scaled_positions=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])


def roundtrip(atoms):
    from ase.io.bytes import to_bytes, parse_atoms
    buf = to_bytes(atoms, format='cif')
    return parse_atoms(buf, format='cif')


def test_cif_roundtrip_periodic(atoms):
    # Reading and writing the cell loses the rotation information,
    # but preserves cellpar and scaled positions.
    atoms1 = roundtrip(atoms)

    assert str(atoms1.symbols) == 'CO'
    assert all(atoms1.pbc)
    assert atoms.cell.cellpar() == pytest.approx(
        atoms1.cell.cellpar(), abs=1e-5)
    assert atoms.get_scaled_positions() == pytest.approx(
        atoms1.get_scaled_positions(), abs=1e-5)


def test_cif_roundtrip_nonperiodic():
    atoms = molecule('H2O')
    atoms1 = roundtrip(atoms)
    assert not compare_atoms(atoms, atoms1, tol=1e-5)


def test_cif_missingvector(atoms):
    # We don't know any way to represent only 2 cell vectors in CIF.
    # So we discard them and warn the user.
    atoms.cell[0] = 0.0
    atoms.pbc[0] = False

    assert atoms.cell.rank == 2

    with pytest.raises(ValueError, match='CIF format can only'):
        roundtrip(atoms)


def test_cif_roundtrip_mixed():
    atoms = Atoms('Au', cell=[1., 2., 3.], pbc=[1, 1, 0])
    atoms1 = roundtrip(atoms)

    # We cannot preserve PBC info for this case:
    assert all(atoms1.pbc)
    assert compare_atoms(atoms, atoms1, tol=1e-5) == ['pbc']
    assert atoms.get_scaled_positions() == pytest.approx(
        atoms1.get_scaled_positions(), abs=1e-5)
    #assert pytest.approx(atoms.positions) == atoms1.positions
    #assert atoms1.cell.rank == 0
