#!/usr/bin/env python3
###############################################################################
# Copyright (c) Intel Corporation - All rights reserved.                      #
# This file is part of the LIBXSMM library.                                   #
#                                                                             #
# For information on the license, see the LICENSE file.                       #
# Further information: https://github.com/hfp/libxsmm/                        #
# SPDX-License-Identifier: BSD-3-Clause                                       #
###############################################################################
# Hans Pabst (Intel Corp.)
###############################################################################
from string import Template
import libxsmm_utilities
import fnmatch
import sys


def precision_kinds(precision):
    """Return the (name, C type, Fortran kind) triples enabled by *precision*.

    Single precision ("smm") is enabled unless precision is 2, and double
    precision ("dmm") unless precision is 1. Hence 0 -- and any other value
    such as 3 -- selects both, always in single-then-double order. Both the C
    and the Fortran generators use this so their selections cannot diverge.
    """
    kinds = []
    if 2 != precision:
        kinds.append(("smm", "float", "C_FLOAT"))
    if 1 != precision:
        kinds.append(("dmm", "double", "C_DOUBLE"))
    return kinds


def mnk_interface_list_c(mnklist, precision, prefetch):
    """Return the MNK_INTERFACE_LIST text for a C header template.

    One ``LIBXSMM_API`` prototype is emitted per (M, N, K) triplet and per
    enabled precision.

    Args:
        mnklist: sequence of (M, N, K) sequences; each triplet is joined
            with "_" to form the kernel-name suffix.
        precision: selects single/double prototypes (see precision_kinds).
        prefetch: < 0 appends a variadic ", ..." to each prototype, > 0
            appends explicit pa/pb/pc prefetch pointers, 0 appends neither.

    Returns:
        The concatenated prototypes (empty string for an empty mnklist).
    """
    optional = ", ..." if prefetch < 0 else ""
    kinds = precision_kinds(precision)
    parts = []
    for mnk in mnklist:
        mnkstr = "_".join(map(str, mnk))
        for name, ctype, _ in kinds:
            if 0 < prefetch:
                pfsig = (
                    ",\n  const " + ctype + "* pa, const " + ctype
                    + "* pb, const " + ctype + "* pc);"
                )
            else:
                pfsig = optional + ");"
            parts.append(
                "\nLIBXSMM_API void libxsmm_" + name + "_" + mnkstr
                + "(const " + ctype + "* a, const " + ctype + "* b, "
                + ctype + "* c" + pfsig
            )
        if 0 == precision:
            # extra newline closes each triplet's single/double pair
            parts.append("\n")
    if mnklist and 0 != precision:
        parts.append("\n")
    return "".join(parts)


def mnk_interface_list_fortran(mnklist, precision, prefetch):
    """Return the MNK_INTERFACE_LIST text for a Fortran module template.

    Emits one OFFLOAD directive line per (M, N, K) triplet naming every
    enabled kernel, followed by an INTERFACE block with one PURE SUBROUTINE
    per triplet and enabled precision.

    Args:
        mnklist: sequence of (M, N, K) sequences; each triplet is joined
            with "_" to form the kernel-name suffix.
        precision: selects single/double interfaces (see precision_kinds).
        prefetch: > 0 makes pa/pb/pc mandatory and adds BIND(C); otherwise
            (including 0) pa/pb/pc are declared OPTIONAL without BIND(C).

    Returns:
        The generated text, or an empty string for an empty mnklist.
    """
    if not mnklist:
        return ""
    # Fortran's OPTIONAL allows to always generate an interface
    # with prefetch signature (more flexible usage)
    if 0 == prefetch:
        prefetch = -1
    kinds = precision_kinds(precision)
    mnkstrs = ["_".join(map(str, mnk)) for mnk in mnklist]

    parts = ["\n"]
    for mnkstr in mnkstrs:
        names = ", ".join("libxsmm_" + name + "_" + mnkstr for name, _, _ in kinds)
        parts.append("\n        !DIR$ ATTRIBUTES OFFLOAD:MIC :: " + names)
    parts.append("\n        INTERFACE")

    optional = "" if 0 < prefetch else ", OPTIONAL"
    bindc = "BIND(C)" if 0 < prefetch else ""
    for mnkstr in mnkstrs:
        # the continuation "&" is right-aligned; its padding shrinks as the
        # kernel name grows (never below the "&" itself)
        pfsiga = (
            "," + "&".rjust(26 - len(mnkstr))
            + "\n     &    pa, pb, pc) " + bindc + "\n"
        )
        for name, _, fkind in kinds:
            parts.append(
                "\n          PURE SUBROUTINE libxsmm_" + name + "_" + mnkstr
                + "(a, b, c" + pfsiga
                + "            IMPORT :: " + fkind + "\n"
                + "            REAL(" + fkind + "), INTENT(IN) :: a(*), b(*)\n"
                + "            REAL(" + fkind + "), INTENT(INOUT) :: c(*)\n"
                + "            REAL(" + fkind + "), INTENT(IN)" + optional
                + " :: pa(*), pb(*), pc(*)\n"
                + "          END SUBROUTINE"
            )
    parts.append("\n        END INTERFACE")
    return "".join(parts)


if __name__ == "__main__":
    # Usage: <template> [ifversion<<2|precision [prefetch [mnk-spec...]]]
    # The filled-in template is printed to stdout.
    argc = len(sys.argv)
    if 1 < argc:
        # required argument(s)
        filename = sys.argv[1]

        # default configuration if no optional arguments are given
        precision = 0  # all
        ifversion = 1  # interface
        prefetch = -1  # auto
        mnklist = []

        # optional argument(s); argv[2] packs the interface level in the
        # upper bits and the precision in the lowest two bits
        if 2 < argc:
            ivalue = int(sys.argv[2])
            ifversion = ivalue >> 2
            precision = ivalue & 3
        if 3 < argc:
            prefetch = int(sys.argv[3])
        if 4 < argc:
            mnklist = sorted(libxsmm_utilities.load_mnklist(sys.argv[4:], 0))

        # read the template and release the file handle immediately
        with open(filename, "r") as templatefile:
            template = Template(templatefile.read())
        if fnmatch.fnmatch(filename, "*.h*"):  # C header
            substitute = {
                "MNK_INTERFACE_LIST": mnk_interface_list_c(
                    mnklist, precision, prefetch
                ),
            }
            print(template.substitute(substitute))
        else:  # Fortran interface
            # only reachable with a negative packed argv[2] value
            if ifversion < 0:
                raise ValueError("Fortran interface level is inconsistent!")
            version, branch, realversion = libxsmm_utilities.version_branch(16)
            major, minor, update, patch = libxsmm_utilities.version_numbers(
                version
            )
            substitute = {
                "VERSION": realversion,
                "BRANCH": branch,
                "MAJOR": major,
                "MINOR": minor,
                "UPDATE": update,
                "PATCH": patch,
                "MNK_INTERFACE_LIST": mnk_interface_list_fortran(
                    mnklist, precision, prefetch
                ),
                "CONTIGUOUS": ", CONTIGUOUS" if 1 < ifversion else "",
            }
            # safe_substitute: leave unknown $-placeholders untouched
            print(template.safe_substitute(substitute))
    else:
        sys.tracebacklimit = 0
        raise ValueError(sys.argv[0] + ": wrong number of arguments!")
