File: test_PDB_KDTree.py

package info (click to toggle)
python-biopython 1.78+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 65,756 kB
  • sloc: python: 221,141; xml: 178,777; ansic: 13,369; sql: 1,208; makefile: 131; sh: 70
file content (191 lines) | stat: -rw-r--r-- 7,776 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
# Copyright 2009-2010 by Eric Talevich.  All rights reserved.
# Revisions copyright 2010 by Peter Cock.  All rights reserved.
#
# Converted by Eric Talevich from an older unit test copyright 2002
# by Thomas Hamelryck.
#
# This file is part of the Biopython distribution and governed by your
# choice of the "Biopython License Agreement" or the "BSD 3-Clause License".
# Please see the LICENSE file that should have been included as part of this
# package.

"""Unit tests for those parts of the Bio.PDB module using Bio.PDB.kdtrees."""

import unittest

try:
    from numpy import array, dot, sqrt, argsort
    from numpy.random import random
except ImportError:
    from Bio import MissingExternalDependencyError

    raise MissingExternalDependencyError(
        "Install NumPy if you want to use Bio.PDB."
    ) from None

try:
    from Bio.PDB import kdtrees
except ImportError:
    from Bio import MissingExternalDependencyError

    raise MissingExternalDependencyError(
        "C module Bio.PDB.kdtrees not compiled"
    ) from None

from Bio.PDB.NeighborSearch import NeighborSearch


class NeighborTest(unittest.TestCase):
    """Tests for NeighborSearch built on randomly placed atoms."""

    def test_neighbor_search(self):
        """NeighborSearch: Find nearby randomly generated coordinates.

        Based on the self test in Bio.PDB.NeighborSearch.
        """

        class RandomAtom:
            """Minimal atom stand-in holding one random 3D coordinate."""

            def __init__(self):
                # random(3) is uniform in [0, 1), so each component is < 100.
                self.coord = 100 * random(3)

            def get_coord(self):
                """Return the coordinate array of this atom."""
                return self.coord

        # Every atom component is below 100, so this point is much further
        # than the 5.0 search radius from all of them.
        far_away = array([250.0, 250.0, 250.0])
        levels = ("A", "R", "C", "M", "S")
        for _ in range(20):
            atoms = [RandomAtom() for _ in range(100)]
            ns = NeighborSearch(atoms)
            hits = ns.search_all(5.0)
            self.assertIsInstance(hits, list)
            # Check the empty-result query on every tree that is built, not
            # only on the one left over after the final iteration.
            for level in levels:
                with self.subTest(level=level):
                    self.assertEqual([], ns.search(far_away, 5.0, level))


