1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208
|
"""
This file is part of the private API. Please do not use these classes directly, as they will be modified in
future versions without warning. The classes should be accessed only via the transforms argument of Weights.
"""
from typing import Optional, Tuple
import torch
from torch import nn, Tensor
from . import functional as F, InterpolationMode
# Names exported by a star-import of this module: the five preset classes defined below.
__all__ = [
"ObjectDetection",
"ImageClassification",
"VideoClassification",
"SemanticSegmentation",
"OpticalFlow",
]
class ObjectDetection(nn.Module):
    """Preset that only converts its input to a floating point tensor."""

    def forward(self, img: Tensor) -> Tensor:
        """Return ``img`` as a ``torch.float`` tensor.

        An input that is not already a ``Tensor`` is first passed through
        ``F.pil_to_tensor``; the dtype conversion is then delegated to
        ``F.convert_image_dtype``.
        """
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        as_float = F.convert_image_dtype(img, torch.float)
        return as_float

    def __repr__(self) -> str:
        """Return the class name followed by empty parentheses."""
        return f"{type(self).__name__}()"

    def describe(self) -> str:
        """Return a human-readable summary of the accepted inputs and the processing."""
        accepted_inputs = (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
        )
        processing = "The images are rescaled to ``[0.0, 1.0]``."
        return accepted_inputs + processing
class ImageClassification(nn.Module):
    """Preset applying resize, center crop, float conversion and normalization to images."""

    def __init__(
        self,
        *,
        crop_size: int,
        resize_size: int = 256,
        mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
        std: Tuple[float, ...] = (0.229, 0.224, 0.225),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ) -> None:
        """Store the preprocessing configuration.

        Args:
            crop_size: Size handed to ``F.center_crop``.
            resize_size: Size handed to ``F.resize``.
            mean: Per-channel means used by ``F.normalize``.
            std: Per-channel standard deviations used by ``F.normalize``.
            interpolation: Interpolation mode forwarded to ``F.resize``.
        """
        super().__init__()
        self.interpolation = interpolation
        # The scalar sizes are kept as single-element lists, which is the form
        # later passed to ``F.resize`` and ``F.center_crop``.
        self.resize_size = [resize_size]
        self.crop_size = [crop_size]
        # Tuples are copied into lists.
        self.mean = list(mean)
        self.std = list(std)

    def forward(self, img: Tensor) -> Tensor:
        """Resize, center-crop, convert to ``torch.float`` and normalize ``img``."""
        resized = F.resize(img, self.resize_size, interpolation=self.interpolation)
        cropped = F.center_crop(resized, self.crop_size)
        # Non-tensor inputs are turned into tensors only after the geometric steps.
        if not isinstance(cropped, Tensor):
            cropped = F.pil_to_tensor(cropped)
        as_float = F.convert_image_dtype(cropped, torch.float)
        return F.normalize(as_float, mean=self.mean, std=self.std)

    def __repr__(self) -> str:
        """Return a multi-line representation listing every stored setting."""
        settings = (
            f"crop_size={self.crop_size}",
            f"resize_size={self.resize_size}",
            f"mean={self.mean}",
            f"std={self.std}",
            f"interpolation={self.interpolation}",
        )
        body = "".join(f"\n {setting}" for setting in settings)
        return f"{type(self).__name__}({body}\n)"

    def describe(self) -> str:
        """Return a human-readable summary of the accepted inputs and the processing."""
        sentences = [
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. ",
            f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, ",
            f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to ",
            f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``.",
        ]
        return "".join(sentences)
class VideoClassification(nn.Module):
    """Preset for video clips: per-frame resize, center crop, float conversion and normalization.

    The output has the channel axis moved in front of the time axis.
    """

    def __init__(
        self,
        *,
        crop_size: Tuple[int, int],
        resize_size: Tuple[int, int],
        mean: Tuple[float, ...] = (0.43216, 0.394666, 0.37645),
        std: Tuple[float, ...] = (0.22803, 0.22145, 0.216989),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ) -> None:
        """Store the preprocessing configuration.

        Args:
            crop_size: ``(height, width)`` handed to ``F.center_crop``; also used
                to restore the frame dimensions after cropping.
            resize_size: Size handed to ``F.resize``.
            mean: Per-channel means used by ``F.normalize``.
            std: Per-channel standard deviations used by ``F.normalize``.
            interpolation: Interpolation mode forwarded to ``F.resize``.
        """
        super().__init__()
        self.crop_size = list(crop_size)
        self.resize_size = list(resize_size)
        self.mean = list(mean)
        self.std = list(std)
        self.interpolation = interpolation

    def forward(self, vid: Tensor) -> Tensor:
        """Preprocess a clip or a batch of clips.

        Args:
            vid: A ``(N, T, C, H, W)`` batch, or a single ``(T, C, H, W)`` clip.

        Returns:
            The processed frames laid out as ``(N, C, T, H, W)``; the leading
            batch axis is removed again when the input had fewer than 5 dimensions.
        """
        need_squeeze = False
        if vid.ndim < 5:
            # Temporarily add a batch axis so both input layouts share one code path.
            vid = vid.unsqueeze(dim=0)
            need_squeeze = True

        N, T, C, H, W = vid.shape
        # Fold batch and time together so the image ops see a plain (B, C, H, W) batch.
        # ``reshape`` is used instead of ``view`` because ``view`` raises on
        # non-contiguous inputs (e.g. a batch that was permuted into this layout);
        # for contiguous inputs ``reshape`` returns the same view without copying.
        vid = vid.reshape(-1, C, H, W)
        vid = F.resize(vid, self.resize_size, interpolation=self.interpolation)
        vid = F.center_crop(vid, self.crop_size)
        vid = F.convert_image_dtype(vid, torch.float)
        vid = F.normalize(vid, mean=self.mean, std=self.std)

        # After cropping, the spatial size is the configured crop size.
        H, W = self.crop_size
        vid = vid.reshape(N, T, C, H, W)
        vid = vid.permute(0, 2, 1, 3, 4)  # (N, T, C, H, W) => (N, C, T, H, W)

        if need_squeeze:
            vid = vid.squeeze(dim=0)
        return vid

    def __repr__(self) -> str:
        """Return a multi-line representation listing every stored setting."""
        format_string = self.__class__.__name__ + "("
        format_string += f"\n crop_size={self.crop_size}"
        format_string += f"\n resize_size={self.resize_size}"
        format_string += f"\n mean={self.mean}"
        format_string += f"\n std={self.std}"
        format_string += f"\n interpolation={self.interpolation}"
        format_string += "\n)"
        return format_string

    def describe(self) -> str:
        """Return a human-readable summary of the accepted inputs and the processing."""
        return (
            "Accepts batched ``(B, T, C, H, W)`` and single ``(T, C, H, W)`` video frame ``torch.Tensor`` objects. "
            f"The frames are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``, "
            f"followed by a central crop of ``crop_size={self.crop_size}``. Finally the values are first rescaled to "
            f"``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ``std={self.std}``. Finally the output "
            "dimensions are permuted to ``(..., C, T, H, W)`` tensors."
        )
