File: __init__.py

package info (click to toggle)
python-gimmik 3.2.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 204 kB
  • sloc: python: 323; makefile: 4
file content (29 lines) | stat: -rw-r--r-- 844 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# -*- coding: utf-8 -*-

from gimmik._version import __version__
from gimmik.c import CMatMul
from gimmik.copenmp import COpenMPMatMul
from gimmik.cuda import CUDAMatMul
from gimmik.ispc import ISPCMatMul
from gimmik.hip import HIPMatMul
from gimmik.metal import MetalMatMul
from gimmik.opencl import OpenCLMatMul


def generate_mm(mat, dtype, platform, alpha=1.0, beta=0.0, funcn='gimmik_mm',
                n=None, ldb=None, ldc=None):
    """Deprecated helper returning the first generated kernel for ``mat``.

    Emits a ``DeprecationWarning``, selects the matrix multiplication class
    registered for ``platform``, constructs it from ``alpha*mat`` and
    ``beta``, and returns element ``[0]`` of the first item produced by its
    ``kernels`` method.

    Args:
        mat: Operator matrix; it is multiplied by ``alpha`` before being
            handed to the platform class.
        dtype: Passed through unchanged to ``kernels``.
        platform: Key selecting the backend; one of ``'c'``, ``'c-omp'``,
            ``'cuda'``, ``'ispc'``, ``'hip'`` or ``'opencl'``.
        alpha: Scalar folded into ``mat`` prior to kernel generation.
        beta: Passed as the second positional argument of the platform
            class.
        funcn: Passed to ``kernels`` as ``kname``.
        n: Passed through unchanged to the platform class.
        ldb: Passed through unchanged to the platform class.
        ldc: Passed through unchanged to the platform class.

    Returns:
        Element ``[0]`` of the first item yielded by ``kernels``
        (presumably the kernel source text -- confirm against ``kernels``).

    Raises:
        KeyError: If ``platform`` is not one of the supported keys.
    """
    import warnings

    # stacklevel=2 attributes the warning to the caller of generate_mm, so
    # the reported file/line is the code that actually needs migrating and
    # the default warning filters treat it as coming from user code.
    warnings.warn('generate_mm is deprecated, use MatMul', DeprecationWarning,
                  stacklevel=2)

    # NOTE(review): MetalMatMul is imported by this module but has no entry
    # here -- confirm whether that omission is intentional.
    platmap = {
        'c': CMatMul,
        'c-omp': COpenMPMatMul,
        'cuda': CUDAMatMul,
        'ispc': ISPCMatMul,
        'hip': HIPMatMul,
        'opencl': OpenCLMatMul
    }

    # The third positional argument is deliberately left as None, exactly as
    # in the original call.
    mm = platmap[platform](alpha*mat, beta, None, n, ldb, ldc)

    # Only the first item from the kernels iterator is consumed.
    return next(mm.kernels(dtype, kname=funcn))[0]