1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225
|
"""An sqlite database for storing nodes and key/value pairs."""
from datetime import datetime, timedelta
from pysqlite2 import dbapi2 as sqlite
from binascii import a2b_base64, b2a_base64
from time import sleep
import os
from twisted.trial import unittest
class DBExcept(Exception):
    """Exception raised for database problems (e.g. the DB cannot be opened)."""
class khash(str):
    """Marker subclass of C{str} used for hashes.

    Wrapping a hash in this type lets the sqlite adapter registered for it
    store the (binary) string in the DB as base64.
    """
class dht_value(str):
    """Marker subclass of C{str} used for DHT values.

    Wrapping a value in this type lets the sqlite adapter registered for it
    store the (binary) string in the DB as base64.
    """
# Writing to the database: 'khash' and 'dht_value' objects (binary strings)
# are encoded as base64 before being stored.
sqlite.register_adapter(khash, b2a_base64)
sqlite.register_adapter(dht_value, b2a_base64)

# Reading from the database: columns declared with these types are decoded
# from base64 back to binary strings (both spellings of each type name are
# registered).
sqlite.register_converter("KHASH", a2b_base64)
sqlite.register_converter("khash", a2b_base64)
sqlite.register_converter("DHT_VALUE", a2b_base64)
sqlite.register_converter("dht_value", a2b_base64)
class DB:
    """An sqlite database for storing persistent node info and key/value pairs.

    @type db: C{string}
    @ivar db: the database file to use
    @type conn: L{pysqlite2.dbapi2.Connection}
    @ivar conn: an open connection to the sqlite database
    """

    def __init__(self, db):
        """Load or create the database file.

        @type db: C{string}
        @param db: the database file to use
        @raise DBExcept: if an existing database file could not be opened
        """
        self.db = db
        # A failing stat() is taken to mean the file does not exist yet,
        # in which case a fresh database with empty tables is created.
        try:
            os.stat(db)
        except OSError:
            self._createNewDB(db)
        else:
            self._loadDB(db)
        # Have TEXT columns returned as plain str objects.
        # NOTE(review): presumably pysqlite older than 2.1 has no
        # Connection.text_factory, hence the converter fallback -- confirm.
        if sqlite.version_info < (2, 1):
            sqlite.register_converter("TEXT", str)
            sqlite.register_converter("text", str)
        else:
            self.conn.text_factory = str

    #{ Loading the DB
    def _loadDB(self, db):
        """Open a new connection to the existing database file.

        @type db: C{string}
        @param db: the database file to open
        @raise DBExcept: if the connection could not be opened; the message
            includes the formatted traceback of the underlying error
        """
        try:
            self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
        except Exception:
            import traceback
            # The three-expression form of raise needs a traceback *object*
            # as its last argument, so the formatted traceback text is put
            # into the exception message instead.
            raise DBExcept("Couldn't open DB:\n" + traceback.format_exc())

    def _createNewDB(self, db):
        """Open a connection to a new database and create the necessary tables.

        Creates the 'kv' table (key, value, last_refresh) with indexes on
        key and last_refresh, the 'nodes' table (id, host, port) and the
        'self' table (num, id), then commits.

        @type db: C{string}
        @param db: the database file to create
        """
        self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
        c = self.conn.cursor()
        c.execute("CREATE TABLE kv (key KHASH, value DHT_VALUE, last_refresh TIMESTAMP, " +
                  "PRIMARY KEY (key, value))")
        c.execute("CREATE INDEX kv_key ON kv(key)")
        c.execute("CREATE INDEX kv_last_refresh ON kv(last_refresh)")
        c.execute("CREATE TABLE nodes (id KHASH PRIMARY KEY, host TEXT, port NUMBER)")
        c.execute("CREATE TABLE self (num NUMBER PRIMARY KEY, id KHASH)")
        self.conn.commit()

    def close(self):
        """Close the connection to the database."""
        self.conn.close()

    #{ This node's ID
    def getSelfNode(self):
        """Retrieve this node's ID from a previous run of the program.

        @return: the stored ID, or None if no ID has been saved
        """
        c = self.conn.cursor()
        c.execute('SELECT id FROM self WHERE num = 0')
        row = c.fetchone()
        if row:
            return row[0]
        else:
            return None

    def saveSelfNode(self, id):
        """Store this node's ID for a subsequent run of the program.

        @param id: the ID to save; it replaces any previously saved ID
        """
        c = self.conn.cursor()
        c.execute("INSERT OR REPLACE INTO self VALUES (0, ?)", (khash(id),))
        self.conn.commit()

    #{ Routing table
    def dumpRoutingTable(self, buckets):
        """Save routing table nodes to the database.

        All previously stored nodes are removed first.

        @param buckets: an iterable of buckets, each with a C{nodes}
            attribute holding objects with C{id}, C{host} and C{port}
        """
        c = self.conn.cursor()
        c.execute("DELETE FROM nodes WHERE id NOT NULL")
        for bucket in buckets:
            for node in bucket.nodes:
                c.execute("INSERT INTO nodes VALUES (?, ?, ?)", (khash(node.id), node.host, node.port))
        self.conn.commit()

    def getRoutingTable(self):
        """Load routing table nodes from database.

        @rtype: C{list}
        @return: one (id, host, port) row for every stored node
        """
        c = self.conn.cursor()
        c.execute("SELECT * FROM nodes")
        return c.fetchall()

    #{ Key/value pairs
    def retrieveValues(self, key):
        """Retrieve values from the database.

        @param key: the key to look up
        @rtype: C{list}
        @return: all the values stored for the key (empty if there are none)
        """
        c = self.conn.cursor()
        c.execute("SELECT value FROM kv WHERE key = ?", (khash(key),))
        return [row[0] for row in c.fetchall()]

    def countValues(self, key):
        """Count the number of values in the database.

        @param key: the key to count the values of
        @return: the number of values stored for the key
        """
        c = self.conn.cursor()
        c.execute("SELECT COUNT(value) as num_values FROM kv WHERE key = ?", (khash(key),))
        res = 0
        row = c.fetchone()
        if row:
            res = row[0]
        return res

    def storeValue(self, key, value):
        """Store or update a key and value.

        The (key, value) pair is the primary key of the table, so storing
        an existing pair again replaces it, refreshing its timestamp to now.

        @param key: the key to store the value under
        @param value: the value to store
        """
        c = self.conn.cursor()
        c.execute("INSERT OR REPLACE INTO kv VALUES (?, ?, ?)",
                  (khash(key), dht_value(value), datetime.now()))
        self.conn.commit()

    def expireValues(self, expireAfter):
        """Expire older values after expireAfter seconds.

        @param expireAfter: the age in seconds; values last refreshed
            longer ago than this are deleted
        """
        t = datetime.now() - timedelta(seconds=expireAfter)
        c = self.conn.cursor()
        c.execute("DELETE FROM kv WHERE last_refresh < ?", (t, ))
        self.conn.commit()

    def keyStats(self):
        """Count the total number of keys and values in the database.

        @rtype: (C{int}, C{int})
        @return: the number of distinct keys and total values in the database
        """
        c = self.conn.cursor()
        c.execute("SELECT COUNT(value) as num_values FROM kv")
        values = 0
        row = c.fetchone()
        if row:
            values = row[0]
        c.execute("SELECT COUNT(key) as num_keys FROM (SELECT DISTINCT key FROM kv)")
        keys = 0
        row = c.fetchone()
        if row:
            keys = row[0]
        return keys, values
