1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185
|
#!/usr/bin/env python
# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
# PEP8 Python style guide and uses a max-width of 120 characters per line.
#
# Author(s):
# Cedric Nugteren <www.cedricnugteren.nl>
import sys
import os.path
import glob
import argparse
import database.io as io
import database.db as db
import database.clblast as clblast
import database.bests as bests
import database.defaults as defaults
# Server storing a copy of the database
DATABASE_SERVER_URL = "https://raw.githubusercontent.com/CNugteren/CLBlast-database/master/database.json"
def remove_mismatched_arguments(database):
"""Checks for tuning results with mis-matched entries and removes them according to user preferences"""
kernel_attributes = clblast.DEVICE_TYPE_ATTRIBUTES + clblast.KERNEL_ATTRIBUTES + ["kernel"]
# For Python 2 and 3 compatibility
try:
user_input = raw_input
except NameError:
user_input = input
pass
# Check for mis-matched entries
for kernel_group_name, kernel_group in db.group_by(database["sections"], kernel_attributes):
group_by_arguments = db.group_by(kernel_group, clblast.ARGUMENT_ATTRIBUTES)
if len(group_by_arguments) != 1:
print("[database] WARNING: entries for a single kernel with multiple argument values " +
str(kernel_group_name))
print("[database] Either quit or remove all but one of the argument combinations below:")
for index, (attribute_group_name, mismatching_entries) in enumerate(group_by_arguments):
print("[database] %d: %s" % (index, attribute_group_name))
for attribute_group_name, mismatching_entries in group_by_arguments:
response = user_input("[database] Remove entries corresponding to %s, [y/n]? " %
str(attribute_group_name))
if response == "y":
for entry in mismatching_entries:
database["sections"].remove(entry)
print("[database] Removed %d entry/entries" % len(mismatching_entries))
# Sanity-check: all mis-matched entries should be removed
for kernel_group_name, kernel_group in db.group_by(database["sections"], kernel_attributes):
group_by_arguments = db.group_by(kernel_group, clblast.ARGUMENT_ATTRIBUTES)
if len(group_by_arguments) != 1:
print("[database] ERROR: entries for a single kernel with multiple argument values " +
str(kernel_group_name))
assert len(group_by_arguments) == 1
def remove_database_entries(database, remove_if_matches_fields):
assert len(remove_if_matches_fields.keys()) > 0
def remove_this_entry(section):
for key in remove_if_matches_fields.keys():
if section[key] != remove_if_matches_fields[key]:
return False
return True
old_length = len(database["sections"])
database["sections"] = [x for x in database["sections"] if not remove_this_entry(x)]
new_length = len(database["sections"])
print("[database] Removed %d entries from the database" % (old_length - new_length))
def add_tuning_parameter(database, parameter_name, kernel, value):
num_changes = 0
for section in database["sections"]:
if section["kernel"] == kernel:
for result in section["results"]:
if parameter_name not in result["parameters"]:
result["parameters"][parameter_name] = value
section["parameter_names"].append(parameter_name)
num_changes += 1
print("[database] Made %d addition(s) of %s" % (num_changes, parameter_name))
def main(argv):
# Parses the command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument("source_folder", help="The folder with JSON files to parse to add to the database")
parser.add_argument("clblast_root", help="Root of the CLBlast sources")
parser.add_argument("-r", "--remove_device", type=str, default=None, help="Removes all entries for a specific device")
parser.add_argument("--add_tuning_parameter", type=str, default=None, help="Adds this parameter to existing entries")
parser.add_argument("--add_tuning_parameter_for_kernel", type=str, default=None, help="Adds the above parameter for this kernel")
parser.add_argument("--add_tuning_parameter_value", type=int, default=0, help="Set this value as the default for the above parameter")
parser.add_argument("-v", "--verbose", action="store_true", help="Increase verbosity of the script")
cl_args = parser.parse_args(argv)
# Parses the path arguments
database_filename = os.path.join(cl_args.clblast_root, "scripts", "database", "database.json")
database_best_filename = os.path.join(cl_args.clblast_root, "scripts", "database", "database_best.json")
json_files = os.path.join(cl_args.source_folder, "*.json")
cpp_database_path = os.path.join(cl_args.clblast_root, "src", "database", "kernels")
# Checks whether the command-line arguments are valid
clblast_header = os.path.join(cl_args.clblast_root, "include", "clblast.h") # Not used but just for validation
if not os.path.isfile(clblast_header):
raise RuntimeError("The path '" + cl_args.clblast_root +
"' does not point to the root of the CLBlast library")
if len(glob.glob(json_files)) < 1:
print("[database] The path '" + cl_args.source_folder + "' does not contain any JSON files")
# Downloads the database if a local copy is not present
if not os.path.isfile(database_filename):
io.download_database(database_filename, DATABASE_SERVER_URL)
# Loads the database from disk
database = io.load_database(database_filename)
# Loops over all JSON files in the supplied folder
for file_json in glob.glob(json_files):
sys.stdout.write("[database] Processing '" + file_json + "' ") # No newline printed
try:
# Loads the newly imported data
imported_data = io.load_tuning_results(file_json)
# Adds the new data to the database
old_size = db.length(database)
database = db.add_section(database, imported_data)
new_size = db.length(database)
print("with " + str(new_size - old_size) + " new items") # Newline printed here
except ValueError:
print("--- WARNING: invalid file, skipping")
# Checks for tuning results with mis-matched entries
remove_mismatched_arguments(database)
# Stores the modified database back to disk
if len(glob.glob(json_files)) >= 1:
io.save_database(database, database_filename)
# Removes database entries before continuing
if cl_args.remove_device is not None:
print("[database] Removing all results for device '%s'" % cl_args.remove_device)
remove_database_entries(database, {"clblast_device_name": cl_args.remove_device})
#, "kernel_family": "xgemm"})
io.save_database(database, database_filename)
# Adds new tuning parameters to existing database entries
if cl_args.add_tuning_parameter is not None and\
cl_args.add_tuning_parameter_for_kernel is not None:
print("[database] Adding tuning parameter: '%s' for kernel '%s' with default %d" %
(cl_args.add_tuning_parameter, cl_args.add_tuning_parameter_for_kernel,
cl_args.add_tuning_parameter_value))
add_tuning_parameter(database, cl_args.add_tuning_parameter,
cl_args.add_tuning_parameter_for_kernel,
cl_args.add_tuning_parameter_value)
io.save_database(database, database_filename)
# Retrieves the best performing results
print("[database] Calculating the best results per device/kernel...")
database_best_results = bests.get_best_results(database)
# Determines the defaults for other vendors and per vendor
print("[database] Calculating the default values...")
database_defaults = defaults.calculate_defaults(database, cl_args.verbose)
database_best_results["sections"].extend(database_defaults["sections"])
# Optionally outputs the database to disk
if cl_args.verbose:
io.save_database(database_best_results, database_best_filename)
# Outputs the database as a C++ database
print("[database] Producing a C++ database in '" + cpp_database_path + "'...")
clblast.print_cpp_database(database_best_results, cpp_database_path)
print("[database] All done")
if __name__ == '__main__':
main(sys.argv[1:])
|