class KDTreeTest(unittest.TestCase):
    """Compare the kdtrees C module against slow reference searches."""

    nr_points = 5000  # number of points used in test
    bucket_size = 5  # number of points per tree node
    radius = 0.05  # radius of search (typically 0.05 or so)
    query_radius = 10  # radius of search; not read by any test in this class

    @staticmethod
    def _neighbor_key(neighbor):
        """Return the (index1, index2) pair used to order neighbor records."""
        return (neighbor.index1, neighbor.index2)

    def _assert_same_neighbors(self, neighbors1, neighbors2):
        """Assert that two neighbor lists hold the same pairs and radii.

        The order of the input lists does not matter; both are sorted by
        their index pair before being compared element by element.
        """
        self.assertEqual(len(neighbors1), len(neighbors2))
        ordered1 = sorted(neighbors1, key=self._neighbor_key)
        ordered2 = sorted(neighbors2, key=self._neighbor_key)
        for neighbor1, neighbor2 in zip(ordered1, ordered2):
            self.assertEqual(neighbor1.index1, neighbor2.index1)
            self.assertEqual(neighbor1.index2, neighbor2.index2)
            self.assertAlmostEqual(neighbor1.radius, neighbor2.radius)

    def test_KDTree_exceptions(self):
        """Test that KDTree rejects out-of-range and wrongly shaped input."""
        bucket_size = self.bucket_size
        nr_points = self.nr_points
        # Scale the random values far beyond the range named in the message.
        coords = random((nr_points, 3)) * 1e14
        with self.assertRaises(Exception) as context:
            kdtrees.KDTree(coords, bucket_size)
        self.assertIn(
            "coordinate values should lie between -1e6 and 1e6", str(context.exception)
        )
        # One column instead of three.
        with self.assertRaises(Exception) as context:
            kdtrees.KDTree(random((nr_points, 1)), bucket_size)
        self.assertIn("expected a Nx3 numpy array", str(context.exception))

    def test_KDTree_point_search(self):
        """Test searching all points within a certain radius of center.

        Using the kdtrees C module, search all point pairs that are
        within radius, and compare the results to a manual search.
        """
        bucket_size = self.bucket_size
        nr_points = self.nr_points
        for radius in (self.radius, 100 * self.radius):
            for _ in range(10):
                # kd tree search
                coords = random((nr_points, 3))
                center = random(3)
                kdt = kdtrees.KDTree(coords, bucket_size)
                points1 = kdt.search(center, radius)
                points1.sort(key=lambda point: point.index)
                # manual search; rows are visited in index order, so points2
                # is already sorted the same way as points1
                points2 = []
                for index, p in enumerate(coords):
                    v = p - center
                    r = sqrt(dot(v, v))
                    if r <= radius:
                        points2.append(kdtrees.Point(index, r))
                # compare results
                self.assertEqual(len(points1), len(points2))
                for point1, point2 in zip(points1, points2):
                    self.assertEqual(point1.index, point2.index)
                    self.assertAlmostEqual(point1.radius, point2.radius)

    def test_KDTree_neighbor_search_simple(self):
        """Test all fixed radius neighbor search.

        Test all fixed radius neighbor search using the KD tree C
        module, and compare the results to those of a simple but
        slow algorithm.
        """
        bucket_size = self.bucket_size
        nr_points = self.nr_points
        radius = self.radius
        for _ in range(10):
            # KD tree search
            coords = random((nr_points, 3))
            kdt = kdtrees.KDTree(coords, bucket_size)
            neighbors1 = kdt.neighbor_search(radius)
            # same search, using a simple but slow algorithm
            neighbors2 = kdt.neighbor_simple_search(radius)
            # compare results
            self._assert_same_neighbors(neighbors1, neighbors2)

    def test_KDTree_neighbor_search_manual(self):
        """Test all fixed radius neighbor search.

        Test all fixed radius neighbor search using the KD tree C
        module, and compare the results to those of a manual search.
        """
        bucket_size = self.bucket_size
        nr_points = self.nr_points // 10  # fewer points to speed up the test
        for radius in (self.radius, 3 * self.radius):
            for _ in range(5):
                # KD tree search
                coords = random((nr_points, 3))
                kdt = kdtrees.KDTree(coords, bucket_size)
                neighbors1 = kdt.neighbor_search(radius)
                # manual search, sweeping over the points in order of x
                neighbors2 = []
                indices = argsort(coords[:, 0])
                for j1 in range(nr_points):
                    index1 = indices[j1]
                    p1 = coords[index1]
                    for j2 in range(j1 + 1, nr_points):
                        index2 = indices[j2]
                        p2 = coords[index2]
                        # Points are sorted by x, so once the x gap alone
                        # exceeds radius no later point can be a neighbor.
                        if p2[0] - p1[0] > radius:
                            break
                        v = p1 - p2
                        r = sqrt(dot(v, v))
                        if r <= radius:
                            # Store the pair with the smaller index first.
                            if index1 < index2:
                                i1, i2 = index1, index2
                            else:
                                i1, i2 = index2, index1
                            neighbors2.append(kdtrees.Neighbor(i1, i2, r))
                # compare results
                self._assert_same_neighbors(neighbors1, neighbors2)


if __name__ == "__main__":
    # Run every test in this module with verbose per-test output.
    unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))