1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553
|
# mypy: allow-untyped-defs
import inspect
from collections import defaultdict
from functools import lru_cache, partial, wraps
from itertools import chain
from typing import (
Callable,
Dict,
FrozenSet,
List,
Optional,
Sequence,
Set,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing_extensions import ParamSpec
if TYPE_CHECKING:
from torch.export.decomp_utils import CustomDecompTable
import torch
import torch.library
from torch._ops import HigherOrderOperator, OperatorBase, OpOverload, OpOverloadPacket
from torch._prims_common import CustomOutParamAnnotation
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.utils import _pytree as pytree
# Names re-exported by ``from <this module> import *``.
__all__ = [
    "decomposition_table",
    "pre_autograd_decomposition_table",
    "meta_table",
    "register_decomposition",
    "get_decompositions",
    "core_aten_decompositions",
    "_should_decompose_because_unsafe_op",
]

# Generic placeholders used in the decorator annotations of
# ``register_decomposition``: ``_T`` stands for the decorated callable's return
# type and ``_P`` for its parameter specification.
_T = TypeVar("_T")
_P = ParamSpec("_P")
# TODO: relax key type here; torch registrations should be possible to; but
# right now this type is accurate
global_decomposition_table: Dict[
str, Dict[torch._ops.OperatorBase, Callable]
] = defaultdict(dict)
decomposition_table = global_decomposition_table["post_autograd"]
pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"]
meta_table = global_decomposition_table["meta"]
def _should_decompose_because_unsafe_op(op: torch._ops.OperatorBase) -> bool:
"""
Returns True if the op must always decompose in export/compile tracing system
In export, we always decompose certain CIA ops that are tagged with
maybe_aliasing_or_mutating because we statically need to know if the op is
mutating or not. But these CIA ops could have different behaviour in runtime.
native_batch_norm is a prim op which has a wrong schema and it needs to be replaced
with correct schema. But until then, we will force decompose it via this tag.
"""
if not isinstance(op, torch._ops.OpOverload):
return False
if torch.Tag.maybe_aliasing_or_mutating in op.tags:
return True
return op == torch.ops.aten.native_batch_norm.default
def _add_op_to_registry(registry, op, fn):
"""
This is an internal API for adding an op to the decomposition table.
If op is OpOverload, it will be added to the registry directly.
If op is OpOverloadPacket, all the valid op_overloads in the packet will be added to the registry.
"""
overloads: List[Union[torch._ops.OperatorBase]] = []
if isinstance(op, HigherOrderOperator):
# There's no concept of overloads for HigherOrderOperator
registry[op] = fn
return
elif isinstance(op, OpOverload):
overloads.append(op)
else:
assert isinstance(op, OpOverloadPacket)
for ol in op.overloads():
overloads.append(getattr(op, ol))
for op_overload in overloads:
if op_overload in registry:
raise RuntimeError(f"duplicate registrations for {op_overload}")
# TorchScript dumps a bunch of extra nonsense overloads
# which don't have corresponding dispatcher entries, we need
# to filter those out, e.g aten.add.float_int
if torch._C._dispatch_has_kernel(op_overload.name()):
registry[op_overload] = fn
def _convert_out_params(f):
out_annotation = f.__annotations__.get("out")
# If there are no out params, do not wrap the function.
if not out_annotation:
return f
# Hack to detect when out is a Tuple. There seems to be no pretty way of doing this
if getattr(out_annotation, "__origin__", None) is tuple:
sig = inspect.signature(f)
out_names = sig.return_annotation._fields
# If out is a tuple, we need to register a function that unpacks all the out
# elements as this is what native_functions.yaml expects
@wraps(f)
def _fn(*args, **kwargs):
out_kwargs = tuple(kwargs.pop(o, None) for o in out_names)
# Either all of the out kwargs are set or none of them
is_none = out_kwargs[0] is None
assert all((o is None) == is_none for o in out_kwargs)
return f(*args, **kwargs, out=None if is_none else out_kwargs)
out_params = [
inspect.Parameter(
o,
kind=inspect.Parameter.KEYWORD_ONLY,
default=None,
annotation=t,
)
for o, t in zip(out_names, out_annotation.__args__)
]
# Drop the out parameter and concatenate the new kwargs in the signature
params = chain((v for k, v in sig.parameters.items() if k != "out"), out_params)
_fn.__signature__ = inspect.Signature( # type: ignore[attr-defined]
parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type]
)
# Drop the out parameter and concatenate the new kwargs in the annotations
_fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
for o in out_params:
_fn.__annotations__[o.name] = o.annotation
# Propagate that this function is wrapped by `out_wrapper`
_fn._torch_decompositions_out_wrapper = f._torch_decompositions_out_wrapper # type: ignore[attr-defined]
return _fn
# Alternatively, there may be a single tensor out parameter with a name
# other than "out". This will need special treatment and is indicated by an
# annotation, which we will remove here so it is not exposed after wrapping.
custom_out_param_name = f.__annotations__.pop(CustomOutParamAnnotation, None)
if custom_out_param_name:
@wraps(f)
def _fn(*args, **kwargs):
out_kwarg = kwargs.pop(custom_out_param_name, None)
return f(*args, **kwargs, out=out_kwarg)
out_param = inspect.Parameter(
custom_out_param_name,
kind=inspect.Parameter.KEYWORD_ONLY,
default=None,
annotation=out_annotation,
)
# Drop the out parameter and concatenate the new kwarg in the signature
sig = inspect.signature(f)
params = chain(
(v for k, v in sig.parameters.items() if k != "out"), (out_param,)
)
_fn.__signature__ = inspect.Signature( # type: ignore[attr-defined]
parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type]
)
# Drop the out parameter and concatenate the new kwargs in the annotations
_fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"}
_fn.__annotations__[out_param.name] = out_param.annotation
return _fn
return f
def register_decomposition(
    aten_op, registry=None, *, type="post_autograd", unsafe=False
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    """
    Build a decorator that registers a function as a decomposition in the
    Python decomposition table. Use it like this::

        @register_decomposition(torch.ops.aten.clamp_min)
        def clamp_min(self, min):
            return torch.clamp(self, min=min)

    If you are writing a new decomposition, consider contributing it
    directly to PyTorch in torch._decomp.decompositions.

    This API is experimental; we are almost certainly going to extend
    the API when we make decompositions eligible for use in transforms (e.g.,
    autograd) and not just backend tracing, where we then need to know if a
    decomposition can be used to simulate a transform.

    Args:
        aten_op: the operator to register for, or any pytree (list, tuple, ...)
            of operators so several can be registered at once.
        registry: the table to register into. When ``None``, the global table
            selected by ``type`` is used.
        type: one of ``"post_autograd"``, ``"pre_autograd"`` or ``"meta"``.
        unsafe: when true, the out-parameter conversion is skipped. This allows
            reuse of this function for registering non-function things.

    Returns:
        A decorator that performs the registration and returns the decorated
        function itself, unmodified.

    NOTE(review): earlier documentation of this function states that the
    decomposition is also registered to the Meta key of the dispatcher,
    replacing an existing C++ Meta implementation. Nothing in this function
    does that directly; confirm where it happens.
    """
    assert type in {"post_autograd", "pre_autograd", "meta"}

    def decomposition_decorator(fn: Callable[_P, _T]) -> Callable[_P, _T]:
        # The table is resolved on first use and remembered for later uses of
        # this same decorator.
        nonlocal registry

        registered_fn = fn if unsafe else _convert_out_params(fn)
        if registry is None:
            registry = global_decomposition_table[type]

        # To handle allowing multiple aten_ops at once
        pytree.tree_map_(
            lambda op: _add_op_to_registry(registry, op, registered_fn), aten_op
        )
        # Callers get back the function exactly as they wrote it.
        return fn

    return decomposition_decorator
def get_decompositions(
    aten_ops: Sequence[Union[torch._ops.OperatorBase, OpOverloadPacket]],
    type: str = "post_autograd",
) -> Dict[torch._ops.OperatorBase, Callable]:
    """
    Collect the registered decompositions for the requested operators.

    ``aten_ops`` may mix operator overloads and overload packets. A packet
    selects every overload of that packet that has a decomposition in the
    table chosen by ``type`` (``"post_autograd"``, ``"pre_autograd"`` or
    ``"meta"``). Requested operators without a decomposition are silently
    ignored.

    This API is experimental; we are almost certainly going to give an alternate,
    more recommended formulation, where a user provides the set of operators
    they know how to implement, and we provide decompositions for everything
    not in this set.
    """
    assert type in {"post_autograd", "pre_autograd", "meta"}

    registry = global_decomposition_table[type]

    # Index the registered entries by their overload packet so a packet
    # request can be expanded without scanning the registry each time.
    overloads_by_packet = defaultdict(list)
    for registered in registry:
        if isinstance(registered, (OpOverload, OpOverloadPacket)):
            overloads_by_packet[registered.overloadpacket].append(registered)

    selected: Dict[torch._ops.OperatorBase, Callable] = {}
    for op in aten_ops:
        if isinstance(op, OpOverloadPacket) and op in overloads_by_packet:
            selected.update(
                (overload, registry[overload]) for overload in overloads_by_packet[op]
            )
        elif isinstance(op, torch._ops.OperatorBase) and op in registry:
            selected[op] = registry[op]
    return selected
def remove_decompositions(
decompositions: Dict[torch._ops.OperatorBase, Callable],
aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]],
) -> None:
"""
Given a dictionary of decompositions obtained from get_decompositions(), removes
operators associated with a list of operator overloads and overload packets passed
as input. If the decomposition dictionary does not contain a decomposition that is
specified to be removed, it is silently ignored.
"""
for op in aten_ops:
if isinstance(op, OpOverloadPacket):
for overload_name in op.overloads():
opo = getattr(op, overload_name)
decompositions.pop(opo, None)
elif isinstance(op, OpOverload):
decompositions.pop(op, None)
# populate the table
import torch._decomp.decompositions
import torch._refs
def core_aten_decompositions() -> "CustomDecompTable":
    """Return the table produced by ``torch.export``'s ``default_decompositions()``."""
    # Imported at call time rather than at module import time
    # (presumably to avoid a circular import -- TODO confirm).
    from torch.export.exported_program import default_decompositions as _defaults

    return _defaults()
