1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
|
import inspect
import itertools
import os
import numpy
tol = float(os.environ.get("NRN_RXD_TEST_TOLERANCE", "1e-10"))
dt_eps = 1e-20
def get_data_file_name(frame):
"""returns the filename for the data file need for the test."""
curframe = inspect.currentframe()
calframe = inspect.getouterframes(curframe, 4)
testfunc_name = calframe[frame][3]
assert testfunc_name.startswith("test_")
return testfunc_name[5:] + ".dat"
def get_correct_data_for_test():
"""returns a path to the file with the correct data for a test."""
data_filename = get_data_file_name(frame=3)
basepath = os.path.dirname(os.path.abspath(__file__))
return os.path.join(basepath, "testdata", "test", data_filename)
def save_data_from_test(save_path):
"""save the data generated by a test."""
basepath = os.path.abspath(save_path)
if not os.path.exists(basepath):
os.mkdir(basepath)
filepath = os.path.join(basepath, get_data_file_name(frame=4))
return filepath
def collect_data(h, rxd, data, save_path, num_record=10):
"""grabs the membrane potential data, h.t, and the rxd state values"""
data["record_count"] += 1
if data["record_count"] > num_record:
h.stoprun = True
return
all_potentials = [seg.v for seg in itertools.chain.from_iterable(h.allsec())]
rxd_1d = list(rxd.node._states)
rxd_3d = []
rxd_ecs = []
for sp in rxd.species._all_species:
s = sp()
if s and hasattr(s, "_intracellular_instances"):
for ics in s._intracellular_instances.values():
rxd_3d += list(ics.states)
if s and hasattr(s, "_extracellular_instances"):
for ecs in s._extracellular_instances.values():
rxd_ecs += list(ecs.states.flatten())
all_rxd = rxd_1d + rxd_3d + rxd_ecs
local_data = [h.t] + all_potentials + all_rxd
# remove data before t=0
if h.t == 0:
data["data"] = []
data["record_count"] = 1
# remove previous record if h.t is the same
if data["record_count"] > 1 and h.t == data["data"][-len(local_data)]:
data["record_count"] -= 1
del data["data"][-len(local_data) :]
# add new data record
data["data"].extend(local_data)
if data["record_count"] == 2:
data["rlen"] = len(local_data)
# save the test data (if the --save option was used)
if save_path:
with open(save_data_from_test(save_path), "wb") as f:
numpy.array(data["data"]).tofile(f)
def compare_data(data):
"""compares the test data with the correct data"""
rlen = data["rlen"]
corr_dat = numpy.fromfile(get_correct_data_for_test()).reshape(-1, rlen)
tst_dat = numpy.array(data["data"]).reshape(-1, rlen)
t1 = corr_dat[:, 0]
t2 = tst_dat[:, 0]
# remove any initial t that are greter than the next t
# (removes times before 0) in correct data
c = 0
while c < len(t1) - 1 and t1[c] > t1[c + 1]:
c = c + 1
t1 = numpy.delete(t1, range(c))
corr_dat = numpy.delete(corr_dat, range(c), 0)
# remove any initial t that are greter than the next t
# (removes times before 0) in test data
c = 0
while c < len(t2) - 1 and t2[c] > t2[c + 1]:
c = c + 1
t2 = numpy.delete(t2, range(c))
tst_dat = numpy.delete(tst_dat, range(c), 0)
# get rid of repeating t in correct data (otherwise interpolation fails)
c = 0
while c < len(t1) - 1:
c1 = c + 1
while c1 < len(t1) and abs(t1[c] - t1[c1]) < dt_eps:
c1 = c1 + 1
t1 = numpy.delete(t1, range(c, c1 - 1))
corr_dat = numpy.delete(corr_dat, range(c, c1 - 1), 0)
c = c + 1
# get rid of the test data outside of the correct data time interval
t2_n = len(t2)
t2_0 = 0
while t2[t2_n - 1] > t1[-1]:
t2_n = t2_n - 1
while t2[t2_0] < t1[0]:
t2_0 = t2_0 + 1
# interpolate and compare
corr_vals = numpy.array(
[numpy.interp(t2[t2_0:t2_n], t1, corr_dat[:, i].T) for i in range(1, rlen)]
)
max_err = numpy.amax(abs(corr_vals.T - tst_dat[t2_0:t2_n, 1:]))
return max_err
|