1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
|
#!/usr/bin/env python3
###############################################################################
# Copyright (c) Intel Corporation - All rights reserved. #
# This file is part of the LIBXSMM library. #
# #
# For information on the license, see the LICENSE file. #
# Further information: https://github.com/hfp/libxsmm/ #
# SPDX-License-Identifier: BSD-3-Clause #
###############################################################################
# Hans Pabst (Intel Corp.)
###############################################################################
import libxsmm_utilities
import sys
import os
def _print_build_state(filename):
    """Print the C snippet that pulls the build-state file into the library.

    Only the base name of *filename* is emitted; the generated ``include``
    directive refers to it relative to the parent directory (``../``).
    """
    print("#if !defined(_WIN32)")
    print("{ static const char *const build_state =")
    print('# include "../' + os.path.basename(filename) + '"')
    print(" ;")
    print(" internal_build_state = build_state;")
    print("}")
    print("#endif")


def _print_registration_guard():
    """Print the C comment and the JIT-dependent ``if`` condition.

    The emitted condition precedes the registration block, so the block is
    conditional whenever ``LIBXSMM_JIT`` is non-zero in the generated code.
    """
    guard = (
        "/* omit registering code if JIT is enabled"
        " and if an ISA extension is found",
        " * which is beyond the static code"
        " path used to compile the library",
        " */",
        "#if (0 != LIBXSMM_JIT) && !defined(__MIC__)",
        "if (LIBXSMM_X86_GENERIC > libxsmm_target_archid "
        "/* JIT code gen. is not available */",
        " /* conditions allows to avoid JIT "
        "(if static code is good enough) */",
        " || (LIBXSMM_STATIC_TARGET_ARCH == libxsmm_target_archid)",
        " || (LIBXSMM_X86_AVX512_CORE <= libxsmm_target_archid &&",
        " libxsmm_cpuid_vlen32(LIBXSMM_STATIC_TARGET_ARCH) ==",
        " libxsmm_cpuid_vlen32(libxsmm_target_archid)))",
        "#endif",
    )
    for text in guard:
        print(text)


def _print_kernels(mnklist, letter, precision_id):
    """Print the registration statements for every entry of *mnklist*.

    *letter* selects the kernel family ("d" or "s") and is used for the
    union member, the function-pointer type and the kernel-name prefix.
    *precision_id* is the C precision constant passed to
    ``internal_register_static_code``.
    """
    member = " func." + letter + "mm = "
    cast = "(libxsmm_" + letter + "mmfunction)libxsmm_" + letter + "mm_"
    for mnk in mnklist:
        # the kernel name joins all components of the entry, whereas the
        # registration call receives exactly the first three (M, N, K)
        mnkstr = "_".join(map(str, mnk))
        mnksig = str(mnk[0]) + ", " + str(mnk[1]) + ", " + str(mnk[2])
        print(member + cast + mnkstr + ";")
        print(
            " internal_register_static_code("
            + precision_id
            + ", "
            + mnksig
            + ", func, new_registry);"
        )


def main(argv):
    """Generate the C code that registers statically compiled kernels.

    Expected arguments (after the program name in ``argv[0]``)::

        [state-file] precision threshold mnk [mnk ...]

    * ``state-file`` is optional: if the first argument names an existing
      file, the build-state snippet is printed and the remaining arguments
      shift by one. A first argument of "0" is never treated as a file.
    * ``precision``: 1 omits the double-precision kernels, 2 omits the
      single-precision kernels, and any other value registers both.
    * ``threshold`` must be an integer but its value is not used here.
    * The remaining arguments are handed to
      ``libxsmm_utilities.load_mnklist``.

    If precision, threshold and at least one further argument are not all
    present, the registration block is silently omitted.

    Raises:
        ValueError: if no argument is given at all, or if ``precision`` or
            ``threshold`` is not an integer.
    """
    argc = len(argv)
    if argc <= 1:
        # hide the traceback so that only the message is reported
        sys.tracebacklimit = 0
        raise ValueError(argv[0] + ": wrong number of arguments!")

    # "0" may be a leading precision value rather than a file name
    state_file = "" if "0" == argv[1] else argv[1]
    base = 1
    if os.path.isfile(state_file):
        _print_build_state(state_file)
        base = 2

    if argc <= (base + 2):
        return  # nothing to register

    precision = int(argv[base + 0])
    int(argv[base + 1])  # threshold: validated for being an integer only
    mnklist = libxsmm_utilities.load_mnklist(argv[base + 2:], 0)

    _print_registration_guard()
    print("{")
    print(" libxsmm_xmmfunction func;")
    # double-precision kernels are registered first, i.e., they are
    # preferred when approaching an exhausted registry
    if 1 != precision:  # not limited to single-precision
        _print_kernels(mnklist, "d", "LIBXSMM_GEMM_PRECISION_F64")
    if 2 != precision:  # not limited to double-precision
        _print_kernels(mnklist, "s", "LIBXSMM_GEMM_PRECISION_F32")
    print("}")


if __name__ == "__main__":
    main(sys.argv)
|