1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
|
from sqlalchemy import schema, exceptions, util, sql, types
import StringIO, sys, re
from sqlalchemy.engine import base, default
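"""Provide a thread-local transactional wrapper around the basic ComposedSQLEngine.

Calls to engine.contextual_connect() are scoped to the current thread: while
a thread-local transaction is in progress, they return that transaction's
connection (engine.connect() still returns an independent, non-thread-local
connection). Also provides begin/commit/rollback methods on the engine
itself, which correspond to a thread-local transaction.
"""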
class TLSession(object):
    """Per-thread bookkeeping for a TLEngine's implicit transaction.

    Tracks the TLConnection that owns the current thread-local transaction
    plus a nesting counter, so that only the outermost ``commit()`` really
    commits, while any ``rollback()`` rolls back the whole transaction.
    """

    def __init__(self, engine):
        self.engine = engine
        # Nesting depth of begin() calls; 0 means no transaction in progress.
        self.__tcount = 0
        # Connection that owns the current transaction, or None.
        self.__transaction = None
        # Transaction object returned by that connection's _begin(), or None.
        self.__trans = None

    def get_connection(self, close_with_result=False):
        """Return a connection for the current thread.

        While a transaction is in progress, the transaction's connection is
        returned with its open count incremented, so all callers share it
        (``close_with_result`` is ignored in that case). Otherwise a new
        TLConnection bound to this session is created.
        """
        if self.__transaction is not None:
            return self.__transaction._increment_connect()
        return TLConnection(self, close_with_result=close_with_result)

    def reset(self):
        """Force-close the transaction's connection (if any) and clear all state.

        The session state is cleared even if force-closing the connection
        raises, so the session never keeps pointing at a dead connection.
        """
        transaction = self.__transaction
        try:
            if transaction is not None:
                transaction._force_close()
        finally:
            self.__transaction = None
            self.__trans = None
            self.__tcount = 0

    def in_transaction(self):
        """Return True if a thread-local transaction is in progress."""
        return self.__tcount > 0

    def begin(self):
        """Begin (or nest into) the thread-local transaction and return it.

        Only the outermost call obtains a connection and starts a real
        transaction; nested calls just increase the nesting depth and return
        the same transaction object.
        """
        if self.__tcount == 0:
            self.__transaction = self.get_connection()
            try:
                self.__trans = self.__transaction._begin()
            except:
                # Starting the transaction failed: detach and release the
                # connection we just obtained so it does not linger as this
                # thread's "transaction" connection, then re-raise.  A bare
                # except is used (with re-raise) for old-Python compatibility.
                connection = self.__transaction
                self.__transaction = None
                connection.close()
                raise
        self.__tcount += 1
        return self.__trans

    def rollback(self):
        """Roll back the whole thread-local transaction, whatever the nesting depth."""
        if self.__tcount > 0:
            try:
                self.__trans._rollback_impl()
            finally:
                self.reset()

    def commit(self):
        """Commit at the outermost level; at deeper levels just un-nest one level."""
        if self.__tcount == 1:
            try:
                self.__trans._commit_impl()
            finally:
                self.reset()
        elif self.__tcount > 1:
            self.__tcount -= 1

    def is_begun(self):
        """Return True if a thread-local transaction is in progress (same as in_transaction)."""
        return self.__tcount > 0
class TLConnection(base.Connection):
    """A Connection owned by a TLSession and shared within one thread.

    An open count tracks how many handles to this connection are
    outstanding; the underlying Connection is only closed when the last
    handle is closed, or when the session force-closes it.
    """

    def __init__(self, session, close_with_result):
        base.Connection.__init__(self, session.engine,
                                 close_with_result=close_with_result)
        self.__session = session
        # Number of outstanding handles; starts at one for the creator.
        self.__opencount = 1

    # Read-only access to the owning TLSession.
    session = property(lambda conn: conn.__session)

    def _increment_connect(self):
        """Register one more handle to this connection and return it."""
        self.__opencount = self.__opencount + 1
        return self

    def _create_transaction(self, parent):
        """Create transactions as TLTransaction objects."""
        return TLTransaction(self, parent)

    def _begin(self):
        """Start a transaction using the base Connection.begin() implementation."""
        return base.Connection.begin(self)

    def in_transaction(self):
        """Report whether the owning session has a transaction in progress."""
        owner = self.session
        return owner.in_transaction()

    def begin(self):
        """Begin (or nest into) the owning session's thread-local transaction."""
        owner = self.session
        return owner.begin()

    def close(self):
        """Release one handle; the real close happens when the last handle goes."""
        is_last_handle = (self.__opencount == 1)
        if is_last_handle:
            base.Connection.close(self)
        self.__opencount -= 1

    def _force_close(self):
        """Close the connection unconditionally, discarding all outstanding handles."""
        self.__opencount = 0
        base.Connection.close(self)
class TLTransaction(base.Transaction):
    """Transaction whose public commit()/rollback() defer to the connection's TLSession.

    The session decides when the real commit or rollback happens and then
    calls ``_commit_impl`` / ``_rollback_impl`` to perform it.
    """

    def _commit_impl(self):
        """Perform the real commit via the base Transaction implementation."""
        base.Transaction.commit(self)

    def _rollback_impl(self):
        """Perform the real rollback via the base Transaction implementation."""
        base.Transaction.rollback(self)

    def commit(self):
        """Ask the owning session to commit; only the outermost level really commits."""
        owner = self.connection.session
        owner.commit()

    def rollback(self):
        """Ask the owning session to roll back the whole thread-local transaction."""
        owner = self.connection.session
        owner.rollback()
class TLEngine(base.Engine):
    """An Engine that includes support for thread-local managed transactions.

    This engine is better suited to be used with a threadlocal Pool
    object.
    """

    def __init__(self, *args, **kwargs):
        """Create the engine and its per-thread storage.

        The TLEngine relies upon the ConnectionProvider having
        "threadlocal" behavior, so that once a connection is checked out
        for the current thread, you get that same connection
        repeatedly.
        """
        super(TLEngine, self).__init__(*args, **kwargs)
        # Per-thread storage; holds the current thread's TLSession as
        # the ``session`` attribute (created lazily by _session()).
        self.context = util.ThreadLocal()

    def raw_connection(self):
        """Return a DBAPI connection."""
        return self.connection_provider.get_connection()

    def connect(self, **kwargs):
        """Return a Connection that is not thread-locally scoped.

        This is the equivalent to calling ``connect()`` on a
        ComposedSQLEngine.  Keyword arguments (e.g. ``close_with_result``)
        are forwarded to the Connection constructor rather than being
        silently discarded.
        """
        return base.Connection(self,
                               self.connection_provider.unique_connection(),
                               **kwargs)

    def _session(self):
        # Lazily create one TLSession per thread.
        if not hasattr(self.context, 'session'):
            self.context.session = TLSession(self)
        return self.context.session

    session = property(_session, doc="returns the current thread's TLSession")

    def contextual_connect(self, **kwargs):
        """Return a TLConnection which is thread-locally scoped."""
        return self.session.get_connection(**kwargs)

    def begin(self):
        """Begin (or nest into) the current thread's transaction and return it."""
        return self.session.begin()

    def commit(self):
        """Commit the current thread's transaction (outermost level only)."""
        self.session.commit()

    def rollback(self):
        """Roll back the current thread's whole transaction."""
        self.session.rollback()
class TLocalConnectionProvider(default.PoolConnectionProvider):
    """PoolConnectionProvider that additionally exposes the pool's unique_connection().

    TLEngine.connect() uses unique_connection() to obtain a connection that
    is not thread-locally scoped (presumably bypassing the pool's
    per-thread checkout -- verify against the Pool implementation).
    """

    def unique_connection(self):
        """Return the result of the underlying pool's ``unique_connection()``."""
        pool = self._pool
        return pool.unique_connection()
|