1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254
|
# mypy: allow-untyped-defs
import contextlib
import itertools
from typing import Any, Callable, Dict, List, Optional
from unittest.mock import patch
import sympy
from .. import ir
from ..select_algorithm import PartialRender
from ..virtualized import V
from .cpp_gemm_template import CppGemmTemplate, GEMM_TEMPLATE
from .cpp_micro_gemm import LayoutType
from .cpp_template_kernel import CppTemplateKernel
from .cpp_utils import DTYPE_TO_CPP, GemmBlocking
# We pass all sizevars present in BY to the GEMM templates so variables are not renamed in the BMM definition
#
# Jinja stub for the single-threaded GEMM routine used by the BMM kernel.
# CppBmmTemplate.codegen_single_thread_gemm renders this stub and places it
# directly in front of GEMM_TEMPLATE (rendered with num_threads forced to 1).
# The stub requests a kernel definition named "single_thread_mm" over the X, W
# and Y_2d options, with BY's size variables plus the batch index symbol passed
# as extra sizevars. The "FOR_BMM" substring in the placeholder is what
# CppBmmTemplate.render matches on when it finalizes the sub-GEMM hooks.
GEMM_SINGLE_THREAD_MM_STUB = r"""
{{kernel.def_kernel(
inputs={"X": X, "W": W},
outputs={"Y": Y_2d},
aliases=aliases,
function_name="single_thread_mm",
extra_sizevars=BY_sizevars + [b_index],
placeholder="<SINGLE_THREAD_MM_DEF_FOR_BMM>")}}"""
# Jinja stub for the multi-threaded GEMM routine used by the BMM kernel.
# CppBmmTemplate.codegen_multi_thread_gemm renders this stub and places it
# directly in front of GEMM_TEMPLATE (rendered with the unmodified options).
# It mirrors the single-threaded stub except for the function name
# ("threaded_mm") and the placeholder; the "FOR_BMM" substring in the
# placeholder is what CppBmmTemplate.render matches on when it finalizes the
# sub-GEMM hooks.
GEMM_THREADED_MM_STUB = r"""
{{kernel.def_kernel(
inputs={"X": X, "W": W},
outputs={"Y": Y_2d},
aliases=aliases,
function_name="threaded_mm",
extra_sizevars=BY_sizevars + [b_index],
placeholder="<THREADED_MM_DEF_FOR_BMM>")}}"""
# Jinja template for the outer BMM kernel.
#
# It first emits the micro-kernel definition and the two GEMM routines
# (single-threaded and multi-threaded), then defines the extern "C" kernel over
# the full batch tensors BX / BW / BY. Inside the kernel the batch range
# [0, B) is split in two:
#   * [0, B_single_thread_block): one "single_thread_mm" call per batch. When
#     num_threads > 1, B_single_thread_block is B rounded down to a multiple of
#     num_threads and the loop is preceded by `#pragma omp parallel for`;
#     otherwise B_single_thread_block == B and this loop covers everything.
#   * [B_single_thread_block, B): the remaining batches, one "threaded_mm"
#     call each.
# The two call sites are placeholders produced by
# CppBmmTemplate.get_gemm_function_call, with "b_start" as the batch index.
BMM_TEMPLATE = r"""
{{ template.codegen_microkernel_def() }}
{{ template.codegen_single_thread_gemm() }}
{{ template.codegen_multi_thread_gemm() }}
extern "C"
{{kernel.def_kernel(inputs={"X": BX, "W": BW}, outputs={"Y": BY}, aliases=aliases)}}
{
const int64_t B = {{kernel.size(BY_2d, 0)}};
{%- if num_threads > 1 %}
constexpr int64_t num_threads = {{num_threads}};
int64_t B_single_thread_block = (B / num_threads) * num_threads;
#pragma omp parallel for num_threads({{num_threads}})
{%- else %}
int64_t B_single_thread_block = B;
{%- endif %}
for (int64_t b_start = 0; b_start < B_single_thread_block; ++b_start) {
{{template.get_gemm_function_call(
kernel,
"single_thread_mm",
"<SINGLE_THREAD_CALL_FOR_BMM>",
b_index="b_start",
)}}
}
for (int64_t b_start = B_single_thread_block; b_start < B; ++b_start) {
{{template.get_gemm_function_call(
kernel,
"threaded_mm",
"<THREADED_MM_CALL_FOR_BMM>",
b_index="b_start",
)}}
}
}
"""
class CppBmmTemplate(CppGemmTemplate):
    """C++ batched-matmul (BMM) template layered on top of ``CppGemmTemplate``.

    The rendered kernel (``BMM_TEMPLATE``) loops over the batch dimension and
    calls two generated GEMM routines, ``single_thread_mm`` and
    ``threaded_mm``. Both routines are produced by rendering ``GEMM_TEMPLATE``
    with the batch tensors sliced at the symbolic index ``self.b_index``.
    """

    def __init__(
        self,
        input_nodes,
        layout: ir.Layout,
        num_threads: int,
        register_blocking: GemmBlocking,
        beta=1,
        alpha=1,
        has_bias=False,
        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
        should_block_weights: bool = False,
        name="bmm",
    ):
        """
        In order to simplify the implementation and increase code reuse, the BMM template implements
        two versions of the GEMM kernel: a single-threaded version and a multi-threaded version.
        GEMM kernels are called in a loop over the batch dimension, with single-threaded GEMM calls
        for all but the last (B % num_threads) batches, which are handled by the multi-threaded
        GEMM kernel.

        We use an extra sizevar `b_index` to index the batch dimension, which we pass into the GEMM
        template as a sympy.Symbol. This allows us to slice the 3D batch tensors in the GEMM template
        without any changes to the GEMM template itself.

        All arguments are forwarded unchanged to ``CppGemmTemplate.__init__``.
        """
        super().__init__(
            input_nodes,
            layout,
            num_threads,
            register_blocking,
            beta=beta,
            alpha=alpha,
            has_bias=has_bias,
            epilogue_creator=epilogue_creator,
            should_block_weights=should_block_weights,
            name=name,
        )
        # Symbolic batch index; replaced by a concrete C++ expression at the
        # call sites generated by get_gemm_function_call.
        self.b_index = sympy.Symbol("s_b_index", integer=True, nonnegative=True)

    @staticmethod
    def get_padded_size(n, block_n, k, should_block_weight):
        """Return ``(new_size, padded_n)`` for the weight, including a batch dim.

        When ``should_block_weight`` is true the 2D result of
        ``CppGemmTemplate.get_padded_size`` is reused and ``-1`` is prepended
        as the batch dimension. Otherwise the weight keeps its ``[-1, k, n]``
        shape and ``n`` is returned unpadded.
        """
        if should_block_weight:
            # Tensor is constant or not contiguous, so we will pad and block
            new_size, padded_n = CppGemmTemplate.get_padded_size(
                n, block_n, k, should_block_weight
            )
            # Add the new batch dimension
            new_size.insert(0, -1)
            return new_size, padded_n
        else:
            new_size = [-1, k, n]
            return new_size, n

    @staticmethod
    def check_if_block_weight(W, micro_gemm):
        """Return whether the weight ``W`` should be blocked.

        True when the micro-GEMM's B layout is not ``LayoutType.NORMAL``, or
        when ``W`` is not contiguous. For an ``ir.IRNode`` weight it is also
        true when the node's name is one of the graph constants.
        """
        return micro_gemm.get_b_layout() != LayoutType.NORMAL or (
            (not W.get_layout().is_contiguous() or W.get_name() in V.graph.constants)  # type: ignore[union-attr]
            if isinstance(W, ir.IRNode)
            else not W.is_contiguous()
        )

    def get_gemm_function_call(
        self,
        kernel: CppTemplateKernel,
        function_name: str,
        placeholder: str,
        b_index: str,
    ) -> str:
        """
        Similar to 'def_kernel' in cpp_template_kernel, but instead of generating a function definition,
        generate a function call for the GEMM kernel.

        Args:
            kernel: The template kernel whose arguments form the call and on
                which the render hook is registered
            function_name: Name of the GEMM function to call
            placeholder: The string to replace the function call with
            b_index: The C++ expression (e.g. a loop variable name) used as the
                index for slicing the 3D batch tensors; it is spliced verbatim
                into the generated argument list, so it must be a string

        Returns:
            ``placeholder``, which is substituted with the call once the
            registered render hook is finalized.
        """

        def hook():
            arg_defs, call_args, _, _ = kernel.args.python_argdefs()
            # Substitute the symbolic batch index argument with the concrete
            # expression supplied by the caller.
            for i, buf in enumerate(call_args):
                if buf == self.b_index:
                    arg_defs[i] = b_index
            return f"{function_name}({', '.join(arg_defs)});"

        assert placeholder not in kernel.render_hooks
        kernel.render_hooks[placeholder] = hook
        return placeholder

    def get_default_reindexers(self, epilogue_nodes):
        """Return one reindexer per epilogue node that prepends the batch index."""

        def reindexer(args):
            # if epilogue nodes exist, they have 3D ranges but args are 2D, so
            # prepend the symbolic batch index
            return [self.b_index] + args

        return [reindexer] * len(epilogue_nodes)

    def get_options(
        self,
        kernel: CppTemplateKernel,
        template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
        flag_template_buffer_has_other_users: Optional[bool] = None,
        epilogue_nodes: Optional[List[ir.IRNode]] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """Build the template options, extending the GEMM options for BMM.

        On top of the options returned by the parent class this:
          * keeps the unsliced tensors under ``BX``, ``BW``, ``BY`` and
            ``BY_2d``;
          * replaces ``X``, ``W``, ``GemmOut`` and ``Y_2d`` with their slice at
            ``self.b_index`` along dimension 0;
          * refreshes the ``*_dtype`` entries for ``X``, ``W`` and ``Y_2d``;
          * adds ``b_index`` and ``BY_sizevars`` (the free symbols found in
            BY's sizes and strides).
        """
        options = super().get_options(
            kernel=kernel,
            template_buffer_node=template_buffer_node,
            flag_template_buffer_has_other_users=flag_template_buffer_has_other_users,
            epilogue_nodes=epilogue_nodes,
            **kwargs,
        )

        # Preserve the full batch tensors before the 2D views overwrite them.
        BX, BW, BY = options["X"], options["W"], options["Y"]
        options["BX"], options["BW"], options["BY"] = BX, BW, BY
        options["BY_2d"] = options["Y_2d"]

        for kword in ["X", "W", "GemmOut", "Y_2d"]:
            options[kword] = kernel.select(options[kword], 0, self.b_index)

        for kword in ["X", "W", "Y_2d"]:
            options[kword + "_dtype"] = DTYPE_TO_CPP[options[kword].dtype]

        options["b_index"] = self.b_index
        options["BY_sizevars"] = [
            s
            for sym in itertools.chain(BY.get_size(), BY.get_stride())
            if isinstance(sym, sympy.Expr)
            for s in sym.free_symbols
        ]
        return options

    def render(  # type: ignore[override, return]
        self,
        kernel: CppTemplateKernel,
        template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
        flag_template_buffer_has_other_users: Optional[bool] = None,
        epilogue_nodes: Optional[List[ir.IRNode]] = None,
        **kwargs,
    ) -> str:
        """Render ``BMM_TEMPLATE`` and finalize the sub-GEMM placeholders.

        The options are stored on ``self.render_options`` so that the
        ``codegen_*_gemm`` helpers invoked from inside the template can reuse
        them. Every render hook whose name contains ``"FOR_BMM"`` is finalized
        here and then removed from ``kernel.render_hooks``; the batch index
        sizevar is removed from ``kernel.args.sizevars`` as well.
        """
        options = self.get_options(
            kernel=kernel,
            template_buffer_node=template_buffer_node,
            flag_template_buffer_has_other_users=flag_template_buffer_has_other_users,
            epilogue_nodes=epilogue_nodes,
            **kwargs,
        )
        self.render_options = options

        with contextlib.ExitStack() as stack:
            for buf in options["fake_buffers"]:
                stack.enter_context(
                    patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
                )
            result = self._template_from_string(BMM_TEMPLATE).render(**options)

            # Finalize the function definitions for the gemm routines
            sub_mm_hooks = {
                name: hook
                for name, hook in kernel.render_hooks.items()
                if "FOR_BMM" in name
            }
            result = PartialRender(result, sub_mm_hooks).finalize_all()
            for name in sub_mm_hooks:
                del kernel.render_hooks[name]
            del kernel.args.sizevars[options["b_index"]]
            return result

    def codegen_single_thread_gemm(self):
        """Render the single-threaded GEMM: its stub plus GEMM_TEMPLATE with num_threads=1."""
        stub = self._template_from_string(GEMM_SINGLE_THREAD_MM_STUB).render(
            self.render_options
        )
        return stub + self._template_from_string(GEMM_TEMPLATE).render(
            {**self.render_options, "num_threads": 1}
        )

    def codegen_multi_thread_gemm(self):
        """Render the multi-threaded GEMM: its stub plus GEMM_TEMPLATE with the stored options."""
        stub = self._template_from_string(GEMM_THREADED_MM_STUB).render(
            self.render_options
        )
        return stub + self._template_from_string(GEMM_TEMPLATE).render(
            self.render_options
        )

    def codegen_gemm_stub_def(self):
        """Return an empty string.

        NOTE(review): presumably this overrides a base-class hook so that
        GEMM_TEMPLATE emits no stub of its own, since the BMM stubs are
        prepended by the codegen_*_gemm helpers - confirm against
        CppGemmTemplate.
        """
        return ""
|