#!/usr/bin/env python3
###############################################################################
# Copyright (c) Intel Corporation - All rights reserved.                      #
# This file is part of the LIBXSMM library.                                   #
#                                                                             #
# For information on the license, see the LICENSE file.                       #
# Further information: https://github.com/hfp/libxsmm/                        #
# SPDX-License-Identifier: BSD-3-Clause                                       #
###############################################################################
# Hans Pabst (Intel Corp.)
###############################################################################
import libxsmm_utilities
import sys
import os


if __name__ == "__main__":
    # Command line, as consumed below:
    #   [<build-state file>] <precision> <threshold> <mnk>...
    # The generated C source text is written to stdout.
    # - If the first argument names an existing file, a snippet is emitted
    #   which #include's "../<basename>" and assigns it to
    #   internal_build_state; a literal "0" is never treated as a file.
    # - <precision> equal to 1 suppresses the F64 (dmm) registrations and
    #   <precision> equal to 2 suppresses the F32 (smm) registrations;
    #   any other value emits both.
    # - Registration code is only emitted if at least one <mnk>
    #   argument follows <threshold>.
    args = sys.argv
    if len(args) < 2:
        # suppress the traceback such that only the message is shown
        sys.tracebacklimit = 0
        raise ValueError(args[0] + ": wrong number of arguments!")

    # "0" acts as placeholder and maps to an empty (non-existing) filename
    statefile = "" if "0" == args[1] else args[1]
    first = 1  # index of the <precision> argument
    if os.path.isfile(statefile):
        include = '#   include "../' + os.path.basename(statefile) + '"'
        for line in (
            "#if !defined(_WIN32)",
            "{ static const char *const build_state =",
            include,
            "  ;",
            "  internal_build_state = build_state;",
            "}",
            "#endif",
        ):
            print(line)
        first = 2  # the file argument shifts all remaining arguments

    if first + 2 < len(args):
        precision = int(args[first])
        # parsed (hence validated as an integer) but not used otherwise:
        # load_mnklist is called with a literal 0 instead
        threshold = int(args[first + 1])
        mnklist = libxsmm_utilities.load_mnklist(args[first + 2:], 0)

        preamble = [
            "/* omit registering code if JIT is enabled"
            + " and if an ISA extension is found",
            " * which is beyond the static code"
            + " path used to compile the library",
            " */",
            "#if (0 != LIBXSMM_JIT) && !defined(__MIC__)",
            "if (LIBXSMM_X86_GENERIC > libxsmm_target_archid "
            + "/* JIT code gen. is not available */",
            "   /* conditions allows to avoid JIT "
            + "(if static code is good enough) */",
            "   || (LIBXSMM_STATIC_TARGET_ARCH == libxsmm_target_archid)",
            "   || (LIBXSMM_X86_AVX512_CORE <= libxsmm_target_archid &&",
            "       libxsmm_cpuid_vlen32(LIBXSMM_STATIC_TARGET_ARCH) ==",
            "       libxsmm_cpuid_vlen32(libxsmm_target_archid)))",
            "#endif",
            "{",
            "  libxsmm_xmmfunction func;",
        ]
        for line in preamble:
            print(line)

        assign = "  func.{0} = (libxsmm_{0}function)libxsmm_{0}_{1};"
        register = (
            "  internal_register_static_code({0}, {1}, func, new_registry);"
        )
        # prefer registering double-precision kernels
        # when approaching an exhausted registry (dmm pass comes first);
        # each entry: (precision value skipping the pass, kind, C enum)
        passes = (
            (1, "dmm", "LIBXSMM_GEMM_PRECISION_F64"),
            (2, "smm", "LIBXSMM_GEMM_PRECISION_F32"),
        )
        for skipped, kind, flag in passes:
            for mnk in mnklist:
                # "M, N, K" for the call and "M_N_K" for the symbol name
                shape = ", ".join((str(mnk[0]), str(mnk[1]), str(mnk[2])))
                suffix = "_".join(map(str, mnk))
                if skipped == precision:
                    continue
                print(assign.format(kind, suffix))
                print(register.format(flag, shape))
        print("}")
