1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
|
import warnings
import pytest
import numpy as np
from ase import Atoms
from ase.io import write, read, iread
from ase.io.formats import all_formats, ioformats
from ase.calculators.singlepoint import SinglePointCalculator
# Optional third-party packages.  When a package is missing, its name is
# bound to the falsy placeholder 0 instead of a module, so feature checks
# elsewhere in this file can simply use ``if not matplotlib``.
try:
    import matplotlib
except ImportError:  # decides whether the 'eps'/'png' formats are tested
    matplotlib = 0

try:
    import netCDF4
except ImportError:  # decides whether 'netcdftrajectory' is tested
    netCDF4 = 0
@pytest.fixture
def atoms():
    """Reference structure written and read back by every format test.

    An Au-H pair lies along x in a cell that is periodic only along x and
    carries a per-atom ``extra`` array; it is then repeated twice along x,
    giving four atoms.  A SinglePointCalculator holding energy, stress and
    forces is attached so that formats which store results (the extxyz
    writer serialises them, for example) can be checked as well.
    """
    box_length = 5.0
    au_h_distance = 1.9
    centre = box_length / 2

    structure = Atoms('AuH',
                      positions=[(0, centre, centre),
                                 (au_h_distance, centre, centre)],
                      cell=(2 * au_h_distance, box_length, box_length),
                      pbc=(1, 0, 0))
    structure.set_array('extra', np.array([2.3, 4.2]))
    structure *= (2, 1, 1)

    # Forces are derived from the final positions, so the calculator is
    # attached only after the structure has been repeated.
    structure.calc = SinglePointCalculator(
        structure,
        energy=-1.0,
        stress=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        forces=-1.0 * structure.positions)
    return structure
def check(a, ref_atoms, format):
    """Assert that *a* matches *ref_atoms* as far as *format* allows.

    Positions are always compared.  The cell, the per-atom ``extra`` array,
    the periodic boundary conditions and the attached calculator results
    (energy, stress, forces) are compared only for the formats listed in
    the corresponding branch below.

    Parameters
    ----------
    a : Atoms
        Structure read back from a file.
    ref_atoms : Atoms
        Reference structure that was written.
    format : str
        Name of the ASE I/O format used for the round trip.

    Raises
    ------
    AssertionError
        If any compared quantity differs beyond its tolerance.
    """
    # Compare atom counts first.  Otherwise numpy broadcasting either raises
    # an opaque ValueError or, for a single-atom read-back, compares that
    # one atom against every reference position and may wrongly pass.
    assert len(a) == len(ref_atoms), (len(a), len(ref_atoms))
    position_diff = a.positions - ref_atoms.positions
    assert abs(position_diff).max() < 1e-6, position_diff

    if format in ('traj', 'cube', 'cfg', 'struct', 'gen', 'extxyz',
                  'db', 'json', 'trj'):
        assert abs(a.cell - ref_atoms.cell).max() < 1e-6

    if format in ('cfg', 'extxyz'):
        assert abs(a.get_array('extra') -
                   ref_atoms.get_array('extra')).max() < 1e-6

    if format in ('extxyz', 'traj', 'trj', 'db', 'json'):
        assert (a.pbc == ref_atoms.pbc).all()
        assert a.get_potential_energy() == ref_atoms.get_potential_energy()
        assert (a.get_stress() == ref_atoms.get_stress()).all()
        assert abs(a.get_forces() - ref_atoms.get_forces()).max() < 1e-12
@pytest.fixture
def catch_warnings():
    """Keep warning-filter changes made by a test local to that test.

    The test body runs inside ``warnings.catch_warnings()``, which restores
    the previous warning-filter state when the context exits.
    """
    filter_guard = warnings.catch_warnings()
    with filter_guard:
        yield
def all_tested_formats():
    """Return the sorted names of the I/O formats exercised by the round trip.

    Every format registered in ``all_formats`` is included except those
    excluded below, for the reasons given in the inline comments.  Formats
    whose optional package (matplotlib, netCDF4) is not installed are also
    excluded.
    """
    excluded = {
        # Someone should do something ...
        'dftb', 'eon', 'lammps-data',
        # Standalone test used as not compatible with 1D periodicity
        'v-sim', 'mustem', 'prismatic',
        # We have a standalone dmol test
        'dmol-arc', 'dmol-car', 'dmol-incoor',
        # Complex dependencies; see animate.py test
        'gif', 'mp4',
        # Let's not worry about these.
        'postgresql', 'trj', 'vti', 'vtu', 'mysql',
    }
    if not matplotlib:
        excluded.update(('eps', 'png'))
    if not netCDF4:
        excluded.add('netcdftrajectory')
    return sorted(set(all_formats).difference(excluded))
@pytest.mark.parametrize('format', all_tested_formats())
def test_ioformat(format, atoms, catch_warnings):
    """Round-trip the reference structure through one I/O format.

    If the format can write, the structure is written to one file and, for
    multi-image formats, a two-image file is written too.  If the format can
    also read, every file is read back in several ways (with and without an
    explicit format, through several index forms, and via ``iread``).  Each
    result is then compared with the reference using ``check``.
    """
    if format in ['proteindatabank', 'netcdftrajectory']:
        warnings.simplefilter('ignore', UserWarning)
        # netCDF4 uses np.bool which may cause warnings in new numpy.
        warnings.simplefilter('ignore', DeprecationWarning)

    if format == 'dlp4':
        atoms.pbc = (1, 1, 0)

    fmt = ioformats[format]
    print('{0:20}{1}{2}{3}{4}'.format(format,
                                      ' R'[fmt.can_read],
                                      ' W'[fmt.can_write],
                                      '+1'[fmt.single],
                                      'SF'[fmt.acceptsfd]))
    single_file = 'io-test.1.{}'.format(format)
    multi_file = 'io-test.2.{}'.format(format)

    if not fmt.can_write:
        return

    write(single_file, atoms, format=format)
    if not fmt.single:
        write(multi_file, [atoms, atoms], format=format)

    if not fmt.can_read:
        return

    # Both read-backs happen before either one is checked.
    readbacks = [read(single_file, format=format), read(single_file)]
    for image in readbacks:
        check(image, atoms, format)

    if fmt.single:
        return

    if format in ['json', 'db']:
        # Database formats select rows with an 'id=N' query string.
        images = (read(multi_file, index='id=1') +
                  read(multi_file, index='id=2'))
    else:
        images = [read(multi_file), read(multi_file, 0)]
    images += read(multi_file, ':')
    images.extend(iread(multi_file, format=format))
    assert len(images) == 6, images
    for image in images:
        check(image, atoms, format)
|