1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
|
from typing import List
import numpy as np
from ase import Atoms
from .spacegroup import Spacegroup, _SPACEGROUP
__all__ = ('get_basis', )
def _has_spglib() -> bool:
"""Check if spglib is available"""
try:
import spglib
assert spglib # silence flakes
except ImportError:
return False
return True
def _get_basis_ase(atoms: Atoms,
                   spacegroup: _SPACEGROUP,
                   tol: float = 1e-5) -> np.ndarray:
    """Get a reduced basis by repeatedly removing equivalent sites.

    Uses the first remaining position as a basis site, removes every
    position that is equivalent to it under the space group, then
    continues with the next position that has not been assigned to a
    basis site, until no positions are left.

    The reduction is done with a loop rather than recursion, so the number
    of inequivalent sites is not limited by the interpreter recursion depth.

    :param atoms: Atoms object to get basis from.
    :param spacegroup: ``int``, ``str``, or
        :class:`ase.spacegroup.Spacegroup` object.
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :return: Array with one row of scaled coordinates per basis site
        (an empty array if ``atoms`` has no positions).
    :raises AssertionError: if removing the equivalent sites of a basis
        site fails to remove at least that site itself.
    """
    remaining = atoms.get_scaled_positions()
    spacegroup = Spacegroup(spacegroup)

    def scaled_in_sites(scaled_pos: np.ndarray, sites: np.ndarray) -> bool:
        """Check if a scaled position matches any of the given sites."""
        # NOTE(review): the comparison is on raw fractional coordinates;
        # presumably a position within ``tol`` of a cell face would not
        # match its periodic image on the opposite face -- confirm.
        return any(np.allclose(site, scaled_pos, atol=tol) for site in sites)

    all_basis = []
    while len(remaining) > 0:
        basis = remaining[0]
        all_basis.append(basis.tolist())  # Add the site as a basis

        # All positions generated from this basis site by the space group
        sites, _ = spacegroup.equivalent_sites(basis)

        # Keep only the positions that are not equivalent to the basis site
        survivors = np.array(
            [pos for pos in remaining if not scaled_in_sites(pos, sites)])

        # We should always have at least removed the basis site itself.
        # An explicit raise (rather than ``assert``) keeps this guard active
        # under ``python -O``, where it is what prevents an endless loop.
        if len(survivors) >= len(remaining):
            raise AssertionError(
                'Removing equivalent sites did not remove the basis site')
        remaining = survivors

    return np.array(all_basis)
def _get_basis_spglib(atoms: Atoms, tol: float = 1e-5) -> np.ndarray:
    """Return the reduced basis of ``atoms`` as determined with spglib.

    The spglib package must be installed for this function to work.

    :param atoms: Atoms, atoms object to get basis from
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :return: The rows of the scaled positions selected by
        ``_get_reduced_indices``.
    :raises ImportError: if spglib is not available.
    """
    spglib_available = _has_spglib()
    if not spglib_available:
        # Point the user at an alternative that works without spglib.
        message = (
            'This function requires spglib. Use "get_basis" and specify '
            'the spacegroup instead, or install spglib.')
        raise ImportError(message)

    fractional = atoms.get_scaled_positions()
    unique_indices = _get_reduced_indices(atoms, tol=tol)
    return fractional[unique_indices]
def _can_use_spglib(spacegroup: _SPACEGROUP = None) -> bool:
    """Decide whether the spglib implementation is usable.

    It is usable only when spglib is installed and no explicit space
    group was given, since the spglib implementation currently does not
    support an explicit space group.
    """
    # Availability is probed first; the space group is only considered
    # when spglib is actually installed.
    return _has_spglib() and spacegroup is None
# Dispatcher function for choosing the get_basis implementation.
def get_basis(atoms: Atoms,
              spacegroup: _SPACEGROUP = None,
              method: str = 'auto',
              tol: float = 1e-5) -> np.ndarray:
    """Determine a reduced basis of an atoms object.

    Two implementations are available: a native ASE algorithm and one
    based on spglib. The native ASE version needs an explicit space group,
    whereas the (current) spglib version cannot take one.

    By default the implementation is picked automatically, based on the
    ``spacegroup`` parameter and on whether spglib is installed.

    :param atoms: ase Atoms object to get basis from
    :param spacegroup: Optional, ``int``, ``str``
        or :class:`ase.spacegroup.Spacegroup` object.
        If unspecified, the spacegroup can be inferred using spglib,
        if spglib is installed, and ``method`` is set to either
        ``'spglib'`` or ``'auto'``.
        Inferring the spacegroup requires spglib.
    :param method: ``str``, one of: ``'auto'`` | ``'ase'`` | ``'spglib'``.
        Selection of which implementation to use.
        It is recommended to use ``'auto'``, which is also the default.
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :raises ValueError: if ``method`` is not one of the allowed values, or
        if the native ASE implementation is selected without a space group.
    :raises ImportError: if the spglib implementation is selected but
        spglib is not installed.
    """
    allowed_methods = ('auto', 'ase', 'spglib')
    if method not in allowed_methods:
        raise ValueError('Expected one of {} methods, got {}'.format(
            allowed_methods, method))

    # With 'auto' we decide ourselves whether spglib is usable; otherwise
    # the caller's explicit choice is followed.
    use_spglib = (_can_use_spglib(spacegroup=spacegroup)
                  if method == 'auto' else method == 'spglib')

    if use_spglib:
        # The space group is deliberately not forwarded: the spglib
        # implementation cannot handle an explicit space group right now.
        return _get_basis_spglib(atoms, tol=tol)

    # Native ASE implementation from here on.
    if spacegroup is None:
        # Reached either because spglib is not installed,
        # or because ASE was explicitly requested.
        raise ValueError(
            ('A space group must be specified for the native ASE '
             'implementation. Try using the spglib version instead, '
             'or explicitly specifying a space group.'))
    return _get_basis_ase(atoms, spacegroup, tol=tol)
def _get_reduced_indices(atoms: Atoms, tol: float = 1e-5) -> List[int]:
    """Get a list of the reduced atomic indices using spglib.

    Note: Does no checks to see if spglib is installed.

    :param atoms: ase Atoms object to reduce
    :param tol: ``float``, numeric tolerance for positional comparisons,
        passed to spglib as ``symprec``.
    :return: Sorted list of the distinct values in spglib's
        ``equivalent_atoms`` array, as plain Python ``int``.
    :raises RuntimeError: if spglib returns no symmetry dataset.
    """
    import spglib

    # Create input for spglib: (lattice, scaled positions, atomic numbers)
    spglib_cell = (atoms.get_cell(), atoms.get_scaled_positions(),
                   atoms.numbers)
    symmetry_data = spglib.get_symmetry_dataset(spglib_cell, symprec=tol)
    if symmetry_data is None:
        # Some spglib versions signal a failed symmetry search with None;
        # fail with a clear message instead of an opaque subscript error.
        raise RuntimeError(
            'spglib could not determine the symmetry dataset for the '
            'given atoms (symprec={}).'.format(tol))

    # Newer spglib versions return a dataset object with attribute access,
    # older ones a plain dict; support both.
    try:
        equivalent_atoms = symmetry_data.equivalent_atoms
    except AttributeError:
        equivalent_atoms = symmetry_data['equivalent_atoms']

    # np.unique gives a deterministic, sorted result (a set has no defined
    # order), and tolist() converts numpy integers to Python ints.
    return np.unique(equivalent_atoms).tolist()
|