File: redis.py

package info (click to toggle)
python-aioredlock 0.7.3-3
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 184 kB
  • sloc: python: 608; makefile: 2
file content (445 lines) | stat: -rw-r--r-- 17,098 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
import asyncio
import logging
import re
import time
from packaging.version import parse as parse_version
from itertools import groupby

import aioredis

from aioredlock.errors import LockError, LockAcquiringError, LockRuntimeError
from aioredlock.sentinel import Sentinel
from aioredlock.utility import clean_password


def all_equal(iterable):
    """
    Tell whether every element of ``iterable`` equals the others.

    An empty iterable counts as all-equal.
    """
    runs = groupby(iterable)
    # Skip the first run of equal elements (if there is one); any further
    # run means some element differed from the one before it.
    next(runs, None)
    for _ in runs:
        return False
    return True


def raise_error(results, default_message):
    """
    Raise the most relevant error found among gathered ``results``.

    ``results`` is expected to be the outcome of an
    ``asyncio.gather(..., return_exceptions=True)`` call, i.e. a mix of
    plain return values and exception objects. This function never
    returns normally.

    Priority of the raised error:
     1. the first LockRuntimeError found in ``results``;
     2. otherwise the first LockAcquiringError found in ``results``;
     3. otherwise a new LockError carrying ``default_message``, chained
        to the first exception of ``results`` when there is one.

    Matching is done on the exact type, so subclasses of the two error
    classes above fall through to the generic LockError.

    :param results: iterable of results, some of which may be exceptions
    :param default_message: message of the fallback LockError
    :raises: LockRuntimeError or LockAcquiringError or LockError
    """
    errors = [e for e in results if isinstance(e, BaseException)]

    for error_type in (LockRuntimeError, LockAcquiringError):
        for error in errors:
            if type(error) is error_type:
                raise error

    # ``results`` may hold no exception at all (e.g. no instances, or all
    # instances answered but with different values); indexing ``errors``
    # unconditionally would then fail with an unrelated IndexError.
    cause = errors[0] if errors else None
    raise LockError(default_message) from cause


