File: sqlite_log_db.py

package info (click to toggle)
golang-github-google-certificate-transparency 0.0~git20160709.0.0f6e3d1~ds1-3
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, bullseye, buster
  • size: 5,676 kB
  • sloc: cpp: 35,278; python: 11,838; java: 1,911; sh: 1,885; makefile: 950; xml: 520; ansic: 225
file content (151 lines) | stat: -rw-r--r-- 6,703 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import logging
import sqlite3
import time

from ct.client.db import log_db
from ct.client.db import database
from ct.proto import client_pb2

class SQLiteLogDB(log_db.LogDB):
    """LogDB implementation backed by SQLite.

    Log metadata is kept in a ``logs`` table (one row per log server, with the
    server name stored in its own UNIQUE column) and audited STHs are kept in
    an ``sths`` table that references ``logs.id``.
    """

    def __init__(self, connection_manager):
        """Initialize the database and tables.

        Creates the ``logs`` and ``sths`` tables and the ``sth_by_timestamp``
        index if they do not already exist.

        Args:
            connection_manager: an SQLiteConnectionManager object."""
        self.__mgr = connection_manager

        with self.__mgr.get_connection() as conn:
            # TODO(ekasper): give users control of table names via flags so
            # we can explicitly avoid conflicts between database objects
            # sharing the same underlying SQLiteConnection.
            conn.execute("CREATE TABLE IF NOT EXISTS logs("
                         "id INTEGER PRIMARY KEY, log_server TEXT UNIQUE, "
                         "metadata BLOB)")
            # Identical (log_id, timestamp, sth_data, audit_info) rows are
            # silently dropped by the ON CONFLICT IGNORE clause.
            conn.execute("CREATE TABLE IF NOT EXISTS sths(log_id INTEGER, "
                         "fetch_timestamp INTEGER,"
                         "timestamp INTEGER, sth_data BLOB, "
                         "audit_info BLOB,"
                         "UNIQUE(log_id, timestamp, sth_data, audit_info) ON "
                         "CONFLICT IGNORE,"
                         "FOREIGN KEY(log_id) REFERENCES logs(id))")
            conn.execute("CREATE INDEX IF NOT EXISTS sth_by_timestamp on sths("
                         "log_id, timestamp)")
        self.__tables = ["logs", "sths"]

    def __repr__(self):
        return "%r(db: %r)" % (self.__class__.__name__, self.__mgr)

    def __str__(self):
        return "%s(db: %s, tables: %s): " % (self.__class__.__name__,
                                             self.__mgr, self.__tables)

    def __encode_log_metadata(self, metadata):
        """Split a CtLogMetadata into its server name and a serialized blob.

        The log_server field is cleared from the serialized copy because it is
        stored in its own column; the caller's proto is not modified.

        Returns:
            (log_server, sqlite3.Binary of the serialized metadata).
        """
        log_server = metadata.log_server
        local_metadata = client_pb2.CtLogMetadata()
        local_metadata.CopyFrom(metadata)
        local_metadata.ClearField("log_server")
        return log_server, sqlite3.Binary(local_metadata.SerializeToString())

    def __decode_log_metadata(self, log_server, serialized_metadata):
        """Inverse of __encode_log_metadata: rebuild a CtLogMetadata proto."""
        metadata = client_pb2.CtLogMetadata()
        metadata.ParseFromString(serialized_metadata)
        metadata.log_server = log_server
        return metadata

    def add_log(self, metadata):
        """Insert a new log; a duplicate log_server is logged and ignored.

        Args:
            metadata: a client_pb2.CtLogMetadata proto.
        """
        log_server, serialized_metadata = self.__encode_log_metadata(
            metadata)
        with self.__mgr.get_connection() as conn:
            try:
                conn.execute("INSERT INTO logs(log_server, metadata) "
                             "VALUES(?, ?)", (log_server, serialized_metadata))
            except sqlite3.IntegrityError:
                # log_server is UNIQUE, so a second add of the same server
                # lands here.
                logging.warning("Ignoring duplicate log server %s", log_server)

    def update_log(self, metadata):
        """Insert or overwrite the stored metadata for metadata.log_server.

        Args:
            metadata: a client_pb2.CtLogMetadata proto.
        """
        log_server, serialized_metadata = self.__encode_log_metadata(
            metadata)
        with self.__mgr.get_connection() as conn:
            # The sub-select reuses the existing row id (if any) so that sths
            # rows referencing this log stay linked after the REPLACE; for a
            # new server it yields NULL and SQLite assigns a fresh id.
            conn.execute("INSERT OR REPLACE INTO logs(id, log_server, "
                         "metadata) VALUES((SELECT id FROM logs WHERE "
                         "log_server = ?), ?, ?) ", (log_server, log_server,
                                                     serialized_metadata))

    def logs(self):
        """Yield a client_pb2.CtLogMetadata for every stored log.

        This is a generator: the connection stays checked out of the manager
        until iteration finishes or the generator is closed.
        """
        with self.__mgr.get_connection() as conn:
            for log_server, metadata in conn.execute(
                "SELECT log_server, metadata FROM logs"):
                yield self.__decode_log_metadata(log_server, metadata)

    def _get_log_id(self, conn, log_server):
        """Return the integer row id of log_server using an open connection.

        Args:
            conn: an open connection obtained from the connection manager.
            log_server: the server name, i.e., the <log_server> path prefix.

        Raises:
            database.KeyError: if log_server has not been added.
        """
        # fetchone() works on both Python 2 and 3 cursors, unlike .next().
        row = conn.execute("SELECT id FROM logs WHERE log_server = ?",
                           (log_server,)).fetchone()
        if row is None:
            raise database.KeyError("Unknown log server: %s" % (log_server,))
        return row[0]

    def get_log_id(self, log_server):
        """Return the integer row id of log_server.

        Raises:
            database.KeyError: if log_server has not been added.
        """
        with self.__mgr.get_connection() as conn:
            return self._get_log_id(conn, log_server)

    def __encode_sth(self, audited_sth):
        """Split an AuditedSth into (timestamp, sth blob, audit blob).

        The timestamp is cleared from the serialized STH because it is stored
        in its own (indexed) column; the caller's proto is not modified.
        """
        timestamp = audited_sth.sth.timestamp
        sth = client_pb2.SthResponse()
        sth.CopyFrom(audited_sth.sth)
        sth.ClearField("timestamp")
        audit = client_pb2.AuditInfo()
        audit.CopyFrom(audited_sth.audit)
        return (timestamp, sqlite3.Binary(sth.SerializeToString()),
                sqlite3.Binary(audit.SerializeToString()))

    def __decode_sth(self, sth_row):
        """Rebuild a client_pb2.AuditedSth from a full ``sths`` row.

        Expects the column order of the sths table as returned by SELECT *:
        (log_id, fetch_timestamp, timestamp, sth_data, audit_info).
        """
        _, _, timestamp, serialized_sth, serialized_audit = sth_row
        audited_sth = client_pb2.AuditedSth()
        audited_sth.sth.ParseFromString(serialized_sth)
        audited_sth.sth.timestamp = timestamp
        audited_sth.audit.ParseFromString(serialized_audit)
        return audited_sth

    # This ignores a duplicate STH even if the audit data differs.
    # TODO(ekasper): add an update method for updating audit data, as needed.
    def store_sth(self, log_server, audited_sth):
        """Store the STH in the database.
        Will store the STH with a unique ID unless an exact copy already exists.
        Note: the fetch_timestamp is time of calling this function, not actual
        fetching timestamp.

        Args:
            log_server: the server name, i.e., the <log_server> path prefix
            audited_sth: a client_pb2.AuditedSth proto

        Raises:
            database.KeyError: if log_server has not been added.
        """
        timestamp, sth_data, audit_info = self.__encode_sth(audited_sth)
        with self.__mgr.get_connection() as conn:
            log_id = self._get_log_id(conn, log_server)
            conn.execute("INSERT INTO sths(log_id, fetch_timestamp, timestamp, "
                         "sth_data, audit_info) VALUES(?, ?, ?, ?, ?)",
                         (log_id, int(time.time()), timestamp, sth_data,
                          audit_info))

    def get_latest_sth(self, log_server):
        """Return the stored AuditedSth with the highest timestamp, or None.

        Raises:
            database.KeyError: if log_server has not been added.
        """
        with self.__mgr.get_connection() as conn:
            log_id = self._get_log_id(conn, log_server)
            row = conn.execute("SELECT * FROM sths WHERE log_id = ? "
                               "ORDER BY timestamp DESC LIMIT 1",
                               (log_id,)).fetchone()
        if row is None:
            return None
        return self.__decode_sth(row)

    def scan_latest_sth_range(self, log_server, start=0,
                              end=log_db.LogDB.timestamp_max, limit=0):
        """Yield AuditedSths with start <= timestamp <= end, newest first.

        Args:
            log_server: the server name, i.e., the <log_server> path prefix.
            start: inclusive lower timestamp bound.
            end: inclusive upper timestamp bound.
            limit: maximum number of STHs to yield; 0 (or any falsy value)
                means no limit.

        Raises:
            database.KeyError: if log_server has not been added.
        """
        # SQLite treats a negative LIMIT as "no upper bound".
        sql_limit = limit if limit else -1
        with self.__mgr.get_connection() as conn:
            log_id = self._get_log_id(conn, log_server)
            for row in conn.execute(
                "SELECT * FROM sths WHERE log_id = ? "
                "AND timestamp >= ? AND timestamp <= ? ORDER BY timestamp DESC "
                "LIMIT ?", (log_id, start, end, sql_limit)):
                yield self.__decode_sth(row)