1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
|
# mypy: allow-untyped-defs
from __future__ import annotations
import functools
import operator
import torch
from torch._inductor.runtime.cache_dir_utils import ( # noqa: F401
cache_dir,
default_cache_dir,
triton_cache_dir,
)
def conditional_product(*args):
    """
    Return the product of the truthy values in ``args``.

    Falsy values (``0``, ``None``, ...) are skipped rather than multiplied in,
    e.g. ``conditional_product(4, None, 8) == 32``. When no truthy value is
    given, the empty product ``1`` is returned.
    """
    # The initializer 1 makes the all-falsy / no-argument case well-defined;
    # without it functools.reduce raises TypeError on an empty sequence.
    return functools.reduce(operator.mul, (x for x in args if x), 1)
def ceildiv(numer: int, denom: int) -> int:
    """Divide ``numer`` by ``denom``, rounding the quotient toward +infinity."""
    quotient, remainder = divmod(numer, denom)
    # divmod floors; any nonzero remainder means the ceiling is one higher.
    if remainder:
        return quotient + 1
    return quotient
def is_power_of_2(n: int) -> bool:
    """Report whether ``n`` equals ``2 ** m`` for some integer ``m``."""
    if n <= 0:
        return False
    # Clearing the lowest set bit leaves zero only when exactly one bit is set.
    return (n & (n - 1)) == 0
def next_power_of_2(n: int) -> int:
    """
    Return the smallest power of 2 greater than or equal to ``n``.

    For ``n <= 0`` the result is ``0``. This matches the historical
    bit-twiddling implementation, whose behavior callers may rely on.
    Works for arbitrarily large Python ints.
    """
    if n <= 0:
        return 0
    n -= 1
    # Smear the highest set bit into every lower position. After the passes
    # with shifts 1, 2, ..., shift // 2, the top ``shift`` bits starting at the
    # MSB are all set. Once nothing remains above bit ``shift - 1``, every
    # lower bit is filled in. A fixed 1..32 ladder only covered 64 bits and
    # produced non-powers-of-2 for n > 2**64.
    shift = 1
    while n >> shift:
        n |= n >> shift
        shift <<= 1
    return n + 1
def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int:
    """
    Compute how many bytes the tensor-typed entries of ``args`` occupy.

    Entries that are not ``torch.Tensor`` contribute nothing but still occupy
    a position. The first ``num_in_out_args`` positions are in/out tensors;
    they are both read and written, so their size is counted twice.
    """
    total = 0
    for position, candidate in enumerate(args):
        if not isinstance(candidate, torch.Tensor):
            continue
        nbytes = candidate.numel() * candidate.element_size()
        if position < num_in_out_args:
            # Counted once for the read and once for the write.
            nbytes *= 2
        total += nbytes
    return total
def triton_config_to_hashable(cfg):
    """
    Build a tuple that uniquely identifies a triton config, so it can serve
    as a dictionary key.

    The tuple contains the ``cfg.kwargs`` items sorted by key, followed by
    ``("num_warps", cfg.num_warps)`` and ``("num_stages", cfg.num_stages)``.
    """
    kwarg_pairs = tuple(sorted(cfg.kwargs.items()))
    launch_pairs = (
        ("num_warps", cfg.num_warps),
        ("num_stages", cfg.num_stages),
    )
    return kwarg_pairs + launch_pairs
def validate_triton_config(cfg):
    """
    Reject triton configs that carry a ``pre_hook``.

    Raises:
        AssertionError: if ``cfg`` has a ``pre_hook`` attribute that is not None.
    """
    # [Note: Triton pre_hook in inductor]
    # A pre-hook is a lambda function, which we don't attempt to serialize.
    # Right now, if a pre-hook is attached to the config, it will not be saved,
    # and so it won't be used when the config is loaded from cache.
    # We therefore reject it: a pre_hook might silently be ignored after caching.
    #
    # An explicit raise (instead of a bare ``assert``) keeps the check active
    # under ``python -O`` while preserving the AssertionError type callers see.
    if getattr(cfg, "pre_hook", None) is not None:
        raise AssertionError("triton configs with pre_hooks not supported")
