1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
|
import numpy as np
from mne.connectivity import seed_target_indices
def test_indices():
    """Test connectivity indexing methods.

    For several combinations of seed/target set sizes, draw disjoint,
    randomly ordered seed and target indices and check that
    ``seed_target_indices`` returns two equal-length index sequences that
    together enumerate every (seed, target) pair exactly once.
    """
    n_seeds_test = [1, 3, 4]
    n_targets_test = [2, 3, 200]
    # Fixed seed so any failure is reproducible.
    rng = np.random.RandomState(42)
    for n_seeds in n_seeds_test:
        for n_targets in n_targets_test:
            # Shuffle a contiguous range and split it, so seeds and targets
            # are disjoint and not in sorted order.
            idx = rng.permutation(np.arange(n_seeds + n_targets))
            seeds = idx[:n_seeds]
            targets = idx[n_seeds:]
            indices = seed_target_indices(seeds, targets)
            assert len(indices) == 2
            assert len(indices[0]) == len(indices[1])
            assert len(indices[0]) == n_seeds * n_targets
            for seed in seeds:
                assert np.sum(indices[0] == seed) == n_targets
            for target in targets:
                assert np.sum(indices[1] == target) == n_seeds
            # The per-index counts above can still be satisfied by a wrong
            # pairing (e.g. the same pair repeated), so also require that
            # the returned pairs are exactly the Cartesian product
            # seeds x targets. Together with the length check above, this
            # also rules out duplicate pairs.
            pairs = {(int(s), int(t)) for s, t in zip(indices[0], indices[1])}
            expected = {(int(s), int(t)) for s in seeds for t in targets}
            assert pairs == expected
|