1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323
|
# RUN: env SUPPORT_LIB=%mlir_cuda_runtime \
# RUN: %PYTHON %s | FileCheck %s
# ===----------------------------------------------------------------------===//
# Chapter 4 : Multistage GEMM with Tensor Core
# ===----------------------------------------------------------------------===//
#
# This program exemplifies a GEMM operation for `f32+=f16*f16`, utilizing the
# Multistage method with a tile size of 128x128x64. The code completely
# parallelizes the two outermost loops into thread blocks. It launches one
# warpgroup (128 threads in total) per block and allocates multiple slots
# (stages) in shared memory. The program consists of three main parts:
# prologue, mainloop, and epilogue. In the prologue, thread0 issues TMA loads
# that fill the shared memory slots. The mainloop executes MMA while TMA
# simultaneously refills the slots whose data has already been consumed. This
# overlap of TMA and MMA operations enhances performance by maximizing
# computational throughput.
#
# Loops illustration:
#
#  for s in range(num_stages - 1):
#    TMA_128x64_64x128...
#  for ti in range(M//128):            # -> blockIdx.x
#   for tj in range(N//128):           # -> blockIdx.y
#    for tk in range(K//64):
#      MMA_128x128x64...
#      TMA_128x64_64x128...
#  Epilogue...
#
# This chapter demonstrates:
#  1. Partition shape based on block IDs
#  2. Prologue
#    2.1 Execute TMA Load for two input matrices for each stage
#  3. Main loop
#    3.1 Wait for completion of TMA load with mbarrier
#    3.2 Perform Tensor Core GEMM 128x128x64 by warpgroup
#    3.3 Load next stage if needed
#  4. Epilogue
#    4.1 Store fragmented registers to shared memory
#    4.2 Store shared memory to global
#
# ===----------------------------------------------------------------------===//
from mlir import ir
from mlir.dialects import gpu, scf, nvgpu, nvvm
from mlir.extras import types as T
from tools.nvdsl import *
import numpy as np
def partition_shape():
    """
    Return the origin (first row, first column) of this block's output tile.

    Thread block (bx, by) owns rows [bx * TILE_M, (bx + 1) * TILE_M) and
    columns [by * TILE_N, (by + 1) * TILE_N), i.e. the loop nest is
    distributed as:

        for (i < M ...)      --> blockIdx.x
          for (j < N ...)    --> blockIdx.y
            for (k < K ...)  (kept inside the block)

    Returns:
        dimX: blockIdx.x * TILE_M, the tile's first row.
        dimY: blockIdx.y * TILE_N, the tile's first column.
    """
    block_x = gpu.block_id(gpu.Dimension.x)
    block_y = gpu.block_id(gpu.Dimension.y)
    # Tuple elements are evaluated left to right, so the x product is emitted first.
    return block_x * TILE_M, block_y * TILE_N