# See NOTE [Core ATen Ops]
#
# The list was copied from torch/_inductor/decomposition.py, excluding
# decompositions that result in prim ops.
# The resulting opset after decomposition consists of core ATen ops.
def _core_aten_decompositions_post_autograd() -> (
    Dict[torch._ops.OperatorBase, Callable]
):
    """
    Return the post-autograd decompositions registered for the operators below.

    The lookup goes through get_decompositions() with its default table, so
    overload packets expand to all of their decomposed overloads and operators
    without a registered decomposition are silently skipped.

    Each operator is listed exactly once; listing one twice has no effect on
    the result.
    """
    aten = torch.ops.aten
    return get_decompositions(
        [
            aten.addcdiv,
            aten.addcdiv_,
            aten.addcmul,
            aten.addcmul_,
            aten.addr,
            aten.affine_grid_generator,
            aten.alias_copy,
            aten.all,
            aten.aminmax,
            aten.arange.default,
            aten.arange.start,
            aten.avg_pool2d_backward,
            aten.baddbmm,
            aten.binary_cross_entropy,
            aten.binary_cross_entropy_backward,
            aten.binary_cross_entropy_with_logits,
            aten.block_diag,
            aten.bernoulli.p,
            aten.bernoulli.default,
            aten.celu,
            aten.celu_,
            aten.channel_shuffle,
            aten.clamp_max,
            aten.clamp_min,
            aten.col2im,
            aten.count_nonzero,
            aten.linalg_cross,
            aten.cudnn_batch_norm,
            aten.cudnn_batch_norm_backward,
            aten.miopen_batch_norm_backward,
            aten.deg2rad,
            aten.deg2rad_,
            aten.detach,
            aten.diag_embed,
            aten.diagonal_backward,
            aten.diagonal_copy,
            aten.dot,
            aten.vdot,
            aten.elu,
            aten.elu_,
            aten.elu_backward,
            aten._embedding_bag,
            aten.embedding_dense_backward,
            aten.empty_like,
            aten._euclidean_dist.default,
            aten.expand_as,
            aten.expand_copy,
            aten.eye,
            aten.fill,
            aten.fill_,
            aten.floor_divide,
            aten.frac,
            aten.frac_,
            aten._fused_moving_avg_obs_fq_helper,
            aten.gelu_,
            aten.gelu_backward,
            aten.glu,
            aten.glu_backward,
            aten.hardshrink,
            aten.hardsigmoid,
            aten.hardsigmoid_,
            aten.hardsigmoid_backward,
            aten.hardswish,
            aten.hardswish_,
            aten.hardswish_backward,
            aten.hardtanh_,
            aten.hardtanh_backward,
            aten.heaviside,
            aten.heaviside_,
            aten.huber_loss,
            aten.huber_loss_backward,
            aten.im2col,
            aten.index_add.out,
            aten.index_add.default,
            aten.index_add_,
            aten.index_copy.out,
            aten.index_copy.default,
            aten.index_copy_,
            aten.index_fill.int_Scalar,
            aten.index_fill.int_Tensor,
            aten.index_fill.int_Scalar_out,
            aten.index_fill.int_Tensor_out,
            aten.index_fill_,
            aten.isin,
            aten.isneginf,
            aten.isposinf,
            aten.l1_loss,
            aten._lazy_clone,
            aten._test_parallel_materialize,
            aten.leaky_relu_,
            aten.leaky_relu_backward,
            aten.lerp,
            aten.lerp_,
            aten.linspace,
            aten.logaddexp,
            aten.logaddexp2,
            aten.logit,
            aten.logit_,
            aten.logit_backward,
            aten.log_sigmoid_backward,
            aten.log_sigmoid_forward,
            aten._log_softmax_backward_data,
            aten.logspace,
            aten.logsumexp.default,
            aten.masked_fill,
            aten.masked_fill_,
            aten.max_unpool2d,
            aten.max_unpool3d,
            aten.mish,
            aten.mish_,
            aten.mse_loss,
            aten.mse_loss_backward,
            aten.multi_margin_loss,
            aten.multilabel_margin_loss_forward,
            aten.mv,
            aten.mvlgamma,
            aten.mvlgamma_,
            aten.nansum,
            aten.nan_to_num,
            aten.nan_to_num_,
            aten.narrow,
            aten.native_batch_norm_backward,
            aten.native_dropout_backward,
            aten.native_group_norm_backward,
            aten.native_layer_norm_backward,
            aten.new_empty,
            aten.new_full,
            aten.new_ones,
            aten.new_zeros,
            aten.nll_loss2d_forward,
            aten.nll_loss2d_backward,
            aten.nll_loss_backward,
            aten.nll_loss_forward,
            aten.norm.ScalarOpt_dtype,
            aten.norm.Scalar,
            aten.norm.ScalarOpt_dim_dtype,
            aten.norm.ScalarOpt_dim,
            aten.norm.dtype_out,
            aten.norm.out,
            aten.norm.names_dtype_out,
            aten.norm.names_out,
            aten.norm.ScalarOpt_dtype_out,
            aten.norm.Scalar_out,
            aten.ones,
            aten.ones_like,
            aten.pixel_shuffle,
            aten.pixel_unshuffle,
            aten._prelu_kernel,
            aten._prelu_kernel_backward,
            aten._reshape_alias,
            aten.rad2deg,
            aten.rad2deg_,
            aten.reflection_pad1d,
            aten.reflection_pad1d_backward,
            aten.reflection_pad2d,
            aten.reflection_pad2d_backward,
            aten.reflection_pad3d,
            aten.reflection_pad3d_backward,
            aten.replication_pad1d,
            aten.replication_pad2d,
            aten.replication_pad3d,
            aten.renorm,
            aten.renorm_,
            aten.resize_as,
            aten.roll,
            aten.rot90,
            aten.rrelu_with_noise,
            aten.rrelu_with_noise_,
            aten.rsub,
            aten._safe_softmax,
            aten._scaled_dot_product_flash_attention_for_cpu.default,
            aten.select_backward,
            aten.select_scatter,
            aten.sgn,
            aten.sgn_,
            aten.sigmoid_backward,
            aten.silu,
            aten.silu_,
            aten.silu_backward.grad_input,
            aten.sinc,
            aten.sinc_,
            aten.slice_backward,
            aten.smooth_l1_loss,
            aten.smooth_l1_loss_backward,
            aten.soft_margin_loss,
            aten.soft_margin_loss_backward,
            aten._softmax_backward_data,
            aten.softplus,
            aten.softplus_backward,
            aten.softshrink,
            aten.special_entr,
            aten.special_log_ndtr,
            aten.special_xlog1py,
            aten.split.Tensor,
            aten.split_with_sizes_copy,
            aten.squeeze_copy,
            aten.squeeze.default,
            aten.squeeze.dim,
            aten.std.correction,
            aten.std.out,
            aten.std.correction_out,
            aten.std.names_out,
            aten.std.correction_names_out,
            aten.std_mean.correction,
            aten.std_mean.correction_out,
            aten.stack,
            aten.sum.default,
            aten.sum.out,
            aten.t,
            aten.t_copy,
            aten.take,
            aten.tanh_backward,
            aten.threshold,
            aten.threshold_,
            aten.threshold_backward,
            aten.trace,
            aten.transpose.int,
            aten.transpose_copy,
            aten.tril,
            aten.tril_,
            aten.triu,
            aten.triu_,
            aten.unbind,
            aten.unfold_backward,
            aten.unfold_copy,
            aten._unsafe_index,
            aten._unsafe_index_put,
            aten._unsafe_masked_index,
            aten._unsafe_masked_index_put_accumulate,
            aten.unsafe_split.Tensor,
            aten.unsafe_split_with_sizes,
            aten.unsqueeze_copy,
            aten._unsafe_view,
            aten.upsample_linear1d,
            aten.upsample_bilinear2d.out,
            aten.upsample_trilinear3d.out,
            aten.upsample_nearest2d_backward,
            aten.view_as_complex,
            aten.xlogy,
            aten.xlogy_,
            aten.zero,
            aten.zero_,
            aten.zeros,
            aten.zeros_like,
            aten._chunk_cat,
            aten._weight_norm_interface,
        ]
    )
|