File: libxsmm_dispatch.py

package info (click to toggle)
libxsmm 1.17-4
  • links: PTS, VCS
  • area: main
  • in suites: sid, trixie
  • size: 14,976 kB
  • sloc: ansic: 119,587; cpp: 27,680; fortran: 9,179; sh: 5,765; makefile: 5,040; pascal: 2,312; python: 1,812; f90: 1,773
file content (116 lines) | stat: -rwxr-xr-x 4,744 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/env python3
###############################################################################
# Copyright (c) Intel Corporation - All rights reserved.                      #
# This file is part of the LIBXSMM library.                                   #
#                                                                             #
# For information on the license, see the LICENSE file.                       #
# Further information: https://github.com/hfp/libxsmm/                        #
# SPDX-License-Identifier: BSD-3-Clause                                       #
###############################################################################
# Hans Pabst (Intel Corp.)
###############################################################################
import libxsmm_utilities
import sys
import os


# C lines emitted ahead of the kernel registrations.  The `#if`/`#endif`
# pair makes the `if (...)` condition exist only when LIBXSMM_JIT is nonzero
# and __MIC__ is undefined; otherwise the following `{ ... }` compound
# statement runs unconditionally.
_REGISTRY_PROLOGUE = (
    "/* omit registering code if JIT is enabled"
    " and if an ISA extension is found",
    " * which is beyond the static code"
    " path used to compile the library",
    " */",
    "#if (0 != LIBXSMM_JIT) && !defined(__MIC__)",
    "if (LIBXSMM_X86_GENERIC > libxsmm_target_archid "
    "/* JIT code gen. is not available */",
    "   /* conditions allows to avoid JIT "
    "(if static code is good enough) */",
    "   || (LIBXSMM_STATIC_TARGET_ARCH == libxsmm_target_archid)",
    "   || (LIBXSMM_X86_AVX512_CORE <= libxsmm_target_archid &&",
    "       libxsmm_cpuid_vlen32(LIBXSMM_STATIC_TARGET_ARCH) ==",
    "       libxsmm_cpuid_vlen32(libxsmm_target_archid)))",
    "#endif",
    "{",
    "  libxsmm_xmmfunction func;",
)


def _build_state_lines(filename):
    """Return C lines that assign the included *filename* to the build state.

    Only the base name of *filename* is used; it is included relative to the
    parent directory ("../") as the initializer of ``build_state``, and the
    whole snippet is compiled out on _WIN32.
    """
    return [
        "#if !defined(_WIN32)",
        "{ static const char *const build_state =",
        '#   include "../' + os.path.basename(filename) + '"',
        "  ;",
        "  internal_build_state = build_state;",
        "}",
        "#endif",
    ]


def _registration_lines(mnklist, precision):
    """Return the C lines that register the statically generated kernels.

    Args:
        mnklist: iterable of index tuples. ``mnk[0]``, ``mnk[1]`` and
            ``mnk[2]`` form the registration signature; all elements joined
            with "_" form the kernel-name suffix.
        precision: 1 emits only single-precision (F32) registrations, 2 emits
            only double-precision (F64) registrations, any other value emits
            both.

    Returns:
        list of str, one per output line: the JIT guard prologue, the
        registrations, and the closing brace.
    """
    # The entries are walked once per precision; materialize them so that a
    # one-shot iterator does not leave the second (F32) pass empty.
    mnklist = list(mnklist)
    lines = list(_REGISTRY_PROLOGUE)
    # All double-precision kernels are registered before any single-precision
    # kernel, to prefer double-precision kernels when approaching an
    # exhausted registry.  Each pass is skipped for the "other-only" setting.
    for skip_if, member, gemm_precision in ((1, "dmm", "F64"),
                                            (2, "smm", "F32")):
        if precision == skip_if:
            continue
        for mnk in mnklist:
            suffix = "_".join(map(str, mnk))
            signature = ", ".join(map(str, (mnk[0], mnk[1], mnk[2])))
            lines.append(
                "  func.%s = (libxsmm_%sfunction)libxsmm_%s_%s;"
                % (member, member, member, suffix)
            )
            lines.append(
                "  internal_register_static_code("
                "LIBXSMM_GEMM_PRECISION_%s, %s, func, new_registry);"
                % (gemm_precision, signature)
            )
    lines.append("}")
    return lines


def main(argv=None):
    """Print the generated C registration snippet to stdout.

    Usage: libxsmm_dispatch.py [BUILD_STATE] PRECISION THRESHOLD INDICES...

    BUILD_STATE is optional. When argv[1] names an existing file, a snippet
    including it is printed and the remaining arguments shift by one.
    Otherwise argv[1] is read as PRECISION. The literal "0" is never treated
    as a file name, presumably so that a precision of 0 is not mistaken for a
    file called "0". THRESHOLD is parsed (rejecting non-integers) but not
    otherwise used here. INDICES are passed to
    ``libxsmm_utilities.load_mnklist``, whose format is defined there. If
    fewer than three arguments remain after the optional BUILD_STATE, no
    registration code is printed.

    Args:
        argv: argument vector including the program name; defaults to
            ``sys.argv``.

    Raises:
        ValueError: if only the program name is given, or if PRECISION or
            THRESHOLD is not an integer.
    """
    if argv is None:
        argv = sys.argv
    argc = len(argv)
    if argc <= 1:
        sys.tracebacklimit = 0  # show just the message, no traceback entries
        raise ValueError(argv[0] + ": wrong number of arguments!")

    arg1 = argv[1]
    build_state = "" if "0" == arg1 else arg1
    base = 1
    if os.path.isfile(build_state):
        print("\n".join(_build_state_lines(build_state)))
        base = 2

    if (base + 2) < argc:
        precision = int(argv[base + 0])
        int(argv[base + 1])  # THRESHOLD: validated only, unused here
        mnklist = libxsmm_utilities.load_mnklist(argv[base + 2:], 0)
        print("\n".join(_registration_lines(mnklist, precision)))


if __name__ == "__main__":
    main()