1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
|
from __future__ import annotations
import enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
import PIL.Image
import torch
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import tv_tensors
from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
from torchvision.utils import _log_api_usage_once
from .functional._utils import _get_kernel
class Transform(nn.Module):
    """Base class to implement your own v2 transforms.

    Subclasses customize behavior by overriding :meth:`make_params` and
    :meth:`transform`; :meth:`forward` drives both and should be left alone.

    See :ref:`sphx_glr_auto_examples_transforms_plot_custom_transforms.py` for
    more details.
    """

    # Types (or predicates performing finer-grained checks on an input) that this transform acts on.
    # Inputs matching none of the entries are passed through untouched.
    _transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image)

    def __init__(self) -> None:
        super().__init__()
        _log_api_usage_once(self)

    def check_inputs(self, flat_inputs: List[Any]) -> None:
        """Hook called by ``forward()`` with the flattened inputs; a no-op in the base class."""
        return None

    # When v2 was introduced, this method was private and called
    # `_get_params()`. Now it's publicly exposed as `make_params()`. It cannot
    # be exposed as `get_params()` because there is already a `get_params()`
    # method for v2 transforms: it's the v1's `get_params()` that we have to
    # keep in order to guarantee 100% BC with v1. (It's defined in
    # __init_subclass__ below).
    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Method to override for custom transforms.

        Receives the flattened inputs that were selected for transformation and
        returns the parameter dict that is handed to every ``transform()`` call.
        The base implementation returns an empty dict.

        See :ref:`sphx_glr_auto_examples_transforms_plot_custom_transforms.py`"""
        return {}

    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
        """Resolve the kernel for ``functional`` and ``type(inpt)`` and call it on ``inpt``."""
        return _get_kernel(functional, type(inpt), allow_passthrough=True)(inpt, *args, **kwargs)

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Method to override for custom transforms.

        Called once per input that needs transforming, with the dict returned
        by ``make_params()``. The base implementation raises
        ``NotImplementedError``.

        See :ref:`sphx_glr_auto_examples_transforms_plot_custom_transforms.py`"""
        raise NotImplementedError

    def forward(self, *inputs: Any) -> Any:
        """Do not override this! Use ``transform()`` instead."""
        # A single positional argument is the sample itself; several arguments form a tuple sample.
        sample = inputs if len(inputs) > 1 else inputs[0]
        leaves, treespec = tree_flatten(sample)

        self.check_inputs(leaves)

        mask = self._needs_transform_list(leaves)

        # Parameters are sampled once, from the selected leaves only, and shared by all of them.
        selected = [leaf for leaf, wanted in zip(leaves, mask) if wanted]
        params = self.make_params(selected)

        outputs = []
        for leaf, wanted in zip(leaves, mask):
            outputs.append(self.transform(leaf, params) if wanted else leaf)

        return tree_unflatten(outputs, treespec)

    def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]:
        """Return one flag per flattened input telling whether it should be transformed."""
        # Below is a heuristic on how to deal with pure tensor inputs:
        # 1. Pure tensors, i.e. tensors that are not a tv_tensor, are passed through if there is an explicit image
        #    (`tv_tensors.Image` or `PIL.Image.Image`) or video (`tv_tensors.Video`) in the sample.
        # 2. If there is no explicit image or video in the sample, only the first encountered pure tensor is
        #    transformed as image, while the rest is passed through. The order is defined by the returned `flat_inputs`
        #    of `tree_flatten`, which recurses depth-first through the input.
        #
        # This heuristic stems from two requirements:
        # 1. We need to keep BC for single input pure tensors and treat them as images.
        # 2. We don't want to treat all pure tensors as images, because some datasets like `CelebA` or `Widerface`
        #    return supplemental numerical data as tensors that cannot be transformed as images.
        #
        # The heuristic should work well for most people in practice. The only case where it doesn't is if someone
        # tries to transform multiple pure tensors at the same time, expecting them all to be treated as images.
        # However, this case wasn't supported by transforms v1 either, so there is no BC concern.
        pure_tensor_slot_open = not has_any(flat_inputs, tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)

        flags: List[bool] = []
        for item in flat_inputs:
            if not check_type(item, self._transformed_types):
                # Not a type this transform handles at all.
                flags.append(False)
                continue

            if not is_pure_tensor(item):
                flags.append(True)
                continue

            # Pure tensor: transformed only while the slot is open, and taking it closes the slot.
            flags.append(pure_tensor_slot_open)
            pure_tensor_slot_open = False

        return flags

    def extra_repr(self) -> str:
        """Build ``name=value`` pairs from public instance attributes holding simple values."""
        printable_types = (bool, int, float, str, tuple, list, enum.Enum)
        parts = [
            f"{name}={value}"
            for name, value in self.__dict__.items()
            if not (name.startswith("_") or name == "training") and isinstance(value, printable_types)
        ]
        return ", ".join(parts)

    # This attribute should be set on all transforms that have a v1 equivalent. Doing so enables two things:
    # 1. In case the v1 transform has a static `get_params` method, it will also be available under the same name on
    #    the v2 transform. See `__init_subclass__` for details.
    # 2. The v2 transform will be JIT scriptable. See `_extract_params_for_v1_transform` and `__prepare_scriptable__`
    #    for details.
    _v1_transform_cls: Optional[Type[nn.Module]] = None

    def __init_subclass__(cls) -> None:
        # Since `get_params` is a `@staticmethod`, we have to bind it to the class itself rather than to an instance.
        # This method is called after subclassing has happened, i.e. `cls` is the subclass, e.g. `Resize`.
        v1_cls = cls._v1_transform_cls
        if v1_cls is None or not hasattr(v1_cls, "get_params"):
            return
        cls.get_params = staticmethod(v1_cls.get_params)  # type: ignore[attr-defined]

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
        # This method is called by `__prepare_scriptable__` to instantiate the equivalent v1 transform from the current
        # v2 transform instance. It extracts all available public attributes that are specific to that transform and
        # not `nn.Module` in general.
        # Overwrite this method on the v2 transform class if the above is not sufficient. For example, this might happen
        # if the v2 transform introduced new parameters that are not supported by the v1 transform.
        generic_module_attrs = nn.Module().__dict__.keys()

        v1_kwargs = {}
        for name, value in self.__dict__.items():
            if name.startswith("_") or name in generic_module_attrs:
                continue
            v1_kwargs[name] = value
        return v1_kwargs

    def __prepare_scriptable__(self) -> nn.Module:
        # This method is called early on when `torch.jit.script`'ing an `nn.Module` instance. If it succeeds, the return
        # value is used for scripting over the original object that should have been scripted. Since the v1 transforms
        # are JIT scriptable, and we made sure that for single image inputs v1 and v2 are equivalent, we just return the
        # equivalent v1 transform here. This of course only makes transforms v2 JIT scriptable as long as transforms v1
        # is around.
        v1_cls = self._v1_transform_cls
        if v1_cls is not None:
            return v1_cls(**self._extract_params_for_v1_transform())

        raise RuntimeError(
            f"Transform {type(self).__name__} cannot be JIT scripted. "
            "torchscript is only supported for backward compatibility with transforms "
            "which are already in torchvision.transforms. "
            "For torchscript support (on tensors only), you can use the functional API instead."
        )