def tma_load(
    mbar_group: Mbarriers,
    a_tma: TMA,
    b_tma: TMA,
    slot,
    stage,
    num_stages,
    p=None,
):
    """
    Issue the TMA loads that fill one shared-memory slot with a K-tile of A and B.

    Dynamic shared memory layout (byte offsets):

        [A slot 0 | ... | A slot num_stages-1 | B slot 0 | ... | B slot num_stages-1]

    Operations, all predicated on ``p``:
      - mbarrier[slot].arrive with a transaction count of
        size(A box) + 2 * size(B box) bytes
        (128x64 + 2 x 64x64 f16 elements = 32 KiB for this file's tile sizes)
      - tma.load A box  -> A slot           at coords [stage * TILE_K, dimX]
      - tma.load B box  -> B slot, 1st half at coords [dimY, stage * TILE_K]
      - tma.load B box  -> B slot, 2nd half at coords [dimY + box_n, stage * TILE_K]
        where box_n is the width of one B box (64 here).

    Args:
        mbar_group: one mbarrier per slot; ``mbar_group[slot]`` tracks these loads.
        a_tma, b_tma: TMA descriptors for A and B.
        slot: shared-memory slot to fill.
        stage: K-tile index to load.
        num_stages: total number of slots (locates the start of the B region).
        p: predicate for issuing the loads; defaults to ``thread_id.x == 0``.

    Raises:
        ValueError: if one B slot (two B boxes) is not the same byte size as one
            A box, which the B-slot stride below (and ``mainloop``) relies on.
    """
    dimX, dimY = partition_shape()

    size_tma_a = get_type_size(a_tma.tma_memref)
    size_tma_b = get_type_size(b_tma.tma_memref)
    if size_tma_b * 2 != size_tma_a:
        raise ValueError(
            "tma_load assumes one B slot (two B boxes) is as large as one A box"
        )
    # B slots start right after all `num_stages` A slots.
    begin_b = num_stages * size_tma_a
    # Bytes the mbarrier must receive before this slot counts as full.
    ta_count = size_tma_a + (size_tma_b * 2)

    # By default only thread 0 issues the loads; emit thread_id only when needed.
    if p is None:
        p = gpu.thread_id(gpu.Dimension.x) == 0

    # B slots are strided by the A-box size (same as `offset_a + begin_b` in the
    # main loop); valid because of the size check above.
    off_a = slot * size_tma_a
    off_b = off_a + begin_b
    off_b2 = off_b + size_tma_b

    a_elem_ty = a_tma.tma_memref.element_type
    b_elem_ty = b_tma.tma_memref.element_type
    a = get_dynamic_shared_memory(a_tma.tma_memref.shape, a_elem_ty, off_a)
    b1 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b)
    b2 = get_dynamic_shared_memory(b_tma.tma_memref.shape, b_elem_ty, off_b2)

    mbar_group[slot].arrive(ta_count, predicate=p)

    # K coordinate of this stage. TMA coordinates list the innermost dimension
    # first, so the second B box sits one box-width further along N.
    k_coord = stage * TILE_K
    b_box_n = b_tma.tma_memref.shape[-1]
    a_tma.load(a, mbar_group[slot], coords=[k_coord, dimX], predicate=p)
    b_tma.load(b1, mbar_group[slot], coords=[dimY, k_coord], predicate=p)
    b_tma.load(b2, mbar_group[slot], coords=[dimY + b_box_n, k_coord], predicate=p)
def initialize(a_tma: TMA, b_tma: TMA, num_stages):
    """
    Set up per-stage synchronization and warm up the TMA descriptors.

    Creates one mbarrier per pipeline stage. Inside a thread-0-only region,
    each barrier is initialized with a count of 1, and the A and B TMA
    descriptors are prefetched.

    Returns:
        The Mbarriers group holding one barrier per stage.
    """
    thread_x = gpu.thread_id(gpu.Dimension.x)
    barriers = Mbarriers(number_of_barriers=num_stages)
    thread0_only = scf.IfOp(thread_x == const(0))
    with ir.InsertionPoint(thread0_only.then_block):
        for idx in scf.for_(0, num_stages, 1):
            barriers[idx].init(1)
            scf.yield_([])
        a_tma.prefetch()
        b_tma.prefetch()
        # Terminator of the `then` region; must be emitted last.
        scf.yield_([])
    return barriers
def prologue(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA, num_stages):
    """
    Pre-fill the pipeline before the main loop starts.

    Issues TMA loads for the first ``num_stages - 1`` K-tiles (exactly one
    when there is a single stage), placing K-tile i in slot i:

        for stage in range(prefetch_count):
            tma_load(slot=stage, stage=stage)

    With more than one stage, the slot left empty here is the one the main
    loop's first iteration fills.
    """
    prefetch_count = 1 if num_stages == 1 else num_stages - 1
    for k_tile in scf.for_(0, prefetch_count, 1):
        # During the prologue the slot index and the K-tile index coincide.
        tma_load(mbar_group, a_tma, b_tma, k_tile, k_tile, num_stages)
        scf.yield_([])
