1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191
|
# Copyright 2009-2010 by Eric Talevich. All rights reserved.
# Revisions copyright 2010 by Peter Cock. All rights reserved.
#
# Converted by Eric Talevich from an older unit test copyright 2002
# by Thomas Hamelryck.
#
# This file is part of the Biopython distribution and governed by your
# choice of the "Biopython License Agreement" or the "BSD 3-Clause License".
# Please see the LICENSE file that should have been included as part of this
# package.
"""Unit tests for those parts of the Bio.PDB module using Bio.PDB.kdtrees."""
import unittest
try:
from numpy import array, dot, sqrt, argsort
from numpy.random import random
except ImportError:
from Bio import MissingExternalDependencyError
raise MissingExternalDependencyError(
"Install NumPy if you want to use Bio.PDB."
) from None
try:
from Bio.PDB import kdtrees
except ImportError:
from Bio import MissingExternalDependencyError
raise MissingExternalDependencyError(
"C module Bio.PDB.kdtrees not compiled"
) from None
from Bio.PDB.NeighborSearch import NeighborSearch
class NeighborTest(unittest.TestCase):
    """Tests for Bio.PDB.NeighborSearch using randomly placed fake atoms."""

    def test_neighbor_search(self):
        """NeighborSearch: Find nearby randomly generated coordinates.

        Based on the self test in Bio.PDB.NeighborSearch. The atom pairs
        reported by search_all are checked against a brute-force count, and
        a query point far from every atom must return no hits at any level.
        """
        radius = 5.0

        class RandomAtom:
            """Minimal atom stand-in that only provides ``get_coord``."""

            def __init__(self):
                self.coord = 100 * random(3)

            def get_coord(self):
                return self.coord

        for _ in range(20):
            atoms = [RandomAtom() for _ in range(100)]
            ns = NeighborSearch(atoms)
            hits = ns.search_all(radius)
            self.assertIsInstance(hits, list)

            # Brute-force reference: full symmetric distance matrix.
            coords = array([atom.get_coord() for atom in atoms])
            diff = coords[:, None, :] - coords[None, :, :]
            distances = sqrt((diff * diff).sum(axis=-1))
            # The diagonal (each atom vs. itself, distance 0) is always within
            # radius, and every unordered pair appears twice in the matrix.
            expected = ((distances <= radius).sum() - len(atoms)) // 2
            self.assertEqual(len(hits), expected)
            for atom1, atom2 in hits:
                self.assertIsNot(atom1, atom2)
                v = atom1.get_coord() - atom2.get_coord()
                self.assertLessEqual(sqrt(dot(v, v)), radius + 1e-9)

            x = array([250, 250, 250])  # Far away from our random atoms
            for level in ("A", "R", "C", "M", "S"):
                with self.subTest(level=level):
                    self.assertEqual([], ns.search(x, radius, level))
