1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207
|
import argparse
import glob
import json
import os
import shlex
import subprocess
import sys

from tqdm import tqdm
# Database files built by main() for the old and new runners; both are deleted
# again at the end of main().
OLD_DB_NAME = "old.duckdb"
NEW_DB_NAME = "new.duckdb"
# File the CLI writes each query's JSON profile to; query_plan_cost() reads it back.
PROFILE_FILENAME = "duckdb_profile.json"
# Pragmas prepended to every benchmark query: turn on JSON profiling and send the
# profile to PROFILE_FILENAME. The filename is spliced into a single-quoted SQL
# string literal, so it must not contain a quote character.
ENABLE_PROFILING = "PRAGMA enable_profiling=json"
PROFILE_OUTPUT = f"PRAGMA profile_output='{PROFILE_FILENAME}'"
# Total width, in characters, of the '=' banners printed by print_banner().
BANNER_SIZE = 52
def init_db(cli, dbname, benchmark_dir):
    """Create *dbname* by piping the benchmark's init scripts into the CLI.

    Runs ``init/schema.sql`` and then ``init/load.sql`` from *benchmark_dir*
    through *cli*, with *dbname* as the CLI's argument. The CLI's stdout is
    discarded.

    Args:
        cli: Command that launches the CLI. It is split with shell-style word
            rules, so it may carry extra arguments.
        dbname: Database file passed to the CLI.
        benchmark_dir: Directory containing ``init/schema.sql`` and
            ``init/load.sql``.

    Raises:
        OSError: if an init script cannot be opened.
        subprocess.CalledProcessError: if the CLI exits with a non-zero status.
    """
    print(f"INITIALIZING {dbname} ...")
    command = shlex.split(cli) + [dbname]
    for script in ("schema.sql", "load.sql"):
        # Hand the script to the CLI on stdin directly instead of through a
        # shell `<` redirect. Paths containing spaces or shell metacharacters
        # are then passed verbatim instead of being split or interpreted.
        with open(os.path.join(benchmark_dir, "init", script), "rb") as sql_file:
            subprocess.run(command, stdin=sql_file, check=True, stdout=subprocess.DEVNULL)
    print("INITIALIZATION DONE")
class PlanCost:
    """Intermediate-cardinality cost of a profiled query plan.

    Instances are filled in by ``op_inspect`` and compared with ``>``. A plan
    is "greater" (worse) than another when its join cardinalities grew, or when
    they are not lower and it ran noticeably slower.
    """

    # Mutable and compared by value, so deliberately unhashable
    # (this is also what defining __eq__ alone implies).
    __hash__ = None

    def __init__(self):
        # Sum of output cardinalities of the measured joins.
        self.total = 0
        # Sum of cardinalities of the joins' second child (the build side).
        self.build_side = 0
        # Sum of cardinalities of the joins' first child (the probe side).
        self.probe_side = 0
        # Query timing from the profile's root node; used only as a tie-breaker.
        self.time = 0

    def __repr__(self):
        return (
            f"{type(self).__name__}(total={self.total}, build_side={self.build_side}, "
            f"probe_side={self.probe_side}, time={self.time})"
        )

    def __add__(self, other):
        """Return a new cost holding the summed cardinalities; *self* is untouched.

        ``time`` is taken from *self* and is not summed.
        """
        result = type(self)()
        result.total = self.total + other.total
        result.build_side = self.build_side + other.build_side
        result.probe_side = self.probe_side + other.probe_side
        result.time = self.time
        return result

    def __iadd__(self, other):
        """Accumulate *other*'s cardinalities into *self* in place (``time`` is kept)."""
        self.total += other.total
        self.build_side += other.build_side
        self.probe_side += other.probe_side
        return self

    def __gt__(self, other):
        """Return True if this plan is worse than *other*.

        Equal cardinalities, or a lower total, are never worse. A plan whose
        total and build-side cardinalities both increased is worse. Otherwise
        (the total is equal, or it increased but the build side did not) the
        plan is worse only if it ran more than 3% slower.
        """
        if self == other or self.total < other.total:
            return False
        # A reordering can raise the total intermediate cardinality without
        # hurting execution time much, so a higher total alone is not enough.
        total_card_increased = self.total > other.total
        build_card_increased = self.build_side > other.build_side
        if total_card_increased and build_card_increased:
            return True
        # Fall back to timing. Even with more probe-side tuples the plan can be
        # faster, since probe tuples are streamed rather than materialized.
        return self.time > other.time * 1.03

    def __lt__(self, other):
        # Defined as the mirror of __gt__. The previous `not (self > other)`
        # made two incomparable plans each "less than" the other.
        return other > self

    def __eq__(self, other):
        """Compare cardinalities only; ``time`` is ignored."""
        if not isinstance(other, PlanCost):
            return NotImplemented
        return (
            self.total == other.total
            and self.build_side == other.build_side
            and self.probe_side == other.probe_side
        )
