import torch
import torch.distributed as dist
from torch.autograd.function import Function

class SyncBatchNorm(Function):
    # Autograd function for synchronized batch normalization: the per-channel
    # statistics are reduced across all processes in `process_group`, so every
    # rank normalizes with the same global mean and variance.

    @staticmethod
    def forward(ctx, input, weight, bias, running_mean, running_var,
                eps, momentum, process_group, world_size):
        input = input.contiguous()

        # number of elements per channel in the local batch
        count = torch.empty(1,
                            dtype=running_mean.dtype,
                            device=input.device).fill_(input.numel() // input.size(1))

        # calculate local mean/invstd for input
        mean, invstd = torch.batch_norm_stats(input, eps)

        num_channels = input.shape[1]
        # C, C, 1 -> (2C + 1)
        combined = torch.cat([mean, invstd, count], dim=0)
        # world_size * (2C + 1)
        combined_list = [
            torch.empty_like(combined) for _ in range(world_size)
        ]
        # Use all_gather rather than all_reduce so that each rank keeps the
        # per-rank statistics and counts needed by
        # batch_norm_gather_stats_with_counts below.
        dist.all_gather(combined_list, combined, process_group, async_op=False)
        combined = torch.stack(combined_list, dim=0)
        # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
        mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

        size = count_all.view(-1).long().sum()
        if size == 1:
            raise ValueError(
                'Expected more than 1 value per channel when training, got input size {}'.format(size))

        # calculate global mean & invstd and update the running statistics
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input,
            mean_all,
            invstd_all,
            running_mean,
            running_var,
            momentum,
            eps,
            count_all.view(-1)
        )

        ctx.save_for_backward(input, weight, mean, invstd, count_all)
        ctx.process_group = process_group

        # apply element-wise normalization
        out = torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        return out
    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd, count_tensor = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        process_group = ctx.process_group

        # calculate local stats as well as grad_weight / grad_bias
        sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
            grad_output,
            saved_input,
            mean,
            invstd,
            weight,
            ctx.needs_input_grad[0],
            ctx.needs_input_grad[1],
            ctx.needs_input_grad[2]
        )

        if ctx.needs_input_grad[0]:
            # synchronize the stats used to calculate the input gradient
            # TODO: move div_ into batch_norm_backward_elemt kernel
            num_channels = sum_dy.shape[0]
            combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
            dist.all_reduce(
                combined, dist.ReduceOp.SUM, process_group, async_op=False)
            sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

            divisor = count_tensor.sum()
            mean_dy = sum_dy / divisor
            mean_dy_xmu = sum_dy_xmu / divisor
            # backward pass for gradient calculation
            grad_input = torch.batch_norm_backward_elemt(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                mean_dy,
                mean_dy_xmu
            )

        # synchronizing grad_weight / grad_bias is not needed, as distributed
        # training handles that all-reduce itself.
        if weight is None or not ctx.needs_input_grad[1]:
            grad_weight = None

        if weight is None or not ctx.needs_input_grad[2]:
            grad_bias = None

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
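
# Illustrative sketch (added for exposition; not part of the original file):
# SyncBatchNorm is normally reached through the torch.nn.SyncBatchNorm module
# rather than called directly. Once torch.distributed has been initialized,
# a module's forward could invoke it roughly as below; `module` and its
# attribute names here are assumptions for the example, not the module's
# actual wiring.
#
#   process_group = torch.distributed.group.WORLD
#   world_size = torch.distributed.get_world_size(process_group)
#   out = SyncBatchNorm.apply(
#       input, module.weight, module.bias,
#       module.running_mean, module.running_var,
#       module.eps, module.momentum, process_group, world_size)
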
class CrossMapLRN2d(Function):
    # Local response normalization across channels:
    #   scale = k + alpha / size * sum(x_j ** 2 for j in a window of `size` channels)
    #   out   = x * scale ** (-beta)
    # The backward pass computes the gradient analytically instead of relying
    # on autograd through the forward loops.

    @staticmethod
    def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
        ctx.size = size
        ctx.alpha = alpha
        ctx.beta = beta
        ctx.k = k
        ctx.scale = None

        assert input.dim() == 4

        ctx.scale = ctx.scale or input.new()
        output = input.new()

        batch_size = input.size(0)
        channels = input.size(1)
        input_height = input.size(2)
        input_width = input.size(3)

        output.resize_as_(input)
        ctx.scale.resize_as_(input)

        # use output storage as temporary buffer
        input_square = output
        torch.pow(input, 2, out=input_square)

        pre_pad = int((ctx.size - 1) / 2 + 1)
        pre_pad_crop = channels if pre_pad > channels else pre_pad

        scale_first = ctx.scale.select(1, 0)
        scale_first.zero_()
        # compute the normalization for the first feature map
        for c in range(pre_pad_crop):
            scale_first.add_(input_square.select(1, c))

        # reuse the computation for the next feature maps' normalization
        # by adding the next feature map and removing the previous one
        for c in range(1, channels):
            scale_previous = ctx.scale.select(1, c - 1)
            scale_current = ctx.scale.select(1, c)
            scale_current.copy_(scale_previous)
            if c < channels - pre_pad + 1:
                square_next = input_square.select(1, c + pre_pad - 1)
                scale_current.add_(square_next, alpha=1)

            if c > pre_pad:
                square_previous = input_square.select(1, c - pre_pad)
                scale_current.add_(square_previous, alpha=-1)

        ctx.scale.mul_(ctx.alpha / ctx.size).add_(ctx.k)

        torch.pow(ctx.scale, -ctx.beta, out=output)
        output.mul_(input)

        ctx.save_for_backward(input, output)
        return output
    @staticmethod
    def backward(ctx, grad_output):
        input, output = ctx.saved_tensors
        grad_input = grad_output.new()

        batch_size = input.size(0)
        channels = input.size(1)
        input_height = input.size(2)
        input_width = input.size(3)

        padded_ratio = input.new(channels + ctx.size - 1, input_height,
                                 input_width)
        accum_ratio = input.new(input_height, input_width)

        cache_ratio_value = 2 * ctx.alpha * ctx.beta / ctx.size
        inversePrePad = int(ctx.size - (ctx.size - 1) / 2)

        grad_input.resize_as_(input)
        torch.pow(ctx.scale, -ctx.beta, out=grad_input).mul_(grad_output)

        padded_ratio.zero_()
        padded_ratio_center = padded_ratio.narrow(0, inversePrePad,
                                                  channels)
        for n in range(batch_size):
            torch.mul(grad_output[n], output[n], out=padded_ratio_center)
            padded_ratio_center.div_(ctx.scale[n])
            torch.sum(
                padded_ratio.narrow(0, 0, ctx.size - 1), 0, keepdim=False, out=accum_ratio)
            for c in range(channels):
                accum_ratio.add_(padded_ratio[c + ctx.size - 1])
                grad_input[n][c].addcmul_(input[n][c], accum_ratio, value=-cache_ratio_value)
                accum_ratio.add_(padded_ratio[c], alpha=-1)

        return grad_input, None, None, None, None
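
if __name__ == "__main__":
    # Minimal usage sketch (added for illustration; not part of the original
    # module). It applies the custom autograd Function directly to a small
    # NCHW tensor and runs a backward pass through the hand-written gradient.
    x = torch.randn(2, 8, 5, 5, dtype=torch.double, requires_grad=True)
    out = CrossMapLRN2d.apply(x, 5, 1e-4, 0.75, 1.0)
    out.sum().backward()
    print(out.shape, x.grad.shape)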