1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
|
import torch
from torch._inductor.codegen.rocm.rocm_template import ROCmTemplate
from torch._inductor.ir import IRNode
from torch._inductor.utils import IndentedBuffer
class CKTemplate(ROCmTemplate):
"""
Base class for generating CK templates, has common, i.e. non-gemm-specific, code generation logic
"""
_TORCH_DTYPE_TO_CK = {
torch.float32: "F32",
torch.float64: "F64",
torch.float16: "F16",
torch.bfloat16: "BF16",
torch.int32: "I32",
torch.int8: "I8",
torch.float8_e4m3fnuz: "F8",
torch.float8_e5m2fnuz: "BF8",
}
def header(self) -> IndentedBuffer:
res = super().header()
res.splice(
"""
// CK headers
#ifdef DEBUG_LOG
#define DEBUG_LOG_TMP DEBUG_LOG
#undef DEBUG_LOG
#else
#define DEBUG_LOG_TMP 0
#endif
#include "ck/ck.hpp"
#undef DEBUG_LOG
#define DEBUG_LOG DEBUG_LOG_TMP
#include "ck/utility/data_type.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
"""
)
return res
def globals(self) -> IndentedBuffer:
res = super().globals()
res.splice(
"""
// CK globals
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
template<typename... Ts>
using Tuple = ck::Tuple<Ts...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using Scale = ck::tensor_operation::element_wise::Scale;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply;
// see "composable_kernel/include/ck/utility/data_type.hpp"
using F8 = ck::f8_t;
using BF8 = ck::bf8_t;
using F16 = ck::half_t;
using F32 = float;
// using F64 = double;
using BF16 = ck::bhalf_t;
// using I32 = int32_t;
// using I8 = int8_t;
// using I4 = ck::int4_t;
#if DEBUG_LOG
static constexpr auto kDEBUG_LOG = 1;
#else
static constexpr auto kDEBUG_LOG = 0;
#endif
"""
)
return res
def torch_type_to_ck(self, node: IRNode, ptr: str) -> str:
if node is None:
return ptr
else:
return f"({self._TORCH_DTYPE_TO_CK.get(node.get_dtype())}*)({ptr})"
|