def is_measured_join(op) -> bool:
    """Return True if profiler node *op* is a hash join whose cardinalities are counted.

    A node qualifies when its ``name`` is ``HASH_JOIN`` and its ``extra_info``
    carries a ``Join Type`` that does not start with ``MARK``. Nodes without
    a name, without ``extra_info``, or without a join type are not measured.
    """
    if op.get('name') != 'HASH_JOIN':
        return False
    # A node without 'extra_info' previously raised KeyError; treat it as unmeasured.
    extra_info = op.get('extra_info', {})
    if 'Join Type' not in extra_info:
        return False
    return not extra_info['Join Type'].startswith('MARK')
def op_inspect(op) -> PlanCost:
    """Recursively accumulate join cardinalities for the JSON profiler node *op*.

    For every measured join (see ``is_measured_join``):
    - its own output cardinality is added to ``total``;
    - its first child's cardinality is added to ``probe_side``;
    - its second child's cardinality is added to ``build_side``.
    The walk then continues into both children. Other nodes just sum their
    children. A node carrying a ``'Query'`` key contributes its
    ``operator_timing`` as ``time``. Missing ``children`` lists and missing
    ``operator_cardinality`` values are treated as empty / 0.

    Args:
        op: One node (a dict) of the profile loaded from the JSON output.

    Returns:
        A new PlanCost for the subtree rooted at *op*.
    """
    cost = PlanCost()
    if 'Query' in op:
        # NOTE(review): presumably only the root node has 'Query'; it supplies the timing.
        cost.time = op['operator_timing']
    children = op.get('children', [])

    if not is_measured_join(op):
        for child in children:
            cost += op_inspect(child)
        return cost

    probe_child, build_child = children[0], children[1]
    cost.total = op.get('operator_cardinality', 0)
    cost.probe_side = probe_child.get('operator_cardinality', 0)
    cost.build_side = build_child.get('operator_cardinality', 0)
    for subtree in (op_inspect(probe_child), op_inspect(build_child)):
        cost.total += subtree.total
        cost.build_side += subtree.build_side
        cost.probe_side += subtree.probe_side
    return cost
