File: utils.py

package info (click to toggle)
python-ase 3.21.1-2
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 13,936 kB
  • sloc: python: 122,428; xml: 946; makefile: 111; javascript: 47
file content (173 lines) | stat: -rw-r--r-- 6,686 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
from typing import List
import numpy as np
from ase import Atoms
from .spacegroup import Spacegroup, _SPACEGROUP

# Only the dispatcher is exported; the underscore-prefixed helpers are private.
__all__ = ("get_basis",)


def _has_spglib() -> bool:
    """Check if spglib is available"""
    try:
        import spglib
        assert spglib  # silence flakes
    except ImportError:
        return False
    return True


def _get_basis_ase(atoms: Atoms,
                   spacegroup: _SPACEGROUP,
                   tol: float = 1e-5) -> np.ndarray:
    """Get a reduced basis by removing symmetry-equivalent sites.

    The first remaining scaled position is taken as a basis site, and every
    remaining position equivalent to it under ``spacegroup`` is removed.
    This repeats on what is left until no positions remain. Basis sites are
    therefore returned in order of first appearance in ``atoms``.

    The reduction is iterative rather than recursive, so structures with
    many inequivalent sites do not run into Python's recursion limit.

    :param atoms: Atoms object to get basis from.
    :param spacegroup: ``int``, ``str``, or
        :class:`ase.spacegroup.Spacegroup` object.
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :returns: ``np.ndarray`` of scaled basis positions, one row per
        inequivalent site (an empty array if ``atoms`` has no atoms).
    :raises RuntimeError: if a step removes no position at all, which would
        otherwise prevent the reduction from terminating.
    """
    remaining = atoms.get_scaled_positions()
    spacegroup = Spacegroup(spacegroup)

    all_basis = []
    while len(remaining) > 0:
        basis = remaining[0]
        all_basis.append(basis.tolist())  # Add the site as a basis

        # Get equivalent sites
        sites, _ = spacegroup.equivalent_sites(basis)
        sites = np.asarray(sites)

        # A remaining position is equivalent if it matches ANY generated
        # site in ALL coordinates. This is the same test as
        # np.allclose(site, pos, atol=tol) per pair (default rtol, taken
        # relative to ``pos``), evaluated for every pair at once:
        # (1, k, 3) vs (n, 1, 3) -> (n, k, 3) -> (n,).
        is_equivalent = np.isclose(sites[np.newaxis, :, :],
                                   remaining[:, np.newaxis, :],
                                   atol=tol).all(axis=2).any(axis=1)

        # We should always have at least popped off the site itself; check
        # explicitly (not with ``assert``) so the loop can never spin forever.
        if not is_equivalent.any():
            raise RuntimeError(
                'Basis site {} was not matched by any of its symmetry-'
                'equivalent sites; cannot reduce the basis.'.format(
                    basis.tolist()))

        remaining = remaining[~is_equivalent]

    return np.array(all_basis)


def _get_basis_spglib(atoms: Atoms, tol: float = 1e-5) -> np.ndarray:
    """Reduce ``atoms`` to its symmetry-inequivalent sites using spglib.

    The space group is inferred from the structure by spglib rather than
    supplied by the caller, so spglib must be installed.

    :param atoms: Atoms, atoms object to get basis from
    :param tol: ``float``, numeric tolerance for positional comparisons,
        forwarded to spglib as ``symprec``. Default: ``1e-5``
    :returns: scaled positions of the representative atoms
    :raises ImportError: if spglib cannot be imported
    """
    if _has_spglib():
        positions = atoms.get_scaled_positions()
        representatives = _get_reduced_indices(atoms, tol=tol)
        return positions[representatives]

    # Point the caller at the spglib-free alternative.
    raise ImportError(
        'This function requires spglib. Use "get_basis" and specify '
        'the spacegroup instead, or install spglib.')


