# mypy: allow-untyped-defs
from typing import Sequence, Union
from ..scheduler import (
BaseSchedulerNode,
BaseScheduling,
FusedSchedulerNode,
Scheduler,
SchedulerNode,
)
from .cuda.cuda_cpp_scheduling import CUDACPPScheduling
from .rocm.rocm_cpp_scheduling import ROCmCPPScheduling
from .triton import TritonScheduling


class CUDACombinedScheduling(BaseScheduling):
"""
Scheduler for CUDA Kernels, which delegates calls as appropriate
to the CUDA-C++ and Triton Schedulers, which both work for CUDA devices
and use a unified-wrapper for codegen.
If Scheduling code needs to be specialized for the case of mixed Triton / CUDA C++ code,
this would also be the place to do it.
"""
def __init__(self, scheduler: Scheduler) -> None:
super().__init__()
self._scheduler = scheduler
self._triton_scheduling = TritonScheduling(scheduler)
self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler)
self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler)

    def get_backend_features(self, device):  # type: ignore[override]
return self._triton_scheduling.get_backend_features(device)

def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling:
if self._cuda_cpp_scheduling.is_cuda_cpp_template(node):
return self._cuda_cpp_scheduling
if self._rocm_cpp_scheduling.is_rocm_cpp_template(node):
return self._rocm_cpp_scheduling
return self._triton_scheduling
def can_fuse_vertical(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
if self._cuda_cpp_scheduling.can_fuse_vertical(node1, node2):
return True
return self._triton_scheduling.can_fuse_vertical(node1, node2)
def can_fuse_horizontal(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode):
        for node in (node1, node2):
            if self._cuda_cpp_scheduling.is_cuda_cpp_template(node):
                # Currently always returns False for CUDA C++ templates.
                return self._cuda_cpp_scheduling.can_fuse_horizontal(node1, node2)
        return self._triton_scheduling.can_fuse_horizontal(node1, node2)

def group_fn(self, sizes):
return self._triton_scheduling.group_fn(sizes)

    def codegen_template(
        self,
        template_node: BaseSchedulerNode,
        epilogue_nodes: Union[Sequence[BaseSchedulerNode], None],
    ):
if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node):
assert epilogue_nodes is None or len(epilogue_nodes) == 0
return self._cuda_cpp_scheduling.codegen_template(
template_node, epilogue_nodes
)
elif self._rocm_cpp_scheduling.is_rocm_cpp_template(template_node):
assert epilogue_nodes is None or len(epilogue_nodes) == 0
return self._rocm_cpp_scheduling.codegen_template(
template_node, epilogue_nodes
)
else:
return self._triton_scheduling.codegen_template(
template_node, epilogue_nodes
)

def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]):
return self._triton_scheduling.codegen_node(node)

def codegen_sync(self):
return self._triton_scheduling.codegen_sync()

def flush(self):
return self._triton_scheduling.flush()

def codegen_combo_kernel(self, *args, **kwargs):
return self._triton_scheduling.codegen_combo_kernel(*args, **kwargs)

def benchmark_fused_nodes(self, nodes):
return self._triton_scheduling.benchmark_fused_nodes(nodes)

def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False):
return self._triton_scheduling.generate_kernel_code_from_nodes(
nodes, benchmark_kernel
)

def benchmark_combo_kernel(self, node_list):
return self._triton_scheduling.benchmark_combo_kernel(node_list)