File: test_vasp_input.py

package info (click to toggle)
python-ase 3.21.1-2
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 13,936 kB
  • sloc: python: 122,428; xml: 946; makefile: 111; javascript: 47
file content (181 lines) | stat: -rw-r--r-- 6,007 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import pytest
from unittest import mock

import numpy as np
from ase.calculators.vasp.create_input import GenerateVaspInput
from ase.calculators.vasp.create_input import _args_without_comment
from ase.calculators.vasp.create_input import _to_vasp_bool, _from_vasp_bool

from ase.build import bulk


@pytest.fixture
def rng():
    """Seeded NumPy ``RandomState`` so shuffles and random moments are
    reproducible from run to run."""
    seed = 42
    return np.random.RandomState(seed=seed)


@pytest.fixture
def nacl(rng):
    """3x3x3 supercell of cubic rock-salt NaCl (a = 4.1) whose chemical
    symbols have been randomly permuted with ``rng``."""
    unit_cell = bulk('NaCl', crystalstructure='rocksalt', a=4.1,
                     cubic=True)
    supercell = unit_cell * (3, 3, 3)
    # Permute the symbols in place so the two species are interleaved
    # rather than grouped.
    rng.shuffle(supercell.symbols)
    return supercell


@pytest.fixture
def vaspinput_factory(nacl):
    """Factory for GenerateVaspInput objects with pseudopotential
    generation mocked out.

    The returned callable takes ``atoms`` plus keyword arguments forwarded
    to ``GenerateVaspInput.set``, replaces ``_build_pp_list`` with a mock
    that returns ``None``, calls ``initialize(atoms)`` and returns the
    instance.
    """
    # NOTE(review): the ``nacl`` fixture argument is not used here.
    def _vaspinput_factory(atoms, **kwargs) -> GenerateVaspInput:
        inputs = GenerateVaspInput()
        inputs.set(**kwargs)
        # Configure the mock itself. Calling ``Mock()(return_value=None)``
        # returns the mock's auto-created child ``return_value`` object,
        # not a mock that returns ``None``.
        inputs._build_pp_list = mock.Mock(return_value=None)  # type: ignore
        inputs.initialize(atoms)
        return inputs

    return _vaspinput_factory


def test_sorting(nacl, vaspinput_factory):
    """``sort`` and ``resort`` must be mutually inverse permutations, and
    sorting must group each chemical species into one contiguous half."""
    vaspinput = vaspinput_factory(nacl)
    order, inverse = vaspinput.sort, vaspinput.resort
    atoms = nacl.copy()

    sorted_atoms = atoms[order]
    # Neither permutation alone reproduces the input ordering...
    assert sorted_atoms != nacl
    assert atoms[inverse] != nacl
    # ...but sorting and then re-sorting does.
    assert sorted_atoms[inverse] == nacl

    # Split the sorted structure in two; each half must hold one species.
    natoms = len(atoms)
    assert natoms % 2 == 0  # an even number of atoms is required
    half = natoms // 2
    first_half = set(sorted_atoms.symbols[:half])
    second_half = set(sorted_atoms.symbols[half:])
    assert len(first_half) == 1
    assert len(second_half) == 1
    # The two halves must carry different symbols.
    assert first_half.isdisjoint(second_half)


@pytest.fixture(params=['random', 'ones', 'binaries'])
def magmoms_factory(rng, request):
    """Return a callable producing one magnetic moment per atom.

    Parametrized over three generators: ``'random'`` draws floats with
    ``rng.rand``, ``'ones'`` returns all ones, and ``'binaries'`` draws
    random 0/1 integers from ``rng``.
    """
    generators = {
        'random': rng.rand,
        'ones': np.ones,
        'binaries': lambda size: rng.randint(2, size=size),
    }
    kind = request.param
    generate = generators.get(kind)
    if generate is None:
        raise ValueError(f'Unknown kind: {kind}')

    def _magmoms_factory(atoms):
        count = len(atoms)
        moments = generate(count)
        assert len(moments) == count
        return moments

    return _magmoms_factory


def read_magmom_from_file(filename) -> np.ndarray:
    """Helper function to parse the magnetic moments from an INCAR file.

    The first line containing ``'MAGMOM = '`` is parsed. It is expected to
    look like ``MAGMOM = n1*val1 n2*val2 ...``, where each ``n*val`` token
    expands to ``n`` copies of ``val``. A bare ``val`` token (no ``*``) is
    treated as ``1*val``.

    Returns a flat float array of the expanded moments.

    Raises AssertionError if the file contains no MAGMOM line. It is raised
    explicitly rather than via ``assert`` so it still fires under
    ``python -O``.
    """
    with open(filename) as file:
        for line in file:
            if 'MAGMOM = ' not in line:
                continue
            magmom = []
            # Skip the leading "MAGMOM" and "=" tokens.
            for token in line.split()[2:]:
                count, star, value = token.rpartition('*')
                # Without a '*', rpartition leaves count/star empty.
                repeat = int(count) if star else 1
                magmom.extend([float(value)] * repeat)
            return np.array(magmom)
    raise AssertionError(f'No MAGMOM line found in {filename}')


