1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
|
import pytest
import numpy as np
from ase.calculators.calculator import equal
def arrays(a, dtype):
return [list(map(dtype, a)), np.array(a, dtype=dtype)]
@pytest.mark.parametrize('a', [1, 1.])
@pytest.mark.parametrize('b', [1, 1.])
@pytest.mark.parametrize('rtol', [None, 0, 1e-8])
@pytest.mark.parametrize('atol', [None, 0, 1e-8])
def test_single_value(a, b, rtol, atol):
assert equal(a, b, rtol=rtol, atol=atol)
@pytest.mark.parametrize('a', arrays([1, 1], int) + arrays([1, 1], float))
@pytest.mark.parametrize('b', arrays([1, 1], int) + arrays([1, 1], float))
@pytest.mark.parametrize('rtol', [None, 0, 1e-8])
@pytest.mark.parametrize('atol', [None, 0, 1e-8])
def test_array_equal(a, b, rtol, atol):
assert a is not b
assert equal(a, b, rtol=rtol, atol=atol)
assert equal({'size': a, 'gamma': True},
{'size': b, 'gamma': True}, rtol=rtol, atol=atol)
assert not equal({'size': a, 'gamma': True},
{'size': b, 'gamma': False}, rtol=rtol, atol=atol)
assert not equal({'size': a}, b, rtol=rtol, atol=atol)
assert not equal(a, {'size': b}, rtol=rtol, atol=atol)
@pytest.mark.parametrize('a', arrays([2, 2], float))
@pytest.mark.parametrize('b', arrays(np.array([2, 2]) + 1.9e-8, float))
@pytest.mark.parametrize('rtol,atol', [[None, 2e-8], [0, 2e-8],
[1e-8, None], [1e-8, 0],
[0.5e-8, 1e-8]])
def test_array_almost_equal(a, b, rtol, atol):
assert a is not b
assert equal(a, b, rtol=rtol, atol=atol)
assert equal({'size': a, 'gamma': True},
{'size': b, 'gamma': True}, rtol=rtol, atol=atol)
@pytest.mark.parametrize('a', arrays([2, 2], float))
@pytest.mark.parametrize('b', arrays(np.array([2, 2]) + 3.1e-8, float))
@pytest.mark.parametrize('rtol', [None, 0, 1e-8])
@pytest.mark.parametrize('atol', [None, 0, 1e-8])
def test_array_not_equal(a, b, rtol, atol):
assert a is not b
assert not equal(a, b, rtol=rtol, atol=atol)
assert not equal({'size': a, 'gamma': True},
{'size': b, 'gamma': True}, rtol=rtol, atol=atol)
|