File: test_serialization_bwc.py

package info (click to toggle)
duckdb 1.5.1-2
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 299,196 kB
  • sloc: cpp: 865,414; ansic: 57,292; python: 18,871; sql: 12,663; lisp: 11,751; yacc: 7,412; lex: 1,682; sh: 747; makefile: 558
file content (226 lines) | stat: -rw-r--r-- 7,789 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
import sqllogictest
from sqllogictest import SQLParserException, SQLLogicParser, SQLLogicTest
import duckdb
from typing import Optional
import argparse
import shutil
import os
import subprocess

# Example invocation (run from the root of the new source tree):
#   python3 scripts/test_serialization_bwc.py --old-source ../duckdb-bugfix --test-file test/sql/aggregate/aggregates/test_median.test

# Locations relative to the root of a DuckDB source tree.
serialized_path = os.path.join('test', 'api', 'serialized_plans')
db_load_path, queries_path, result_binary = (
    os.path.join(serialized_path, leaf) for leaf in ('db_load.sql', 'queries.sql', 'serialized_plans.binary')
)
unittest_binary = os.path.join('build', 'debug', 'test', 'unittest')


def complete_query(q):
    """Return *q* with surrounding whitespace removed, ensuring it ends in ';'.

    A semicolon is appended only when the stripped text does not already end
    with one.
    """
    stripped = q.strip()
    return stripped if stripped.endswith(';') else f'{stripped};'


def parse_test_file(filename):
    """Split a sqllogictest file into setup statements and SELECT queries.

    The file is parsed with ``SQLLogicParser`` and its records are walked in
    order:

    * a ``skip`` record stops collection; records collected before it are kept;
    * records inside ``loop``/``foreach`` bodies are ignored;
    * only ``statement`` and ``query`` records are considered, and
      ``statement`` records expected to error are skipped;
    * each record's SQL is split with ``duckdb.extract_statements``. SELECT
      statements go to ``'query'``, PRAGMA statements are dropped, and all
      others go to ``'load'``.

    Returns:
        dict with keys ``'load'`` and ``'query'``, each a list of SQL strings.
        Both lists are empty if the file cannot be parsed (best-effort).
    """
    parser = SQLLogicParser()
    try:
        out: Optional[SQLLogicTest] = parser.parse(filename)
        if not out:
            raise SQLParserException(f"Test {filename} could not be parsed")
    except Exception:
        # Best-effort by design: an unparseable file contributes no statements.
        # Catching Exception (instead of a bare except) lets Ctrl-C and
        # SystemExit propagate rather than being silently swallowed.
        return {'load': [], 'query': []}
    loop_count = 0
    load_statements = []
    query_statements = []
    for stmt in out.statements:
        stmt_type = type(stmt)
        if stmt_type is sqllogictest.statement.skip.Skip:
            # 'mode skip': stop collecting here; anything after this record
            # (even after a later 'mode unskip') is ignored
            break
        if stmt_type is sqllogictest.statement.loop.Loop or stmt_type is sqllogictest.statement.foreach.Foreach:
            loop_count += 1
        if stmt_type is sqllogictest.statement.endloop.Endloop:
            loop_count -= 1
        if loop_count > 0:
            # loops are ignored currently
            continue
        if not (
            stmt_type is sqllogictest.statement.query.Query or stmt_type is sqllogictest.statement.statement.Statement
        ):
            # only handle query and statement nodes for now
            continue
        if stmt_type is sqllogictest.statement.statement.Statement:
            # skip statements that are expected to error
            if stmt.expected_result.type == sqllogictest.ExpectedResult.Type.ERROR:
                continue
        query = ' '.join(stmt.lines)
        try:
            sql_stmt_list = duckdb.extract_statements(query)
        except Exception:
            # SQL that extract_statements rejects is skipped
            continue
        for sql_stmt in sql_stmt_list:
            # Use the text of this individual statement, not the whole record.
            # Otherwise a record holding N statements would be emitted N times,
            # and a mixed record would leak setup SQL into the query list
            # (and vice versa).
            sql_text = sql_stmt.query
            if sql_stmt.type == duckdb.StatementType.SELECT:
                query_statements.append(sql_text)
            elif sql_stmt.type == duckdb.StatementType.PRAGMA:
                continue
            else:
                load_statements.append(sql_text)
    return {'load': load_statements, 'query': query_statements}


