1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178
|
from pymbar import timeseries
from pymbar import testsystems
import numpy as np
from scipy import stats
import pytest
from pymbar.utils_for_testing import assert_almost_equal
try:
import statsmodels.api as sm # pylint: disable=unused-import
HAVE_STATSMODELS = True
except ImportError as err:
HAVE_STATSMODELS = False
has_statmodels = pytest.mark.skipif(
not HAVE_STATSMODELS, reason="Skipping FFT based tests because statsmodels not installed."
)
@pytest.fixture(scope="module")
def data(N=10000, K=10):
var = np.ones(N)
for replica in range(2, K + 1):
var = np.concatenate((var, np.ones(N)))
X = np.random.normal(np.zeros(K * N), var).reshape((K, N)) / 10.0
Y = np.random.normal(np.zeros(K * N), var).reshape((K, N))
energy = 10 * (X**2) / 2.0 + (Y**2) / 2.0
return X, Y, energy
def test_statistical_inefficiency_single(data):
X, Y, energy = data
timeseries.statistical_inefficiency(X[0])
timeseries.statistical_inefficiency(X[0], X[0])
timeseries.statistical_inefficiency(X[0] ** 2)
timeseries.statistical_inefficiency(X[0] ** 2, X[0] ** 2)
timeseries.statistical_inefficiency(energy[0])
timeseries.statistical_inefficiency(energy[0], energy[0])
timeseries.statistical_inefficiency(X[0], X[0] ** 2)
# TODO: Add some checks to test statistical inefficinecies are within normal range
def test_statistical_inefficiency_multiple(data):
X, Y, energy = data
timeseries.statistical_inefficiency_multiple(X)
timeseries.statistical_inefficiency_multiple(X**2)
timeseries.statistical_inefficiency_multiple(X[0, :] ** 2)
timeseries.statistical_inefficiency_multiple(X[0:2, :] ** 2)
timeseries.statistical_inefficiency_multiple(energy)
# TODO: Add some checks to test statistical inefficinecies are within normal range
@has_statmodels
def test_statistical_inefficiency_fft(data):
X, Y, energy = data
timeseries.statistical_inefficiency_fft(X[0])
timeseries.statistical_inefficiency_fft(X[0] ** 2)
timeseries.statistical_inefficiency_fft(energy[0])
g0 = timeseries.statistical_inefficiency_fft(X[0])
g1 = timeseries.statistical_inefficiency(X[0])
g2 = timeseries.statistical_inefficiency(X[0], X[0])
g3 = timeseries.statistical_inefficiency(X[0], fft=True)
assert_almost_equal(g0, g1, decimal=6)
assert_almost_equal(g0, g2, decimal=6)
assert_almost_equal(g0, g3, decimal=6)
@has_statmodels
def test_statistical_inefficiency_fft_gaussian():
# Run multiple times to get things with and without negative "spikes" at C(1)
for i in range(5):
x = np.random.normal(size=100000)
g0 = timeseries.statistical_inefficiency(x, fast=False)
g1 = timeseries.statistical_inefficiency(x, x, fast=False)
g2 = timeseries.statistical_inefficiency_fft(x)
g3 = timeseries.statistical_inefficiency(x, fft=True)
assert_almost_equal(g0, g1, decimal=5)
assert_almost_equal(g0, g2, decimal=5)
assert_almost_equal(g0, g3, decimal=5)
assert_almost_equal(np.log(g0), np.log(1.0), decimal=1)
for i in range(5):
x = np.random.normal(size=100000)
x = np.repeat(
x, 3
) # e.g. Construct correlated gaussian e.g. [a, b, c] -> [a, a, a, b, b, b, c, c, c]
g0 = timeseries.statistical_inefficiency(x, fast=False)
g1 = timeseries.statistical_inefficiency(x, x, fast=False)
g2 = timeseries.statistical_inefficiency_fft(x)
g3 = timeseries.statistical_inefficiency(x, fft=True)
assert_almost_equal(g0, g1, decimal=5)
assert_almost_equal(g0, g2, decimal=5)
assert_almost_equal(g0, g3, decimal=5)
assert_almost_equal(np.log(g0), np.log(3.0), decimal=1)
def test_detectEquil():
x = np.random.normal(size=10000)
t, g, Neff_max = timeseries.detect_equilibration(x)
@has_statmodels
def test_detectEquil_binary():
x = np.random.normal(size=10000)
t, g, Neff_max = timeseries.detect_equilibration_binary_search(x)
@has_statmodels
def test_compare_detectEquil(show_hist=False):
"""
compare detect_equilibration implementations (with and without binary search + fft)
"""
t_res = []
N = 100
for _ in range(100):
A_t = testsystems.correlated_timeseries_example(N=N, tau=5.0) + 2.0
B_t = testsystems.correlated_timeseries_example(N=N, tau=5.0) + 1.0
C_t = testsystems.correlated_timeseries_example(N=N * 2, tau=5.0)
D_t = np.concatenate([A_t, B_t, C_t])
bs_de = timeseries.detect_equilibration_binary_search(D_t, bs_nodes=10)
std_de = timeseries.detect_equilibration(D_t, fast=False, nskip=1)
t_res.append(bs_de[0] - std_de[0])
try:
# scipy<1.9
t_res_mode = float(stats.mode(t_res)[0][0])
except IndexError:
# scipy>=1.9
t_res_mode = float(stats.mode(t_res, keepdims=True)[0][0])
assert_almost_equal(t_res_mode, 0.0, decimal=1)
if show_hist:
import matplotlib.pyplot as plt
plt.hist(t_res)
plt.show()
def test_detectEquil_constant_trailing():
"""
This explicitly tests issue #122, see https://github.com/choderalab/pymbar/issues/122
We only check that the code doesn't give an exception. The exact value of Neff can either be
~50 if we try to include part of the equilibration samples, or it can be Neff=1 if we find that the
whole first half is discarded.
"""
x = np.random.normal(size=100) * 0.01
x[50:] = 3.0
# The input data is some MCMC chain where the trailing end of the chain is a constant sequence.
t, g, Neff_max = timeseries.detect_equilibration(x)
def test_correlationFunctionMultiple():
"""
tests the truncate and norm feature
"""
A_t = [testsystems.correlated_timeseries_example(N=10000, tau=10.0) for _ in range(10)]
corr_norm = timeseries.normalized_fluctuation_correlation_function_multiple(A_kn=A_t)
corr = timeseries.normalized_fluctuation_correlation_function_multiple(A_kn=A_t, norm=False)
corr_norm_trun = timeseries.normalized_fluctuation_correlation_function_multiple(
A_kn=A_t, truncate=True
)
corr_trun = timeseries.normalized_fluctuation_correlation_function_multiple(
A_kn=A_t, norm=False, truncate=True
)
assert corr_norm_trun[-1] >= 0
assert corr_trun[-1] >= 0
assert corr_norm[0] == 1.0
assert corr_norm_trun[0] == 1.0
assert len(corr_trun) == len(corr_norm_trun)
|