1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
|
from __future__ import print_function
import threading
import time
import random
import traceback
class DummyDatabase(object):
    """In-memory key-value datastore whose item access is artificially slow."""

    def __init__(self):
        # Backing dictionary that holds every stored key/value pair.
        self.storage = {}
        # User names that are permitted to work with this database.
        self.allowed_users = ["user123", "admin"]

    def _simulate_latency(self):
        """Sleep for a random duration below 0.1 seconds."""
        delay = random.random() / 10
        time.sleep(delay)

    def connect(self, user):
        """Return a new Connection to this database on behalf of *user*."""
        return Connection(self, user)

    def __setitem__(self, key, value):
        """Store *value* under *key* after the artificial delay."""
        self._simulate_latency()
        self.storage[key] = value

    def __getitem__(self, item):
        """Return the value stored under *item* after the artificial delay.

        Raises KeyError when *item* has never been stored.
        """
        self._simulate_latency()
        return self.storage[item]
class Connection(object):
    """
    Connection to the key-value datastore with the artificial limitation
    that only a single thread may use the connection at the same time.

    A store/retrieve call made while another thread is inside a call on
    the same connection fails immediately with RuntimeError instead of
    waiting for the other thread to finish.
    """

    def __init__(self, db, user=None):
        """
        :param db: datastore object; it must expose ``allowed_users`` and
            support item assignment and item lookup.
        :param user: default user for calls that do not pass their own.
        """
        self.db = db
        self.user = user
        # Re-entrant lock: only access from a *different* thread is
        # reported as concurrent use.
        self.lock = threading.RLock()

    def _check_access(self, user):
        """Raise AssertionError unless *user* is one of the db's allowed users."""
        # An explicit raise is used instead of an ``assert`` statement,
        # because asserts are removed under ``python -O`` and the access
        # check would silently vanish. AssertionError is kept so callers
        # see the same exception type and message as before.
        if user not in self.db.allowed_users:
            raise AssertionError("access denied")

    def store(self, key, value, user=None):
        """
        Store *value* under *key* in the datastore.

        :param user: user performing the call; falls back to the
            connection's default user when omitted (or falsy).
        :raises AssertionError: if the user is not allowed.
        :raises RuntimeError: if another thread is currently using this
            connection.
        """
        user = user or self.user
        self._check_access(user)
        if not self.lock.acquire(blocking=False):
            raise RuntimeError("ERROR: concurrent connection access (write) by multiple different threads")
        try:
            print("DB: user %s stores: %s = %s" % (user, key, value))
            self.db[key] = value
        finally:
            # Always release, even if the datastore raised; otherwise the
            # connection would stay locked for every other thread.
            self.lock.release()

    def retrieve(self, key, user=None):
        """
        Return the value stored under *key* in the datastore.

        :param user: user performing the call; falls back to the
            connection's default user when omitted (or falsy).
        :raises AssertionError: if the user is not allowed.
        :raises RuntimeError: if another thread is currently using this
            connection.

        Any exception raised by the datastore lookup itself (for example
        KeyError for a missing key) propagates to the caller.
        """
        user = user or self.user
        self._check_access(user)
        if not self.lock.acquire(blocking=False):
            raise RuntimeError("ERROR: concurrent connection access (read) by multiple different threads")
        try:
            print("DB: user %s retrieve: %s" % (user, key))
            return self.db[key]
        finally:
            # Always release, even if the lookup raised; otherwise the
            # connection would stay locked for every other thread.
            self.lock.release()
if __name__ == "__main__":
    # Phase 1: a single thread uses the connection on its own.
    db = DummyDatabase()
    conn = db.connect("user123")
    for amount in range(100, 105):
        conn.store("amount", amount)
        conn.retrieve("amount")

    # Phase 2: two threads share the very same connection, which is
    # expected to trigger the concurrent-access errors.
    class ClientThread(threading.Thread):
        """Daemon thread doing five store/retrieve rounds on a shared connection."""

        def __init__(self, conn):
            threading.Thread.__init__(self)
            self.daemon = True
            self.conn = conn

        def run(self):
            # Each failure is printed and the loop carries on, so every
            # round attempts both a store and a retrieve.
            for round_no in range(5):
                try:
                    self.conn.store("amount", 100 + round_no)
                except Exception:
                    traceback.print_exc()
                try:
                    self.conn.retrieve("amount")
                except Exception:
                    traceback.print_exc()

    # Create both clients before starting either of them.
    clients = [ClientThread(conn) for _ in range(2)]
    for client in clients:
        client.start()
    time.sleep(0.1)
    for client in clients:
        client.join()
|