def _build_if_missing(source, label):
    """Run ``make debug`` inside *source* unless its unittest binary already exists."""
    if os.path.isfile(os.path.join(source, unittest_binary)):
        return
    # cwd= runs make inside the tree without changing this process's cwd
    res = subprocess.run(['make', 'debug'], cwd=source).returncode
    if res != 0:
        raise Exception(f"Failed to build {label} sources")


def build_sources(old_source, new_source):
    """Make sure debug builds exist for both the old and the new source tree.

    Each tree is built with ``make debug`` only if its unittest binary is
    missing. Paths are resolved against the current working directory, which
    is left unchanged (also when a build fails).

    Raises:
        Exception: "Failed to build old sources" / "Failed to build new
            sources" when the corresponding ``make debug`` exits non-zero.
    """
    _build_if_missing(old_source, 'old')
    _build_if_missing(new_source, 'new')


def _write_statements(path, statements):
    """Write *statements* to *path*, one ';'-terminated statement per line."""
    with open(path, 'w') as f:
        for stmt in statements:
            f.write(complete_query(stmt) + '\n')


def run_test(filename, old_source, new_source, no_exit):
    """Check that plans serialized by the old build deserialize in the new one.

    Steps:
    1. Split *filename* into load/query statements.
    2. Write them into the old tree's serialized_plans files.
    3. Run the old unittest binary with ``GEN_PLAN_STORAGE=1`` to generate
       the serialized plans.
    4. Copy the files into the new tree.
    5. Run the new unittest binary to deserialize them.

    All paths are resolved against the caller's working directory, and the
    process cwd is never changed. Each unittest binary runs with cwd set to
    its own source tree, matching the original chdir-based flow.

    Returns:
        True on success, or when generation fails (the test is skipped).
        False when deserialization fails and *no_exit* is set; the test name
        is then appended to ``broken_tests.list`` in *new_source*.

    Raises:
        Exception: "Deserialization failure" when deserialization fails and
            *no_exit* is not set.
    """
    statements = parse_test_file(filename)

    # write the input files into the old tree. Joining with old_source while
    # staying in the caller's cwd avoids resolving old_source twice.
    _write_statements(os.path.join(old_source, db_load_path), statements['load'])
    _write_statements(os.path.join(old_source, queries_path), statements['query'])

    # generate the serialization with the old build
    my_env = os.environ.copy()
    my_env['GEN_PLAN_STORAGE'] = '1'
    # absolute executable path: how a relative executable is resolved
    # together with cwd= is platform dependent
    old_unittest = os.path.abspath(os.path.join(old_source, unittest_binary))
    res = subprocess.run([old_unittest, 'Generate serialized plans file'], cwd=old_source, env=my_env).returncode
    if res != 0:
        print(f"SKIPPING TEST {filename}")
        return True

    # copy over the files
    for f in [db_load_path, queries_path, result_binary]:
        shutil.copy(os.path.join(old_source, f), os.path.join(new_source, f))

    # run the verification with the new build
    new_unittest = os.path.abspath(os.path.join(new_source, unittest_binary))
    res = subprocess.run([new_unittest, "Test deserialized plans from file"], cwd=new_source).returncode
    if res != 0:
        if no_exit:
            print("BROKEN TEST")
            # same location as before: the original wrote this file while cwd was new_source
            with open(os.path.join(new_source, 'broken_tests.list'), 'a') as f:
                f.write(filename + '\n')
            return False
        raise Exception("Deserialization failure")
    return True


