1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
|
# mypy: allow-untyped-defs
import copy
from collections import defaultdict
import torch
from torch._dynamo.source import LocalSource
from torch._subclasses import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.shape_inference.infer_symbol_values import (
infer_symbol_values,
)
from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
from torch.utils import _pytree
"""
This is the function that runs shape inference. It will modify the input graph module so that shapes are annotated.
"""
def infer_shape(gm, input_tensors):
    """Trace ``gm`` on symbolically shaped fake inputs, retrying on shape errors.

    Each leaf of ``input_tensors`` is replaced by a fake tensor (created with
    ``torch.randn`` under a ``FakeTensorMode``) whose leading dimension is the
    shared symbol ``s0`` and whose remaining dimensions each use their own
    symbol. ``gm`` is then traced with ``make_fx`` in symbolic mode; the traced
    result itself is discarded. When tracing raises ``RuntimeError`` or
    ``ValueError``, the error message is passed to ``infer_symbol_values``
    along with the current symbols, and tracing is attempted again. At most
    ``2 * dim_count`` attempts are made, where ``dim_count`` is the number of
    symbols created.

    Args:
        gm: The graph module (or callable) handed to ``make_fx``.
        input_tensors: A pytree of tensors. Only ``dim()`` of each leaf is
            read here; the values are never used.

    Returns:
        ``(gm, input_tensors, fake_mode, symints[0])`` as soon as one trace
        succeeds, where ``symints[0]`` is the symbol used as the leading
        dimension of every input. ``None`` if every attempt failed with
        ``RuntimeError`` or ``ValueError``.

    Any other exception type raised while tracing propagates to the caller.
    """
    # Prepare environments
    shape_env = ShapeEnv()
    fake_mode = FakeTensorMode(shape_env=shape_env, allow_non_fake_inputs=True)

    flatten_inputs, spec = _pytree.tree_flatten(input_tensors)

    # One shared symbol for the leading dimension, plus one symbol for every
    # remaining dimension of every input.
    # NOTE(review): a 0-dim input contributes -1 here (and to the offset in
    # the loop below), shifting the symbols used by later inputs -- confirm
    # that callers never pass 0-dim tensors.
    dim_count = 1 + sum(tensor.dim() - 1 for tensor in flatten_inputs)

    # Every symbol starts with the hint value 2.
    symbol_names = [f"s{i}" for i in range(dim_count)]
    init_symints = [
        mksym(shape_env, 2, LocalSource(name), DimDynamic.DYNAMIC)
        for name in symbol_names
    ]
    symints = copy.deepcopy(init_symints)
    symbol_to_idx_dict = {name: idx for idx, name in enumerate(symbol_names)}
    padding_constraints = defaultdict(list)  # type: ignore[var-annotated]

    # Bounded retry: one iteration per allowed attempt.
    for _ in range(dim_count * 2):
        # Create symbolic input tensors. They are rebuilt on every attempt
        # because ``symints`` is handed to infer_symbol_values after a failure
        # (presumably updated in place there -- TODO confirm).
        with fake_mode:
            flat_sym_tensors = []
            offset = 1  # index 0 is the shared leading-dimension symbol
            for input_tensor in flatten_inputs:
                extra_dims = input_tensor.dim() - 1
                desired_size = [symints[0]] + [
                    symints[idx] for idx in range(offset, offset + extra_dims)
                ]
                flat_sym_tensors.append(torch.randn(desired_size))
                offset += extra_dims
            sym_tensors = _pytree.tree_unflatten(flat_sym_tensors, spec)

        try:
            with fake_mode:
                make_fx(
                    gm,
                    tracing_mode="symbolic",
                    _allow_non_fake_inputs=True,
                    pre_dispatch=True,
                    _allow_fake_constant=True,
                )(*sym_tensors)
        except (RuntimeError, ValueError) as e:
            # Both error types are handled identically: feed the message back
            # so the symbol values can be adjusted before the next attempt.
            infer_symbol_values(
                symints,
                init_symints,
                symbol_to_idx_dict,
                padding_constraints,
                str(e),
            )
        else:
            return (gm, input_tensors, fake_mode, symints[0])

    # All attempts were used up without a successful trace.
    return None
def mksym(shape_env, value, source, dynamic_dim):
    """Create a symbolic integer in ``shape_env`` with ``value`` as its hint.

    A symbol is first allocated through ``shape_env.create_symbol`` from
    ``value``, ``source`` and ``dynamic_dim``. That symbol is then wrapped by
    ``shape_env.create_symintnode`` using ``hint=value`` and the same
    ``source``.

    Returns:
        Whatever ``shape_env.create_symintnode`` returns for the new symbol.
    """
    symbol = shape_env.create_symbol(
        value,
        source=source,
        dynamic_dim=dynamic_dim,
    )
    return shape_env.create_symintnode(symbol, hint=value, source=source)
|