File: plan_cost_runner.py

package info (click to toggle)
duckdb 1.5.1-2
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 299,196 kB
  • sloc: cpp: 865,414; ansic: 57,292; python: 18,871; sql: 12,663; lisp: 11,751; yacc: 7,412; lex: 1,682; sh: 747; makefile: 558
file content (207 lines) | stat: -rw-r--r-- 6,569 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import argparse
import glob
import json
import os
import subprocess
import sys
from tqdm import tqdm


OLD_DB_NAME = "old.duckdb"
NEW_DB_NAME = "new.duckdb"
PROFILE_FILENAME = "duckdb_profile.json"

ENABLE_PROFILING = "PRAGMA enable_profiling=json"
PROFILE_OUTPUT = f"PRAGMA profile_output='{PROFILE_FILENAME}'"

BANNER_SIZE = 52


def init_db(cli, dbname, benchmark_dir):
    print(f"INITIALIZING {dbname} ...")
    subprocess.run(
        f"{cli} {dbname} < {benchmark_dir}/init/schema.sql", shell=True, check=True, stdout=subprocess.DEVNULL
    )
    subprocess.run(f"{cli} {dbname} < {benchmark_dir}/init/load.sql", shell=True, check=True, stdout=subprocess.DEVNULL)
    print("INITIALIZATION DONE")


class PlanCost:
    def __init__(self):
        self.total = 0
        self.build_side = 0
        self.probe_side = 0
        self.time = 0

    def __add__(self, other):
        self.total += other.total
        self.build_side += other.build_side
        self.probe_side += other.probe_side
        return self

    def __gt__(self, other):
        if self == other or self.total < other.total:
            return False
        # if the total intermediate cardinalities is greater, also inspect time.
        # it's possible a plan reordering increased cardinalities, but overall execution time
        # was not greatly affected
        total_card_increased = self.total > other.total
        build_card_increased = self.build_side > other.build_side
        if total_card_increased and build_card_increased:
            return True
        # we know the total cardinality is either the same or higher and the build side has not increased
        # in this case fall back to the timing. It's possible that even if the probe side is higher
        # since the tuples are in flight, the plan executes faster
        return self.time > other.time * 1.03

    def __lt__(self, other):
        if self == other:
            return False
        return not (self > other)

    def __eq__(self, other):
        return self.total == other.total and self.build_side == other.build_side and self.probe_side == other.probe_side


def is_measured_join(op) -> bool:
    if 'name' not in op:
        return False
    if op['name'] != 'HASH_JOIN':
        return False
    if 'Join Type' not in op['extra_info']:
        return False
    if op['extra_info']['Join Type'].startswith('MARK'):
        return False
    return True


def op_inspect(op) -> PlanCost:
    cost = PlanCost()
    if 'Query' in op:
        cost.time = op['operator_timing']
    if is_measured_join(op):
        cost.total = op['operator_cardinality']
        if 'operator_cardinality' in op['children'][0]:
            cost.probe_side += op['children'][0]['operator_cardinality']
        if 'operator_cardinality' in op['children'][1]:
            cost.build_side += op['children'][1]['operator_cardinality']

        left_cost = op_inspect(op['children'][0])
        right_cost = op_inspect(op['children'][1])
        cost.probe_side += left_cost.probe_side + right_cost.probe_side
        cost.build_side += left_cost.build_side + right_cost.build_side
        cost.total += left_cost.total + right_cost.total
        return cost

    for child_op in op['children']:
        cost += op_inspect(child_op)

    return cost


def query_plan_cost(cli, dbname, query):
    try:
        subprocess.run(
            f"{cli} --readonly {dbname} -c \"{ENABLE_PROFILING};{PROFILE_OUTPUT};{query}\"",
            shell=True,
            check=True,
            capture_output=True,
        )
    except subprocess.CalledProcessError as e:
        print("-------------------------")
        print("--------Failure----------")
        print("-------------------------")
        print(e.stderr.decode('utf8'))
        print("-------------------------")
        print("--------Output----------")
        print("-------------------------")
        print(e.output.decode('utf8'))
        print("-------------------------")
        raise e
    with open(PROFILE_FILENAME, 'r') as file:
        return op_inspect(json.load(file))


def print_banner(text):
    text_len = len(text)
    rest = BANNER_SIZE - text_len - 10
    l_width = int(rest / 2)
    r_width = l_width
    if rest % 2 != 0:
        l_width += 1
    print("")
    print("=" * BANNER_SIZE)
    print("=" * l_width + " " * 5 + text + " " * 5 + "=" * r_width)
    print("=" * BANNER_SIZE)


def print_diffs(diffs):
    for query_name, old_cost, new_cost in diffs:
        print("")
        print("Query:", query_name)
        print("Old total cost:", old_cost.total)
        print("Old build cost:", old_cost.build_side)
        print("Old probe cost:", old_cost.probe_side)
        print("New total cost:", new_cost.total)
        print("New build cost:", new_cost.build_side)
        print("New probe cost:", new_cost.probe_side)


def main():
    parser = argparse.ArgumentParser(description="Plan cost regression test script with old and new versions.")

    parser.add_argument("--old", type=str, help="Path to the old runner.", required=True)
    parser.add_argument("--new", type=str, help="Path to the new runner.", required=True)
    parser.add_argument("--dir", type=str, help="Path to the benchmark directory.", required=True)

    args = parser.parse_args()

    old = args.old
    new = args.new
    benchmark_dir = args.dir

    init_db(old, OLD_DB_NAME, benchmark_dir)
    init_db(new, NEW_DB_NAME, benchmark_dir)

    improvements = []
    regressions = []

    files = glob.glob(f"{benchmark_dir}/queries/*.sql")
    files.sort()

    print("")
    print("RUNNING BENCHMARK QUERIES")
    for f in tqdm(files):
        query_name = f.split("/")[-1].replace(".sql", "")

        with open(f, "r") as file:
            query = file.read()

        old_cost = query_plan_cost(old, OLD_DB_NAME, query)
        new_cost = query_plan_cost(new, NEW_DB_NAME, query)

        if old_cost > new_cost:
            improvements.append((query_name, old_cost, new_cost))
        elif new_cost > old_cost:
            regressions.append((query_name, old_cost, new_cost))

    exit_code = 0
    if improvements:
        print_banner("IMPROVEMENTS DETECTED")
        print_diffs(improvements)
    if regressions:
        exit_code = 1
        print_banner("REGRESSIONS DETECTED")
        print_diffs(regressions)
    if not improvements and not regressions:
        print_banner("NO DIFFERENCES DETECTED")

    os.remove(OLD_DB_NAME)
    os.remove(NEW_DB_NAME)
    os.remove(PROFILE_FILENAME)

    exit(exit_code)


if __name__ == "__main__":
    main()