# SRResNet super-resolution model: residual blocks followed by pixel-shuffle upscaling blocks.
import math
from torch import nn
from torch.nn import init
def _initialize_orthogonal(conv):
    """Initialise a convolution in place with orthogonal weights and zero bias.

    Args:
        conv: Layer exposing ``weight`` and ``bias`` attributes (the callers
            in this file pass ``nn.Conv2d`` instances). ``bias`` may be
            ``None``, in which case only the weight is touched.

    The weight is drawn with ``init.orthogonal_`` using a gain of
    ``sqrt(2)``; the bias, when present, is filled with zeros.
    """
    # sqrt(2) is the gain PyTorch recommends for ReLU; it is reused here for
    # the PReLU activations that follow the convolutions.
    prelu_gain = math.sqrt(2)
    # Use the in-place initialisers (trailing underscore): the underscore-less
    # aliases are deprecated, and these run under no_grad instead of going
    # through ``.data``.
    init.orthogonal_(conv.weight, gain=prelu_gain)
    if conv.bias is not None:
        init.zeros_(conv.bias)
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages with an identity skip connection.

    The first stage is followed by a per-channel PReLU; the second is not.
    Both convolutions keep the channel count and spatial size unchanged
    (``padding=1``) so the result can be added to the block input.

    Args:
        n_filters: Number of input and output channels.
    """

    def __init__(self, n_filters):
        super().__init__()

        # Shared settings for the two channel-preserving 3x3 convolutions.
        # No bias: each convolution feeds straight into a BatchNorm2d.
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)

        # Registration order is kept stable so parameter ordering and seeded
        # initialisation match earlier versions of this block.
        self.conv1 = nn.Conv2d(n_filters, n_filters, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(n_filters)
        self.prelu = nn.PReLU(n_filters)
        self.conv2 = nn.Conv2d(n_filters, n_filters, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(n_filters)

        # Orthogonal initialisation of both convolutions, first then second.
        for conv in (self.conv1, self.conv2):
            _initialize_orthogonal(conv)

    def forward(self, x):
        """Return ``x`` plus the output of the conv/BN/PReLU/conv/BN branch."""
        branch = self.conv1(x)
        branch = self.bn1(branch)
        branch = self.prelu(branch)
        branch = self.conv2(branch)
        branch = self.bn2(branch)
        return x + branch
class UpscaleBlock(nn.Module):
    """Upscaling stage: 3x3 convolution, ``PixelShuffle(2)``, then PReLU.

    The convolution expands ``n_filters`` channels to ``4 * n_filters``; the
    pixel shuffle with factor 2 rearranges those back into ``n_filters``
    channels, which is the channel count the per-channel PReLU is built for.

    Args:
        n_filters: Number of channels entering and leaving the block.
    """

    def __init__(self, n_filters):
        super().__init__()
        expanded_channels = 4 * n_filters
        self.upscaling_conv = nn.Conv2d(
            n_filters, expanded_channels, kernel_size=3, padding=1
        )
        self.upscaling_shuffler = nn.PixelShuffle(2)
        self.upscaling = nn.PReLU(n_filters)

        # Orthogonal initialisation of the expanding convolution.
        _initialize_orthogonal(self.upscaling_conv)

    def forward(self, x):
        """Apply convolution, pixel shuffle and PReLU to ``x`` in that order."""
        expanded = self.upscaling_conv(x)
        shuffled = self.upscaling_shuffler(expanded)
        return self.upscaling(shuffled)
class SRResNet(nn.Module):
    """Residual super-resolution network with pixel-shuffle upscaling.

    Pipeline: a 9x9 input convolution with PReLU, ``n_blocks`` residual
    blocks, a 3x3 convolution + batch norm whose output is added to the
    features produced by the input stage, a chain of ``UpscaleBlock``s, and
    a final 9x9 convolution back to 3 channels.

    Args:
        rescale_factor: Requested upscaling factor. The number of upscale
            blocks is ``int(log2(rescale_factor))``, so a factor that is not
            a power of two is truncated to the next lower power of two.
            Must be positive (``math.log2`` raises ``ValueError`` otherwise).
        n_filters: Number of feature channels used throughout the network.
        n_blocks: Number of residual blocks; ``0`` is allowed and skips the
            residual stage entirely.
    """

    def __init__(self, rescale_factor, n_filters, n_blocks):
        super().__init__()
        # math.log2 is preferred over math.log(x, 2): the two-argument form
        # divides two rounded logarithms, so the quotient is not guaranteed
        # to be an exact integer before int() truncates it.
        self.rescale_levels = int(math.log2(rescale_factor))
        self.n_filters = n_filters
        self.n_blocks = n_blocks

        self.conv1 = nn.Conv2d(3, n_filters, kernel_size=9, padding=4)
        self.prelu1 = nn.PReLU(n_filters)

        # Submodules are registered as "residual_block<i>" / "upscale_block<i>"
        # (1-based) and wrapped in nn.Sequential so that state_dict keys stay
        # compatible with previously saved checkpoints.
        for residual_block_num in range(1, n_blocks + 1):
            self.add_module(
                "residual_block" + str(residual_block_num),
                nn.Sequential(ResidualBlock(self.n_filters)),
            )

        self.skip_conv = nn.Conv2d(
            n_filters, n_filters, kernel_size=3, padding=1, bias=False
        )
        self.skip_bn = nn.BatchNorm2d(n_filters)

        for upscale_block_num in range(1, self.rescale_levels + 1):
            self.add_module(
                "upscale_block" + str(upscale_block_num),
                nn.Sequential(UpscaleBlock(self.n_filters)),
            )

        self.output_conv = nn.Conv2d(n_filters, 3, kernel_size=9, padding=4)

        # Orthogonal initialisation of the convolutions owned directly by this
        # module (the blocks initialise their own convolutions).
        _initialize_orthogonal(self.conv1)
        _initialize_orthogonal(self.skip_conv)
        _initialize_orthogonal(self.output_conv)

    def forward(self, x):
        """Upscale ``x`` and return the 3-channel result.

        Args:
            x: Input batch with 3 channels (required by ``conv1``).

        Returns:
            Output of ``output_conv`` after all upscale blocks have been
            applied.
        """
        x_init = self.prelu1(self.conv1(x))

        # Start from the input-stage features and run every residual block in
        # order; with n_blocks == 0 the loop is skipped instead of failing on
        # a missing "residual_block1" attribute.
        x = x_init
        for residual_block_num in range(1, self.n_blocks + 1):
            x = getattr(self, "residual_block" + str(residual_block_num))(x)

        # Long skip connection around the whole residual stage.
        x = self.skip_bn(self.skip_conv(x)) + x_init

        for upscale_block_num in range(1, self.rescale_levels + 1):
            x = getattr(self, "upscale_block" + str(upscale_block_num))(x)

        return self.output_conv(x)
|