class KDTreeTest(unittest.TestCase):
    """Tests for the Bio.PDB.kdtrees C module against reference searches."""

    nr_points = 5000  # number of points used in test
    bucket_size = 5  # number of points per tree node
    radius = 0.05  # radius of search (typically 0.05 or so)
    query_radius = 10  # radius of search

    def _assert_same_neighbors(self, neighbors1, neighbors2):
        """Assert that two lists of kdtrees.Neighbor hold the same pairs.

        Both lists are compared after sorting on (index1, index2). This makes
        the order in which each search reported its pairs irrelevant. Radii
        are compared with assertAlmostEqual.
        """

        def by_indices(neighbor):
            return (neighbor.index1, neighbor.index2)

        self.assertEqual(len(neighbors1), len(neighbors2))
        for neighbor1, neighbor2 in zip(
            sorted(neighbors1, key=by_indices), sorted(neighbors2, key=by_indices)
        ):
            self.assertEqual(neighbor1.index1, neighbor2.index1)
            self.assertEqual(neighbor1.index2, neighbor2.index2)
            self.assertAlmostEqual(neighbor1.radius, neighbor2.radius)

    def test_KDTree_exceptions(self):
        """Test that KDTree rejects invalid coordinate arrays with a message."""
        bucket_size = self.bucket_size
        nr_points = self.nr_points
        # Coordinates far outside the +/- 1e6 range named in the error message.
        coords = random((nr_points, 3)) * 100000000000000
        with self.assertRaises(Exception) as context:
            kdtrees.KDTree(coords, bucket_size)
        self.assertIn(
            "coordinate values should lie between -1e6 and 1e6", str(context.exception)
        )
        # One column instead of the required three.
        with self.assertRaises(Exception) as context:
            kdtrees.KDTree(random((nr_points, 1)), bucket_size)
        self.assertIn("expected a Nx3 numpy array", str(context.exception))

    def test_KDTree_point_search(self):
        """Test searching all points within a certain radius of center.

        Using the kdtrees C module, search all points that lie within
        radius of a random center. Compare the results to a brute-force
        search over every point.
        """
        bucket_size = self.bucket_size
        nr_points = self.nr_points
        for radius in (self.radius, 100 * self.radius):
            for _ in range(10):
                coords = random((nr_points, 3))
                center = random(3)
                # kd tree search, ordered by point index
                kdt = kdtrees.KDTree(coords, bucket_size)
                points = sorted(
                    kdt.search(center, radius), key=lambda point: point.index
                )
                # brute-force search; nonzero() yields indices in ascending order
                v = coords - center
                distances = sqrt((v * v).sum(axis=1))
                expected = (distances <= radius).nonzero()[0]
                # compare results
                self.assertEqual(len(points), len(expected))
                for point, index in zip(points, expected):
                    self.assertEqual(point.index, index)
                    self.assertAlmostEqual(point.radius, distances[index])

    def test_KDTree_neighbor_search_simple(self):
        """Test all fixed radius neighbor search.

        Test all fixed radius neighbor search using the KD tree C
        module, and compare the results to those of a simple but
        slow algorithm.
        """
        bucket_size = self.bucket_size
        nr_points = self.nr_points
        radius = self.radius
        for _ in range(10):
            coords = random((nr_points, 3))
            # KD tree search
            kdt = kdtrees.KDTree(coords, bucket_size)
            neighbors1 = kdt.neighbor_search(radius)
            # same search, using a simple but slow algorithm
            neighbors2 = kdt.neighbor_simple_search(radius)
            self._assert_same_neighbors(neighbors1, neighbors2)

    def test_KDTree_neighbor_search_manual(self):
        """Test all fixed radius neighbor search.

        Test all fixed radius neighbor search using the KD tree C
        module, and compare the results to those of a manual search.
        """
        bucket_size = self.bucket_size
        nr_points = self.nr_points // 10  # fewer points to speed up the test
        for radius in (self.radius, 3 * self.radius):
            for _ in range(5):
                coords = random((nr_points, 3))
                # KD tree search
                kdt = kdtrees.KDTree(coords, bucket_size)
                neighbors1 = kdt.neighbor_search(radius)
                # Manual sweep: visit points in order of x coordinate. Once the
                # x gap alone exceeds radius, no later point can be a neighbor,
                # so the inner loop can stop early.
                neighbors2 = []
                indices = argsort(coords[:, 0])
                for j1 in range(nr_points):
                    index1 = indices[j1]
                    p1 = coords[index1]
                    for j2 in range(j1 + 1, nr_points):
                        index2 = indices[j2]
                        p2 = coords[index2]
                        if p2[0] - p1[0] > radius:
                            break
                        v = p1 - p2
                        r = sqrt(dot(v, v))
                        if r <= radius:
                            # Store each pair with the smaller index first.
                            neighbors2.append(
                                kdtrees.Neighbor(
                                    min(index1, index2), max(index1, index2), r
                                )
                            )
                self._assert_same_neighbors(neighbors1, neighbors2)
if __name__ == "__main__":
    # Run the module's tests with verbose (per-test) output.
    unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))
|