File: log_db_test.py

package info (click to toggle)
golang-github-google-certificate-transparency 0.0~git20160709.0.0f6e3d1~ds1-3
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, bullseye, buster
  • size: 5,676 kB
  • sloc: cpp: 35,278; python: 11,838; java: 1,911; sh: 1,885; makefile: 950; xml: 520; ansic: 225
file content (190 lines) | stat: -rw-r--r-- 8,030 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import abc
from ct.client.db import log_db
from ct.client.db import database
from ct.proto import client_pb2

# This class provides common tests for all CT log database implementations.
# It only inherits from object so that unittest won't attempt to run the test_*
# methods on this class directly. Derived classes should use multiple
# inheritance from LogDBTest and unittest.TestCase to get test automation.
class LogDBTest(object):
    """Implementation-agnostic tests for the LogDB interface.

    All LogDB tests should derive from this class as well as
    unittest.TestCase, and must override db() to return the database under
    test. The tests call self.db() repeatedly and expect state written through
    one call (e.g. add_log) to be visible through the next (e.g. logs()), so
    db() should return the same database instance for the duration of a test.
    """
    __metaclass__ = abc.ABCMeta

    # Set up a default fake test log server and STH. These are shared class
    # attributes: tests that need a variant copy them with CopyFrom() rather
    # than mutating them in place.
    default_log = client_pb2.CtLogMetadata()
    default_log.log_server = "test"
    default_log.log_id = "c29tZWtleWlk"  # b64("somekeyid")
    default_log.public_key_info.type = client_pb2.KeyInfo.ECDSA
    default_log.public_key_info.pem_key = "base64encodedkey"

    default_sth = client_pb2.AuditedSth()
    default_sth.sth.timestamp = 1234
    default_sth.sth.sha256_root_hash = "base64hash"
    default_sth.audit.status = client_pb2.VERIFIED

    @abc.abstractmethod
    def db(self):
        """Return the database under test.

        Derived classes must override this to initialize a database.
        """
        pass

    @staticmethod
    def _numbered_sth(index):
        """Build an AuditedSth with timestamp `index` and hash "hash-<index>"."""
        sth = client_pb2.AuditedSth()
        sth.sth.timestamp = index
        sth.sth.sha256_root_hash = "hash-%d" % index
        return sth

    def test_add_log(self):
        self.db().add_log(LogDBTest.default_log)
        generator = self.db().logs()
        metadata = next(generator)
        self.assertEqual(metadata, LogDBTest.default_log)
        self.assertRaises(StopIteration, next, generator)

    def test_update_log(self):
        self.db().add_log(LogDBTest.default_log)
        self.db().store_sth(LogDBTest.default_log.log_server,
                            LogDBTest.default_sth)

        new_log = client_pb2.CtLogMetadata()
        new_log.CopyFrom(LogDBTest.default_log)
        new_log.public_key_info.pem_key = "newkey"
        self.db().update_log(new_log)
        generator = self.db().logs()
        metadata = next(generator)
        self.assertEqual(metadata, new_log)
        self.assertRaises(StopIteration, next, generator)

        # Should still be able to access STHs after updating log metadata
        read_sth = self.db().get_latest_sth(new_log.log_server)
        self.assertIsNotNone(read_sth)
        self.assertEqual(LogDBTest.default_sth, read_sth)

    def test_update_log_adds_log(self):
        # update_log() on an unknown log is expected to insert it.
        self.db().update_log(LogDBTest.default_log)
        generator = self.db().logs()
        metadata = next(generator)
        self.assertEqual(metadata, LogDBTest.default_log)
        self.assertRaises(StopIteration, next, generator)

    def test_store_sth(self):
        self.db().add_log(LogDBTest.default_log)
        self.db().store_sth(LogDBTest.default_log.log_server,
                            LogDBTest.default_sth)
        read_sth = self.db().get_latest_sth(LogDBTest.default_log.log_server)
        self.assertIsNotNone(read_sth)
        self.assertEqual(LogDBTest.default_sth, read_sth)

    def test_store_sth_ignores_duplicate(self):
        self.db().add_log(LogDBTest.default_log)
        self.db().store_sth(LogDBTest.default_log.log_server,
                            LogDBTest.default_sth)
        # A true duplicate: same STH (timestamp and root hash) as the one
        # already stored, differing only in audit status. Without the copy the
        # "duplicate" would have timestamp 0 and the test would pass trivially
        # because the original is simply newer.
        duplicate_sth = client_pb2.AuditedSth()
        duplicate_sth.CopyFrom(LogDBTest.default_sth)
        duplicate_sth.audit.status = client_pb2.VERIFY_ERROR
        self.db().store_sth(LogDBTest.default_log.log_server, duplicate_sth)
        read_sth = self.db().get_latest_sth(LogDBTest.default_log.log_server)
        self.assertIsNotNone(read_sth)
        # The originally stored entry (VERIFIED) must not be overwritten.
        self.assertEqual(LogDBTest.default_sth, read_sth)

    def test_log_not_found_raises(self):
        self.assertRaises(database.KeyError, self.db().store_sth,
                          LogDBTest.default_log.log_server,
                          LogDBTest.default_sth)

    def test_get_latest_sth_returns_latest(self):
        self.db().add_log(LogDBTest.default_log)
        self.db().store_sth(LogDBTest.default_log.log_server,
                            LogDBTest.default_sth)
        # Store an *older* STH afterwards; insertion order must not matter.
        new_sth = client_pb2.AuditedSth()
        new_sth.CopyFrom(LogDBTest.default_sth)
        new_sth.sth.timestamp = LogDBTest.default_sth.sth.timestamp - 1
        self.db().store_sth(LogDBTest.default_log.log_server, new_sth)
        read_sth = self.db().get_latest_sth(LogDBTest.default_log.log_server)
        self.assertIsNotNone(read_sth)
        self.assertEqual(LogDBTest.default_sth, read_sth)

    def test_get_latest_sth_returns_none_if_empty(self):
        self.db().add_log(LogDBTest.default_log)
        self.assertIsNone(self.db().get_latest_sth(
            LogDBTest.default_log.log_server))

    def test_get_latest_sth_honours_log_server(self):
        self.db().add_log(LogDBTest.default_log)
        self.db().store_sth(LogDBTest.default_log.log_server,
                            LogDBTest.default_sth)
        # A newer STH stored under a different log must not be returned for
        # the default log.
        new_sth = client_pb2.AuditedSth()
        new_sth.CopyFrom(LogDBTest.default_sth)
        new_sth.sth.timestamp = LogDBTest.default_sth.sth.timestamp + 1

        new_log = client_pb2.CtLogMetadata()
        new_log.log_server = "test2"
        new_log.log_id = "c29tZW90aGVya2V5aWQ="  # b64("someotherkeyid")
        self.db().add_log(new_log)

        new_sth.sth.sha256_root_hash = "hash2"
        self.db().store_sth(new_log.log_server, new_sth)
        read_sth = self.db().get_latest_sth(LogDBTest.default_log.log_server)
        self.assertIsNotNone(read_sth)
        self.assertEqual(LogDBTest.default_sth, read_sth)

    def test_scan_latest_sth_range_finds_all(self):
        log_server = LogDBTest.default_log.log_server
        self.db().add_log(LogDBTest.default_log)
        for i in range(4):
            self.db().store_sth(log_server, self._numbered_sth(i))

        generator = self.db().scan_latest_sth_range(log_server)
        # Scan runs in descending timestamp order
        for i in range(3, -1, -1):
            sth = next(generator)
            self.assertEqual(sth.sth.timestamp, i)
            self.assertEqual(sth.sth.sha256_root_hash, "hash-%d" % i)

        self.assertRaises(StopIteration, next, generator)

    def test_scan_latest_sth_range_honours_log_server(self):
        servers = ["test-%d" % i for i in range(4)]
        for server in servers:
            log = client_pb2.CtLogMetadata()
            log.log_server = server
            self.db().add_log(log)
        for i, server in enumerate(servers):
            self.db().store_sth(server, self._numbered_sth(i))

        for i, server in enumerate(servers):
            generator = self.db().scan_latest_sth_range(server)
            sth = next(generator)
            self.assertEqual(sth.sth.timestamp, i)
            self.assertEqual(sth.sth.sha256_root_hash, "hash-%d" % i)
            # Each log holds exactly one STH; nothing from other logs may
            # appear in its scan.
            self.assertRaises(StopIteration, next, generator)

    def test_scan_latest_sth_range_honours_range(self):
        log_server = LogDBTest.default_log.log_server
        self.db().add_log(LogDBTest.default_log)
        for i in range(4):
            self.db().store_sth(log_server, self._numbered_sth(i))

        # Both bounds are inclusive; results come back newest first.
        generator = self.db().scan_latest_sth_range(log_server, start=1, end=2)
        for expected in (2, 1):
            sth = next(generator)
            self.assertEqual(sth.sth.timestamp, expected)
            self.assertEqual(sth.sth.sha256_root_hash, "hash-%d" % expected)

        self.assertRaises(StopIteration, next, generator)

    def test_scan_latest_sth_range_honours_limit(self):
        log_server = LogDBTest.default_log.log_server
        self.db().add_log(LogDBTest.default_log)
        for i in range(4):
            self.db().store_sth(log_server, self._numbered_sth(i))

        generator = self.db().scan_latest_sth_range(log_server, limit=1)
        sth = next(generator)
        # Returns most recent
        self.assertEqual(sth.sth.timestamp, 3)
        self.assertEqual(sth.sth.sha256_root_hash, "hash-%d" % 3)

        self.assertRaises(StopIteration, next, generator)