File: test_trie.py

package info (click to toggle)
python-biopython 1.73+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 57,852 kB
  • sloc: python: 169,977; xml: 97,539; ansic: 15,653; sql: 1,208; makefile: 159; sh: 63
file content (194 lines) | stat: -rw-r--r-- 7,734 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
#!/usr/bin/env python

# This code is part of the Biopython distribution and governed by its
# license.  Please see the LICENSE file that should have been included
# as part of this package.

import random
import tempfile
import unittest
from string import ascii_lowercase
from io import BytesIO

import warnings
from Bio import BiopythonDeprecationWarning
with warnings.catch_warnings():
    warnings.simplefilter('ignore', BiopythonDeprecationWarning)
    warnings.simplefilter('ignore', RuntimeWarning)  # for the trie module
    try:
        from Bio import trie
        from Bio import triefind
    except ImportError:
        import os
        from Bio import MissingPythonDependencyError
        if os.name == "java":
            message = "Not available on Jython, Bio.trie requires compiled C code."
        else:
            message = "Could not import Bio.trie, check C code was compiled."
        raise MissingPythonDependencyError(message)


class TestTrie(unittest.TestCase):
    """Tests for the dictionary-like ``Bio.trie.trie`` object."""

    def test_get_set(self):
        """Store, retrieve, count and overwrite values by key."""
        trieobj = trie.trie()
        trieobj["hello world"] = "s1"
        trieobj["bye"] = "s2"
        trieobj["hell sucks"] = "s3"
        trieobj["hebee"] = "s4"
        self.assertEqual(trieobj["hello world"], "s1")
        self.assertEqual(trieobj["bye"], "s2")
        self.assertEqual(trieobj["hell sucks"], "s3")
        self.assertEqual(trieobj["hebee"], "s4")
        trieobj["blah"] = "s5"
        self.assertEqual(trieobj["blah"], "s5")
        self.assertIsNone(trieobj.get("foobar"))
        self.assertEqual(len(trieobj), 5)
        # Re-assigning an existing key replaces its value.
        trieobj["blah"] = "snew"
        self.assertEqual(trieobj["blah"], "snew")

    def test_prefix(self):
        """Check membership, has_prefix() and with_prefix() lookups."""
        trieobj = trie.trie()
        trieobj["hello"] = 5
        trieobj["he"] = 7
        trieobj["hej"] = 9
        trieobj["foo"] = "bar"
        k = sorted(trieobj.keys())
        self.assertEqual(k, ["foo", "he", "hej", "hello"])
        self.assertEqual(trieobj["hello"], 5)
        self.assertIsNone(trieobj.get("bye"))
        self.assertIn("hello", trieobj)
        self.assertIn("he", trieobj)
        self.assertNotIn("bye", trieobj)
        self.assertTrue(trieobj.has_prefix("h"))
        self.assertTrue(trieobj.has_prefix("hel"))
        self.assertFalse(trieobj.has_prefix("foa"))
        # A string longer than every stored key is not a prefix of any key.
        self.assertFalse(trieobj.has_prefix("hello world"))
        self.assertEqual(len(trieobj), 4)
        k = sorted(trieobj.with_prefix("he"))
        self.assertEqual(k, ["he", "hej", "hello"])
        k = trieobj.with_prefix("l")
        self.assertEqual(k, [])
        k = trieobj.with_prefix("hej")
        self.assertEqual(k, ["hej"])
        k = trieobj.with_prefix("hejk")
        self.assertEqual(k, [])

    def test_save(self):
        """Check keys/values, approximate matching, and a BytesIO round trip."""
        trieobj = trie.trie()
        trieobj["foo"] = 1
        k = list(trieobj.keys())
        self.assertEqual(k, ["foo"])
        v = list(trieobj.values())
        self.assertEqual(v, [1])
        self.assertEqual(trieobj.get("bar", 99), 99)
        trieobj["hello"] = '55a'
        # get_approximate(key, k) yields (key, value, mismatches) tuples.
        self.assertEqual(trieobj.get_approximate("foo", 0), [("foo", 1, 0)])
        self.assertEqual(trieobj.get_approximate("foo", 1), [("foo", 1, 0)])
        self.assertEqual(trieobj.get_approximate("foa", 0), [])
        self.assertEqual(trieobj.get_approximate("foa", 1), [("foo", 1, 1)])
        x = sorted(trieobj.get_approximate("foa", 2))
        self.assertEqual(x, [("foo", 1, 1), ("foo", 1, 2), ("foo", 1, 2)])
        # foo  foo-  foo-
        # foa  f-oa  fo-a
        # mismatch a->o
        # insertion after f, deletion of o
        # insertion after o, deletion of o
        x = trieobj.get_approximate("foo", 4)
        # Count how often each distinct alignment result is reported.
        y = {}
        for z in x:
            y[z] = y.get(z, 0) + 1
        x = sorted(y.items())
        self.assertEqual(x, [(('foo', 1, 0), 1), (('hello', '55a', 4), 6)])
        h = BytesIO()
        trie.save(h, trieobj)
        h.seek(0)
        trieobj = trie.load(h)
        k = list(trieobj.keys())
        self.assertIn("foo", k)
        self.assertIn("hello", k)
        self.assertEqual(repr(trieobj["foo"]), '1')
        self.assertEqual(repr(trieobj["hello"]), "'55a'")

    def test_get_approximate(self):
        """Approximate matching must count insertions/deletions at the key end."""
        # Regression test: get_approximate() used to mishandle insertions and
        # deletions at the end of the query string.
        trieobj = trie.trie()
        trieobj["hello"] = 1
        self.assertEqual(trieobj.get_approximate('he', 2), [])
        self.assertEqual(trieobj.get_approximate('he', 3), [('hello', 1, 3)])
        self.assertEqual(trieobj.get_approximate('hello me!', 3), [])
        self.assertEqual(trieobj.get_approximate('hello me!', 4), [('hello', 1, 4)])
        self.assertEqual(trieobj.get_approximate('hello me!', 5), [('hello', 1, 4)])

    def test_with_prefix(self):
        """Check with_prefix() on a trie holding every suffix of "BANANA"."""
        trieobj = trie.trie()
        s = "BANANA"
        for i in range(len(s)):  # insert all suffixes into trie
            trieobj[s[i:]] = i
            self.assertEqual(trieobj[s[i:]], i)
        self.assertEqual(set(trieobj.values()), set(range(6)))
        self.assertEqual({'A', 'ANA', 'ANANA', 'BANANA', 'NA', 'NANA'},
                         set(trieobj.keys()))
        self.assertEqual({'NA', 'NANA'},
                         set(trieobj.with_prefix("N")))
        self.assertEqual({'NA', 'NANA'},
                         set(trieobj.with_prefix("NA")))
        self.assertEqual({'A', 'ANA', 'ANANA'},
                         set(trieobj.with_prefix("A")))
        self.assertEqual({'ANA', 'ANANA'},
                         set(trieobj.with_prefix("AN")))

    def test_large_save_load(self):
        """Round-trip a large trie through save()/load() on a temporary file.

        Generate 1000 random key/value pairs in each of three length
        categories (keys and values are exactly 100, 1000 or 10000 lowercase
        letters long). Insert them into a trie and into a reference dict.
        Write the trie to a temp file and read it back, then verify that the
        reloaded trie holds exactly the entries of the reference dict, for all
        three categories.
        """
        def random_string(length):
            # Join of ``length`` independently chosen lowercase letters.
            return ''.join([random.choice(ascii_lowercase) for _ in range(length)])

        n_per_category = 1000
        # A single reference dict accumulated across ALL categories, so that
        # the post-load check covers every inserted entry, not just the last
        # category.  (Keys of different categories differ in length, so they
        # cannot collide.)
        cmp_dict = {}
        trieobj = trie.trie()
        self.assertIsNone(trieobj.get("foobar"))
        for max_str_len in [100, 1000, 10000]:
            for _ in range(n_per_category):
                key = random_string(max_str_len)
                val = random_string(max_str_len)
                trieobj[key] = val
                cmp_dict[key] = val
            for key in cmp_dict:
                self.assertEqual(trieobj[key], cmp_dict[key])
        self.assertEqual(len(trieobj), len(cmp_dict))

        with tempfile.TemporaryFile(mode='w+b') as f:
            trie.save(f, trieobj)
            f.seek(0)
            trieobj = trie.load(f)
        self.assertEqual(len(trieobj), len(cmp_dict))
        for key in cmp_dict:
            self.assertEqual(trieobj[key], cmp_dict[key])