def mainloop(mbar_group: Mbarriers, a_tma: TMA, b_tma: TMA, num_stages):
    """
    Main loop of the multistage GEMM kernel.

    Emits an scf.for over the K // TILE_K K-tiles that carries two values:
    the accumulator D and the mbarrier phase parity. Each iteration k does:

        stage = k % num_stages
        try_wait(mbar[stage], phase)      # wait for this slot's TMA load
        A, B = shared-memory slot `stage` # TILE_M x TILE_K and TILE_K x TILE_N f16
        D += A @ B                        # multiply and accumulate
        if k + ns < K // TILE_K:          # thread 0 only
            tma_load(slot=(k + ns) % num_stages, stage=k + ns)
        flip phase after the last slot

    Here ns = num_stages - 1 (1 when num_stages == 1), which is the same
    count the prologue preloads.

    Args:
        mbar_group: one mbarrier per shared-memory slot.
        a_tma, b_tma: TMA descriptors for A and B (also passed on to tma_load).
        num_stages: number of shared-memory slots.

    Returns:
        The accumulator D, updated to the loop's final result after waiting
        on all pending wgmma groups.
    """
    # Prefetch distance: the load issued in iteration k targets stage k + ns.
    ns = num_stages if num_stages == 1 else num_stages - 1
    tidx = gpu.thread_id(gpu.Dimension.x)
    # Byte offset where the B slots begin (after num_stages A slots).
    begin_b = num_stages * get_type_size(a_tma.tma_memref)
    # Bytes per A slot; also used as the B-slot stride below.
    size_a = TILE_M * TILE_K * get_type_size(T.f16())

    # Initialize A and B (input matrices) and D (accumulator)
    A = WGMMAMatrix(WGMMAType.Descriptor, [TILE_M, TILE_K], desc=a_tma)
    B = WGMMAMatrix(WGMMAType.Descriptor, [TILE_K, TILE_N], desc=b_tma)
    D = WGMMAMatrix(WGMMAType.Accumulator, shape=[TILE_M, TILE_N], ty=T.f32())

    # mbarrier phase parity passed to try_wait; starts at False (parity 0).
    phase = const(False, ty=T.bool())

    # Main Loop
    # Loop-carried values: [accumulator, phase parity].
    for_op = scf.ForOp(const(0), const(K // TILE_K), const(1), [D.acc_op, phase])
    with ir.InsertionPoint(for_op.body):
        phase = for_op.inner_iter_args[1]
        iv = for_op.induction_variable
        stage = iv % num_stages

        # Wait for current stage
        mbar_group[stage].try_wait(phase=phase)

        # Find shared memory slot
        offset_a = stage * size_a
        offset_b = offset_a + begin_b
        a_smem = get_dynamic_shared_memory([TILE_M, TILE_K], T.f16(), offset_a)
        b_smem = get_dynamic_shared_memory([TILE_K, TILE_N], T.f16(), offset_b)

        # Iterate input matrices, update accumulator
        A.update_smem(a_smem)
        B.update_smem(b_smem)
        D.update_accumulator(for_op.inner_iter_args[0])

        # Matrix Multiply
        D += A @ B

        # Wait Tensor Core for single stage: with one slot, the load below
        # targets the slot just read, so the MMA has to finish first.
        if num_stages == 1:
            nvvm.WgmmaWaitGroupSyncOp(0)

        # Load next stage: predicated on thread 0 and on K-tiles remaining.
        pred = ((iv + ns) < const(K // TILE_K)) & (tidx == 0)
        nextStage = iv + ns
        nextSlot = nextStage % num_stages
        tma_load(mbar_group, a_tma, b_tma, nextSlot, nextStage, num_stages, pred)

        # Switch phase parity for the mbarrier: flip after the last slot,
        # i.e. once every slot has been waited on in the current phase.
        newPhase = arith.select(
            stage == (num_stages - 1),
            (phase ^ const(True, ty=T.bool())),
            phase,
        )
        scf.yield_([D.acc_op, newPhase])

    # Wait for all pending wgmma groups before using the final accumulator.
    nvvm.WgmmaWaitGroupSyncOp(0)

    D.update_accumulator(for_op.results[0])
    return D
def epilogue(D: WGMMAMatrix, d_dev):
    """
    Write this block's TILE_M x TILE_N accumulator tile to global memory.

    The tile is staged through shared memory:

        D (accumulator fragments)  --> shared-memory tile
        shared-memory tile         --> d_dev[dimX : dimX + TILE_M, dimY : dimY + TILE_N]

    In the copy-out loop each thread handles the column equal to its x thread
    index, for every row. This presumably requires TILE_N threads along x per
    block; confirm this if the launch configuration changes.
    """
    lane = gpu.thread_id(gpu.Dimension.x)
    row0, col0 = partition_shape()

    staging = get_dynamic_shared_memory([TILE_M, TILE_N], T.f32())
    out_tile = memref.subview(d_dev, [row0, col0], [TILE_M, TILE_N], [1, 1])

    # Registers -> shared memory. The barrier ensures every thread's fragment
    # is written before any thread reads the tile back.
    D.store_accumulator(staging)
    gpu.barrier()

    # Shared memory -> global memory, one row per loop iteration.
    for row in scf.for_(0, TILE_M, 1):
        memref.store(memref.load(staging, [row, lane]), out_tile, [row, lane])
        scf.yield_([])
# The decorator generates
#   a -> memref<MxKxf16>
#   b -> memref<KxNxf16>
#   d -> memref<MxNxf32>
@NVDSL.mlir_func
def gemm_multistage(a, b, d, num_stages):
    """
    Host-side driver for the multistage GEMM ``d = a @ b`` (f16 inputs, f32 result).

    Allocates device buffers and copies A and B to the device. It then builds
    the TMA descriptors and launches one 128-thread block per
    TILE_M x TILE_N output tile, using ``num_stages`` shared-memory slots.
    Finally it copies the device result back into ``d``.

    Raises:
        ValueError: if ``num_stages < 1``, or if M, N, K are not multiples of
            TILE_M, TILE_N, TILE_K. The grid and the main loop use floor
            division, so any remainder would silently be left uncomputed.
    """
    if num_stages < 1:
        raise ValueError(f"num_stages must be >= 1, got {num_stages}")
    if M % TILE_M or N % TILE_N or K % TILE_K:
        raise ValueError(
            f"M, N, K ({M}, {N}, {K}) must be multiples of "
            f"TILE_M, TILE_N, TILE_K ({TILE_M}, {TILE_N}, {TILE_K})"
        )

    token_ty = gpu.AsyncTokenType.get()
    t1 = gpu.wait(token_ty, [])
    a_dev, t2 = gpu.alloc(a.type, token_ty, [t1], [], [])
    b_dev, t3 = gpu.alloc(b.type, token_ty, [t2], [], [])
    d_dev, t4 = gpu.alloc(d.type, token_ty, [t3], [], [])
    t5 = gpu.memcpy(token_ty, [t4], a_dev, a)
    t6 = gpu.memcpy(token_ty, [t5], b_dev, b)
    t7 = gpu.wait(token_ty, [t6])

    # 128-byte swizzling limits the innermost TMA box dimension to 128 bytes
    # (64 f16 elements). Each TILE_K x TILE_N B tile is therefore fetched as
    # two TILE_K x (TILE_N // 2) boxes; tma_load issues both.
    sw = nvgpu.TensorMapSwizzleKind.SWIZZLE_128B
    a_tma = TMA([TILE_M, TILE_K], a.type, swizzle=sw)
    b_tma = TMA([TILE_K, TILE_N // 2], b.type, swizzle=sw)
    a_tma.create_descriptor(a_dev)
    b_tma.create_descriptor(b_dev)

    grid = [(M // TILE_M), (N // TILE_N), 1]
    # One warpgroup (128 threads) per block.
    block = [128, 1, 1]

    # Each stage holds one A tile and one B tile.
    size_a = get_type_size(a.type.element_type) * TILE_M * TILE_K
    size_b = get_type_size(b.type.element_type) * TILE_N * TILE_K
    smem_size_in_bytes = (size_a + size_b) * num_stages

    @NVDSL.mlir_gpu_launch(grid=grid, block=block, smem=smem_size_in_bytes)
    def gemm_multistage_kernel():
        # Initialize mbarriers and prefetch TMA descriptors
        mbar_group = initialize(a_tma, b_tma, num_stages)

        # Fill the pipeline stages
        prologue(mbar_group, a_tma, b_tma, num_stages)

        # Main loop
        D = mainloop(mbar_group, a_tma, b_tma, num_stages)

        # Store registers to global memory
        epilogue(D, d_dev)

    gemm_multistage_kernel()

    t8 = gpu.memcpy(token_ty, [t7], d, d_dev)
    gpu.wait(None, [t8])
# Problem and tile sizes. The functions above read these module-level names
# when they build the IR, so they must be defined before gemm_multistage runs.
N = 256
M = 512
K = 1024
TILE_M = 128
TILE_N = 128
TILE_K = 64

a = np.random.randn(M, K).astype(np.float16)
b = np.random.randn(K, N).astype(np.float16)
d = np.zeros((M, N), np.float32)

gemm_multistage(a, b, d, num_stages=7)

# Verify MLIR with reference computation.
# The kernel multiplies f16 inputs and accumulates in f32, so the reference is
# computed in f32 as well. An f16 `@` would round the reference itself to f16.
ref_d = a.astype(np.float32) @ b.astype(np.float32)
np.testing.assert_allclose(d, ref_d, rtol=5e-03, atol=1e-01)
print("PASS")
# CHECK-NOT: Mismatched elements
|