1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
|
# -*- coding: utf-8 -*-
from gimmik.base import MatMul
class HIPMatMul(MatMul):
    """Matrix-multiplication kernel provider for the ``hip`` platform.

    Supplies the list of candidate kernels (name, template arguments and
    launch metadata) and fills in the launch grid once ``self.n`` is known.
    """

    platform = 'hip'

    # Baseline launch metadata; presumably combined with the per-kernel
    # meta yielded below by the base class -- confirm against MatMul
    basemeta = {'block': (128, 1, 1), 'width': 1, 'shared': 0}

    def _kernel_generators(self, dtype, dsize, *, gcn_arch=None, warp_size=64):
        """Yield a ``(name, args, meta)`` triple for each candidate kernel.

        ``dsize`` scales the ``shared`` entry of the split kernels' meta.
        ``dtype``, ``gcn_arch`` and ``warp_size`` are accepted but are not
        read by this method.
        """
        # B loading, C streaming kernel
        yield 'cstream', {}, {}

        # B streaming, C accumulation kernel
        yield 'bstream', {}, {}

        # Four-way m-split B streaming, C accumulation kernel
        msplit, bsz, blockx = 4, 24, 64
        yield (
            'bstream-msplit',
            dict(msplit=msplit, bsz=bsz, blockx=blockx),
            dict(block=(blockx, msplit, 1), shared=2*bsz*blockx*dsize),
        )

        # Two-way k-split B loading, C streaming kernel
        ksplit, csz, blockx = 2, 24, 64
        yield (
            'cstream-ksplit',
            dict(ksplit=ksplit, csz=csz, blockx=blockx),
            dict(block=(blockx, ksplit, 1),
                 shared=(ksplit - 1)*csz*blockx*dsize),
        )

    def _process_meta(self, meta):
        """Add a ``grid`` entry to *meta* in place when ``self.n`` is set.

        Nothing is changed if ``self.n`` is ``None``.
        """
        n = self.n
        if n is None:
            return

        # Extent handled by one block along x: block x-size times width
        per_block = meta['block'][0]*meta['width']

        # Negated floor division rounds the block count up
        meta['grid'] = (-(-n // per_block), 1, 1)
|