1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195
|
import math
from typing import Optional, Tuple
import torch
from torch import nn, Tensor
from torch.nn import init
from torch.nn.modules.utils import _pair
from torch.nn.parameter import Parameter
from torchvision.extension import _assert_has_ops
from ..utils import _log_api_usage_once
def deform_conv2d(
    input: Tensor,
    offset: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Tuple[int, int] = (1, 1),
    padding: Tuple[int, int] = (0, 0),
    dilation: Tuple[int, int] = (1, 1),
    mask: Optional[Tensor] = None,
) -> Tensor:
    r"""
    Performs Deformable Convolution v2, described in
    `Deformable ConvNets v2: More Deformable, Better Results
    <https://arxiv.org/abs/1811.11168>`__ if :attr:`mask` is not ``None`` and
    Performs Deformable Convolution, described in
    `Deformable Convolutional Networks
    <https://arxiv.org/abs/1703.06211>`__ if :attr:`mask` is ``None``.

    Args:
        input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
        offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width, out_height, out_width]):
            offsets to be applied for each position in the convolution kernel.
        weight (Tensor[out_channels, in_channels // groups, kernel_height, kernel_width]): convolution weights,
            split into groups of size (in_channels // groups)
        bias (Tensor[out_channels]): optional bias of shape (out_channels,). Default: None
        stride (int or Tuple[int, int]): distance between convolution centers. Default: 1
        padding (int or Tuple[int, int]): height/width of padding of zeroes around
            each image. Default: 0
        dilation (int or Tuple[int, int]): the spacing between kernel elements. Default: 1
        mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width, out_height, out_width]):
            masks to be applied for each position in the convolution kernel. Default: None

    Returns:
        Tensor[batch_sz, out_channels, out_h, out_w]: result of convolution

    Raises:
        RuntimeError: if ``offset.shape[1]`` is not a non-zero multiple of
            ``2 * kernel_height * kernel_width``.

    Examples::
        >>> input = torch.rand(4, 3, 10, 10)
        >>> kh, kw = 3, 3
        >>> weight = torch.rand(5, 3, kh, kw)
        >>> # offset and mask should have the same spatial size as the output
        >>> # of the convolution. In this case, for an input of 10, stride of 1
        >>> # and kernel size of 3, without padding, the output size is 8
        >>> offset = torch.rand(4, 2 * kh * kw, 8, 8)
        >>> mask = torch.rand(4, kh * kw, 8, 8)
        >>> out = deform_conv2d(input, offset, weight, mask=mask)
        >>> print(out.shape)
        torch.Size([4, 5, 8, 8])
    """
    # Usage logging is skipped while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(deform_conv2d)
    _assert_has_ops()
    out_channels = weight.shape[0]

    # Remember whether the caller supplied a mask before the placeholder
    # below overwrites the ``None``; the flag is forwarded to the op.
    use_mask = mask is not None

    # The op is always handed tensors for mask and bias, so substitute
    # zero-filled stand-ins when they were omitted.
    if mask is None:
        mask = torch.zeros((input.shape[0], 1), device=input.device, dtype=input.dtype)

    if bias is None:
        bias = torch.zeros(out_channels, device=input.device, dtype=input.dtype)

    stride_h, stride_w = _pair(stride)
    pad_h, pad_w = _pair(padding)
    dil_h, dil_w = _pair(dilation)
    weights_h, weights_w = weight.shape[-2:]
    _, n_in_channels, _, _ = input.shape

    # Each offset group carries one (dy, dx) pair per kernel position.
    offset_channels_per_group = 2 * weights_h * weights_w
    n_offset_grps = offset.shape[1] // offset_channels_per_group
    n_weight_grps = n_in_channels // weight.shape[1]

    # Reject both "fewer channels than one group" and "not a whole number of
    # groups"; the latter would otherwise be silently floored into a wrong
    # group count.
    if n_offset_grps == 0 or offset.shape[1] % offset_channels_per_group != 0:
        raise RuntimeError(
            "the shape of the offset tensor at dimension 1 is not valid. It should "
            "be a multiple of 2 * weight.size[2] * weight.size[3].\n"
            f"Got offset.shape[1]={offset.shape[1]}, while 2 * weight.size[2] * weight.size[3]={offset_channels_per_group}"
        )

    return torch.ops.torchvision.deform_conv2d(
        input,
        weight,
        offset,
        mask,
        bias,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        dil_h,
        dil_w,
        n_weight_grps,
        n_offset_grps,
        use_mask,
    )
class DeformConv2d(nn.Module):
    """
    Module form of :func:`deform_conv2d`.

    Holds the convolution ``weight`` (and optional ``bias``) as parameters and
    forwards them, together with the configured stride, padding and dilation,
    to :func:`deform_conv2d`. See that function for the full description.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        """
        Args:
            in_channels: number of input channels; must be divisible by ``groups``.
            out_channels: number of output channels; must be divisible by ``groups``.
            kernel_size: kernel size, normalised with ``_pair``.
            stride: convolution stride, normalised with ``_pair``.
            padding: input padding, normalised with ``_pair``.
            dilation: kernel dilation, normalised with ``_pair``.
            groups: number of weight groups.
            bias: when false, ``bias`` is registered as ``None``.

        Raises:
            ValueError: if either channel count is not divisible by ``groups``.
        """
        super().__init__()
        _log_api_usage_once(self)

        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups

        # ``weight`` is registered before ``bias`` so parameter order is stable.
        weight_shape = (out_channels, in_channels // groups, self.kernel_size[0], self.kernel_size[1])
        self.weight = Parameter(torch.empty(weight_shape))

        if not bias:
            self.register_parameter("bias", None)
        else:
            self.bias = Parameter(torch.empty(out_channels))

        # Parameters are allocated uninitialised above; fill them in now.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise ``weight`` and, when present, ``bias`` in place."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return

        # Bias is drawn uniformly from +/- 1/sqrt(fan_in) of the weight.
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in)
        init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor, offset: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """
        Args:
            input (Tensor[batch_size, in_channels, in_height, in_width]): input tensor
            offset (Tensor[batch_size, 2 * offset_groups * kernel_height * kernel_width, out_height, out_width]):
                offsets to be applied for each position in the convolution kernel.
            mask (Tensor[batch_size, offset_groups * kernel_height * kernel_width, out_height, out_width]):
                masks to be applied for each position in the convolution kernel.
        """
        result = deform_conv2d(
            input,
            offset,
            self.weight,
            self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            mask=mask,
        )
        return result

    def __repr__(self) -> str:
        """Describe the module; settings left at their defaults are omitted."""
        fields = [
            f"{self.in_channels}",
            f"{self.out_channels}",
            f"kernel_size={self.kernel_size}",
            f"stride={self.stride}",
        ]
        if self.padding != (0, 0):
            fields.append(f"padding={self.padding}")
        if self.dilation != (1, 1):
            fields.append(f"dilation={self.dilation}")
        if self.groups != 1:
            fields.append(f"groups={self.groups}")
        if self.bias is None:
            fields.append("bias=False")
        return f"{self.__class__.__name__}(" + ", ".join(fields) + ")"
|