class _RandomApplyTransform(Transform):
    """Transform base that returns the sample unchanged whenever a ``torch.rand(1)`` draw is ``>= p``."""

    def __init__(self, p: float = 0.5) -> None:
        # Validate before initializing the module so an invalid `p` never yields a half-built instance.
        p_in_unit_interval = 0.0 <= p <= 1.0
        if not p_in_unit_interval:
            raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")

        super().__init__()
        self.p = p

    def forward(self, *inputs: Any) -> Any:
        # We need to almost duplicate `Transform.forward()` here since we always want to check the inputs, but return
        # early afterwards in case the random check triggers. The same result could be achieved by calling
        # `super().forward()` after the random check, but that would call `self.check_inputs` twice.
        sample = inputs if len(inputs) > 1 else inputs[0]
        leaves, treespec = tree_flatten(sample)

        self.check_inputs(leaves)

        # Random skip: hand back the (unflattened) sample exactly as it was received.
        if torch.rand(1) >= self.p:
            return sample

        mask = self._needs_transform_list(leaves)

        # Parameters are sampled once, from the selected leaves only, and shared by all of them.
        selected = [leaf for leaf, wanted in zip(leaves, mask) if wanted]
        params = self.make_params(selected)

        outputs = []
        for leaf, wanted in zip(leaves, mask):
            outputs.append(self.transform(leaf, params) if wanted else leaf)

        return tree_unflatten(outputs, treespec)
|