File: issue_85.py

package info (click to toggle)
diskcache 5.6.3-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,364 kB
  • sloc: python: 7,026; makefile: 20
file content (142 lines) | stat: -rw-r--r-- 3,765 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""Test Script for Issue #85

$ export PYTHONPATH=`pwd`
$ python tests/issue_85.py
"""

import collections
import os
import random
import shutil
import sqlite3
import threading
import time

import django


def remove_cache_dir():
    """Delete the ``.cache`` directory under the current working directory.

    Any error during removal (including the directory not existing) is
    ignored, so this is safe to call on a clean checkout.
    """
    print('REMOVING CACHE DIRECTORY')
    cache_dir = '.cache'
    shutil.rmtree(cache_dir, ignore_errors=True)


def init_django():
    """Set up Django and bind the first cache shard to the global ``shard``.

    ``DJANGO_SETTINGS_MODULE`` defaults to ``tests.settings`` unless it is
    already set in the environment. The other helpers in this script read
    the global ``shard`` after this function has run.
    """
    global shard
    print('INITIALIZING DJANGO')
    os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'tests.settings')
    django.setup()
    from django.core.cache import cache as default_cache

    # Reaches into private attributes. This presumably relies on the default
    # cache being diskcache's Django backend with a sharded inner cache, as
    # configured in tests.settings (not visible here) -- confirm.
    backend = default_cache._cache
    shard = backend._shards[0]


def multi_threading_init_test():
    """Call ``cache.get('key')`` concurrently from 50 threads.

    Exercises concurrent first use of the Django default cache. Exceptions
    raised inside worker threads are not collected here. They are only
    reported by Python's default ``threading.excepthook``.
    """
    print('RUNNING MULTI-THREADING INIT TEST')
    from django.core.cache import cache

    def run():
        cache.get('key')

    threads = [threading.Thread(target=run) for _ in range(50)]
    # Start every thread before joining any, so the cache accesses overlap.
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()


def show_sqlite_compile_options():
    """Print the SQLite compile-time options reported through ``shard``.

    Runs ``pragma compile_options`` and prints the single column of every
    result row, one option per line, in a single ``print`` call.
    """
    print('SQLITE COMPILE OPTIONS')
    rows = shard._sql('pragma compile_options').fetchall()
    option_names = [name for (name,) in rows]
    print('\n'.join(option_names))


def create_data_table():
    """Create table ``data`` (single column ``x``) holding the integers 0-999.

    Uses the shard's raw connection (``shard._con``). Creating the table
    fails if ``data`` already exists, which is why the cache directory is
    removed first when run as a script.
    """
    print('CREATING DATA TABLE')
    shard._con.execute('create table data (x)')
    rows = [(n,) for n in range(1000)]
    shard._con.executemany('insert into data values (?)', rows)


# Statement sequences, one per combination of transaction-begin mode
# (deferred BEGIN, BEGIN IMMEDIATE, BEGIN EXCLUSIVE) and read/write order.
# Keys look like 'begin immediate/write/read'. Each value is the list of
# statements one worker thread executes via run(): the BEGIN variant, then
# the two body statements, then COMMIT.
commands = {
    f'{begin.lower()}/{order}': [begin, *body, 'COMMIT']
    for begin in ('BEGIN', 'BEGIN IMMEDIATE', 'BEGIN EXCLUSIVE')
    for order, body in (
        ('read/write', ('SELECT MAX(x) FROM data', 'UPDATE data SET x = x + 1')),
        ('write/read', ('UPDATE data SET x = x + 1', 'SELECT MAX(x) FROM data')),
    )
}


# Shared event log of (event, thread ident) tuples, appended to by run()
# from many threads at once. deque.append is documented as thread-safe.
values = collections.deque()


def run(statements):
    """Execute one transaction's statements in order, logging events to ``values``.

    Worker body for test_transaction_errors(). ``statements`` is one of the
    sequences in ``commands``: a BEGIN variant first and COMMIT last. Each
    statement is preceded by a random sleep in [0, 0.1) seconds to vary how
    threads interleave.

    Appends ``(event, ident)`` tuples to the shared ``values`` deque, where
    ``ident`` is this thread's ``threading.get_ident()``:

    * ``'BEGIN'``: after the first statement has executed without error.
    * ``'COMMIT'``: before the last statement is executed, i.e. while the
      transaction is still open. A thread whose final statement then raises
      records both COMMIT and ERROR.
    * ``'ERROR'``: when any statement raises ``sqlite3.OperationalError``.
      The remaining statements are skipped.

    NOTE(review): no ROLLBACK is issued after an error, so a transaction this
    thread began stays open on its connection. Presumably it is released
    when that connection is cleaned up -- confirm against how
    ``shard._sql`` manages connections.
    """
    ident = threading.get_ident()
    try:
        for index, statement in enumerate(statements):
            if index == (len(statements) - 1):
                # Log COMMIT before issuing the final statement, while the
                # transaction is still open.
                values.append(('COMMIT', ident))
            # Random jitter in [0, 0.1) seconds to vary thread interleaving.
            time.sleep(random.random() / 10.0)
            shard._sql(statement)
            if index == 0:
                # Log BEGIN only once the first statement has succeeded.
                values.append(('BEGIN', ident))
    except sqlite3.OperationalError:
        values.append(('ERROR', ident))


def test_transaction_errors():
    """Run every statement sequence in ``commands`` from 100 concurrent threads.

    For each sequence, the shared ``values`` log is cleared and 100 threads
    each execute ``run(statements)``. The function then prints:

    * how many threads recorded ERROR,
    * how many recorded BEGIN and COMMIT,
    * ``Serialized: True`` when the thread idents in BEGIN order equal the
      thread idents in COMMIT order.

    A thread that errors after BEGIN has no matching COMMIT, which also
    makes the two orders differ.
    """
    for key, statements in commands.items():
        print(f'RUNNING {key}')
        values.clear()
        threads = [
            threading.Thread(target=run, args=(statements,))
            for _ in range(100)
        ]
        # Start every thread before joining any, so the transactions overlap.
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # Group the logged thread idents by event kind in one pass,
        # preserving the order in which the events were appended.
        idents_by_event = collections.defaultdict(list)
        for event, ident in values:
            idents_by_event[event].append(ident)

        print('Error count:', len(idents_by_event['ERROR']))
        print('Begin count:', len(idents_by_event['BEGIN']))
        print('Commit count:', len(idents_by_event['COMMIT']))
        print(
            'Serialized:',
            idents_by_event['BEGIN'] == idents_by_event['COMMIT'],
        )


if __name__ == '__main__':
    # Order matters: init_django() binds the global ``shard`` that the
    # SQLite helpers use, and the data table must exist before the
    # transaction experiments run.
    steps = (
        remove_cache_dir,
        init_django,
        multi_threading_init_test,
        show_sqlite_compile_options,
        create_data_table,
        test_transaction_errors,
    )
    for step in steps:
        step()