class Instance:
    """
    A single redis node taking part in the distributed lock

    Holds the connection description, lazily creates the connection pool
    and runs the lock scripts defined below against that node.
    """

    # Sets (or refreshes) the lock when the key is free or already holds
    # the same identifier; replies with an error otherwise.
    # KEYS[1] - lock resource key
    # ARGV[1] - lock unique identifier
    # ARGV[2] - expiration time in milliseconds
    SET_LOCK_SCRIPT = """
    local identifier = redis.call('get', KEYS[1])
    if not identifier or identifier == ARGV[1] then
        return redis.call("set", KEYS[1], ARGV[1], 'PX', ARGV[2])
    else
        return redis.error_reply('ERROR')
    end"""

    # Deletes the lock when it holds the given identifier; a missing key
    # is reported as success, a different identifier as an error.
    # KEYS[1] - lock resource key
    # ARGV[1] - lock unique identifier
    UNSET_LOCK_SCRIPT = """
    local identifier = redis.call('get', KEYS[1])
    if not identifier then
        return redis.status_reply('OK')
    elseif identifier == ARGV[1] then
        return redis.call("del", KEYS[1])
    else
        return redis.error_reply('ERROR')
    end"""

    # Returns the TTL of the lock when it holds the given identifier;
    # a missing key or a different identifier is reported as an error.
    # KEYS[1] - lock resource key
    # ARGV[1] - lock unique identifier
    GET_LOCK_TTL_SCRIPT = """
    local identifier = redis.call('get', KEYS[1])
    if not identifier then
        return redis.error_reply('ERROR')
    elseif identifier == ARGV[1] then
        return redis.call("TTL", KEYS[1])
    else
        return redis.error_reply('ERROR')
    end"""

    def __init__(self, connection):
        """
        Redis instance constructor

        Constructor takes single argument - a redis host address
        The address can be one of the following:
         * a dict - {'host': 'localhost', 'port': 6379,
                     'db': 0, 'password': 'pass'}
           all keys except host and port will be passed as kwargs to
           the aioredis.create_redis_pool();
         * an aioredlock.redis.Sentinel object;
         * a Redis URI - "redis://host:6379/0?encoding=utf-8";
         * a (host, port) tuple - ('localhost', 6379);
         * a unix domain socket path string - "/path/to/redis.sock";
         * or a redis connection pool.

        No connection is opened here; see connect().

        :param connection: redis host address (dict, tuple or str),
            Sentinel object or redis connection pool
        """

        self.connection = connection

        self._pool = None
        # serializes pool creation between concurrent connect() calls
        self._lock = asyncio.Lock()

        # SHA1 digests of the loaded scripts, filled by _register_scripts()
        self.set_lock_script_sha1 = None
        self.unset_lock_script_sha1 = None
        self.get_lock_ttl_script_sha1 = None

    @property
    def log(self):
        """Logger named after this module"""
        return logging.getLogger(__name__)

    def __repr__(self):
        # the connection details go through clean_password() before being
        # rendered, since this representation ends up in log messages
        connection_details = clean_password(self.connection)
        return "<%s(connection='%s'>" % (self.__class__.__name__, connection_details)

    @staticmethod
    async def _create_redis_pool(*args, **kwargs):
        """
        Adapter to support both aioredis-0.3.0 and aioredis-1.0.0
        For aioredis-1.0.0 and later calls:
            aioredis.create_redis_pool(*args, **kwargs)
        For aioredis-0.3.0 calls:
            aioredis.create_pool(*args, **kwargs)
        """

        if parse_version(aioredis.__version__) >= parse_version('1.0.0'):  # pragma no cover
            return await aioredis.create_redis_pool(*args, **kwargs)
        else:  # pragma no cover
            return await aioredis.create_pool(*args, **kwargs)

    async def _register_scripts(self, redis):
        """
        Load the three lock scripts and remember their SHA1 digests

        :param redis: object providing an awaitable script_load(script)
        """
        tasks = []
        for script in [
                self.SET_LOCK_SCRIPT,
                self.UNSET_LOCK_SCRIPT,
                self.GET_LOCK_TTL_SCRIPT,
        ]:
            # drop the source indentation of the script before sending it
            script = re.sub(r'^\s+', '', script, flags=re.M).strip()
            tasks.append(redis.script_load(script))
        # gather() keeps the order of its arguments, so the digests line up
        # with the scripts listed above; bytes replies are decoded to str
        (
            self.set_lock_script_sha1,
            self.unset_lock_script_sha1,
            self.get_lock_ttl_script_sha1,
        ) = (r.decode() if isinstance(r, bytes) else r for r in await asyncio.gather(*tasks))

    async def connect(self):
        """
        Get a connection for this instance

        Creates the pool on first use (or takes it from the Sentinel /
        the provided pool) and registers the lock scripts when they have
        not been registered yet.

        :return: the result of awaiting the pool; callers use it as a
            context manager yielding the redis client
        """
        address, redis_kwargs = (), {}

        if isinstance(self.connection, Sentinel):
            # the master is looked up again on every call
            self._pool = await self.connection.get_master()
        elif isinstance(self.connection, dict):
            # a dict like {'host': 'localhost', 'port': 6379,
            #              'db': 0, 'password': 'pass'}
            # work on a copy so that the caller's dict is left untouched
            kwargs = self.connection.copy()
            address = (
                kwargs.pop('host', 'localhost'),
                kwargs.pop('port', 6379)
            )
            redis_kwargs = kwargs
        elif isinstance(self.connection, aioredis.Redis):
            # an already created pool is used as is
            self._pool = self.connection
        else:
            # a tuple or list ('localhost', 6379)
            # a string "redis://host:6379/0?encoding=utf-8" or
            # a unix domain socket path "/path/to/redis.sock"
            address = self.connection

        if self._pool is None:
            if 'minsize' not in redis_kwargs:
                redis_kwargs['minsize'] = 1
            if 'maxsize' not in redis_kwargs:
                redis_kwargs['maxsize'] = 100
            async with self._lock:
                # checked again under the lock: another task may have
                # created the pool while this one was waiting
                if self._pool is None:
                    self.log.debug('Connecting %s', repr(self))
                    self._pool = await self._create_redis_pool(address, **redis_kwargs)

        if self.set_lock_script_sha1 is None or self.unset_lock_script_sha1 is None:
            with await self._pool as redis:
                await self._register_scripts(redis)

        return await self._pool

    async def close(self):
        """
        Closes connection and resets pool

        A pool that was handed over to the constructor is not closed here,
        only forgotten.
        """
        if self._pool is not None and not isinstance(self.connection, aioredis.Redis):
            self._pool.close()
            await self._pool.wait_closed()
        self._pool = None

    async def set_lock(self, resource, lock_identifier, lock_timeout, register_scripts=False):
        """
        Lock this instance and set lock expiration time to lock_timeout
        :param resource: redis key to set
        :param lock_identifier: unique id of lock
        :param lock_timeout: timeout for lock in seconds
        :param register_scripts: load the scripts again before running,
            usually already done, so 'False'.
        :raises: LockAcquiringError if redis replies with an error,
            LockRuntimeError on other redis or OS errors
        """

        # the script expects the expiration in milliseconds
        lock_timeout_ms = int(lock_timeout * 1000)

        try:
            with await self.connect() as redis:
                if register_scripts is True:
                    await self._register_scripts(redis)
                await redis.evalsha(
                    self.set_lock_script_sha1,
                    keys=[resource],
                    args=[lock_identifier, lock_timeout_ms]
                )
        except aioredis.errors.ReplyError as exc:  # script fault
            # script unknown to the server: register it and retry, but only
            # once, so that a persistent NOSCRIPT reply can not recurse forever
            if not register_scripts and exc.args[0].startswith('NOSCRIPT'):
                return await self.set_lock(resource, lock_identifier, lock_timeout, register_scripts=True)
            self.log.debug('Can not set lock "%s" on %s',
                           resource, repr(self))
            raise LockAcquiringError('Can not set lock') from exc
        except (aioredis.errors.RedisError, OSError) as exc:
            self.log.error('Can not set lock "%s" on %s: %s',
                           resource, repr(self), repr(exc))
            raise LockRuntimeError('Can not set lock') from exc
        except asyncio.CancelledError:
            self.log.debug('Lock "%s" is cancelled on %s',
                           resource, repr(self))
            raise
        except Exception:
            self.log.exception('Can not set lock "%s" on %s',
                               resource, repr(self))
            raise
        else:
            self.log.debug('Lock "%s" is set on %s', resource, repr(self))

    async def get_lock_ttl(self, resource, lock_identifier, register_scripts=False):
        """
        Fetch the TTL of the lock held on this instance
        :param resource: redis key to get
        :param lock_identifier: unique id of the lock to get
        :param register_scripts: load the scripts again before running,
            usually already done, so 'False'.
        :return: the value returned by the redis TTL command for the key
        :raises: LockAcquiringError if redis replies with an error,
            LockRuntimeError on other redis or OS errors
        """
        try:
            with await self.connect() as redis:
                if register_scripts is True:
                    await self._register_scripts(redis)
                ttl = await redis.evalsha(
                    self.get_lock_ttl_script_sha1,
                    keys=[resource],
                    args=[lock_identifier]
                )
        except aioredis.errors.ReplyError as exc:  # script fault
            # script unknown to the server: register it and retry, but only
            # once, so that a persistent NOSCRIPT reply can not recurse forever
            if not register_scripts and exc.args[0].startswith('NOSCRIPT'):
                return await self.get_lock_ttl(resource, lock_identifier, register_scripts=True)
            self.log.debug('Can not get lock "%s" on %s',
                           resource, repr(self))
            raise LockAcquiringError('Can not get lock') from exc
        except (aioredis.errors.RedisError, OSError) as exc:
            self.log.error('Can not get lock "%s" on %s: %s',
                           resource, repr(self), repr(exc))
            raise LockRuntimeError('Can not get lock') from exc
        except asyncio.CancelledError:
            self.log.debug('Lock "%s" is cancelled on %s',
                           resource, repr(self))
            raise
        except Exception:
            self.log.exception('Can not get lock "%s" on %s',
                               resource, repr(self))
            raise
        else:
            self.log.debug('Lock "%s" with TTL %s is on %s', resource, ttl, repr(self))
            return ttl

    async def unset_lock(self, resource, lock_identifier, register_scripts=False):
        """
        Unlock this instance
        :param resource: redis key to unset
        :param lock_identifier: unique id of lock
        :param register_scripts: load the scripts again before running,
            usually already done, so 'False'.
        :raises: LockAcquiringError if redis replies with an error (the
            script does so when the key holds a different identifier),
            LockRuntimeError on other redis or OS errors
        """
        try:
            with await self.connect() as redis:
                if register_scripts is True:
                    await self._register_scripts(redis)
                await redis.evalsha(
                    self.unset_lock_script_sha1,
                    keys=[resource],
                    args=[lock_identifier]
                )
        except aioredis.errors.ReplyError as exc:  # script fault
            # script unknown to the server: register it and retry, but only
            # once, so that a persistent NOSCRIPT reply can not recurse forever
            if not register_scripts and exc.args[0].startswith('NOSCRIPT'):
                return await self.unset_lock(resource, lock_identifier, register_scripts=True)
            self.log.debug('Can not unset lock "%s" on %s',
                           resource, repr(self))
            raise LockAcquiringError('Can not unset lock') from exc
        except (aioredis.errors.RedisError, OSError) as exc:
            self.log.error('Can not unset lock "%s" on %s: %s',
                           resource, repr(self), repr(exc))
            raise LockRuntimeError('Can not unset lock') from exc
        except asyncio.CancelledError:
            self.log.debug('Lock "%s" unset is cancelled on %s',
                           resource, repr(self))
            raise
        except Exception:
            self.log.exception('Can not unset lock "%s" on %s',
                               resource, repr(self))
            raise
        else:
            self.log.debug('Lock "%s" is unset on %s', resource, repr(self))

    async def is_locked(self, resource):
        """
        Checks if the resource is locked on this instance, whatever the
        lock identifier is.

        :param resource: The resource string name to check
        :returns: True if locked else False
        """

        with await self.connect() as redis:
            lock_identifier = await redis.get(resource)
        # a missing (or empty) value means nobody holds the lock
        return bool(lock_identifier)


