File: base.py

package info (click to toggle)
python-gimmik 3.2.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 204 kB
  • sloc: python: 323; makefile: 4
file content (155 lines) | stat: -rw-r--r-- 4,945 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# -*- coding: utf-8 -*-

import itertools as it
import pkgutil
import re

from mako.lookup import TemplateLookup
from mako.template import Template
import numpy as np


class _PlatformTemplateLookup(TemplateLookup):
    """Template lookup which reads Mako sources from this package's data.

    Templates are resolved to ``kernels/<platform>/<name>.mako`` relative
    to the package identified by ``__name__``.
    """

    def __init__(self, platform):
        # Only the platform name is recorded; the base class initialiser
        # is not invoked here
        self.platform = platform

    def adjust_uri(self, uri, relto):
        """Return *uri* unchanged; *relto* is ignored."""
        return uri

    def get_template(self, name):
        """Load the template *name* for the stored platform.

        The returned template uses this object as its own lookup.
        """
        path = f'kernels/{self.platform}/{name}.mako'
        data = pkgutil.get_data(__name__, path)

        return Template(data, lookup=self)


def _dot(bfn, row, maxsplit=1):
    """Return source text for the dot product of *row* with some operand.

    Each non-zero entry ``row[i]`` contributes a term of the form
    ``{row[i]}*{bfn(i)}``.  The terms are distributed over one or more
    parenthesised groups which are themselves summed.  A row without any
    non-zero entries gives the literal ``'0.0'``.

    The number of groups is ``maxsplit`` capped at one third (rounded
    down) of the number of non-zeros, and is never less than one.
    """
    (nonzero,) = np.nonzero(row)
    count = nonzero.size

    if count == 0:
        return '0.0'

    # Cap the group count and then make sure there is at least one group
    ngroups = min(maxsplit, count // 3)
    if ngroups < 1:
        ngroups = 1

    groups = []
    for ixs in np.array_split(nonzero, ngroups):
        terms = [f'{row[i]}*{bfn(i)}' for i in ixs]
        groups.append('(' + ' + '.join(terms) + ')')

    return ' + '.join(groups)


def _partition(mat, into, by):
    """Split the row or column indices of *mat* into *into* strided lists.

    The list at position ``i`` holds the indices ``i, i + into, ...`` up
    to the number of rows (``by='rows'``) or columns (``by='cols'``).

    Raises ValueError for any other value of *by*.
    """
    if by not in ('rows', 'cols'):
        raise ValueError('Invalid partition by')

    parts = []
    for start in range(into):
        extent = len(mat) if by == 'rows' else len(mat.T)
        parts.append(list(range(start, extent, into)))

    return parts


def _chunk(l, chunksz):
    """Break the sized iterable *l* into consecutive lists of *chunksz*.

    Every chunk has *chunksz* items apart from the last, which holds
    whatever remains.  An empty input yields an empty list.
    """
    stream = iter(l)

    # Ceiling division: one extra chunk if there is a remainder
    whole, rest = divmod(len(l), chunksz)
    nchunks = whole + (1 if rest else 0)

    chunks = []
    for _ in range(nchunks):
        chunks.append(list(it.islice(stream, chunksz)))

    return chunks


class MatMul:
    platform = None

    def __init__(self, A, beta=0.0, aligne=None, n=None, ldb=None, ldc=None):
        self.A = A
        self.beta = beta
        self.aligne = aligne

        if n is None and ldb is None and ldc is None:
            self.n = self.ldb = self.ldc = None
        elif n is not None and ldb is not None and ldc is not None:
            if aligne is not None and (ldb % aligne or ldc % aligne):
                raise ValueError('ldb/ldc not compatible with aligne')

            self.n, self.ldb, self.ldc = n, ldb, ldc
        else:
            raise ValueError('Must provide all of (n, ldb, ldc) or none')

        # Check the matrix has a non-zero
        if not A.any():
            raise ValueError('A can not be empty')

        # Extract the shape of A
        self.m, self.k = m, k = A.shape

        # Determine the index of the first and last non-zero in each row of A
        self.afix = (A != 0).argmax(axis=1)
        self.alix = k - 1 - (A != 0)[:, ::-1].argmax(axis=1)

        # Mark rows of A which are all zero
        self.afix = np.where(np.any(A != 0, axis=1), self.afix, -1)
        self.alix = np.where(np.any(A != 0, axis=1), self.alix, -1)
        self.has_zero_rows = np.any(self.afix == -1)

        # Determine which entries of B partake in the multiplication
        self.bix = np.nonzero(np.any(A != 0, axis=0))[0]
        self.bix = {kx: k for k, kx in enumerate(self.bix)}

    def kernels(self, dtype, kname='gimmik_mm', **kwargs):
        basemeta = self.basemeta

        # Process the data type
        dtype = np.dtype(dtype).type
        if dtype == np.float32:
            dtype, dsize = 'float', 4
        elif dtype == np.float64:
            dtype, dsize = 'double', 8
        else:
            raise ValueError('Invalid floating point data type')

        # Common template arguments
        baseargs = {
            'dtype': dtype, 'kname': kname,
            'A': self.A, 'beta': self.beta, 'width': 1,
            'm': self.m, 'n': self.n, 'k': self.k,
            'ldb': self.ldb, 'ldc': self.ldc,
            'afix': self.afix, 'alix': self.alix, 'bix': self.bix,
            'dot': _dot, 'partition': _partition, 'chunk': _chunk
        }

        # Incrementally generate and render the kernels
        gen = self._kernel_generators(dtype, dsize, **kwargs)
        try:
            resp = None
            while True:
                # Generate the next kernel in the sequence
                name, exargs, exmeta = gen.send(resp)

                # Merge in the base arguments and metadata
                args = baseargs | exargs
                meta = basemeta | exmeta

                # Render the kernel template
                src = self._render_kernel(dtype, name, args)

                # Post-process the metadata
                meta['tplname'] = name
                self._process_meta(meta)

                # Yield the source and metadata and await a response
                resp = yield (src, meta)
        except StopIteration:
            pass

    def _process_meta(self, meta):
        pass

    def _render_kernel(self, dtype, tplname, tplargs):
        tpl = _PlatformTemplateLookup(self.platform).get_template(tplname)
        src = tpl.render(**tplargs)

        # At single precision suffix all floating point constants by 'f'
        if dtype == 'float':
            src = re.sub(r'(?=\d*[.eE])(?=\.?\d)\d*\.?\d*(?:[eE][+-]?\d+)?',
                         r'\g<0>f', src)

        # Cleanup
        src = re.sub(r'^\w+\n$', '', src.strip())
        src = re.sub(r'\n\n+', r'\n\n', src) + '\n'
        src = re.sub(r'\w+$', '', src)
        return src