@pytest.fixture
def assert_magmom_equal_to_incar_value():
    """Fixture returning a checker that compares an expected magmom array
    with the MAGMOM values read back from ``INCAR`` after calling
    ``GenerateVaspInput.write_incar``."""
    def _assert_magmom_equal_to_incar_value(atoms, expected_magmom, vaspinput):
        n_expected = len(expected_magmom)
        assert len(atoms) == n_expected
        vaspinput.write_incar(atoms)
        written = read_magmom_from_file('INCAR')
        assert len(written) == n_expected

        order = vaspinput.sort
        inverse = vaspinput.resort
        # INCAR values are rounded (to 4 digits), so compare with an
        # absolute tolerance instead of exact equality.
        tol = 1e-3
        # INCAR holds the moments in sorted order: undo the sort to compare
        # in input order, and apply it to compare in INCAR order.
        assert np.allclose(expected_magmom, written[inverse], atol=tol)
        assert np.allclose(np.array(expected_magmom)[order], written,
                           atol=tol)

    return _assert_magmom_equal_to_incar_value


@pytest.mark.parametrize('list_func', [list, tuple, np.array])
def test_write_magmom(magmoms_factory, list_func, nacl, vaspinput_factory,
                      assert_magmom_equal_to_incar_value):
    """Test writing magnetic moments to INCAR, and ensure we can do it
    passing different types of sequences.

    ``list_func`` converts the generated moments to a ``list``, ``tuple`` or
    ``np.ndarray`` before they are passed as the ``magmom`` keyword.
    """
    # Apply the parametrized container type; otherwise every
    # parametrization would pass the same ndarray.
    magmom = list_func(magmoms_factory(nacl))

    vaspinput = vaspinput_factory(nacl, magmom=magmom, ispin=2)
    assert vaspinput.spinpol
    assert_magmom_equal_to_incar_value(nacl, magmom, vaspinput)


def test_atoms_with_initial_magmoms(magmoms_factory, nacl, vaspinput_factory,
                                    assert_magmom_equal_to_incar_value):
    """Moments set only on the Atoms object (no ``magmom``/``ispin``
    keywords) should make the input spin-polarized and appear in INCAR."""
    moments = magmoms_factory(nacl)
    assert len(moments) == len(nacl)
    nacl.set_initial_magnetic_moments(moments)

    generator = vaspinput_factory(nacl)
    assert generator.spinpol
    assert_magmom_equal_to_incar_value(nacl, moments, generator)


def test_vasp_from_bool():
    """``_from_vasp_bool`` accepts mixed-case VASP logical strings, raises
    ValueError for unknown strings and AssertionError for non-strings."""
    cases = [
        (('T', '.true.'), True),
        (('f', '.False.'), False),
    ]
    for strings, expected in cases:
        for text in strings:
            assert _from_vasp_bool(text) is expected

    with pytest.raises(ValueError):
        _from_vasp_bool('yes')
    with pytest.raises(AssertionError):
        _from_vasp_bool(True)


def test_vasp_to_bool():
    """``_to_vasp_bool`` maps Python booleans and VASP logical strings onto
    the spellings ``'.TRUE.'`` / ``'.FALSE.'``."""
    expectations = {
        '.TRUE.': ('T', '.true.', True),
        '.FALSE.': ('f', '.FALSE.', False),
    }
    for expected, inputs in expectations.items():
        for value in inputs:
            assert _to_vasp_bool(value) == expected

    # Unknown strings raise ValueError; an integer such as 1 raises
    # AssertionError.
    with pytest.raises(ValueError):
        _to_vasp_bool('yes')
    with pytest.raises(AssertionError):
        _to_vasp_bool(1)


@pytest.mark.parametrize('args, expected_len',
                         [(['a', 'b', '#', 'c'], 2),
                          (['a', 'b', '!', 'c', '#', 'd'], 2),
                          (['#', 'a', 'b', '!', 'c', '#', 'd'], 0)])
def test_vasp_args_without_comment(args, expected_len):
    """Only the arguments before the first comment marker ('#' or '!')
    should be kept."""
    assert len(_args_without_comment(args)) == expected_len