class Redis:
    """
    Group of redis instances on which every lock operation is run

    An operation succeeds when it succeeds on a majority (N/2 + 1) of
    the instances.
    """

    def __init__(self, redis_connections):
        """
        :param redis_connections: iterable of redis connection
            descriptions; each one is wrapped in an Instance
        """

        self.instances = [Instance(connection) for connection in redis_connections]

    @property
    def log(self):
        """Logger named after this module"""
        return logging.getLogger(__name__)

    @property
    def _quorum(self):
        """Smallest number of instances forming a majority (N/2 + 1)"""
        return len(self.instances) // 2 + 1

    async def set_lock(self, resource, lock_identifier, lock_timeout=10.0):
        """
        Tries to set the lock to all the redis instances

        :param resource: The resource string name to lock
        :param lock_identifier: The id of the lock. A unique string
        :param lock_timeout: lock's lifetime
        :return float: The elapsed time that took to lock the instances
            in seconds
        :raises: LockRuntimeError or LockAcquiringError or LockError if the lock has not
            been set to at least (N/2 + 1) instances
        """
        start_time = time.monotonic()

        # failures are collected instead of raised, so that every instance
        # gets a chance to answer
        successes = await asyncio.gather(*[
            i.set_lock(resource, lock_identifier, lock_timeout) for
            i in self.instances
        ], return_exceptions=True)
        # an instance reports success by returning None
        successful_sets = sum(s is None for s in successes)

        elapsed_time = time.monotonic() - start_time
        locked = successful_sets >= self._quorum

        self.log.debug('Lock "%s" is set on %d/%d instances in %s seconds',
                       resource, successful_sets, len(self.instances), elapsed_time)

        if not locked:
            raise_error(successes, 'Can not acquire the lock "%s"' % resource)

        return elapsed_time

    async def get_lock_ttl(self, resource, lock_identifier=None):
        """
        Tries to get the lock from all the redis instances

        :param resource: The resource string name to fetch
        :param lock_identifier: The id of the lock. A unique string
        :return: The TTL of that lock reported by redis
        :raises: LockRuntimeError or LockAcquiringError or LockError if the TTL
            could not be fetched from at least (N/2 + 1) instances, or if
            the fetched values are not all equal
        """
        start_time = time.monotonic()
        successes = await asyncio.gather(*[
            i.get_lock_ttl(resource, lock_identifier) for
            i in self.instances
        ], return_exceptions=True)
        # BaseException rather than Exception: gather() may also hand back
        # asyncio.CancelledError, which is not an Exception subclass on
        # Python 3.8+ and must not be mistaken for a TTL value
        successful_list = [s for s in successes if not isinstance(s, BaseException)]
        # should check if all the value are approx. the same with math.isclose...
        locked = len(successful_list) >= self._quorum
        success = all_equal(successful_list) and locked
        elapsed_time = time.monotonic() - start_time

        self.log.debug('Lock "%s" is set on %d/%d instances in %s seconds',
                       resource, len(successful_list), len(self.instances), elapsed_time)

        if not success:
            raise_error(successes, 'Could not fetch the TTL for lock "%s"' % resource)

        return successful_list[0]

    async def unset_lock(self, resource, lock_identifier):
        """
        Tries to unset the lock to all the redis instances

        :param resource: The resource string name to unlock
        :param lock_identifier: The id of the lock. A unique string
        :return float: The elapsed time that took to unlock the instances in seconds
        :raises: LockRuntimeError or LockAcquiringError or LockError if the lock has not
            been unset on at least (N/2 + 1) instances
        """

        # nothing to unlock; unlike set_lock() this is not an error
        if not self.instances:
            return .0

        start_time = time.monotonic()

        successes = await asyncio.gather(*[
            i.unset_lock(resource, lock_identifier) for
            i in self.instances
        ], return_exceptions=True)
        # an instance reports success by returning None
        successful_removes = sum(s is None for s in successes)

        elapsed_time = time.monotonic() - start_time
        unlocked = successful_removes >= self._quorum

        self.log.debug('Lock "%s" is unset on %d/%d instances in %s seconds',
                       resource, successful_removes, len(self.instances), elapsed_time)

        if not unlocked:
            raise_error(successes, 'Can not release the lock')

        return elapsed_time

    async def is_locked(self, resource):
        """
        Checks if the resource is locked on a majority of the instances.

        Instances that fail to answer are counted as not locked.

        :param resource: The resource string name to check
        :returns: True if locked else False
        """

        successes = await asyncio.gather(*[
            i.is_locked(resource) for
            i in self.instances
        ], return_exceptions=True)
        # "is True" leaves out both False answers and returned exceptions
        successful_sets = sum(s is True for s in successes)

        return successful_sets >= self._quorum

    async def clear_connections(self):
        """
        Closes every instance and empties the list of instances
        """

        self.log.debug('Clearing connection')

        if self.instances:
            coros = []
            while self.instances:
                coros.append(self.instances.pop().close())
            await asyncio.gather(*(coros))