import math
import numbers
import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import PIL.Image
import torch
from torch.nn.functional import one_hot
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import transforms as _transforms, tv_tensors
from torchvision.transforms.v2 import functional as F
from ._transform import _RandomApplyTransform, Transform
from ._utils import _check_sequence_input, _parse_labels_getter, has_any, is_pure_tensor, query_chw, query_size


class RandomErasing(_RandomApplyTransform):
    """Randomly select a rectangular region in the input image or video and erase its pixels.

    This transform does not support PIL Images.

    'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896

    Args:
        p (float, optional): probability that the random erasing operation will be performed.
        scale (tuple of float, optional): range of the proportion of the erased area against the input image.
        ratio (tuple of float, optional): range of the aspect ratio of the erased area.
        value (number, tuple of numbers or str, optional): erasing value. Default is 0. If a single number,
            it is used to erase all pixels. If a tuple of length 3, it is used to erase
            the R, G, B channels respectively. If the str 'random', each pixel is
            erased with random values.
        inplace (bool, optional): whether to apply this transform in-place. Default is False.

    Returns:
        Erased input.

    Example:
        >>> from torchvision.transforms import v2 as transforms
        >>>
        >>> transform = transforms.Compose([
        >>>     transforms.RandomHorizontalFlip(),
        >>>     transforms.PILToTensor(),
        >>>     transforms.ConvertImageDtype(torch.float),
        >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        >>>     transforms.RandomErasing(),
        >>> ])
    """
_v1_transform_cls = _transforms.RandomErasing

    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
return dict(
super()._extract_params_for_v1_transform(),
value="random" if self.value is None else self.value,
)

    def __init__(
self,
p: float = 0.5,
scale: Sequence[float] = (0.02, 0.33),
ratio: Sequence[float] = (0.3, 3.3),
value: float = 0.0,
inplace: bool = False,
):
super().__init__(p=p)
if not isinstance(value, (numbers.Number, str, tuple, list)):
raise TypeError("Argument value should be either a number or str or a sequence")
if isinstance(value, str) and value != "random":
raise ValueError("If value is str, it should be 'random'")
if not isinstance(scale, Sequence):
raise TypeError("Scale should be a sequence")
if not isinstance(ratio, Sequence):
raise TypeError("Ratio should be a sequence")
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("Scale and ratio should be of kind (min, max)")
if scale[0] < 0 or scale[1] > 1:
raise ValueError("Scale should be between 0 and 1")
self.scale = scale
self.ratio = ratio
if isinstance(value, (int, float)):
self.value = [float(value)]
elif isinstance(value, str):
self.value = None
elif isinstance(value, (list, tuple)):
self.value = [float(v) for v in value]
else:
self.value = value
self.inplace = inplace
self._log_ratio = torch.log(torch.tensor(self.ratio))

    def _call_kernel(self, functional: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
if isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)):
warnings.warn(
f"{type(self).__name__}() is currently passing through inputs of type "
f"tv_tensors.{type(inpt).__name__}. This will likely change in the future."
)
return super()._call_kernel(functional, inpt, *args, **kwargs)

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
img_c, img_h, img_w = query_chw(flat_inputs)
if self.value is not None and not (len(self.value) in (1, img_c)):
raise ValueError(
f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)"
)
area = img_h * img_w
log_ratio = self._log_ratio
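        # Rejection sampling: try up to 10 random boxes and keep the first one
        # that fits strictly inside the image; otherwise fall back to a no-op.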
for _ in range(10):
erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
aspect_ratio = torch.exp(
torch.empty(1).uniform_(
log_ratio[0], # type: ignore[arg-type]
log_ratio[1], # type: ignore[arg-type]
)
).item()
h = int(round(math.sqrt(erase_area * aspect_ratio)))
w = int(round(math.sqrt(erase_area / aspect_ratio)))
if not (h < img_h and w < img_w):
continue
if self.value is None:
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
else:
v = torch.tensor(self.value)[:, None, None]
i = torch.randint(0, img_h - h + 1, size=(1,)).item()
j = torch.randint(0, img_w - w + 1, size=(1,)).item()
break
else:
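            # No box fit after 10 attempts: v=None makes transform() a no-op.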
i, j, h, w, v = 0, 0, img_h, img_w, None
return dict(i=i, j=j, h=h, w=w, v=v)

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["v"] is not None:
inpt = self._call_kernel(F.erase, inpt, **params, inplace=self.inplace)
return inpt


class _BaseMixUpCutMix(Transform):
def __init__(self, *, alpha: float = 1.0, num_classes: Optional[int] = None, labels_getter="default") -> None:
super().__init__()
self.alpha = float(alpha)
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
self.num_classes = num_classes
self._labels_getter = _parse_labels_getter(labels_getter)

    def forward(self, *inputs):
inputs = inputs if len(inputs) > 1 else inputs[0]
flat_inputs, spec = tree_flatten(inputs)
needs_transform_list = self._needs_transform_list(flat_inputs)
if has_any(flat_inputs, PIL.Image.Image, tv_tensors.BoundingBoxes, tv_tensors.Mask):
raise ValueError(f"{type(self).__name__}() does not support PIL images, bounding boxes and masks.")
labels = self._labels_getter(inputs)
if not isinstance(labels, torch.Tensor):
raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.")
if labels.ndim not in (1, 2):
raise ValueError(
f"labels should be index based with shape (batch_size,) "
f"or probability based with shape (batch_size, num_classes), "
f"but got a tensor of shape {labels.shape} instead."
)
if labels.ndim == 2 and self.num_classes is not None and labels.shape[-1] != self.num_classes:
raise ValueError(
f"When passing 2D labels, "
f"the number of elements in last dimension must match num_classes: "
f"{labels.shape[-1]} != {self.num_classes}. "
f"You can Leave num_classes to None."
)
if labels.ndim == 1 and self.num_classes is None:
raise ValueError("num_classes must be passed if the labels are index-based (1D)")
params = {
"labels": labels,
"batch_size": labels.shape[0],
**self.make_params(
[inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
),
}
# By default, the labels will be False inside needs_transform_list, since they are a torch.Tensor coming
# after an image or video. However, we need to handle them in _transform, so we make sure to set them to True
needs_transform_list[next(idx for idx, inpt in enumerate(flat_inputs) if inpt is labels)] = True
flat_outputs = [
self.transform(inpt, params) if needs_transform else inpt
for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
]
return tree_unflatten(flat_outputs, spec)

    def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int):
expected_num_dims = 5 if isinstance(inpt, tv_tensors.Video) else 4
if inpt.ndim != expected_num_dims:
raise ValueError(
f"Expected a batched input with {expected_num_dims} dims, but got {inpt.ndim} dimensions instead."
)
if inpt.shape[0] != batch_size:
raise ValueError(
f"The batch size of the image or video does not match the batch size of the labels: "
f"{inpt.shape[0]} != {batch_size}."
)

    def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:
if label.ndim == 1:
label = one_hot(label, num_classes=self.num_classes) # type: ignore[arg-type]
if not label.dtype.is_floating_point:
label = label.float()
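        # Pair each sample with the next one in the (shuffled) batch:
        # lam * label + (1 - lam) * label rolled by one along the batch dim.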
return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))


class MixUp(_BaseMixUpCutMix):
    """Apply MixUp to the provided batch of images and labels.

    Paper: `mixup: Beyond Empirical Risk Minimization <https://arxiv.org/abs/1710.09412>`_.

    .. note::
        This transform is meant to be used on **batches** of samples, not
        individual images. See
        :ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py` for detailed usage
        examples.
        The sample pairing is deterministic and done by matching consecutive
        samples in the batch, so the batch needs to be shuffled (this is an
        implementation detail, not a guaranteed convention).

    In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
    into a tensor of shape ``(batch_size, num_classes)``.

    Args:
        alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
        num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
            Can be None only if the labels are already one-hot-encoded.
        labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
            By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
            common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
            It can also be a callable that takes the same input as the transform, and returns the labels.
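
    Example:
        A minimal illustrative sketch with random data; the shapes and
        ``num_classes`` below are arbitrary assumptions, not requirements.

        >>> import torch
        >>> from torchvision.transforms import v2 as transforms
        >>> imgs = torch.rand(4, 3, 16, 16)  # batch of 4 RGB images
        >>> labels = torch.randint(0, 10, (4,))  # integer labels for 10 classes
        >>> mixup = transforms.MixUp(num_classes=10, alpha=1.0)
        >>> imgs, labels = mixup(imgs, labels)
        >>> labels.shape  # labels are now per-sample class probabilities
        torch.Size([4, 10])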
"""

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
return dict(lam=float(self._dist.sample(()))) # type: ignore[arg-type]

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
lam = params["lam"]
if inpt is params["labels"]:
return self._mixup_label(inpt, lam=lam)
elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
self._check_image_or_video(inpt, batch_size=params["batch_size"])
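            # Same convex combination as for the labels: blend each sample with
            # the next one in the batch (rolled by 1 along the batch dimension).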
output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
output = tv_tensors.wrap(output, like=inpt)
return output
else:
return inpt


class CutMix(_BaseMixUpCutMix):
    """Apply CutMix to the provided batch of images and labels.

    Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
    <https://arxiv.org/abs/1905.04899>`_.

    .. note::
        This transform is meant to be used on **batches** of samples, not
        individual images. See
        :ref:`sphx_glr_auto_examples_transforms_plot_cutmix_mixup.py` for detailed usage
        examples.
        The sample pairing is deterministic and done by matching consecutive
        samples in the batch, so the batch needs to be shuffled (this is an
        implementation detail, not a guaranteed convention).

    In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
    into a tensor of shape ``(batch_size, num_classes)``.

    Args:
        alpha (float, optional): hyperparameter of the Beta distribution used for cutmix. Default is 1.
        num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
            Can be None only if the labels are already one-hot-encoded.
        labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
            By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
            common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.
            It can also be a callable that takes the same input as the transform, and returns the labels.
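
    Example:
        A minimal illustrative sketch with random data; the shapes and
        ``num_classes`` below are arbitrary assumptions, not requirements.

        >>> import torch
        >>> from torchvision.transforms import v2 as transforms
        >>> imgs = torch.rand(4, 3, 16, 16)  # batch of 4 RGB images
        >>> labels = torch.randint(0, 10, (4,))  # integer labels for 10 classes
        >>> cutmix = transforms.CutMix(num_classes=10, alpha=1.0)
        >>> imgs, labels = cutmix(imgs, labels)
        >>> labels.shape  # soft labels, weighted by the pasted box area
        torch.Size([4, 10])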
"""

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
lam = float(self._dist.sample(())) # type: ignore[arg-type]
H, W = query_size(flat_inputs)
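        # Sample a box centered at a uniform-random point, with half-extents
        # chosen so the box covers roughly a (1 - lam) fraction of the image.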
r_x = torch.randint(W, size=(1,))
r_y = torch.randint(H, size=(1,))
r = 0.5 * math.sqrt(1.0 - lam)
r_w_half = int(r * W)
r_h_half = int(r * H)
x1 = int(torch.clamp(r_x - r_w_half, min=0))
y1 = int(torch.clamp(r_y - r_h_half, min=0))
x2 = int(torch.clamp(r_x + r_w_half, max=W))
y2 = int(torch.clamp(r_y + r_h_half, max=H))
box = (x1, y1, x2, y2)
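        # The box may have been clipped at the image borders, so recompute lam
        # as the exact fraction of pixels left untouched.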
lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
return dict(box=box, lam_adjusted=lam_adjusted)

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if inpt is params["labels"]:
return self._mixup_label(inpt, lam=params["lam_adjusted"])
elif isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)) or is_pure_tensor(inpt):
self._check_image_or_video(inpt, batch_size=params["batch_size"])
x1, y1, x2, y2 = params["box"]
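            # Paste the box region from the next sample in the batch (roll by 1),
            # mirroring the label mixing done with lam_adjusted.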
rolled = inpt.roll(1, 0)
output = inpt.clone()
output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]
if isinstance(inpt, (tv_tensors.Image, tv_tensors.Video)):
output = tv_tensors.wrap(output, like=inpt)
return output
else:
return inpt


class JPEG(Transform):
    """Apply JPEG compression and decompression to the given images.

    If the input is a :class:`torch.Tensor`, it is expected
    to be of dtype uint8, on CPU, and have [..., 3 or 1, H, W] shape,
    where ... means an arbitrary number of leading dimensions.

    Args:
        quality (sequence or number): JPEG quality, from 1 to 100. Lower means more compression.
            If quality is a sequence like (min, max), it specifies the range of JPEG quality to
            randomly select from (inclusive of both ends).

    Returns:
        image with JPEG compression.
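
    Example:
        A minimal illustrative sketch; the image shape and quality range
        below are arbitrary assumptions.

        >>> import torch
        >>> from torchvision.transforms import v2 as transforms
        >>> img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
        >>> jpeg = transforms.JPEG(quality=(50, 90))  # quality drawn uniformly from [50, 90]
        >>> out = jpeg(img)
        >>> out.shape, out.dtype
        (torch.Size([3, 64, 64]), torch.uint8)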
"""

    def __init__(self, quality: Union[int, Sequence[int]]):
super().__init__()
if isinstance(quality, int):
quality = [quality, quality]
else:
_check_sequence_input(quality, "quality", req_sizes=(2,))
if not (1 <= quality[0] <= quality[1] <= 100 and isinstance(quality[0], int) and isinstance(quality[1], int)):
raise ValueError(f"quality must be an integer from 1 to 100, got {quality =}")
self.quality = quality

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
quality = torch.randint(self.quality[0], self.quality[1] + 1, ()).item()
return dict(quality=quality)

    def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return self._call_kernel(F.jpeg, inpt, quality=params["quality"])