def query_plan_cost(cli, dbname, query):
    """Run *query* read-only against *dbname* with JSON profiling and return its PlanCost.

    The profiling pragmas are prepended to *query* and executed with
    ``<cli> --readonly <dbname> -c <sql>``. The resulting profile is then read
    from PROFILE_FILENAME and passed to ``op_inspect``.

    Args:
        cli: Command that launches the CLI (split with shell-style word rules).
        dbname: Database file to open read-only.
        query: SQL text of the benchmark query.

    Raises:
        subprocess.CalledProcessError: if the CLI fails. Its stderr and stdout
            are printed first.
        FileNotFoundError: if the CLI succeeded but wrote no profile.
    """
    # Drop the previous query's profile so a run that writes none fails loudly
    # instead of silently reusing stale results.
    try:
        os.remove(PROFILE_FILENAME)
    except FileNotFoundError:
        pass

    sql = f"{ENABLE_PROFILING};{PROFILE_OUTPUT};{query}"
    # Pass the SQL as its own argv entry. Interpolating it into a double-quoted
    # shell string let the shell mangle or execute `"`, `$`, backticks and `\`.
    command = shlex.split(cli) + ["--readonly", dbname, "-c", sql]
    try:
        subprocess.run(command, check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        print("-------------------------")
        print("--------Failure----------")
        print("-------------------------")
        print(e.stderr.decode('utf8', errors='replace'))
        print("-------------------------")
        print("--------Output----------")
        print("-------------------------")
        print(e.output.decode('utf8', errors='replace'))
        print("-------------------------")
        raise
    with open(PROFILE_FILENAME, 'r') as file:
        return op_inspect(json.load(file))
def print_banner(text):
    """Print *text* centred in a three-line '=' banner BANNER_SIZE columns wide.

    The text is framed by five spaces on each side. The remaining width is
    split between the left and right '=' runs, with the left run taking the
    extra column when the split is uneven. Text too long to fit gets no '='
    padding on its line, rather than a lopsided one.
    """
    # Clamp at 0: a negative remainder used to produce a stray '=' on one side.
    padding = max(BANNER_SIZE - len(text) - 10, 0)
    r_width, odd = divmod(padding, 2)
    l_width = r_width + odd
    print("")
    print("=" * BANNER_SIZE)
    print("=" * l_width + " " * 5 + text + " " * 5 + "=" * r_width)
    print("=" * BANNER_SIZE)
def print_diffs(diffs):
    """Print an old-versus-new cost breakdown for each ``(name, old, new)`` entry.

    Each entry starts with an empty line and the query name. The old and new
    total, build-side and probe-side figures follow, one per line.
    """
    for name, before, after in diffs:
        report = (
            ("Old total cost:", before, "total"),
            ("Old build cost:", before, "build_side"),
            ("Old probe cost:", before, "probe_side"),
            ("New total cost:", after, "total"),
            ("New build cost:", after, "build_side"),
            ("New probe cost:", after, "probe_side"),
        )
        print("")
        print("Query:", name)
        for label, cost, attribute in report:
            print(label, getattr(cost, attribute))
def _query_name(path):
    """Return the benchmark query name for *path*: its file name without the extension."""
    return os.path.splitext(os.path.basename(path))[0]


def _remove_if_exists(path):
    """Delete *path*, ignoring the case where it was never created."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def main():
    """Compare the plan cost of every benchmark query between two CLI builds.

    Steps:
    1. Build one database per runner from ``<dir>/init``.
    2. Run each ``<dir>/queries/*.sql`` file against both databases.
    3. Report queries whose plan got cheaper (improvements) or more expensive
       (regressions), as decided by ``PlanCost.__gt__``.

    The database and profile files are removed even if a step fails. Exits
    with status 1 if any regression was found, otherwise 0.
    """
    parser = argparse.ArgumentParser(description="Plan cost regression test script with old and new versions.")
    parser.add_argument("--old", type=str, help="Path to the old runner.", required=True)
    parser.add_argument("--new", type=str, help="Path to the new runner.", required=True)
    parser.add_argument("--dir", type=str, help="Path to the benchmark directory.", required=True)
    args = parser.parse_args()
    old = args.old
    new = args.new
    benchmark_dir = args.dir

    improvements = []
    regressions = []
    try:
        init_db(old, OLD_DB_NAME, benchmark_dir)
        init_db(new, NEW_DB_NAME, benchmark_dir)
        files = sorted(glob.glob(os.path.join(benchmark_dir, "queries", "*.sql")))
        print("")
        print("RUNNING BENCHMARK QUERIES")
        for f in tqdm(files):
            query_name = _query_name(f)
            with open(f, "r") as file:
                query = file.read()
            old_cost = query_plan_cost(old, OLD_DB_NAME, query)
            new_cost = query_plan_cost(new, NEW_DB_NAME, query)
            if old_cost > new_cost:
                improvements.append((query_name, old_cost, new_cost))
            elif new_cost > old_cost:
                regressions.append((query_name, old_cost, new_cost))
    finally:
        # Clean up even on failure. Tolerate missing files: with no query
        # files, no profile was ever written, and removing it used to crash.
        for path in (OLD_DB_NAME, NEW_DB_NAME, PROFILE_FILENAME):
            _remove_if_exists(path)

    exit_code = 0
    if improvements:
        print_banner("IMPROVEMENTS DETECTED")
        print_diffs(improvements)
    if regressions:
        exit_code = 1
        print_banner("REGRESSIONS DETECTED")
        print_diffs(regressions)
    if not improvements and not regressions:
        print_banner("NO DIFFERENCES DETECTED")
    # sys.exit rather than the site-provided exit(), which is absent under `python -S`.
    sys.exit(exit_code)
# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|