1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
|
# mypy: allow-untyped-defs
import collections
import logging
import torch
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from .. import config, inductor_prims
from ..pattern_matcher import (
CallFunctionVarArgs,
Match,
PatternMatcherPass,
register_graph_pattern,
)
from ..virtualized import V
log = logging.getLogger(__name__)
# Registry for the graph patterns registered in this module via
# ``@register_graph_pattern(..., pass_dict=patterns)``; it is applied to the
# graph by ``replace_random_passes``.
patterns = PatternMatcherPass()
# Shorthand for the ATen operator namespace used by the patterns below.
aten = torch.ops.aten
def replace_random_passes(gm: torch.fx.GraphModule):
    """Rewrite random ops in ``gm`` into backend-native inductor random prims.

    First applies the rand/randn/randint patterns registered on ``patterns``,
    then horizontally fuses the per-device ``inductor_prims.seed`` calls that
    those rewrites introduce.

    Args:
        gm: the FX graph module to rewrite in place.

    Returns:
        The count reported by ``patterns.apply`` plus the number of devices
        whose seed creation was fused; ``0`` when ``config.fallback_random``
        is enabled (in which case the graph is left untouched).
    """
    if config.fallback_random:
        return 0
    replaced = patterns.apply(gm)
    with GraphTransformObserver(gm, "fuse_seed_creation_pass"):
        fused = fuse_seed_creation_pass(gm.graph)
    return replaced + fused
def fuse_seed_creation_pass(graph: torch.fx.Graph):
    """
    Horizontally fuse all the seed generation on each device

        a = inductor_seed(dev)
        b = inductor_seed(dev)

    Becomes:

        seeds = inductor_seeds(2, dev)
        a = inductor_lookup_seed(seeds, 0)
        b = inductor_lookup_seed(seeds, 1)

    We do this because seed creation is entirely launch overhead bound.

    Args:
        graph: the FX graph to rewrite in place.

    Returns:
        The number of distinct devices whose seed nodes were fused, or ``0``
        if the graph contains no ``inductor_prims.seed`` calls.
    """
    # The matcher is loop-invariant: build it once rather than constructing a
    # fresh pattern object for every node in the graph.
    seed_pattern = CallFunctionVarArgs(inductor_prims.seed)

    # Group seed nodes by their device argument, in graph order.
    device_seeds = collections.defaultdict(list)
    for node in graph.nodes:
        if seed_pattern.match(node):
            device_seeds[node.args[0]].append(node)

    if not device_seeds:
        return 0

    for device, seeds in device_seeds.items():
        # Insert the combined node before the first seed on this device so it
        # precedes every seed (and hence every lookup) it replaces.
        with graph.inserting_before(seeds[0]):
            combined = graph.call_function(inductor_prims.seeds, (len(seeds), device))
            with V.fake_mode:
                combined.meta["val"] = torch.empty(
                    [len(seeds)], device=device, dtype=torch.int64
                )
                combined.meta["tensor_meta"] = _extract_tensor_metadata(
                    combined.meta["val"]
                )

        # Replace each individual seed with a lookup into the combined tensor,
        # carrying over the original node's metadata.
        for idx, seed in enumerate(seeds):
            with graph.inserting_before(seed):
                new_seed = graph.call_function(
                    inductor_prims.lookup_seed, (combined, idx)
                )
            seed.replace_all_uses_with(new_seed)
            new_seed.meta.update(seed.meta)
            graph.erase_node(seed)

    return len(device_seeds)
def default_kwargs(device):
    """Return extra keyword arguments for ``inductor_prims.random``.

    ``device`` is currently ignored; a new, empty dict is returned on every
    call.
    """
    return dict()
def get_device(device):
    """Return ``device`` unchanged, or torch's current default device if None.

    The default is discovered by allocating a 0-d tensor without an explicit
    device, so it reflects whatever default device torch is configured with.
    """
    return torch.empty([]).device if device is None else device
@register_graph_pattern(CallFunctionVarArgs(aten.rand.default), pass_dict=patterns)
@register_graph_pattern(CallFunctionVarArgs(aten.rand.generator), pass_dict=patterns)
@register_graph_pattern(CallFunctionVarArgs(aten.randn.default), pass_dict=patterns)
@register_graph_pattern(CallFunctionVarArgs(aten.randn.generator), pass_dict=patterns)
def replace_random(
    match: Match,
    size,
    *,
    generator=None,
    dtype=None,
    device=None,
    layout=None,
    pin_memory=None,
):
    """Replace a matched ``aten.rand``/``aten.randn`` call with an inductor prim.

    The replacement draws from ``inductor_prims.random`` using a per-call
    ``inductor_prims.seed`` on the resolved device, then casts to ``dtype``
    when one was given. Calls that pass an explicit ``generator`` are left
    untouched. ``layout`` and ``pin_memory`` are accepted to mirror the ATen
    signature but are not used.
    """
    if generator is not None:
        return

    # The sampling mode is chosen by which op overload packet was matched.
    matched_op = match.output_node().target.overloadpacket  # type: ignore[union-attr]
    mode = {aten.rand: "rand", aten.randn: "randn"}[matched_op]
    target_device = get_device(device)

    def replacement(size):
        seed = inductor_prims.seed(target_device)
        out = inductor_prims.random(size, seed, mode, **default_kwargs(target_device))
        return out if dtype is None else out.to(dtype)

    match.replace_by_example(replacement, [size])
@register_graph_pattern(CallFunctionVarArgs(aten.randint.low), pass_dict=patterns)
def replace_randint(
    match: Match,
    low,
    high,
    size,
    *,
    dtype=torch.int64,
    device=None,
    layout=None,
    pin_memory=None,
):
    """Replace a matched ``aten.randint.low`` call with ``inductor_prims.randint``.

    The replacement seeds the draw with ``inductor_prims.seed`` on the
    resolved device (see ``get_device``) and casts to ``dtype`` when one is
    given. ``layout`` and ``pin_memory`` are accepted to mirror the ATen
    signature but are not used.
    """
    device = get_device(device)

    def replacement(low, high, size):
        result = inductor_prims.randint(low, high, size, inductor_prims.seed(device))
        # NOTE(review): the ATen schema presumably types ``dtype`` as optional,
        # so an explicit None can reach this point. Guard the cast the same way
        # replace_random does instead of relying on ``.to(None)`` being a no-op.
        if dtype is not None:
            result = result.to(dtype)
        return result

    match.replace_by_example(replacement, [low, high, size])
|