# mypy: allow-untyped-defs
from typing import Callable, List, Tuple, Union
import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.utils import (
    _has_potential_branch_input_alias,
    _has_potential_branch_input_mutation,
    _maybe_run_with_interpreter,
    _set_compilation_env,
    autograd_not_implemented,
    diff_tensor_meta,
    reenter_make_fx,
    UnsupportedAliasMutationException,
    validate_subgraph_args_types,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    _temp_remove_metadata_torch_function_mode,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)
from torch.fx.passes.shape_prop import _extract_tensor_metadata
class WhileLoopOp(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("while_loop")

    def __call__(
        self,
        cond_fn: Callable,
        body_fn: Callable,
        carried_inputs: Tuple[Union[torch.Tensor, int, float, bool], ...],
        additional_inputs: Tuple[Union[torch.Tensor, torch.SymInt, int], ...],
        /,
    ):
        if not isinstance(carried_inputs, tuple):
            raise RuntimeError(
                f"carried_inputs must be a tuple, got {type(carried_inputs)}"
            )
        if not isinstance(additional_inputs, tuple):
            raise RuntimeError(
                f"additional_inputs must be a tuple, got {type(additional_inputs)}"
            )
        validate_subgraph_args_types(carried_inputs)
        validate_subgraph_args_types(additional_inputs)
        return super().__call__(cond_fn, body_fn, carried_inputs, additional_inputs)
while_loop_op = WhileLoopOp()
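# Note: as a HigherOrderOperator registered under the name "while_loop", this op is
# typically also addressable as torch.ops.higher_order.while_loop in traced graphs.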
def while_loop(cond_fn, body_fn, carried_inputs):
r"""
Run body_fn(*carried_inputs) while cond_fn(*carried_inputs) returns a True scalar tensor. Returns the output of body_fn or
initial carried_inputs.
.. warning::
`torch.while_loop` is a prototype feature in PyTorch. It has limited support for input and output types and
doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch.
Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
`while_loop` is a structured control flow operator. It preserves the loop semantic across the torch.compile and torch.export.
`while_loop` is equivalent to the following:
def while_loop(cond_fn, body_fn, carried_inputs):
val = carried_inputs
while cond_fn(*val):
val = body_fn(*val)
return val
Args:
cond_fn (Callable): A callable function that returns a boolean Scalar tensor.
body_fn (Callable): A callable function that takes the same inputs as `cond_fn` and returns a tuple of tensors
carried_inputs (Tuple of possibly nested dict/list/tuple of tensors): A tuple of inputs to cond_fn and body_fn. It's also
the initial value of states that are carried across iterations.
Example:
def cond_fn(iter, x):
return iter.sum() < 10
def body_fn(iter, x):
return iter + 1, x.sin()
while_loop(cond_fn, body_fn, (torch.zeros(1), torch.randn(3, 4)))
Restrictions:
- body_fn must return tensors with the same metadata (e.g.shape, dtype) as inputs.
- body_fn and cond_fn must not in-place mutate the carried_inputs. A clone before the mutation is required.
- body_fn and cond_fn must not mutate python varialbles (e.g. list/dict) created outside of the body_fn.
- body_fn and cond_fn's output cannot aliase any of the inputs. A clone is required.
.. warning::
Temporal Limitations:
- 'while_loop' only supports **inference** right now. Autograd will be supported in the future.
"""
    from torch._dynamo.backends.debugging import (
        make_eager_backend_with_torch_function_mode,
    )

    # Currently, additional_inputs is not a user-facing input. It will be automatically set in dynamo:
    # parameters, buffers, and tensor closures accessed in cond_fn or body_fn will become additional_inputs.
    additional_inputs: Tuple = ()

    # The reason we flatten the inputs before calling into dynamo is that
    # we want to create a consistent input ordering for cond_fn and body_fn,
    # and we also want the input ordering to match the output ordering.
    # Also see NOTE: [why we cannot use "automatic" for while_loop]
    # Construct flat_cond_fn and flat_body_fn, which take the flattened inputs.
    flat_inputs, in_spec = pytree.tree_flatten((carried_inputs, additional_inputs))
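    # Illustrative sketch (hypothetical values, not executed): for nested inputs such as
    # carried_inputs = ({"i": it, "x": x},) with additional_inputs = (), tree_flatten
    # produces flat_inputs == [it, x] and an in_spec recording the tuple/dict structure,
    # which flat_cond_fn/flat_body_fn below use to rebuild the original layout via
    # tree_unflatten before calling the user-provided cond_fn/body_fn.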
    def flat_cond_fn(*flat_args):
        carried, additional = pytree.tree_unflatten(flat_args, in_spec)
        return cond_fn(*carried, *additional)

    def flat_body_fn(*flat_args):
        carried, additional = pytree.tree_unflatten(flat_args, in_spec)
        return body_fn(*carried, *additional)

    if torch.compiler.is_dynamo_compiling():
        return while_loop_op(flat_cond_fn, flat_body_fn, tuple(flat_inputs), tuple())
    def _validate_input(cond_fn, body_fn, carried_inputs):
        from torch._higher_order_ops.utils import validate_subgraph_args_types

        if not callable(cond_fn) or not callable(body_fn):
            raise RuntimeError("Expect cond_fn and body_fn to be callable.")

        validate_subgraph_args_types(flat_inputs)

        if not pytree.tree_all(lambda t: isinstance(t, torch.Tensor), carried_inputs):
            raise RuntimeError(
                "Expect carried_inputs to be a tuple of possibly nested dict/list/tuple that only "
                f"consists of tensor leaves, but got {carried_inputs}."
            )

    _validate_input(cond_fn, body_fn, carried_inputs)
    # Dynamo is expecting a callable with a "__code__" attribute.
    # We cannot directly pass while_loop_op to it, so we wrap it in a dummy function.
    def _while_loop_op_wrapper(*args, **kwargs):
        return while_loop_op(*args, **kwargs)

    with _set_compilation_env(), torch._dynamo.utils.disable_cache_limit():
        with _temp_remove_metadata_torch_function_mode() as metadata_mode:
            if metadata_mode:
                backend = make_eager_backend_with_torch_function_mode(metadata_mode)
            else:
                backend = "eager"
            return torch.compile(
                _while_loop_op_wrapper, backend=backend, fullgraph=True
            )(flat_cond_fn, flat_body_fn, tuple(flat_inputs), tuple())
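# Illustrative usage sketch (hypothetical names; relies on dynamo lifting tensor
# closures into additional_inputs as described above):
#
#     scale = torch.tensor(2.0)
#
#     def cond_fn(it, x):
#         return it < 5
#
#     def body_fn(it, x):
#         # `scale` is a tensor closure; under torch.compile it becomes an additional_input.
#         return it + 1, x * scale
#
#     compiled_fn = torch.compile(lambda it, x: while_loop(cond_fn, body_fn, (it, x)))
#     out_it, out_x = compiled_fn(torch.tensor(0), torch.randn(3))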
@while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd)
def while_loop_dense(cond_fn, body_fn, carried_inputs, additional_inputs):
    carried_vals = carried_inputs

    def _is_boolean_scalar_tensor(pred):
        return (
            isinstance(pred, torch.Tensor)
            and pred.size() == torch.Size([])
            and pred.dtype == torch.bool
        )

    if not isinstance(carried_inputs, tuple):
        raise RuntimeError(
            f"carried_inputs must be a tuple but got {type(carried_inputs)}"
        )

    while pred := cond_fn(*carried_vals, *additional_inputs):
        if not _is_boolean_scalar_tensor(pred):
            raise RuntimeError(
                f"cond_fn must return a boolean scalar tensor but got {pred}"
            )
        out = body_fn(*carried_vals, *additional_inputs)
        assert isinstance(
            out, tuple
        ), f"body_fn should return a tuple but got {type(out)}"
        assert len(out) == len(
            carried_inputs
        ), "body_fn should return the same number of elements as carried_inputs"
        carried_vals = out
    return carried_vals
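# Minimal eager sketch of the dense semantics above (hypothetical tensors):
#
#     while_loop_dense(
#         lambda i, x: i < 3,
#         lambda i, x: (i + 1, x + 1.0),
#         (torch.tensor(0), torch.zeros(2)),
#         (),
#     )
#
# runs three iterations and returns (tensor(3), tensor([3., 3.])).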
while_loop_op.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(while_loop_op, deferred_error=True)
)
@while_loop_op.py_impl(ProxyTorchDispatchMode)
def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs):
    def _trace_while_loop(
        proxy_mode, while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs
    ):
        cond_graph = reenter_make_fx(cond_fn)(*carried_inputs, *additional_inputs)
        body_graph = reenter_make_fx(body_fn)(*carried_inputs, *additional_inputs)

        # Find an unused submodule name on the tracer root (e.g. while_loop_cond_graph_0)
        # so that repeated while_loop calls don't clobber each other's subgraphs.
        next_name = None
        i = 0
        while not next_name:
            candidate = f"while_loop_cond_graph_{i}"
            if hasattr(proxy_mode.tracer.root, candidate):
                i += 1
            else:
                next_name = candidate
        cond_graph_name = next_name
        body_graph_name = f"while_loop_body_graph_{i}"
        assert not hasattr(proxy_mode.tracer.root, body_graph_name)

        proxy_mode.tracer.root.register_module(cond_graph_name, cond_graph)
        proxy_mode.tracer.root.register_module(body_graph_name, body_graph)

        args = (cond_graph, body_graph, carried_inputs, additional_inputs)
        proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args)
        out_proxy = proxy_mode.tracer.create_proxy(
            "call_function", while_loop_op, proxy_args, {}, name="while_loop"
        )

        out = while_loop_op(cond_graph, body_graph, carried_inputs, additional_inputs)
        return track_tensor_tree(
            out, out_proxy, constant=None, tracer=proxy_mode.tracer
        )

    return _trace_while_loop(
        mode, while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs
    )
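# Rough shape of the traced result (names are illustrative): the enclosing fx graph
# gains submodules like `while_loop_cond_graph_0` and `while_loop_body_graph_0`, plus a
# single call_function node targeting this op (printed as
# torch.ops.higher_order.while_loop) whose outputs mirror the carried_inputs' metadata.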
def check_outputs_carry_consistency(
    outs: List[torch.Tensor], carries: List[torch.Tensor]
) -> None:
    all_diffs_in_meta = []
    for out, cry in zip(outs, carries):
        if diff := diff_tensor_meta(
            _extract_tensor_metadata(cry), _extract_tensor_metadata(out)
        ):
            all_diffs_in_meta.append(",".join(diff))

    if all_diffs_in_meta:
        diff_str = "\n".join(all_diffs_in_meta)
        raise RuntimeError(
            f"Expected body outputs and carried_inputs to be tensors with the same metadata but found:\n{diff_str}"
        )
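# Example of what this check rejects (hypothetical shapes): if a carried input is a
# (3, 4) float32 tensor but body_fn returns a (3, 5) float32 tensor in that position,
# diff_tensor_meta reports the shape mismatch and the RuntimeError above is raised.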
@while_loop_op.py_impl(FakeTensorMode)
def while_loop_fake_tensor_mode(
    mode, cond_fn, body_fn, carried_inputs, additional_inputs
):
    with mode:
        # NOTE: [Handling unbacked symints created in the subgraph of while_loop]
        # The idea is that the scope of unbacked symints is limited to the subgraph.
        #
        # We're implementing the fake tensor mode of the while_loop operator,
        # and we run body_fn once to get a fake output.
        # Let's only consider tensor outputs for now:
        #
        # Case 1:
        # If the unbacked symints are local to the subgraph, e.g.
        #
        #   def body_fn(it, x):
        #       nz = x.nonzero()
        #       return it + 1, nz.sum()
        #
        # we can just ignore the newly created unbacked symints because they have
        # no effect on the output of while_loop and they are tracked when we trace
        # the subgraph.
        #
        # Case 2.1:
        # If the unbacked symints are part of the output of while_loop, e.g.
        #
        #   def body_fn(it, x):
        #       nz = x.nonzero()
        #       return it + 1, nz
        #
        # this will fail the shape check: in each iteration the carried input's shape
        # must match the output shape, but nz.shape contains a newly allocated
        # unbacked symint, which won't match the carried input's shape.
        #
        # Case 2.2:
        # If the unbacked symints are part of carried_inputs, e.g.
        #
        #   nz = a.nonzero()
        #
        #   def body_fn(it, nz):
        #       return it + 1, nz.sin() + 1
        #
        # no new unbacked symint is allocated in the subgraph, so we're safe.
        with mode.shape_env.ignore_fresh_unbacked_symbols():
            # body_fn returns output with the same pytree structure and tensor metadata
            # as carried_inputs, so we can just return the output after one iteration.
            body_outs = body_fn(*carried_inputs, *additional_inputs)
            check_outputs_carry_consistency(body_outs, carried_inputs)
        return body_outs
@while_loop_op.py_functionalize_impl
def while_loop_func(ctx, cond_fn, body_fn, carried_inputs, additional_inputs):
    unwrapped_carried_inputs = ctx.unwrap_tensors(carried_inputs)
    unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs)
    unwrapped_inputs = unwrapped_carried_inputs + unwrapped_additional_inputs
    with ctx.redispatch_to_next() as m:
        functional_cond_fn = ctx.functionalize(_maybe_run_with_interpreter(cond_fn))
        functional_body_fn = ctx.functionalize(_maybe_run_with_interpreter(body_fn))
        pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch

        for fn, fn_name in [
            (functional_cond_fn, "cond_fn"),
            (functional_body_fn, "body_fn"),
        ]:
            if _has_potential_branch_input_mutation(
                fn, unwrapped_inputs, pre_dispatch=pre_dispatch
            ):
                raise UnsupportedAliasMutationException(
                    f"torch.while_loop's {fn_name} might be modifying the input!"
                )
            if _has_potential_branch_input_alias(
                fn, unwrapped_inputs, pre_dispatch=pre_dispatch
            ):
                raise UnsupportedAliasMutationException(
                    f"torch.while_loop's {fn_name} might be aliasing the input!"
                )

        ret = while_loop_op(
            functional_cond_fn,
            functional_body_fn,
            unwrapped_carried_inputs,
            unwrapped_additional_inputs,
        )
        return ctx.wrap_tensors(ret)
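# Illustrative sketches of body_fns this pass rejects (hypothetical functions), matching
# the restrictions documented in while_loop's docstring:
#
#     def body_fn(it, x):
#         x.add_(1)        # in-place mutation of an input -> "might be modifying the input!"
#         return it + 1, x
#
#     def body_fn(it, x):
#         return it + 1, x.view(-1)   # output aliases an input -> "might be aliasing the input!"
#
# Cloning (e.g. `x.clone()`) before mutating or returning avoids both errors.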