File: test_change_of_basis.py

package info (click to toggle)
spglib 2.7.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 14,180 kB
  • sloc: ansic: 125,066; python: 7,717; cpp: 2,197; f90: 2,143; ruby: 792; makefile: 22; sh: 18
file content (62 lines) | stat: -rw-r--r-- 2,227 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from __future__ import annotations

import numpy as np
from spglib import get_symmetry_from_database


def test_change_of_basis(crystal_data_dataset):
    crystal_data = crystal_data_dataset["crystal_data"]
    dataset = crystal_data_dataset["dataset"]
    symprec = crystal_data_dataset["symprec"]
    std_pos = dataset.std_positions
    tmat = dataset.transformation_matrix
    orig_shift = dataset.origin_shift
    lat = np.dot(crystal_data.cell[0].T, np.linalg.inv(tmat))
    pos = np.dot(crystal_data.cell[1], tmat.T) + orig_shift
    for p in pos:
        diff = std_pos - p
        diff -= np.rint(diff)
        diff = np.dot(diff, lat.T)
        delta = np.sqrt((diff**2).sum(axis=1))
        indices = np.where(delta < symprec)[0]
        assert len(indices) == 1


def test_std_symmetry(crystal_data_dataset):
    """Check the standardized positions against the database symmetry operations.

    The operations for ``dataset.hall_number`` are read from spglib's
    database. For every standardized atom ``p`` and every operation
    ``(R, t)``, exactly one image ``R x + t`` of the standardized atoms must
    coincide with ``p`` modulo integer lattice translations. "Coincide" means
    the L1 distance in fractional coordinates is below 1e-3.
    """
    dataset = crystal_data_dataset["dataset"]
    symmetry = get_symmetry_from_database(dataset.hall_number)
    std_pos = dataset.std_positions

    rot = symmetry["rotations"]
    trans = symmetry["translations"]
    # Images of all standardized atoms under all operations, computed in one
    # broadcast instead of a Python double loop: shape (n_sym, 3, n_atom).
    rot_pos = np.dot(rot, std_pos.T) + trans[:, :, None]
    for p in std_pos:
        diff = rot_pos - p[None, :, None]
        # Wrap to the nearest lattice image before measuring distance.
        diff -= np.rint(diff)
        # Per operation, count images that coincide with ``p``: shape (n_sym,).
        num_match = (np.abs(diff).sum(axis=1) < 1e-3).sum(axis=1)
        bad = np.flatnonzero(num_match != 1)
        assert bad.size == 0, (
            f"atom at {p} matched {num_match[bad]} images under "
            f"symmetry operations {bad} (expected exactly 1 each)"
        )


def test_std_rotation(crystal_data_dataset):
    crystal_data = crystal_data_dataset["crystal_data"]
    dataset = crystal_data_dataset["dataset"]
    symprec = crystal_data_dataset["symprec"]
    std_lat = dataset.std_lattice
    tmat = dataset.transformation_matrix
    lat = np.dot(crystal_data.cell[0].T, np.linalg.inv(tmat))
    lat_rot = np.dot(dataset.std_rotation_matrix, lat)
    np.testing.assert_allclose(
        std_lat,
        lat_rot.T,
        atol=symprec,
    )