# mypy: allow-untyped-defs
from __future__ import annotations
import collections
import functools
import typing
from enum import auto, Enum
from typing import Dict, List, Optional, Union
# The following maximums only apply to runtime autotuning; when using FixedTritonConfig
# one may see larger values.
# NOTE: if asserts against these limits fail, submit a PR to increase them.
TRITON_MAX_BLOCK = {
"X": 4096,
"Y": 1024,
"Z": 1024,
"R": 4096 * 16, # * 16 is multi-kernel only
}
TRITON_MAX_RSPLIT = 64
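# Illustrative check (variable names are hypothetical): code that picks block sizes
# is expected to stay within the caps above, e.g.
#
#   assert xblock <= TRITON_MAX_BLOCK["X"] and rblock <= TRITON_MAX_BLOCK["R"]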
class ReductionHint(Enum):
    INNER = 0
    OUTER = 1
    OUTER_TINY = 2
    DEFAULT = 3
class TileHint(Enum):
    SQUARE = 0
    DEFAULT = 1
def _is_triton_available() -> bool:
    try:
        import triton  # noqa: F401

        return True
    except ImportError:
        return False
# Define `AttrsDescriptorWrapper` function with clear conditional handling
if _is_triton_available():
    try:
        from triton.backends.compiler import AttrsDescriptor

        def AttrsDescriptorWrapper(
            divisible_by_16=None,
            equal_to_1=None,
        ):
            # Prepare the arguments for AttrsDescriptor
            kwargs = {
                "tt.divisibility": divisible_by_16,
                "tt.equal_to": equal_to_1,
            }

            # Instantiate AttrsDescriptor with the prepared arguments
            res = AttrsDescriptor.from_dict(
                {"arg_properties": kwargs, "cls": AttrsDescriptor.__name__}
            )
            assert res.property_values["tt.divisibility"] == 16
            assert res.property_values["tt.equal_to"] == 1
            return res

    except ImportError:
        from triton.compiler.compiler import AttrsDescriptor

        def AttrsDescriptorWrapper(
            divisible_by_16=None,
            equal_to_1=None,
        ):
            # Prepare the arguments for AttrsDescriptor
            kwargs = {
                "divisible_by_16": divisible_by_16,
                "equal_to_1": equal_to_1,
            }

            # Instantiate AttrsDescriptor with the prepared arguments
            return AttrsDescriptor(**kwargs)
else:
    # Define a namedtuple as a fallback when AttrsDescriptor is not available
    AttrsDescriptorWrapper = collections.namedtuple(  # type: ignore[no-redef, name-match]
        "AttrsDescriptor",
        ["divisible_by_16", "equal_to_1"],
        defaults=[(), ()],
    )
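# Illustrative usage sketch (argument values are hypothetical): call sites typically
# pass tuples of kernel-argument indices, e.g.
#
#   attrs = AttrsDescriptorWrapper(divisible_by_16=(0, 1, 2), equal_to_1=(5,))
#
# which yields either a real Triton AttrsDescriptor or the namedtuple fallback above,
# depending on which import succeeded.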
_NUM_THREADS_PER_WARP = 32
class HeuristicType(Enum):
    PERSISTENT_REDUCTION = auto()
    POINTWISE = auto()
    REDUCTION = auto()
    SPLIT_SCAN = auto()
    TEMPLATE = auto()
    USER_AUTOTUNE = auto()
    FIXED = auto()
class AutotuneHint(Enum):
    ONE_ELEMENT_PER_THREAD = 0

    # Triton codegen tries to codegen a set of AutotuneHints.
    # Enum.__repr__ looks like "<AutotuneHint.ONE_ELEMENT_PER_THREAD: 0>",
    # which isn't valid python.
    # Enum.__str__ will just return "AutotuneHint.ONE_ELEMENT_PER_THREAD".
    __repr__ = Enum.__str__
class DeviceProperties(typing.NamedTuple):
"""Copy device properties into a data structure not requiring torch to be imported"""
type: str # type: ignore[assignment]
index: int # type: ignore[assignment]
multi_processor_count: int
cc: int
major: Optional[int] = None
regs_per_multiprocessor: Optional[int] = None
max_threads_per_multi_processor: Optional[int] = None
warp_size: Optional[int] = None
@classmethod
@functools.lru_cache(None)
def create(cls, device) -> DeviceProperties:
import torch
from torch._dynamo.device_interface import get_interface_for_device
device_type = device.type
if torch.version.hip and device_type == "cuda":
device_type = "hip"
device_interface = get_interface_for_device(device)
props = device_interface.get_device_properties(device)
try:
multi_processor_count = props.multi_processor_count
except AttributeError:
if device_type == "xpu":
multi_processor_count = props.gpu_subslice_count
else:
raise
return cls(
type=device_type,
index=device.index,
multi_processor_count=multi_processor_count,
cc=device_interface.get_compute_capability(device),
major=getattr(props, "major", None),
regs_per_multiprocessor=getattr(props, "regs_per_multiprocessor", None),
max_threads_per_multi_processor=getattr(
props, "max_threads_per_multi_processor", None
),
warp_size=getattr(props, "warp_size", 32 if device_type != "cpu" else None),
)
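# Illustrative usage sketch (the device string is hypothetical; requires torch and a
# visible accelerator):
#
#   import torch
#   props = DeviceProperties.create(torch.device("cuda:0"))
#   print(props.type, props.cc, props.multi_processor_count, props.warp_size)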
class HalideInputSpec(typing.NamedTuple):
    ctype: str
    name: str
    shape: Optional[List[str]] = None
    stride: Optional[List[str]] = None
    offset: Optional[str] = None
    alias_of: Optional[str] = None

    def bindings_type(self) -> str:
        if self.ctype in ("half*", "bfloat16*"):
            return "uint16_t*"  # half/bfloat16 not defined in the bindings
        return self.ctype

    def halide_type(self) -> str:
        if self.ctype == "half*":
            return "halide_type_t(halide_type_float, 16)"  # half not defined
        if self.ctype == "bfloat16*":
            return "halide_type_t(halide_type_bfloat, 16)"  # bfloat16 not defined
        return f"halide_type_of<{self.ctype.replace('*', '')}>()"

    def is_scalar(self) -> bool:
        return self.shape is None

    def is_buffer(self) -> bool:
        return self.shape is not None
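# Illustrative example (names and sizes are hypothetical):
#
#   spec = HalideInputSpec(ctype="float*", name="in_ptr0", shape=["1024L"], stride=["1L"])
#   spec.is_buffer()    # True, since a shape is provided
#   spec.halide_type()  # "halide_type_of<float>()"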
class HalideMeta(typing.NamedTuple):
    argtypes: List[HalideInputSpec]
    target: str
    scheduler: Optional[str] = None
    scheduler_flags: Optional[Dict[str, Union[int, str]]] = None
    cuda_device: Optional[int] = None

    def args(self) -> List[str]:
        """Command line args to pass to the Halide generator"""
        args = [f"target={self.target}"]
        if self.scheduler:
            args.append(f"autoscheduler={self.scheduler}")
        if self.scheduler_flags:
            assert self.scheduler
            for k, v in self.scheduler_flags.items():
                args.append(f"autoscheduler.{k}={v}")
        return args

    def is_cuda(self) -> bool:
        return self.cuda_device is not None
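# Illustrative example (target, scheduler, and flag values are hypothetical):
#
#   meta = HalideMeta(
#       argtypes=[],
#       target="host-cuda",
#       scheduler="Anderson2021",
#       scheduler_flags={"parallelism": 82},
#   )
#   meta.args()     # ["target=host-cuda", "autoscheduler=Anderson2021",
#                   #  "autoscheduler.parallelism=82"]
#   meta.is_cuda()  # False, because cuda_device was not set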