def parse_excluded_tests(path):
    """Load an exclusion list file into a ``{entry: True}`` lookup table.

    Blank lines and lines whose first character is '#' are ignored. Every
    other line is stripped of surrounding whitespace and used as a key.
    """
    with open(path) as handle:
        kept = [raw.strip() for raw in handle if raw.strip() and not raw.startswith('#')]
    return dict.fromkeys(kept, True)


def find_tests_recursive(dir, excluded_paths):
    """Collect every ``*.test`` file below *dir*.

    Paths are built with ``os.path.join(dir, ...)``. Any file or directory
    whose path is a key in *excluded_paths* is skipped; an excluded directory
    is not descended into. The result is in directory-listing order.
    """
    collected = []
    for entry_name in os.listdir(dir):
        entry_path = os.path.join(dir, entry_name)
        if entry_path in excluded_paths:
            continue
        if os.path.isdir(entry_path):
            collected.extend(find_tests_recursive(entry_path, excluded_paths))
            continue
        if entry_path.endswith('.test'):
            collected.append(entry_path)
    return collected


def main():
    """Command-line entry point for the serialization backwards-compat check.

    Test selection is exactly one of:
    * ``--test-file`` — a single file;
    * ``--all-tests`` — every ``.test`` file under ``test/sql`` of the new
      source, minus ``test/api/serialized_plans/excluded_tests.list``;
    * ``--test-list`` — a file listing one test path per line.

    The selected files are sorted. ``--start-at`` skips every test until the
    named one is reached; if it never appears, nothing runs.

    Both trees are built if needed (``build_sources``), then ``run_test`` runs
    for each file. The process exits with status 1 if any test reports
    failure, which only happens with ``--no-exit``; otherwise ``run_test``
    raises. The working directory is restored on exit.
    """
    parser = argparse.ArgumentParser(description="Test serialization")
    parser.add_argument("--new-source", type=str, help="Path to the new source", default='.')
    # Required: every run builds and executes the old tree, so a missing value
    # would otherwise only fail later with an unhelpful TypeError.
    parser.add_argument("--old-source", type=str, help="Path to the old source", required=True)
    parser.add_argument("--start-at", type=str, help="Start running tests at this specific test", default=None)
    parser.add_argument("--no-exit", action="store_true", help="Keep running even if a test fails", default=False)
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--test-file", type=str, help="Path to the SQL logic file", default='')
    group.add_argument("--all-tests", action='store_true', help="Run all tests", default=False)
    group.add_argument("--test-list", type=str, help="Load tests to run from a file list", default=None)
    args = parser.parse_args()

    old_source = args.old_source
    new_source = args.new_source
    files = []
    if args.all_tests:
        excluded_tests = parse_excluded_tests(
            os.path.join(new_source, 'test', 'api', 'serialized_plans', 'excluded_tests.list')
        )
        test_dir = os.path.join('test', 'sql')
        if new_source != '.':
            test_dir = os.path.join(new_source, test_dir)
            # find_tests_recursive now yields paths prefixed with new_source.
            # The exclusion entries are compared unprefixed when new_source is
            # '.', so prefix them the same way or they would never match.
            excluded_tests = {os.path.join(new_source, entry): True for entry in excluded_tests}
        files = find_tests_recursive(test_dir, excluded_tests)
    elif args.test_list is not None:
        with open(args.test_list, 'r') as f:
            files = [line.strip() for line in f if line.strip()]
    else:
        # run a single test
        files.append(args.test_file)
    files.sort()

    current_path = os.getcwd()
    try:
        build_sources(old_source, new_source)

        all_succeeded = True
        started = args.start_at is None
        for filename in files:
            if not started:
                if filename != args.start_at:
                    continue
                started = True

            print(f"Run test {filename}")
            # defensive: restore cwd in case a helper left it changed
            os.chdir(current_path)
            if not run_test(filename, old_source, new_source, args.no_exit):
                all_succeeded = False
        if not all_succeeded:
            # SystemExit does not depend on the site-provided exit() helper
            raise SystemExit(1)
    finally:
        os.chdir(current_path)


if __name__ == '__main__':
    # run the CLI only when executed as a script, not when imported
    main()