1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
# mypy: allow-untyped-defs
from __future__ import annotations
import functools
import logging
from ctypes import byref, c_int, c_size_t, c_void_p
from typing import Any, Callable, Iterable, List, Optional, Union
import torch
from torch._inductor import config
from torch._inductor.autotune_process import (
BenchmarkRequest,
GPUDeviceBenchmarkMixin,
TensorMeta,
)
from torch._inductor.codecache import DLLWrapper, ROCmCodeCache
# Module-level logger; the benchmark request class in this file emits its
# debug traces (precompilation, kernel launch arguments) through it.
log = logging.getLogger(__name__)
class ROCmBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
# Important: Instances of this class have to be serializable
# across process boundaries. Do not put CUDA Tensors in here!
def __init__(
self,
kernel_name: str,
input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
extra_args: Iterable[Any],
source_code: str,
) -> None:
super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
self.source_code = source_code
self.workspace_size: int = 0
self.workspace: Optional[torch.Tensor] = None
self.DLL: Optional[DLLWrapper] = None
self._workspace_size_updated = False
self.hash_key: str = ""
self.source_file: str = ""
self.hash_key, self.source_file = ROCmCodeCache.write(self.source_code, "so")
def precompile(self):
# Prepopulate code cache
# may happen in separate Threadpool
log.debug("Precompiling %s", self)
ROCmCodeCache.compile(self.source_code, "so")
if config.rocm.generate_test_runner:
ROCmCodeCache.compile(self.source_code, "exe")
log.debug("Done precompiling %s", self)
def make_run_fn(
self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
) -> Callable[[], None]:
self.ensure_dll_loaded()
self.update_workspace_size()
args = [
c_void_p(tensor.data_ptr())
for tensor in list(input_tensors) + [output_tensor]
]
size_args = [c_int(arg) for arg in self.extra_args]
log.debug(
"make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",
self.kernel_name,
self.source_file,
self.hash_key,
self.DLL,
args,
self.extra_args,
)
stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)
run_method = getattr(self.DLL, self.kernel_name)
workspace_ptr = c_void_p(0)
if self.workspace_size > 0:
self.workspace = torch.zeros(
(self.workspace_size + 7) // 8,
dtype=torch.float64,
device=output_tensor.device,
)
workspace_ptr = c_void_p(self.workspace.data_ptr())
# Generate partial function.
return functools.partial(
run_method,
*args,
*size_args,
None, # null workspace size ptr
workspace_ptr, # set workspace ptr,
stream_ptr,
)
def update_workspace_size(self) -> None:
if self._workspace_size_updated:
return
self.ensure_dll_loaded()
unique_input_count = len({meta.name for meta in self.input_tensor_meta})
args = [c_void_p(None) for _ in range(unique_input_count + 1)]
stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)
run_method = getattr(self.DLL, self.kernel_name)
# Retrieve workspace_size and initialize workspace.
c_workspace_size = c_size_t()
size_args = [c_int(arg) for arg in self.extra_args]
run_method(
*args, # input ptrs and output ptrs
*size_args,
byref(
c_workspace_size
), # set workspace size ptr to retrieve workspace size
None, # null workspace ptr
stream_ptr,
)
torch.cuda.synchronize() # shake out any CUDA errors
self.workspace_size = c_workspace_size.value
log.debug(
"update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", # noqa: B950
self.workspace_size,
self.kernel_name,
self.source_file,
self.hash_key,
self.DLL,
args,
self.extra_args,
)
self._workspace_size_updated = True
def ensure_dll_loaded(self):
if self.DLL is None:
self.DLL, self.hash_key, self.source_file = ROCmCodeCache.load(
self.source_code, "so"
)
def cleanup_run_fn(self) -> None:
if self.DLL is not None:
self.DLL.close()
self.workspace = None
def __str__(self) -> str:
return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}"
|