class TestTrieFind(unittest.TestCase):
    """Tests for the ``Bio.triefind`` search helpers run against a trie."""

    def test_find(self):
        """Locate trie keys inside a text via match/match_all/find/find_words."""
        text = "hello world!"
        t = trie.trie()
        for key, value in (("hello", 5), ("he", 7), ("hej", 9),
                           ("foo", "bar"), ("wor", "ld")):
            t[key] = value

        # Both "he" and "hello" are prefixes of the text; match() is expected
        # to report "hello", while match_all() reports both.
        self.assertEqual(triefind.match(text, t), "hello")
        self.assertEqual(sorted(triefind.match_all(text, t)), ["he", "hello"])

        # find() reports (key, start, end) for every occurrence; find_words()
        # is expected to keep only "hello" here.
        self.assertEqual(sorted(triefind.find(text, t)),
                         [("he", 0, 2), ("hello", 0, 5), ("wor", 6, 9)])
        self.assertEqual(sorted(triefind.find_words(text, t)),
                         [("hello", 0, 5)])

        # Once "world" is a key, it shows up in both find() and find_words().
        t["world"] = "full"
        self.assertEqual(sorted(triefind.find(text, t)),
                         [("he", 0, 2), ("hello", 0, 5), ("wor", 6, 9), ("world", 6, 11)])
        self.assertEqual(sorted(triefind.find_words(text, t)),
                         [("hello", 0, 5), ("world", 6, 11)])


if __name__ == "__main__":
    # Run the suite with per-test (verbosity=2) reporting.
    unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))