"""Kernel-feature analysis helpers for SIMD (Triton-style) codegen scheduling."""
from __future__ import annotations
import collections
import itertools
from typing import Any, Dict, Iterable, List, Type, Union
import sympy
import torch
from ...utils._ordered_set import OrderedSet
from ..dependencies import Dep, MemoryDep
from ..runtime.hints import ReductionHint
from ..scheduler import SchedulerNode
from ..utils import cache_on_self
from ..virtualized import V
class NodeScheduleMarker:
    """Base for sentinel entries interleaved with SchedulerNodes in a schedule."""

    @staticmethod
    def only_nodes(it: Iterable[NodeScheduleEntry]) -> Iterable[SchedulerNode]:
        """Yield only real scheduler nodes, skipping the marker sentinels."""
        for entry in it:
            # Markers are the classes themselves, so compare by identity.
            if entry is DisableReduction or entry is EnableReduction:
                continue
            yield entry  # type: ignore[misc]

    @staticmethod
    def is_reduction() -> bool:
        # Markers never represent a reduction themselves.
        return False
# A node schedule entry is either a real SchedulerNode or one of the marker
# classes below (the classes themselves, not instances).
NodeScheduleEntry = Union[SchedulerNode, Type[NodeScheduleMarker]]
class DisableReduction(NodeScheduleMarker):
    """
    Marker that triggers `kernel.disable_reduction()`: it closes the
    current reduction loop so that pointwise ops placed after it can
    operate on the reduction's output.
    """
class EnableReduction(NodeScheduleMarker):
    """
    Marker that terminates a DisableReduction block.
    """

    @staticmethod
    def filter(node_schedule: List[NodeScheduleEntry]) -> Iterable[SchedulerNode]:
        """
        Yield the scheduler nodes of node_schedule, omitting any that
        fall inside a DisableReduction block.
        """
        inside_disabled = False
        for entry in node_schedule:
            if entry is DisableReduction:
                # Don't tile stuff outside the main reduction loop
                inside_disabled = True
            elif entry is EnableReduction:
                inside_disabled = False
            elif not inside_disabled:
                yield entry  # type: ignore[misc]
class SIMDKernelFeatures:
    """
    An ordered schedule of nodes that will become a single kernel.
    """

    def __init__(
        self,
        node_schedule: List[NodeScheduleEntry],
        numel: sympy.Expr,
        reduction_numel: sympy.Expr = sympy.S.One,
    ):
        self.node_schedule = node_schedule
        # numel excludes reduction_numel
        self.numel: sympy.Expr = V.graph.sizevars.simplify(numel)
        self.reduction_numel: sympy.Expr = V.graph.sizevars.simplify(reduction_numel)

    @cache_on_self
    def is_reduction(self) -> bool:
        """True when this kernel contains a reduction dimension."""
        return self.reduction_numel != 1

    @cache_on_self
    def scheduler_nodes(self) -> Iterable[SchedulerNode]:
        """All real scheduler nodes in order, with marker entries removed."""
        return tuple(NodeScheduleMarker.only_nodes(self.node_schedule))

    def reduction_nodes(self) -> List[SchedulerNode]:
        """The subset of scheduler_nodes() that perform reductions."""
        return [node for node in self.scheduler_nodes() if node.is_reduction()]

    @cache_on_self
    def buf_accesses(self) -> Dict[str, List[Dep]]:
        """only needed for config.benchmark_kernel"""
        accesses: Dict[str, List[Dep]] = collections.defaultdict(list)
        for node in self.scheduler_nodes():
            # Group every read and write dependency by buffer name.
            for dep in node.read_writes.reads | node.read_writes.writes:
                accesses[dep.name].append(dep)
        return accesses

    @cache_on_self
    def op_counts(self) -> collections.Counter[str]:
        """Aggregate per-op usage counts across all nodes in the schedule."""
        totals: collections.Counter[str] = collections.Counter()
        for node in self.scheduler_nodes():
            totals.update(node._body.op_counts)
        return totals

    def contains_op(self, op_name: str) -> bool:
        """True if V.ops.{op_name} is used in node_schedule"""
        return self.op_counts().get(op_name, 0) > 0

    def get_mutations(self) -> OrderedSet[str]:
        """Names of all buffers mutated by any output of any node."""
        mutated: OrderedSet[str] = OrderedSet()
        for node in self.scheduler_nodes():
            for output in node.get_outputs():
                mutated.update(output.get_mutations())
        return mutated

    @cache_on_self
    def select_index_dtype(self) -> torch.dtype:
        """Pick int32 indexing when every buffer fits, else fall back to int64."""
        from .simd import SIMDScheduling

        # Gather all used buffer names (both produced and consumed)
        names: OrderedSet[str] = OrderedSet()
        for node in self.scheduler_nodes():
            names.update(node.get_buffer_names())
            names.update(node.used_buffer_names())
        buffers = [V.graph.get_buffer(name) for name in names]

        # In theory we can separately check xnumel and rnumel are <= int_max
        # but some indexers do use the full linear index so we need to be
        # conservative here.
        total_numel = self.numel * self.reduction_numel
        if SIMDScheduling.can_use_32bit_indexing(total_numel, buffers):
            return torch.int32
        return torch.int64

    @cache_on_self
    def get_reduction_hint(self) -> ReductionHint:
        """Combine per-node reduction hints into a single hint for the kernel."""
        reductions = self.reduction_nodes()
        if not reductions:
            return ReductionHint.DEFAULT
        hints = [self.reduction_hint(node) for node in reductions]
        # Only honor a specific hint when every reduction node agrees on it.
        hint = hints[0] if hints.count(hints[0]) == len(hints) else ReductionHint.DEFAULT
        if (
            hint == ReductionHint.INNER
            and self.has_non_contiguous_pw_in_reduction_kernel()
        ):
            hint = ReductionHint.DEFAULT
        return hint

    def has_non_contiguous_pw_in_reduction_kernel(self) -> bool:
        """True if a full-size pointwise node in this kernel has a
        non-contiguous memory access (excluding integer indices and
        stride-1-last-dim accesses)."""
        full_numel = self.numel * self.reduction_numel
        for node in self.scheduler_nodes():
            # Only consider pointwise nodes spanning the whole iteration space.
            if node.is_reduction() or node.group[1][0] != full_numel:
                continue
            for dep in itertools.chain(
                node.read_writes.reads, node.read_writes.writes
            ):
                if not isinstance(dep, MemoryDep):
                    continue
                # An index can be an integer when loading a random seed.
                if (
                    dep.is_contiguous()
                    or isinstance(dep.index, (sympy.Integer, int))
                    or dep.stride1_for_last_dim()
                ):
                    continue
                return True
        return False

    @staticmethod
    def reduction_hint(node: Any) -> ReductionHint:
        """Upgrade a node's hint to INNER when all of its accesses are
        contiguous; otherwise return the node's own hint unchanged."""
        assert node.is_reduction()
        hint = node.node.data.reduction_hint
        # Keep short-circuit order: skip the contiguity scan for INNER hints.
        if hint != ReductionHint.INNER and all(
            dep.is_contiguous()
            for dep in itertools.chain(node.read_writes.reads, node.read_writes.writes)
        ):
            return ReductionHint.INNER
        return hint