class SemanticSegmentation(nn.Module):
    """Preset applying an optional resize, float conversion and normalization to images."""

    def __init__(
        self,
        *,
        resize_size: Optional[int],
        mean: Tuple[float, ...] = (0.485, 0.456, 0.406),
        std: Tuple[float, ...] = (0.229, 0.224, 0.225),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ) -> None:
        """Store the preprocessing configuration.

        Args:
            resize_size: Size handed to ``F.resize``, or ``None`` to skip resizing.
            mean: Per-channel means used by ``F.normalize``.
            std: Per-channel standard deviations used by ``F.normalize``.
            interpolation: Interpolation mode forwarded to ``F.resize``.
        """
        super().__init__()
        # ``None`` is stored unchanged; an int is wrapped in a single-element list,
        # which is what ``forward`` checks for before resizing.
        if resize_size is None:
            self.resize_size = None
        else:
            self.resize_size = [resize_size]
        self.interpolation = interpolation
        self.mean = list(mean)
        self.std = list(std)

    def forward(self, img: Tensor) -> Tensor:
        """Optionally resize ``img``, then convert it to ``torch.float`` and normalize it."""
        # Resizing only happens when a size list was stored by ``__init__``.
        if isinstance(self.resize_size, list):
            img = F.resize(img, self.resize_size, interpolation=self.interpolation)
        if not isinstance(img, Tensor):
            img = F.pil_to_tensor(img)
        as_float = F.convert_image_dtype(img, torch.float)
        return F.normalize(as_float, mean=self.mean, std=self.std)

    def __repr__(self) -> str:
        """Return a multi-line representation listing every stored setting."""
        settings = (
            f"resize_size={self.resize_size}",
            f"mean={self.mean}",
            f"std={self.std}",
            f"interpolation={self.interpolation}",
        )
        body = "".join(f"\n {setting}" for setting in settings)
        return f"{type(self).__name__}({body}\n)"

    def describe(self) -> str:
        """Return a human-readable summary of the accepted inputs and the processing."""
        sentences = [
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. ",
            f"The images are resized to ``resize_size={self.resize_size}`` using ``interpolation={self.interpolation}``. ",
            f"Finally the values are first rescaled to ``[0.0, 1.0]`` and then normalized using ``mean={self.mean}`` and ",
            f"``std={self.std}``.",
        ]
        return "".join(sentences)
class OpticalFlow(nn.Module):
    """Preset that converts an image pair to float tensors normalized with mean 0.5 and std 0.5."""

    def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]:
        """Preprocess both images identically and return them as a pair.

        Each input that is not already a ``Tensor`` goes through
        ``F.pil_to_tensor``; both are then converted to ``torch.float``,
        normalized, and returned as contiguous tensors.
        """
        if not isinstance(img1, Tensor):
            img1 = F.pil_to_tensor(img1)
        if not isinstance(img2, Tensor):
            img2 = F.pil_to_tensor(img2)

        first = F.convert_image_dtype(img1, torch.float)
        second = F.convert_image_dtype(img2, torch.float)

        # map [0, 1] into [-1, 1]
        shift = [0.5, 0.5, 0.5]
        scale = [0.5, 0.5, 0.5]
        first = F.normalize(first, mean=shift, std=scale)
        second = F.normalize(second, mean=shift, std=scale)

        return first.contiguous(), second.contiguous()

    def __repr__(self) -> str:
        """Return the class name followed by empty parentheses."""
        return f"{type(self).__name__}()"

    def describe(self) -> str:
        """Return a human-readable summary of the accepted inputs and the processing."""
        accepted_inputs = (
            "Accepts ``PIL.Image``, batched ``(B, C, H, W)`` and single ``(C, H, W)`` image ``torch.Tensor`` objects. "
        )
        processing = "The images are rescaled to ``[-1.0, 1.0]``."
        return accepted_inputs + processing
|