1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634
|
# mypy: allow-untyped-defs
# ruff: noqa: TCH004
import functools
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, Type, TYPE_CHECKING, TypeVar
import torch
from torch.utils._contextlib import _DecoratorContextManager
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from . import trace_rules, variables
from .comptime import comptime
from .eval_frame import (
_set_stance,
DisableContext,
DynamoStance,
innermost_fn,
RunOnlyContext,
)
from .exc import IncorrectUsage
from .external_utils import is_compiling
from .utils import is_function
if TYPE_CHECKING:
from types import FunctionType
from torch._C._dynamo.eval_frame import ( # noqa: F401
reset_code,
set_eval_frame,
set_guard_error_hook,
skip_code,
unsupported,
)
from .variables import VariableTracker
else:
for name in dir(torch._C._dynamo.eval_frame):
if name.startswith("__"):
continue
globals()[name] = getattr(torch._C._dynamo.eval_frame, name)
_F = TypeVar("_F", bound=Callable[..., Any])
def run(fn=None):
"""Don't do any dynamic compiles, just use prior optimizations"""
if fn is not None:
fn = innermost_fn(fn)
assert callable(fn)
return RunOnlyContext()(fn)
return RunOnlyContext()
def disable(fn=None, recursive=True):
"""
Decorator to disable TorchDynamo
If recursive=True, Dynamo is completely skipped on the decorated function
frame as well as the recursively invoked functions.
If recursive=False, Dynamo skips frames associated with the function code,
but still process recursively invoked frames.
"""
if recursive:
if fn is not None:
fn = innermost_fn(fn)
assert callable(fn)
return DisableContext()(fn)
return DisableContext()
else:
return skip(fn)
def skip(fn=None):
"""
Skip frames associated with the function code, but still process recursively
invoked frames
"""
if fn is None:
return skip
fn = innermost_fn(fn)
assert callable(fn)
skip_code(fn.__code__)
fn._torchdynamo_disable = True
return fn
class set_stance(_DecoratorContextManager):
"""
Decorator, context manager, function to set the current stance of the compiler.
Stances documented in corresponding function in torch/compiler/__init__.py
"""
_dynamo_forbidden = True
def __init__(
self,
stance: str = "default",
*,
skip_guard_eval_unsafe: bool = False,
force_backend=None,
) -> None:
if force_backend is not None and stance != "default":
raise RuntimeError("non-default stance cannot have force_backend set")
self.stance = DynamoStance(stance, skip_guard_eval_unsafe, force_backend)
self.prev = _set_stance(self.stance)
def __call__(self, fn):
_set_stance(self.prev)
wrapper = super().__call__(fn)
# forbid wrapper in graph
wrapper._dynamo_forbidden = True # type: ignore[attr-defined]
return wrapper
def __enter__(self):
_set_stance(self.stance)
def __exit__(self, exc_type, exc_val, exc_tb):
_set_stance(self.prev)
def clone(self):
return self.__class__(self.stance.stance, force_backend=self.stance.backend)
def assume_constant_result(fn):
fn._dynamo_marked_constant = True
return fn
def allow_in_graph(fn):
"""
Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function
and instead directly write it to the graph when encountered.
See :func:`torch.compiler.allow_in_graph`'s docstring for the full documentation
WARNING: this API can be a footgun, please read the documentation carefully.
"""
if isinstance(fn, (list, tuple)):
return [allow_in_graph(x) for x in fn]
assert callable(fn), "allow_in_graph expects a callable"
if trace_rules.lookup_callable(fn) != variables.TorchInGraphFunctionVariable:
trace_rules._disallowed_callable_ids.remove(id(fn))
trace_rules._allowed_callable_ids.add(id(fn))
return fn
def _disallow_in_graph_helper(throw_if_not_allowed):
def inner(fn):
if isinstance(fn, (list, tuple)):
return [disallow_in_graph(x) for x in fn]
assert callable(fn), "disallow_in_graph expects a callable"
if (
throw_if_not_allowed
and trace_rules.lookup_callable(fn)
!= variables.TorchInGraphFunctionVariable
and trace_rules.lookup(fn) != variables.TorchInGraphFunctionVariable
):
raise IncorrectUsage(
"disallow_in_graph is expected to be used on an already allowed callable (like torch.* ops). "
"Allowed callables means callables that TorchDynamo puts as-is in the extracted graph."
)
trace_rules._allowed_callable_ids.remove(id(fn))
trace_rules._disallowed_callable_ids.add(id(fn))
return fn
return inner
def disallow_in_graph(fn):
"""
Customize which functions TorchDynamo will exclude in the generated
graph and force a graph break on.
::
torch._dynamo.disallow_in_graph(torch.sub)
@torch._dynamo.optimize(...)
def fn(a):
x = torch.add(x, 1)
x = torch.sub(x, 1)
x = torch.add(x, 1)
return x
fn(...)
Will break the graph on `torch.sub`, and give two graphs each with a
single `torch.add()` op.
"""
return _disallow_in_graph_helper(throw_if_not_allowed=True)(fn)
@_disallow_in_graph_helper(throw_if_not_allowed=False)
def graph_break():
"""Force a graph break"""
def forbid_in_graph(fn):
"""
Customize which functions TorchDynamo will assert are not present while tracing.
If you want a graph break on this function instead, use disallow_in_graph.
TODO(voz): We now have allow_in_graph, disallow_in_graph, forbid_in_graph - some more robust
documentation would not be amiss.
"""
if isinstance(fn, (list, tuple)):
return [forbid_in_graph(x) for x in fn]
assert callable(fn), "forbid_in_graph applies only to callables"
fn._dynamo_forbidden = True
return fn
def substitute_in_graph(
original_fn: _F,
*,
can_constant_fold_through: bool = False,
skip_signature_check: bool = False,
# type that is embedded in the Python interpreter
is_embedded_type: bool = False, # internal use only
) -> Callable[[_F], _F]:
"""
Register a polyfill handler for a function, usually a C function from the C extension, to be
used in place of the original function when inlining the original function in the graph.
.. note::
The polyfill handler is only used when inlining the original function. It is not used when
the original function is called directly. In the eager mode, the decorated function calls
the performant C function rather than the polyfill handler.
The polyfill handler is a function that will be called in place of the original function when
inlining the original function. The polyfill handler should have the same signature and the same
behavior as the original function.
Args:
original_fn (callable): The original function, usually a C function, to register a polyfill
handler for.
can_constant_fold_through (bool, optional): Whether the polyfill handler can be constant
folded through. That is, if the polyfill handler is a pure function and its arguments
are constant, the result of the polyfill handler can be constant folded during the
compilation. Defaults to ``False``.
skip_signature_check (bool, optional): Whether to skip the signature check between the
original function and the polyfill handler. Defaults to ``False``.
Returns:
A decorator that registers the polyfill handler for the original function.
Example::
>>> # xdoctest: +SKIP("conflict with the tests: duplicate polyfill handlers")
>>> import operator
>>> operator.indexOf([1, 2, 3, 4, 5], 3)
2
>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
Traceback (most recent call last):
...
torch._dynamo.exc.Unsupported: ...
>>> @torch.compiler.substitute_in_graph(operator.indexOf)
... def indexOf(a, b, /):
... for i, item in enumerate(a):
... if item is b or item == b:
... return i
... raise ValueError("sequence.index(x): x not in sequence")
>>>
>>> torch.compile(operator.indexOf, fullgraph=True)([1, 2, 3, 4, 5], 3)
2
"""
if not is_function(original_fn) and not (
is_embedded_type and inspect.isclass(original_fn)
):
raise TypeError(
f"substitute_in_graph expects a function but got {type(original_fn)!r}"
)
if is_embedded_type:
if not inspect.isclass(original_fn):
raise TypeError(
f"substitute_in_graph expects a class but got {type(original_fn)!r}"
)
from .variables.builder import ITERTOOLS_POLYFILLED_TYPE_IDS, ITERTOOLS_TYPE_IDS
if id(original_fn) in ITERTOOLS_TYPE_IDS:
ITERTOOLS_POLYFILLED_TYPE_IDS.add(id(original_fn))
def wrapper(traceable_fn: _F) -> _F:
if not is_function(traceable_fn):
raise TypeError(
f"@substitute_in_graph(...) expects a function but got {type(traceable_fn)!r}"
)
if not skip_signature_check:
try:
original_sig = inspect.signature(original_fn)
except ValueError:
pass
else:
traceable_sig = inspect.signature(traceable_fn)
def sig_ident(sig):
# Ignore annotations for parameters and return type
return (
tuple(
p.name
for p in sig.parameters.values()
if (
p.kind
not in {
p.KEYWORD_ONLY,
# the name of *args and **kwargs is not important
p.VAR_POSITIONAL,
p.VAR_KEYWORD,
}
)
),
{
p.name
for p in sig.parameters.values()
if p.kind == p.KEYWORD_ONLY
},
{
p.name: p.default
for p in sig.parameters.values()
# the name of *args and **kwargs is not important
if p.kind not in {p.VAR_POSITIONAL, p.VAR_KEYWORD}
},
)
wildcard_sig = inspect.signature(lambda *args, **kwargs: None)
if (
sig_ident(original_sig) != sig_ident(traceable_sig)
and sig_ident(original_sig) != sig_ident(wildcard_sig)
and sig_ident(traceable_sig) != sig_ident(wildcard_sig)
):
raise TypeError(
f"Signature mismatch between {original_fn} and {traceable_fn}: "
f"{original_sig} != {traceable_sig}"
)
from torch._dynamo.guards import GuardBuilder
from torch._dynamo.trace_rules import (
_polyfilled_function_ids,
get_torch_obj_rule_map,
)
from torch._dynamo.variables import PolyfilledFunctionVariable
from torch._dynamo.variables.builder import VariableBuilder
id_dispatch_map = VariableBuilder._id_dispatch()
if id(original_fn) in id_dispatch_map:
raise ValueError(
f"Duplicate dispatch rule for {original_fn}: "
"already registered in VariableBuilder's id dispatch map"
)
if id(original_fn) in _polyfilled_function_ids:
raise ValueError(f"Duplicate polyfilled object {original_fn}")
rule_map: Dict[Any, Type[VariableTracker]] = get_torch_obj_rule_map()
if original_fn in rule_map:
raise ValueError(
f"Duplicate object {original_fn} with different rules: "
f"{PolyfilledFunctionVariable}, {rule_map[original_fn]}"
)
polyfill_handlers: Dict[Callable[..., Any], FunctionType]
polyfill_handlers = PolyfilledFunctionVariable._get_polyfill_handlers()
if original_fn in polyfill_handlers:
raise ValueError(
f"Duplicate polyfill handlers for {original_fn}: "
f"already handled by {polyfill_handlers[original_fn]}"
)
# Need to wrap the function because we may cannot assign __torch_dynamo_polyfill__ to a
# C++ function.
@functools.wraps(traceable_fn)
def wrapped(*args, **kwargs):
return original_fn(*args, **kwargs)
def dispatch_fn(self, value: _F) -> PolyfilledFunctionVariable:
return PolyfilledFunctionVariable(
value,
source=self.source,
**self.install_guards(GuardBuilder.FUNCTION_MATCH),
)
id_dispatch_map[id(original_fn)] = id_dispatch_map[id(wrapped)] = dispatch_fn
_polyfilled_function_ids.add(id(original_fn))
_polyfilled_function_ids.add(id(wrapped))
rule_map[original_fn] = rule_map[wrapped] = PolyfilledFunctionVariable
polyfill_handlers[original_fn] = polyfill_handlers[wrapped] = wrapped # type: ignore[assignment]
wrapped.__torch_dynamo_original__ = original_fn # type: ignore[attr-defined]
wrapped.__torch_dynamo_polyfill__ = traceable_fn # type: ignore[attr-defined]
wrapped.__torch_dynamo_can_constant_fold_through__ = can_constant_fold_through # type: ignore[attr-defined]
return wrapped # type: ignore[return-value]
return wrapper
# Helper function to flatten a tensor subclass and apply a function to
# all inner tensors that match the outer dim. Used to reduce duplication
# across the various marking APIs.
def _apply_func_to_inner_tensors_of_same_dim(func, t, *args, **kwargs):
assert is_traceable_wrapper_subclass(t)
attrs, ctx = t.__tensor_flatten__()
assert isinstance(t, torch.Tensor)
for attr in attrs:
inner = getattr(t, attr)
if inner.dim() == t.dim():
func(inner, *args, **kwargs)
@dataclass(frozen=True)
class _DimRange:
"""
This represents an dimension of a tensor and the corresponding
min and max values it can take. Don't create this
class directly; instead, use :func:`mark_dynamic`.
"""
dim: int
min: int
max: int
@forbid_in_graph
def mark_unbacked(t, index):
"""
Mark a tensor as having an unbacked dim. This changes the semantics of operations,
we will always report the size does not equal zero/one, we will turn asserts
on this index into runtime asserts, and if you try to get the real value we will
raise an exception. In other words, we will treat this dimension as if it was
data dependent (we do not know anything about its value.)
"""
# You could have copied the mark_dynamic behavior but I'm not convinced
# it's what you want
assert not is_traceable_wrapper_subclass(t), "not implemented yet"
if isinstance(index, int):
if not hasattr(t, "_dynamo_unbacked_indices"):
t._dynamo_unbacked_indices = set()
t._dynamo_unbacked_indices.add(index)
return
assert isinstance(index, (list, tuple))
for i in index:
mark_unbacked(t, i)
@forbid_in_graph
def mark_dynamic(t, index, *, min=None, max=None):
"""
Mark a tensor as having a dynamic dim and set corresponding min and max range for the dim.
[Note - on the state of mark_dynamic]
The behavior of having a dynamic dimension on a tensor is governed by a few factors:
1) torch._dynamo.config dynamic_shapes True or False.
a) dynamic_shapes=True - dynamic_shapes must be True for mark_dynamic to work.
a) dynamic_shapes=False - This config will raise an exception when used in conjunction with
mark_dynamic. We will eventually support this.
2) If the dimension is fully constrained - as in, it does not allow more than a single value
in both eager (torch.compile, torch._dynamo.optimize) mode and export mode (torch._dynamo.export),
we will raise an error
3) If the dimension is partially constrained - allowing at least 2 values but not the full unbounded
range of shapes, in eager we will pass it through, but export will raise an error.
4) Attempts to trace this function will explicitly raise. As such, all calls to mark_dynamic must be made
before torch.compile.
"""
if is_traceable_wrapper_subclass(t):
# default behavior: mirror mark_dynamic() on all inner tensors with same dim as t
# TODO: Make this configurable via a supported public API
_apply_func_to_inner_tensors_of_same_dim(
mark_dynamic, t, index, min=min, max=max
)
if isinstance(index, int):
if not hasattr(t, "_dynamo_dynamic_indices"):
t._dynamo_dynamic_indices = set()
t._dynamo_dynamic_range = set()
# TODO(voz): Should we bounds check?
t._dynamo_dynamic_indices.add(index)
t._dynamo_dynamic_range.add(_DimRange(index, min, max))
return
assert isinstance(index, (list, tuple))
for i in index:
mark_dynamic(t, i, min=min, max=max)
@forbid_in_graph
def maybe_mark_dynamic(t, index):
"""
Mark a tensor as having a dynamic dim, but don't enforce it (i.e., if this
dimension ends up getting specialized, don't error).
"""
if is_traceable_wrapper_subclass(t):
# default behavior: mirror maybe_mark_dynamic() on all inner tensors with same dim as t
# TODO: Make this configurable via a supported public API
_apply_func_to_inner_tensors_of_same_dim(maybe_mark_dynamic, t, index)
if isinstance(index, int):
if not hasattr(t, "_dynamo_weak_dynamic_indices"):
t._dynamo_weak_dynamic_indices = set()
# TODO(voz): Should we bounds check?
t._dynamo_weak_dynamic_indices.add(index)
return
assert isinstance(index, (list, tuple))
for i in index:
maybe_mark_dynamic(t, i)
def mark_static(t, index=None):
"""
Mark a tensor as having a static dim or mark a nn module class as static.
For tensors
===========
This will prevent us from attempting to compile it dynamically
when dynamic=True; this can improve trace-time performance.
This has lower precedence than mark_dynamic.
Unlike mark_dynamic, this can be done inside a graph, in which case it
induces specialization on the tensor.
For nn.Module classes
=====================
For static nn.Module classes, TorchDynamo assumes that the module instance
attributes will not be modified after compilation. This will ensure that
TorchDynamo keeps integer attributes CONSTANT and not symints.
From TorchDynamo implementation side, the instances of static-marked
nn.Module class will be converted to UnspecializedBuiltinNNModuleVariable,
which have the same properties.
Note that we still have to guard on the attributes, because different
instances of the nn.Module can have different values of the attributes. The
key point here is that the attributes are static.
"""
if is_compiling():
if index is None:
for s in t.size():
comptime.force_static(s)
else:
comptime.force_static(t.size(index))
return
if is_traceable_wrapper_subclass(t):
# default behavior: mirror mark_static() on all inner tensors with same dim as t
# TODO: Make this configurable via a supported public API
_apply_func_to_inner_tensors_of_same_dim(mark_static, t, index)
if not isinstance(t, torch.Tensor) and issubclass(t, torch.nn.Module):
t._dynamo_marked_static = True
return t
if not isinstance(t, torch.Tensor):
raise TypeError(
f"mark_static expects a tensor/nn.Module class but recieved {type(t)}"
)
if isinstance(index, int):
if not hasattr(t, "_dynamo_static_indices"):
t._dynamo_static_indices = set() # type: ignore[attr-defined]
# TODO(voz): Should we bounds check?
t._dynamo_static_indices.add(index) # type: ignore[attr-defined]
elif index is None:
for i in range(t.dim()):
mark_static(t, i)
else:
assert isinstance(index, (list, tuple))
for i in index:
mark_static(t, i)
@forbid_in_graph
def mark_static_address(t, guard=True):
"""
Marks an input tensor whose data_ptr will not change across multiple calls
to a dynamo-compiled function. This indicates to cudagraphs that an extra allocation
is not needed for this input. The data_ptr will be guarded if guard=True. Note:
Tensors marked in this way will be kept alive until `torch._dynamo.reset()` is called.
"""
if not isinstance(t, torch.Tensor):
raise TypeError(f"mark_static_address expects a tensor but recieved {type(t)}")
if guard:
t._dynamo_static_input_type = "guarded" # type: ignore[attr-defined]
else:
t._dynamo_static_input_type = "unguarded" # type: ignore[attr-defined]
# Note: this carefully avoids eagerly import einops.
# TODO: we should delete this whole _allow_in_graph_einops logic by approximately 2024 Q2
def _allow_in_graph_einops():
import einops
try:
# requires einops > 0.6.1, torch >= 2.0
from einops._torch_specific import ( # type: ignore[attr-defined] # noqa: F401
_ops_were_registered_in_torchdynamo,
)
# einops > 0.6.1 will call the op registration logic as it is imported.
except ImportError:
# einops <= 0.6.1
allow_in_graph(einops.rearrange)
allow_in_graph(einops.reduce)
if hasattr(einops, "repeat"):
allow_in_graph(einops.repeat) # available since einops 0.2.0
if hasattr(einops, "einsum"):
allow_in_graph(einops.einsum) # available since einops 0.5.0
if hasattr(einops, "pack"):
allow_in_graph(einops.pack) # available since einops 0.6.0
if hasattr(einops, "unpack"):
allow_in_graph(einops.unpack) # available since einops 0.6.0
trace_rules.add_module_init_func("einops", _allow_in_graph_einops)
|