def create_bandwidth_info_str(ms, num_gb, gb_per_s, prefix="", suffix="", color=True):
    """
    Format a one-line timing / data-volume / bandwidth summary.

    When ``color`` is truthy and the measurement counts as slow (``ms`` above
    0.012 while ``gb_per_s`` is below 650), the line is passed through
    ``red_text``. Otherwise the plain line is returned.
    """
    info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}"
    if not color:
        return info_str
    is_slow = ms > 0.012 and gb_per_s < 650
    if not is_slow:
        return info_str
    return red_text(info_str)
def get_max_y_grid():
    """
    Return the largest grid size allowed along the y dimension (65535).

    NOTE(review): 2**16 - 1 presumably mirrors the CUDA limit on gridDim.y;
    confirm before relying on it for other backends.
    """
    return (1 << 16) - 1
# colorama is optional. HAS_COLORAMA records whether the import succeeded;
# when it did not, ``colorama`` is bound to None.
try:
    import colorama
except ModuleNotFoundError:
    HAS_COLORAMA = False
    colorama = None  # type: ignore[assignment]
else:
    HAS_COLORAMA = True
def _color_text(msg, color):
    """
    Wrap ``msg`` between the colorama foreground code named by ``color``
    (matched case-insensitively, e.g. "red") and ``Fore.RESET``.

    ``msg`` is returned unchanged when colorama is unavailable.
    """
    if HAS_COLORAMA:
        foreground = colorama.Fore
        return getattr(foreground, color.upper()) + msg + foreground.RESET
    return msg
def green_text(msg):
    """Return ``msg`` colored green (unchanged if colorama is unavailable)."""
    return _color_text(msg=msg, color="green")
def yellow_text(msg):
    """Return ``msg`` colored yellow (unchanged if colorama is unavailable)."""
    return _color_text(msg=msg, color="yellow")
def red_text(msg):
    """Return ``msg`` colored red (unchanged if colorama is unavailable)."""
    return _color_text(msg=msg, color="red")
def blue_text(msg):
    """Return ``msg`` colored blue (unchanged if colorama is unavailable)."""
    return _color_text(msg=msg, color="blue")
def get_first_attr(obj, *attrs):
    """
    Return the value of the first attribute in ``attrs`` that ``obj`` has.

    Attributes are tried in the given order. An attribute whose value is None
    still counts as present.

    Raises:
        AssertionError: if ``obj`` has none of the attributes in ``attrs``.
    """
    missing = object()
    for attr in attrs:
        # A single getattr with a sentinel (rather than hasattr + getattr)
        # evaluates properties / __getattr__ only once per attribute.
        value = getattr(obj, attr, missing)
        if value is not missing:
            return value
    raise AssertionError(f"{obj} does not have any of the attributes: {attrs}")
# Module-level alias for torch._dynamo.utils.dynamo_timed, so it can be
# referenced from this module directly.
# NOTE(review): this attribute access needs torch._dynamo to be resolvable
# when this module is imported; confirm that holds for every import path.
dynamo_timed = torch._dynamo.utils.dynamo_timed # type: ignore[has-type]
def triton_hash_to_path_key(key):
    """
    Convert a Triton cache hash into the form Triton uses in cache path names.

    The encoding has changed across Triton versions:

    * early versions used the hash directly in the path name;
    * later versions converted it with ``triton.runtime.cache._base64``;
    * later still, the base64 conversion was replaced by ``_base32``.

    Try ``_base64`` first, then ``_base32``. If neither helper can be imported
    (including when Triton itself is not installed), return ``key`` unchanged.

    Only the imports are guarded. An error raised by an available converter
    propagates, rather than being mistaken for "helper unavailable" and
    silently producing a path in the wrong encoding.
    """
    try:
        from triton.runtime.cache import _base64
    except ImportError:
        pass
    else:
        return _base64(key)

    try:
        from triton.runtime.cache import _base32
    except ImportError:
        return key
    return _base32(key)
|