1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
|
import math
from typing import Any, Dict, List
import sympy
import torch
from torch.utils._sympy.value_ranges import ValueRanges
from .loop_body import LoopBody
from .utils import dominated_nodes
def val_expressable_in_32_bits(val: Any) -> bool:
    """Return True if ``val`` can be represented exactly in 32 bits.

    Booleans are always expressible.  A numeric ``sympy.Expr`` is first
    collapsed to a native ``int``/``float``.  Integers must fall inside the
    ``torch.int32`` range; floats must stay within ``2**24`` in magnitude
    (the float32 mantissa), so the value survives a narrowing cast exactly.

    Raises:
        TypeError: for any value that is not a bool, int, float, or a
            numeric ``sympy.Expr``.
    """
    if getattr(val, "is_Boolean", False):
        return True

    if isinstance(val, sympy.Expr):
        assert val.is_number
        # Collapse the sympy scalar to a plain python number before checking.
        val = int(val) if (val.is_Integer or val.is_Boolean) else float(val)

    if isinstance(val, float):
        # bound within mantissa
        return -(2**24) <= val <= 2**24

    if isinstance(val, int):
        int32_info = torch.iinfo(torch.int32)
        return int32_info.min <= val <= int32_info.max

    raise TypeError(f"Unexpected value {val}")
def range_expressable_in_32_bits(range: ValueRanges[sympy.Expr]) -> bool:
    """Return True if both endpoints of ``range`` fit in 32 bits."""
    return all(
        val_expressable_in_32_bits(bound) for bound in (range.lower, range.upper)
    )
def try_to_reduce_precision(
    node: Any,
    bounds: Dict[Any, Any],
    indirect_vars: List[Any],
    indices: Dict[Any, sympy.Expr],
    replacement_vals: Dict[Any, ValueRanges[sympy.Expr]],
) -> None:
    """Rewrite ``node`` (a ``to_dtype`` fx node) to emit int32 in place of
    its current dtype, if it is provably safe.

    The rewrite happens only when every dominated use of ``node`` — and
    every index expression reached through indirect indexing — has a value
    range expressible in 32 bits.  Otherwise the function returns without
    mutating anything.

    Args:
        node: fx node whose ``args[2]`` holds the target dtype.
        bounds: mapping from fx node to its computed value range.
        indirect_vars: indirect-index symbols, positioned by the integer
            suffix of ``set_indirect{N}`` targets.
        indices: mapping from index name to its sympy index expression.
        replacement_vals: value ranges keyed by index name.
    """

    # if a downstream use of a node explicitly converts to int32, or float16/float32/float64,
    # then it's precision is set for that chain of uses, and we don't need to consider those
    # dominated values
    def skip_filter(node: Any) -> bool:
        # NOTE(review): the comment above mentions float16, yet torch.float16
        # is absent from this tuple — confirm whether that is intentional.
        return node.target == "to_dtype" and node.args[2] in (
            torch.int32,
            torch.float32,
            torch.float64,
        )

    # TODO - there are dominated uses whose dtype does not depend on whether
    # we reduce the precision here, e.g. add(int64, int64) one of the args can be reduced to
    # int32 without changing the output precision of the node. this case hasn't shown up
    for dominated in dominated_nodes([node], skip_filter):
        # Stores/outputs don't constrain the intermediary's precision.
        if dominated.target in ["store", "output"]:
            continue

        if isinstance(dominated.target, str) and "set_indirect" in dominated.target:
            # Target is "set_indirect{N}"; N selects the indirect variable.
            idx = int(dominated.target[len("set_indirect") :])
            indirect_var = indirect_vars[idx]

            # We check that we can compute all the indices it's involved in with int32
            for index, expr in indices.items():
                if indirect_var in expr.free_symbols:
                    index_val = replacement_vals[index]

                    # Unbounded range: 32-bit safety cannot be proven; abort.
                    if math.isinf(index_val.lower) or math.isinf(index_val.upper):
                        return

                    # all indices are integers, so make sure that we
                    # use the bounds of integers instead of floats.
                    # TODO - not sure if we should be doing int/float casts while tracing,
                    # might interfere with sympy.
                    index_val_int = ValueRanges[sympy.Expr](
                        int(index_val.lower), int(index_val.upper)
                    )
                    if not range_expressable_in_32_bits(index_val_int):
                        return

        if not range_expressable_in_32_bits(bounds[dominated]):
            return

    # All dominated uses fit in 32 bits: rewrite the dtype arg in place.
    args = list(node.args)
    args[2] = torch.int32
    node.args = tuple(args)
def indexing_dtype_strength_reduction(loop_body: LoopBody) -> None:
    """
    Performs Value Range Analysis on LoopBody's fx graph to reduce precision of
    intermediaries from int64 to int32
    """
    bv = loop_body.bounds()

    def is_candidate(n: Any) -> bool:
        # Only bounded `to_dtype(..., int64)` nodes can be narrowed.
        return (
            n.target == "to_dtype"
            and n.args[2] == torch.int64
            and n not in bv.unbounded_vars
        )

    int64_dtype_nodes = [n for n in loop_body.get_nodes() if is_candidate(n)]
    if not int64_dtype_nodes:
        return

    bounds = bv.get_bounds()

    # TODO - if dominated node of one to_dtype is not expressible in int32,
    # we should short circuit another to_dtype node if that node also dominates
    for node in int64_dtype_nodes:
        try_to_reduce_precision(
            node,
            bounds,
            loop_body.indirect_vars,
            loop_body.indexing_exprs,
            bv.replacement_vals,
        )
|