import functools
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

import torch
from torchvision import tv_tensors

_FillType = Union[int, float, Sequence[int], Sequence[float], None]
_FillTypeJIT = Optional[List[float]]


def is_pure_tensor(inpt: Any) -> bool:
    return isinstance(inpt, torch.Tensor) and not isinstance(inpt, tv_tensors.TVTensor)
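
# Illustrative usage (not part of the module's behavior): a plain tensor is "pure",
# while an instance of any TVTensor subclass is not. The tensor shape below is arbitrary.
#
#     >>> is_pure_tensor(torch.rand(3, 32, 32))
#     True
#     >>> is_pure_tensor(tv_tensors.Image(torch.rand(3, 32, 32)))
#     False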


# {functional: {input_type: type_specific_kernel}}
_KERNEL_REGISTRY: Dict[Callable, Dict[Type, Callable]] = {}


def _kernel_tv_tensor_wrapper(kernel):
    @functools.wraps(kernel)
    def wrapper(inpt, *args, **kwargs):
        # If you're wondering whether we could / should get rid of this wrapper,
        # the answer is no: we want to pass pure Tensors to avoid the overhead
        # of the __torch_function__ machinery. Note that this is always valid,
        # regardless of whether we override __torch_function__ in our base class
        # or not.
        # Also, even if we didn't call `as_subclass` here, we would still need
        # this wrapper to call wrap(), because the TVTensor type would be
        # lost after the first operation due to our own __torch_function__
        # logic.
        output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs)
        return tv_tensors.wrap(output, like=inpt)

    return wrapper
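
# Illustrative sketch of what the wrapper above does (the kernel name is hypothetical):
#
#     def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor:
#         return image.flip(-1)
#
#     wrapped = _kernel_tv_tensor_wrapper(horizontal_flip_image)
#     img = tv_tensors.Image(torch.rand(3, 4, 4))
#     out = wrapped(img)  # the kernel sees a plain Tensor, the caller gets an Image back
#     assert isinstance(out, tv_tensors.Image)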


def _register_kernel_internal(functional, input_type, *, tv_tensor_wrapper=True):
    registry = _KERNEL_REGISTRY.setdefault(functional, {})
    if input_type in registry:
        raise ValueError(f"Functional {functional} already has a kernel registered for type {input_type}.")

    def decorator(kernel):
        registry[input_type] = (
            _kernel_tv_tensor_wrapper(kernel)
            if issubclass(input_type, tv_tensors.TVTensor) and tv_tensor_wrapper
            else kernel
        )
        return kernel

    return decorator
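
# Roughly how built-in kernels hook into the registry elsewhere in the codebase
# (illustrative sketch, not copied verbatim from the kernel modules):
#
#     @_register_kernel_internal(horizontal_flip, torch.Tensor)
#     @_register_kernel_internal(horizontal_flip, tv_tensors.Image)
#     def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor:
#         ...
#
# For the tv_tensors.Image entry the kernel is stored wrapped by
# _kernel_tv_tensor_wrapper, so dispatched calls unwrap and re-wrap automatically.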


def _name_to_functional(name):
    import torchvision.transforms.v2.functional  # noqa

    try:
        return getattr(torchvision.transforms.v2.functional, name)
    except AttributeError:
        raise ValueError(
            f"Could not find functional with name '{name}' in torchvision.transforms.v2.functional."
        ) from None
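
# Illustrative usage: resolving a functional by name, as register_kernel below does
# for string inputs.
#
#     >>> from torchvision.transforms.v2 import functional as F
#     >>> _name_to_functional("resize") is F.resize
#     True
#     >>> _name_to_functional("not_a_functional")  # raises ValueError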


_BUILTIN_DATAPOINT_TYPES = {
    obj for obj in tv_tensors.__dict__.values() if isinstance(obj, type) and issubclass(obj, tv_tensors.TVTensor)
}


def register_kernel(functional, tv_tensor_cls):
    """Decorate a kernel to register it for a functional and a (custom) tv_tensor type.

    See :ref:`sphx_glr_auto_examples_transforms_plot_custom_tv_tensors.py` for usage
    details.
    """
    if isinstance(functional, str):
        functional = _name_to_functional(name=functional)
    elif not (
        callable(functional)
        and getattr(functional, "__module__", "").startswith("torchvision.transforms.v2.functional")
    ):
        raise ValueError(
            f"Kernels can only be registered on functionals from the torchvision.transforms.v2.functional namespace, "
            f"but got {functional}."
        )

    if not (isinstance(tv_tensor_cls, type) and issubclass(tv_tensor_cls, tv_tensors.TVTensor)):
        raise ValueError(
            f"Kernels can only be registered for subclasses of torchvision.tv_tensors.TVTensor, "
            f"but got {tv_tensor_cls}."
        )

    if tv_tensor_cls in _BUILTIN_DATAPOINT_TYPES:
        raise ValueError(f"Kernels cannot be registered for the builtin tv_tensor classes, but got {tv_tensor_cls}")

    return _register_kernel_internal(functional, tv_tensor_cls, tv_tensor_wrapper=False)
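
# Illustrative public usage with a hypothetical custom TVTensor subclass
# (MyTVTensor and my_resize are made-up names; see the gallery example referenced
# in the docstring above for the documented walkthrough):
#
#     class MyTVTensor(tv_tensors.TVTensor):
#         pass
#
#     @register_kernel("resize", MyTVTensor)
#     def my_resize(inpt, size, **kwargs):
#         # receives the MyTVTensor instance as-is (tv_tensor_wrapper=False),
#         # so it is responsible for returning the desired type itself
#         ...
#
#     # F.resize(MyTVTensor(...), size=[32, 32]) now dispatches to my_resize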


def _get_kernel(functional, input_type, *, allow_passthrough=False):
    registry = _KERNEL_REGISTRY.get(functional)
    if not registry:
        raise ValueError(f"No kernel registered for functional {functional.__name__}.")

    for cls in input_type.__mro__:
        if cls in registry:
            return registry[cls]
        elif cls is tv_tensors.TVTensor:
            # We don't want user-defined tv_tensors to dispatch to the pure Tensor kernels, so we explicitly stop
            # the MRO traversal before hitting torch.Tensor. We can even stop at tv_tensors.TVTensor, since we don't
            # allow kernels to be registered for tv_tensors.TVTensor anyway.
            break

    if allow_passthrough:
        return lambda inpt, *args, **kwargs: inpt

    raise TypeError(
        f"Functional F.{functional.__name__} supports inputs of type {registry.keys()}, "
        f"but got {input_type} instead."
    )
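
# Illustrative dispatch behavior (assuming kernels were registered as sketched above):
#
#     _get_kernel(horizontal_flip, tv_tensors.Image)  # exact match in the registry
#     _get_kernel(horizontal_flip, MyTVTensor)        # walks the MRO; a custom subclass with no
#                                                     # kernel of its own stops at TVTensor instead
#                                                     # of falling back to the pure-Tensor kernel
#     _get_kernel(horizontal_flip, str)               # raises TypeError
#     _get_kernel(horizontal_flip, str, allow_passthrough=True)  # returns an identity lambda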


# This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop.
# We could get rid of this by letting _register_kernel_internal take an arbitrary wrapper callable rather than
# the tv_tensor_wrapper: bool flag.
def _register_five_ten_crop_kernel_internal(functional, input_type):
    registry = _KERNEL_REGISTRY.setdefault(functional, {})
    if input_type in registry:
        raise TypeError(f"Functional '{functional}' already has a kernel registered for type '{input_type}'.")

    def wrap(kernel):
        @functools.wraps(kernel)
        def wrapper(inpt, *args, **kwargs):
            output = kernel(inpt, *args, **kwargs)
            container_type = type(output)
            return container_type(tv_tensors.wrap(o, like=inpt) for o in output)

        return wrapper

    def decorator(kernel):
        registry[input_type] = wrap(kernel) if issubclass(input_type, tv_tensors.TVTensor) else kernel
        return kernel

    return decorator
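
# Illustrative effect of the specialized wrapper: five_crop-style kernels return a
# container of tensors, and each element is re-wrapped like the input
# (the kernel name below is hypothetical):
#
#     @_register_five_ten_crop_kernel_internal(five_crop, tv_tensors.Image)
#     def five_crop_image(image, size):
#         ...  # returns a tuple of 5 plain Tensors
#
#     crops = F.five_crop(tv_tensors.Image(torch.rand(3, 8, 8)), size=[4, 4])
#     # -> tuple of 5 tv_tensors.Image instances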