File: test_spacegroup_utils.py

package info (click to toggle)
python-ase 3.21.1-2
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 13,936 kB
  • sloc: python: 122,428; xml: 946; makefile: 111; javascript: 47
file content (121 lines) | stat: -rw-r--r-- 3,624 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import pytest
import numpy as np
from ase.build import bulk
from ase.spacegroup import crystal, Spacegroup
from ase.spacegroup.spacegroup import SpacegroupValueError
from ase.spacegroup import utils


@pytest.fixture(params=[
    # Each case is wrapped in a lambda so the structures are only built when
    # the fixture is requested; an error while building one then fails the
    # affected tests instead of crashing pytest during collection.
    # rocksalt NaCl
    lambda: {
        'atoms': bulk('NaCl', crystalstructure='rocksalt', a=4.1),
        'spacegroup': 225,
        'expected': [[0, 0, 0], [0.5, 0.5, 0.5]]
    },
    # diamond
    lambda: {
        'atoms':
        crystal('C', [(0, 0, 0)],
                spacegroup=227,
                cellpar=[4, 4, 4, 90, 90, 90],
                primitive_cell=True),
        'spacegroup':
        227,
        'expected': [[0, 0, 0]]
    },
    # hexagonal Mg
    lambda: {
        'atoms':
        crystal('Mg', [(1 / 3, 2 / 3, 3 / 4)],
                spacegroup=194,
                cellpar=[3.21, 3.21, 5.21, 90, 90, 120]),
        'spacegroup':
        194,
        'expected': [(1 / 3, 2 / 3, 3 / 4)]
    },
    # rutile-type TiO2; the spacegroup is given as a Spacegroup instance
    # rather than an int, so both accepted forms are exercised.
    lambda: {
        'atoms':
        crystal(['Ti', 'O'],
                basis=[(0, 0, 0), (0.3, 0.3, 0.0)],
                spacegroup=136,
                cellpar=[4, 4, 6, 90, 90, 90]),
        'spacegroup':
        Spacegroup(136),
        'expected': [(0, 0, 0), (0.3, 0.3, 0.0)]
    },
], ids=['rocksalt-NaCl', 'diamond-C', 'hexagonal-Mg', 'rutile-TiO2'])
def basis_tests(request):
    """Return a dict of test inputs and expected values for `get_basis`.

    Keys:
        atoms: the structure whose basis is reduced.
        spacegroup: the structure's spacegroup, as an int or a
            ``Spacegroup`` instance.
        expected: the expected reduced basis, in scaled coordinates.
    """
    return request.param()


def test_get_basis(basis_tests):
    """Reduce to a basis with the spacegroup supplied explicitly."""
    found = utils.get_basis(basis_tests['atoms'],
                            spacegroup=basis_tests['spacegroup'])
    assert np.allclose(found, basis_tests['expected'])


def test_get_basis_infer_sg(basis_tests):
    """Reduce to a basis without a spacegroup, letting `get_basis` infer it.

    Inference relies on `_get_basis_spglib` under the hood, so the test is
    skipped when spglib is not installed.
    """
    pytest.importorskip('spglib')

    found = utils.get_basis(basis_tests['atoms'])
    assert np.allclose(found, basis_tests['expected'])


def test_get_basis_spglib(basis_tests):
    """Reduce to a basis by calling the spglib backend directly.

    Skipped when spglib is not installed.
    """
    pytest.importorskip('spglib')

    found = utils._get_basis_spglib(basis_tests['atoms'])
    assert np.allclose(found, basis_tests['expected'])


def test_get_basis_ase(basis_tests):
    """Reduce to a basis by calling the ASE backend with the known spacegroup.

    Unlike the spglib tests, this does not need spglib to be installed.
    """
    found = utils._get_basis_ase(basis_tests['atoms'],
                                 basis_tests['spacegroup'])
    assert np.allclose(found, basis_tests['expected'])


@pytest.mark.parametrize(
    'spacegroup',
    [251.5, [1, 2, 3], np.array([255])],
)
def test_get_basis_wrong_type(basis_tests, spacegroup):
    """Reject spacegroup arguments of an unsupported type.

    A float, a list and a numpy array must each make both the private ASE
    backend and the public `get_basis` raise `SpacegroupValueError`.
    """
    atoms = basis_tests['atoms']
    # Checked in order: the ASE backend first, then the public entry point.
    entry_points = [
        lambda: utils._get_basis_ase(atoms, spacegroup),
        lambda: utils.get_basis(atoms, spacegroup=spacegroup),
    ]
    for invoke in entry_points:
        with pytest.raises(SpacegroupValueError):
            invoke()


@pytest.mark.parametrize(
    'method',
    [None, 12, 'nonsense', True, False],
)
def test_get_basis_wrong_method(basis_tests, method):
    """Reject `method` values that are not a supported method name."""
    structure = basis_tests['atoms']
    with pytest.raises(ValueError):
        utils.get_basis(structure, method=method)


def test_get_basis_group_1(basis_tests):
    """Reduce with spacegroup 1, where no two sites are equivalent.

    Whatever the actual symmetry of the structure, forcing spacegroup 1 means
    nothing is removed: the basis must equal the scaled positions.
    """
    atoms = basis_tests['atoms']
    scaled = atoms.get_scaled_positions()

    # Pass the spacegroup by keyword, like the other tests, so this test does
    # not silently depend on the positional order of `get_basis` parameters.
    basis = utils.get_basis(atoms, spacegroup=1)
    # Basis should now be the same as the scaled positions
    assert np.allclose(basis, scaled)