def _can_use_spglib(spacegroup: _SPACEGROUP = None) -> bool:
    """Decide whether the dispatcher may pick the spglib implementation.

    Two conditions must hold:

    * spglib must be importable.
    * No explicit ``spacegroup`` may be given, because the spglib-based
      path does not currently accept one.
    """
    # _has_spglib() is evaluated first, regardless of ``spacegroup``.
    return _has_spglib() and spacegroup is None


# Dispatcher for choosing a get_basis implementation.
def get_basis(atoms: Atoms,
              spacegroup: _SPACEGROUP = None,
              method: str = 'auto',
              tol: float = 1e-5) -> np.ndarray:
    """Determine a reduced (symmetry-inequivalent) basis of an Atoms object.

    Two implementations are available:

    * The native ASE algorithm, which needs an explicit space group.
    * An spglib-based one, which infers the symmetry itself and currently
      cannot take an explicit space group.

    With ``method='auto'`` (the default), spglib is used when it is
    installed and no ``spacegroup`` is given; otherwise the ASE version is
    used.

    :param atoms: ase Atoms object to get basis from
    :param spacegroup: Optional, ``int``, ``str``
        or :class:`ase.spacegroup.Spacegroup` object.
        If unspecified, the spacegroup can be inferred using spglib,
        if spglib is installed, and ``method`` is set to either
        ``'spglib'`` or ``'auto'``. With ``method='spglib'`` a given
        spacegroup is not forwarded to the spglib implementation.
    :param method: ``str``, one of: ``'auto'`` | ``'ase'`` | ``'spglib'``.
        Selection of which implementation to use.
        It is recommended to use ``'auto'``, which is also the default.
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :returns: ``np.ndarray`` of scaled basis positions.
    :raises ValueError: if ``method`` is not recognised, or if the ASE
        implementation is selected without a ``spacegroup``.
    :raises ImportError: if the spglib implementation is selected but
        spglib is not installed.
    """
    allowed_methods = ('auto', 'ase', 'spglib')
    if method not in allowed_methods:
        raise ValueError('Expected one of {} methods, got {}'.format(
            allowed_methods, method))

    # 'auto' decides from spglib availability and the spacegroup argument;
    # any other valid method is an explicit choice by the caller.
    use_spglib = (_can_use_spglib(spacegroup=spacegroup)
                  if method == 'auto' else method == 'spglib')

    if use_spglib:
        # The spglib path cannot take an explicit space group right now,
        # so ``spacegroup`` is intentionally not passed along.
        return _get_basis_spglib(atoms, tol=tol)

    # Native ASE path: method='ase' was requested, or 'auto' found spglib
    # missing or an explicit spacegroup supplied.
    if spacegroup is None:
        raise ValueError(
            ('A space group must be specified for the native ASE '
             'implementation. Try using the spglib version instead, '
             'or explicitly specifying a space group.'))
    return _get_basis_ase(atoms, spacegroup, tol=tol)


def _get_reduced_indices(atoms: Atoms, tol: float = 1e-5) -> List[int]:
    """Get an ascending list of the reduced atomic indices using spglib.
    Note: Does no checks to see if spglib is installed.

    The indices are the distinct values spglib reports in the
    ``equivalent_atoms`` field of its symmetry dataset.

    :param atoms: ase Atoms object to reduce
    :param tol: ``float``, numeric tolerance for positional comparisons,
        passed to spglib as ``symprec``
    :returns: sorted list of representative atom indices
    :raises RuntimeError: if spglib fails to determine the symmetry and
        returns no dataset.
    """
    import spglib

    # Create input for spglib
    spglib_cell = (atoms.get_cell(), atoms.get_scaled_positions(),
                   atoms.numbers)
    symmetry_data = spglib.get_symmetry_dataset(spglib_cell, symprec=tol)
    if symmetry_data is None:
        # spglib signals a failed symmetry search by returning None; report
        # it clearly instead of failing later on a subscript of None.
        raise RuntimeError(
            'spglib could not determine the symmetry of the structure '
            '(symprec={}).'.format(tol))
    # ``set`` iteration order follows hashing details and is not guaranteed
    # to be ascending, so sort to make the basis order deterministic.
    return sorted(set(symmetry_data['equivalent_atoms']))