class TestDB(unittest.TestCase):
    """Tests for the khashmir database."""

    timeout = 5
    db = '/tmp/khashmir.db'
    key = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'

    def setUp(self):
        """Open the database file before every test."""
        self.store = DB(self.db)

    def tearDown(self):
        """Close the database and remove its file after every test."""
        self.store.close()
        os.unlink(self.db)

    def test_selfNode(self):
        """A saved node ID is returned unchanged."""
        self.store.saveSelfNode(self.key)
        saved = self.store.getSelfNode()
        self.failUnlessEqual(saved, self.key)

    def test_Value(self):
        """A stored value is counted and retrieved for its key."""
        self.store.storeValue(self.key, self.key)
        self.failUnlessEqual(self.store.countValues(self.key), 1)
        found = self.store.retrieveValues(self.key)
        self.failUnlessEqual(len(found), 1)
        self.failUnlessEqual(found[0], self.key)

    def test_expireValues(self):
        """Only the value stored before the expiry window is removed."""
        newer = self.key + self.key
        self.store.storeValue(self.key, self.key)
        sleep(2)
        self.store.storeValue(self.key, newer)
        self.store.expireValues(1)
        found = self.store.retrieveValues(self.key)
        self.failUnlessEqual(len(found), 1)
        self.failUnlessEqual(found[0], newer)

    def test_RoutingTable(self):
        """Nodes dumped from buckets are all loaded back as rows."""
        class dummy:
            # Stand-in for a routing table node.
            id = self.key
            host = "127.0.0.1"
            port = 9977

            def contents(self):
                return (self.id, self.host, self.port)

        class bl:
            # Stand-in for a routing table bucket.
            def __init__(self):
                self.nodes = []

        dummy2 = dummy()
        dummy2.id = '\xaa\xbb\xcc\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
        dummy2.host = '205.23.67.124'
        dummy2.port = 12345

        # One bucket per node.
        buckets = []
        for node in (dummy(), dummy2):
            bucket = bl()
            bucket.nodes.append(node)
            buckets.append(bucket)

        self.store.dumpRoutingTable(buckets)
        rt = self.store.getRoutingTable()
        for expected in (dummy(), dummy2):
            self.failUnlessIn(expected.contents(), rt)
|