1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399
|
# coding=utf-8
r"""Dynamically quantized convolution modules."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch._ops import ops
from torch.nn.common_types import _size_1_t
from torch.nn.modules.utils import _single, _pair, _triple
from torch.ao.nn.quantized.modules.conv import _reverse_repeat_padding
import torch.ao.nn.quantized as nnq
import warnings
# Public names exported by this module.
__all__ = [
    'Conv1d',
    'Conv2d',
    'Conv3d',
    'ConvTranspose1d',
    'ConvTranspose2d',
    'ConvTranspose3d',
]
class Conv1d(nnq.Conv1d):
    r"""Dynamically quantized 1D convolution with floating point inputs and outputs.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv1d` and :class:`~torch.nn.quantized.dynamic.Conv1d`.

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv1d` for other attributes.

    Examples::

        >>> m = nn.quantized.dynamic.Conv1d(16, 33, 3, stride=2)
        >>> input = torch.randn(20, 16, 100)
        >>> # xdoctest: +SKIP
        >>> output = m(input)

    """

    _FLOAT_MODULE = nn.Conv1d
    _NNIQAT_CONV_BN_MODULE = None  # type: ignore[assignment]
    _NNI_CONV_RELU_MODULE = None  # type: ignore[assignment]

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_1_t,
                 stride: _size_1_t = 1,
                 padding: _size_1_t = 0,
                 dilation: _size_1_t = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 device=None,
                 dtype=None,
                 reduce_range=True):
        """Emit the accuracy warning, normalize the size arguments, and build the module.

        ``kernel_size``, ``stride`` and ``dilation`` are passed through
        ``_single``; ``padding`` is too unless it is given as a string.
        ``reduce_range`` is accepted here but is not read by this constructor;
        the value that matters is the one passed to :meth:`forward`.
        """
        warnings.warn(
            f"The current implementation of the {self._get_name()} module "
            "has poor numerical accuracy and its use is not recommended"
        )
        kernel_size, stride, dilation = map(_single, (kernel_size, stride, dilation))
        # String paddings are handed to the parent class untouched.
        if not isinstance(padding, str):
            padding = _single(padding)
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode, device=device, dtype=dtype)

    def _get_name(self):
        """Return the module name interpolated into the constructor warning."""
        return 'DynamicQuantizedConv1d'

    def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
        """Run the dynamically quantized 1D convolution.

        Args:
            input: tensor with exactly three dimensions, ``(N, C, L)``.
            reduce_range: forwarded unchanged to ``quantized.conv1d_dynamic``.

        Raises:
            ValueError: if ``input`` does not have exactly three dimensions.
        """
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        rank = len(input.shape)
        if rank != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        if self.padding_mode == 'zeros':
            padded = input
        else:
            # Non-zero padding modes are applied explicitly with F.pad. Only
            # the first padding entry is used; per the original note, Conv1d
            # padding may be stored as (p, p).
            pad_widths = _reverse_repeat_padding(self.padding[:1])
            padded = F.pad(input, pad_widths, mode=self.padding_mode)
        return ops.quantized.conv1d_dynamic(padded, self._packed_params, reduce_range)
class Conv2d(nnq.Conv2d):
    r"""Dynamically quantized 2D convolution with floating point inputs and outputs.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv2d` and :class:`~torch.nn.quantized.dynamic.Conv2d`.

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv2d` for other attributes.

    Examples::

        >>> # With square kernels and equal stride
        >>> m = nn.quantized.dynamic.Conv2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.quantized.dynamic.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.quantized.dynamic.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> # xdoctest: +SKIP
        >>> output = m(input)

    """

    _FLOAT_MODULE = nn.Conv2d
    _NNIQAT_CONV_BN_MODULE = None  # type: ignore[assignment]
    _NNI_CONV_RELU_MODULE = None  # type: ignore[assignment]

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None):
        """Emit the accuracy warning, expand size arguments with ``_pair``, and build the module."""
        warnings.warn(
            f"The current implementation of the {self._get_name()} module "
            "has poor numerical accuracy and its use is not recommended"
        )
        kernel_size, stride, padding, dilation = map(
            _pair, (kernel_size, stride, padding, dilation))
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode, device=device, dtype=dtype)

    def _get_name(self):
        """Return the module name interpolated into the constructor warning."""
        return 'DynamicQuantizedConv2d'

    def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
        """Run the dynamically quantized 2D convolution.

        Args:
            input: tensor with exactly four dimensions, ``(N, C, H, W)``.
            reduce_range: forwarded unchanged to ``quantized.conv2d_dynamic``.

        Raises:
            ValueError: if ``input`` does not have exactly four dimensions.
        """
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        rank = len(input.shape)
        if rank != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        if self.padding_mode == 'zeros':
            padded = input
        else:
            # Non-zero padding modes are applied explicitly with F.pad.
            pad_widths = _reverse_repeat_padding(self.padding)
            padded = F.pad(input, pad_widths, mode=self.padding_mode)
        return ops.quantized.conv2d_dynamic(padded, self._packed_params, reduce_range)
class Conv3d(nnq.Conv3d):
    r"""Dynamically quantized 3D convolution with floating point inputs and outputs.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv3d` and :class:`~torch.nn.quantized.dynamic.Conv3d`.

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv3d` for other attributes.

    Examples::

        >>> # With square kernels and equal stride
        >>> m = nn.quantized.dynamic.Conv3d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.quantized.dynamic.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.quantized.dynamic.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2))
        >>> input = torch.randn(20, 16, 56, 56, 56)
        >>> # xdoctest: +SKIP
        >>> output = m(input)

    """

    _FLOAT_MODULE = nn.Conv3d
    _NNIQAT_CONV_BN_MODULE = None  # type: ignore[assignment]
    _NNI_CONV_RELU_MODULE = None  # type: ignore[assignment]

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None):
        """Emit the accuracy warning, validate ``padding_mode``, and build the module.

        Size arguments are expanded with ``_triple`` before being handed to the
        parent's ``_init``.

        Raises:
            AssertionError: if ``padding_mode`` is ``'reflect'``.
        """
        warnings.warn(
            f"The current implementation of the {self._get_name()} module "
            "has poor numerical accuracy and its use is not recommended"
        )
        # Raised explicitly instead of via `assert`, so the check still runs
        # under `python -O`. The exception type and message are unchanged.
        if padding_mode == 'reflect':
            raise AssertionError("Conv3d does not support reflection padding")
        kernel_size, stride, padding, dilation = map(
            _triple, (kernel_size, stride, padding, dilation))
        # This class calls `_init` rather than `__init__` on the parent. The
        # two extra positional values (False, _triple(0)) are presumably the
        # `transposed` flag and `output_padding` -- TODO confirm against the
        # parent's `_init` signature.
        super()._init(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _triple(0), groups, bias, padding_mode,
            device=device, dtype=dtype)

    def _get_name(self):
        """Return the module name interpolated into the constructor warning."""
        return 'DynamicQuantizedConv3d'

    def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
        """Run the dynamically quantized 3D convolution.

        Args:
            input: tensor with exactly five dimensions, ``(N, C, D, H, W)``.
            reduce_range: forwarded unchanged to ``quantized.conv3d_dynamic``.

        Raises:
            ValueError: if ``input`` does not have exactly five dimensions.
        """
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
        if self.padding_mode != 'zeros':
            # Non-zero padding modes are applied explicitly with F.pad.
            pad_widths = _reverse_repeat_padding(self.padding)
            input = F.pad(input, pad_widths, mode=self.padding_mode)
        return ops.quantized.conv3d_dynamic(
            input, self._packed_params, reduce_range)
class ConvTranspose1d(nnq.ConvTranspose1d):
    r"""Dynamically quantized 1D transposed convolution with floating point inputs and outputs.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose1d`.

    For special notes, please, see :class:`~torch.nn.quantized.dynamic.Conv1d`

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.ConvTranspose1d` for other attributes.

    .. note::
        ``forward`` accepts only ``input`` and ``reduce_range``; it has no
        ``output_size`` parameter.

    Examples::

        >>> # xdoctest: +SKIP
        >>> import torch.nn.quantized.dynamic as nndq
        >>> # With equal stride
        >>> m = nndq.ConvTranspose1d(16, 33, 3, stride=2)
        >>> # with explicit padding
        >>> m = nndq.ConvTranspose1d(16, 33, 5, stride=1, padding=2)
        >>> input = torch.randn(1, 16, 12)
        >>> output = m(input)

    """

    _FLOAT_MODULE = nn.ConvTranspose1d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros', device=None, dtype=None):
        """Emit the accuracy warning and delegate construction to the parent class."""
        warnings.warn(
            f"The current implementation of the {self._get_name()} module "
            "has poor numerical accuracy and its use is not recommended"
        )
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, output_padding,
            groups, bias, dilation, padding_mode, device=device, dtype=dtype)

    def _get_name(self):
        """Return the module name interpolated into the constructor warning."""
        # NOTE(review): "Tranpose" (sic) is the existing runtime name and is
        # kept verbatim here.
        return 'DynamicQuantizedConvTranpose1d'

    def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
        """Run the dynamically quantized 1D transposed convolution.

        Args:
            input: tensor with exactly three dimensions, ``(N, C, L)``.
            reduce_range: forwarded unchanged to
                ``quantized.conv_transpose1d_dynamic``.

        Raises:
            ValueError: if ``input`` does not have exactly three dimensions.
        """
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        rank = len(input.shape)
        if rank != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        return ops.quantized.conv_transpose1d_dynamic(
            input, self._packed_params, reduce_range)
class ConvTranspose2d(nnq.ConvTranspose2d):
    r"""Dynamically quantized 2D transposed convolution with floating point inputs and outputs.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose2d`.

    For special notes, please, see :class:`~torch.nn.quantized.dynamic.Conv2d`

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.ConvTranspose2d` for other attributes.

    .. note::
        ``forward`` accepts only ``input`` and ``reduce_range``; it has no
        ``output_size`` parameter.

    Examples::

        >>> # xdoctest: +SKIP
        >>> import torch.nn.quantized.dynamic as nndq
        >>> # With square kernels and equal stride
        >>> m = nndq.ConvTranspose2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nndq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> input = torch.randn(1, 16, 12, 12)
        >>> output = m(input)

    """

    _FLOAT_MODULE = nn.ConvTranspose2d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros', device=None, dtype=None):
        """Emit the accuracy warning and delegate construction to the parent class."""
        warnings.warn(
            f"The current implementation of the {self._get_name()} module "
            "has poor numerical accuracy and its use is not recommended"
        )
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, output_padding,
            groups, bias, dilation, padding_mode, device=device, dtype=dtype)

    def _get_name(self):
        """Return the module name interpolated into the constructor warning."""
        # NOTE(review): "Tranpose" (sic) is the existing runtime name and is
        # kept verbatim here.
        return 'DynamicQuantizedConvTranpose2d'

    def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
        """Run the dynamically quantized 2D transposed convolution.

        Args:
            input: tensor with exactly four dimensions, ``(N, C, H, W)``.
            reduce_range: forwarded unchanged to
                ``quantized.conv_transpose2d_dynamic``.

        Raises:
            ValueError: if ``input`` does not have exactly four dimensions.
        """
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        rank = len(input.shape)
        if rank != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        return ops.quantized.conv_transpose2d_dynamic(
            input, self._packed_params, reduce_range)
class ConvTranspose3d(nnq.ConvTranspose3d):
    r"""Dynamically quantized 3D transposed convolution with floating point inputs and outputs.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose3d`.

    For special notes, please, see :class:`~torch.nn.quantized.dynamic.Conv3d`

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.ConvTranspose3d` for other attributes.

    .. note::
        ``forward`` accepts only ``input`` and ``reduce_range``; it has no
        ``output_size`` parameter.

    Examples::

        >>> # xdoctest: +SKIP
        >>> import torch.nn.quantized.dynamic as nndq
        >>> # With cubic kernels and equal stride
        >>> m = nndq.ConvTranspose3d(16, 33, 3, stride=2)
        >>> # non-cubic kernels and unequal stride and with padding
        >>> m = nndq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2))
        >>> input = torch.randn(1, 16, 12, 12, 12)
        >>> output = m(input)

    """

    _FLOAT_MODULE = nn.ConvTranspose3d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros', device=None, dtype=None):
        """Emit the accuracy warning and delegate construction to the parent class."""
        warnings.warn(
            f"The current implementation of the {self._get_name()} module "
            "has poor numerical accuracy and its use is not recommended"
        )
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, output_padding,
            groups, bias, dilation, padding_mode, device=device, dtype=dtype)

    def _get_name(self):
        """Return the module name interpolated into the constructor warning."""
        # NOTE(review): "Tranpose" (sic) is the existing runtime name and is
        # kept verbatim here.
        return 'DynamicQuantizedConvTranpose3d'

    def forward(self, input: Tensor, reduce_range: bool = True) -> Tensor:
        """Run the dynamically quantized 3D transposed convolution.

        Args:
            input: tensor with exactly five dimensions, ``(N, C, T, H, W)``.
            reduce_range: forwarded unchanged to
                ``quantized.conv_transpose3d_dynamic``.

        Raises:
            ValueError: if ``input`` does not have exactly five dimensions.
        """
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        rank = len(input.shape)
        if rank != 5:
            raise ValueError("Input shape must be `(N, C, T, H, W)`!")
        return ops.quantized.conv_transpose3d_dynamic(
            input, self._packed_params, reduce_range)
|