File: test_pickle_bundle_trajectory.py

package info (click to toggle)
python-ase 3.21.1-2
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 13,936 kB
  • sloc: python: 122,428; xml: 946; makefile: 111; javascript: 47
file content (139 lines) | stat: -rw-r--r-- 4,101 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import pytest
import numpy as np
import sys
from subprocess import check_call, check_output

from pathlib import Path

from ase.build import bulk
from ase.io import read, write
from ase.io.pickletrajectory import PickleTrajectory
from ase.calculators.calculator import compare_atoms
from ase.calculators.emt import EMT
from ase.constraints import FixAtoms
from ase.io.bundletrajectory import (BundleTrajectory,
                                     print_bundletrajectory_info)


trajname = 'pickletraj.traj'


def test_raises():
    with pytest.raises(DeprecationWarning):
        PickleTrajectory(trajname, 'w')


@pytest.fixture
def images():
    atoms = bulk('Ti') * (1, 2, 1)
    atoms.symbols = 'Au'
    atoms.calc = EMT()
    atoms1 = atoms.copy()
    atoms1.rattle()
    images = [atoms, atoms1]

    # Set all sorts of weird data:
    for i, atoms in enumerate(images):
        ints = np.arange(len(atoms)) + i
        floats = 1.0 + np.arange(len(atoms))
        atoms.set_tags(ints)
        atoms.set_initial_magnetic_moments(floats)
        atoms.set_initial_charges(floats)
        atoms.set_masses(floats)
        floats3d = 1.2 * np.arange(3 * len(atoms)).reshape(-1, 3)
        atoms.set_momenta(floats3d)
        atoms.info = {'hello': 'world'}
        atoms.calc = EMT()
        atoms.get_potential_energy()

    atoms.set_constraint(FixAtoms(indices=[0]))
    return [atoms, atoms1]


def read_images(filename):
    with PickleTrajectory(filename, _warn=False) as traj:
        return list(traj)


@pytest.fixture
def trajfile(images):
    ptraj = PickleTrajectory(trajname, 'w', _warn=False)
    for image in images:
        ptraj.write(image)
    ptraj.close()
    return trajname


def assert_images_equal(images1, images2):
    assert len(images1) == len(images2), 'length mismatch'
    for atoms1, atoms2 in zip(images1, images2):
        differences = compare_atoms(atoms1, atoms2)
        assert not differences


def test_write_read_pickle(images, trajfile):
    images1 = read_images(trajfile)
    assert_images_equal(images, images1)


@pytest.mark.xfail(reason='bug: writes initial magmoms but reads magmoms '
                   'as part of calculator')
def test_write_read_bundle(images, bundletraj):
    images1 = read(bundletraj, ':')
    assert_images_equal(images, images1)


def test_append_pickle(images, trajfile):
    with PickleTrajectory(trajfile, 'a', _warn=False) as traj:
        for image in images:
            traj.write(image)

    images1 = read_images(trajfile)
    assert_images_equal(images * 2, images1)


@pytest.mark.xfail(reason='same as test_read_write_bundle')
def test_append_bundle(images, bundletraj):
    traj = BundleTrajectory(bundletraj, mode='a')
    assert len(read(bundletraj, ':')) == 2
    #write(bundletraj, images, append=True)
    for atoms in images:
        traj.write(atoms)
    traj.close()
    images1 = read(bundletraj, ':')
    assert len(images1) == 4
    # XXX Fix the magmoms/charges bug
    assert_images_equal(images * 2, images1)


def test_old_trajectory_conversion_utility(images, trajfile):
    trajpath = Path(trajfile)
    assert trajpath.exists()
    check_call([sys.executable, '-m', 'ase.io.trajectory', trajfile])
    oldtrajpath = trajpath.with_suffix('.traj.old')
    assert oldtrajpath.exists()
    assert trajpath.exists()  # New file should be where the old one was
    new_images = read(trajpath, ':', format='traj')
    assert_images_equal(images, new_images)


@pytest.fixture
def bundletraj(images):
    fname = 'traj.bundle'
    write(fname, images, format='bundletrajectory')
    return fname


def test_bundletrajectory_info(images, bundletraj, capsys):
    print_bundletrajectory_info(bundletraj)
    output, _ = capsys.readouterr()

    natoms = len(images[0])
    expected_substring = f'Number of atoms: {natoms}'
    assert expected_substring in output

    # Same thing but via main():
    output2 = check_output([sys.executable,
                            '-m', 'ase.io.bundletrajectory', bundletraj],
                           encoding='ascii')
    assert expected_substring in output2