# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional
import sympy
import torch
from .. import config
from ..runtime.hints import AttrsDescriptorWrapper
from ..utils import _type_of, expr_fits_within_32bit
from ..virtualized import V
from .common import KernelArgType, SizeArg, TensorArg, TMADescriptorArg, WorkspaceArg


def should_unwrap_unspec_arg(name: str) -> bool:
    """Return True if `name` is an unspec (0-dim tensor) graph input that should be
    unwrapped and passed to the kernel as a plain scalar."""
if V.graph.is_unspec_arg(name):
# Unwrap on all devices except CPU
if V.graph.get_current_device_or_throw().type != "cpu":
return True
# Only unwrap on CPU if the input is not used as an output
if name not in V.graph.mutated_buffers:
return True
return False
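
# Note on should_unwrap_unspec_arg (a descriptive sketch of the intent, not extra logic):
# on non-CPU devices an unspec 0-dim tensor input is always unwrapped to a scalar; on CPU
# it is kept as a tensor when it is also a mutated buffer, presumably because a scalar
# passed by value could not be written back in place.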


def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str:
    """Map a kernel argument to its Triton signature type string (e.g. "*fp16", "i32", "*i8")."""
if isinstance(arg, TensorArg):
# TODO: Remove fp8 special handling when Triton supports PyTorch fp8 dtypes.
# Related PR: https://github.com/openai/triton/pull/2279/
if arg.dtype == torch.float8_e4m3fn:
tye = "*fp8e4nv"
elif arg.dtype == torch.float8_e5m2:
tye = "*fp8e5"
elif arg.dtype == torch.float8_e4m3fnuz:
tye = "*fp8e4b8"
elif arg.dtype == torch.float8_e5m2fnuz:
tye = "*fp8e5b16"
else:
tye = _type_of(arg.dtype)
if should_unwrap_unspec_arg(arg.buffer):
            # the 0d tensor was unwrapped and is passed as a scalar
new_tye = tye.lstrip("*")
if new_tye in ["fp16", "bf16"]:
return "fp32"
else:
return new_tye
else:
return tye
if isinstance(arg, SizeArg):
if arg.expr is None:
# From triton/runtime/jit.py
# `None` is nullptr. Implicitly convert to *i8.
return "*i8"
elif isinstance(arg.expr, (float, sympy.Float)):
return "fp32"
        # otherwise this is an integer
if size_dtype == "tl.int32":
return "i32"
elif size_dtype == "tl.int64":
return "i64"
elif size_dtype is None:
# no hint: we'll see if we know that this is a 32-bit int, and guard if possible.
int_max = torch.iinfo(torch.int32).max
if expr_fits_within_32bit(arg.expr):
V.graph.sizevars.guard_leq(arg.expr, int_max)
return "i32"
else:
return "i64"
else:
raise NotImplementedError(f"unhandled size_dtype {size_dtype}")
if isinstance(arg, WorkspaceArg):
return _type_of(arg.dtype)
if isinstance(arg, TMADescriptorArg):
return "nvTmaDesc"
raise NotImplementedError(f"unhandled {type(arg)}: {arg}")
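
# Example mappings for signature_of (an illustrative sketch, not executed; it assumes
# SizeArg can be constructed positionally as SizeArg(name, expr), and the argument
# names below are made up for illustration):
#
#   signature_of(SizeArg("xnumel", sympy.Integer(8192)), size_dtype="tl.int32")  # -> "i32"
#   signature_of(SizeArg("scale", sympy.Float(0.5)), size_dtype="tl.int32")      # -> "fp32"
#   signature_of(SizeArg("opt_arg", None), size_dtype=None)                      # -> "*i8"
#
# TensorArg cases additionally consult V.graph (via should_unwrap_unspec_arg), so they
# are only meaningful inside an active graph-lowering context.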


def signature_to_meta(
signature: List[KernelArgType],
*,
size_dtype: Optional[str],
argdefs: List[str],
indices: Optional[List[int]] = None,
) -> Dict[str, str]:
    """Build the {argdef name: triton type string} mapping used as kernel signature metadata."""
if indices is None:
indices = list(range(len(signature)))
return {
argdefs[i]: signature_of(arg, size_dtype=size_dtype)
for i, arg in zip(indices, signature)
}
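
# Example for signature_to_meta (an illustrative sketch, not executed; it reuses the
# SizeArg(name, expr) construction assumed above):
#
#   signature_to_meta(
#       [SizeArg("xnumel", sympy.Integer(64)), SizeArg("rnumel", sympy.Integer(128))],
#       size_dtype="tl.int32",
#       argdefs=["xnumel", "rnumel"],
#   )
#   # -> {"xnumel": "i32", "rnumel": "i32"}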


def is_unaligned_buffer(arg: TensorArg) -> bool:
    """Return True if we cannot prove that the buffer backing `arg` is suitably aligned."""
buf_name = arg.buffer
if buf_name in V.graph.graph_inputs:
# See Note: [Input Alignment handling in Inductor]
return buf_name not in V.graph.aligned_inputs
if buf_name in V.graph.constants:
# all constants are assumed to be aligned
return False
if V.graph.scheduler:
layout = V.graph.scheduler.get_buffer_layout(buf_name)
else:
buffer = V.graph.try_get_buffer(buf_name)
# output arg
if not buffer:
assert buf_name == V.kernel.output_node.name
layout = V.kernel.output_node.layout
else:
layout = buffer.get_layout()
if isinstance(layout, torch._inductor.ir.NonOwningLayout):
return not layout.maybe_guard_aligned()
else:
return False
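
# Summary of is_unaligned_buffer (descriptive note): graph inputs count as aligned only
# if Inductor recorded them in V.graph.aligned_inputs; constants are always treated as
# aligned; other buffers are aligned unless they view another buffer's storage
# (NonOwningLayout) at an offset that cannot be guarded as aligned.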


def config_of(
args: List[KernelArgType],
*,
indices: Optional[List[int]] = None,
) -> Any:
    """Compute the attrs descriptor for `args`: which argument indices can be assumed
    divisible by (aligned to) 16 and which size arguments are statically equal to 1."""
if indices is None:
indices = list(range(len(args)))

    def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool:
"""
Roughly follow triton code here:
https://github.com/openai/triton/blob/5282ed890d453e10b9ee30076ef89115dd197761/python/triton/runtime/jit.py#L208-L222
"""
if isinstance(x, TensorArg):
if include_tensor:
offset_aligned = V.graph.sizevars.statically_known_multiple_of(
x.offset * x.dtype.itemsize, alignment # type: ignore[arg-type]
)
return offset_aligned and not is_unaligned_buffer(x)
else:
return False
if isinstance(x, SizeArg):
# TODO(voz): These are kinda redundant, if we can solve out statically_known_multiple_of with
# _maybe_evaluate_static...
if x.name.startswith("load_seed_offset"):
return False
if x.expr is None:
return False
if isinstance(x.expr, float):
return False
return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment) # type: ignore[arg-type]
if isinstance(x, WorkspaceArg):
# We allocate the workspace ourselves, so it is always aligned
return True
if isinstance(x, TMADescriptorArg):
return False
raise NotImplementedError(f"unhandled {type(x)}: {x}")

    if config.triton.divisible_by_16:
divisible_by_16 = tuple(
i
for i, arg in zip(indices, args)
if is_aligned(arg, alignment=16, include_tensor=True)
)
else:
divisible_by_16 = ()
equal_to_1 = tuple(
i
for i, arg in zip(indices, args)
if isinstance(arg, SizeArg)
and isinstance(arg.expr, (int, sympy.Integer))
and V.graph.sizevars.statically_known_equals(arg.expr, 1) # type: ignore[arg-type]
)
return AttrsDescriptorWrapper(divisible_by_16, equal_to_1)
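
# Example for config_of (an illustrative sketch, not executed): for
#
#   args = [TensorArg(...), SizeArg("xnumel", sympy.Integer(1))]
#
# the SizeArg lands in equal_to_1 (its expression is statically 1), while the TensorArg
# index appears in divisible_by_16 only when config.triton.divisible_by_16 is enabled
# and its storage offset/buffer can be proven 16-byte aligned, e.g.
# AttrsDescriptorWrapper((0,), (1,)).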