1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
|
import math
import torch
from ..optimizer import Optimizer
class AdamW(Optimizer):
    r"""Implements AdamW algorithm.

    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.

    This implementation batches the per-parameter arithmetic of each
    parameter group through the ``torch._foreach_*`` multi-tensor kernels.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=1e-2, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamW, self).__setstate__(state)
        # Restored param groups may lack 'amsgrad' (presumably state saved
        # before the option existed); default it to off so step() can read it.
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure`` if one was given, else ``None``.

        Raises:
            RuntimeError: if a parameter with a gradient has a sparse gradient.
                The check runs before the offending group is modified, so no
                parameter in that group is weight-decayed or updated.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad; the closure needs autograd
            # enabled to recompute gradients.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']
            lr = group['lr']
            eps = group['eps']

            # Collect and validate everything before mutating anything, so a
            # sparse gradient cannot leave the group partially weight-decayed.
            params_with_grad = []
            grads = []
            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients')
                params_with_grad.append(p)
                grads.append(p.grad)

            # Nothing to update in this group (e.g. all parameters frozen);
            # the _foreach_* kernels may also reject empty tensor lists.
            if not params_with_grad:
                continue

            # Perform stepweight decay: shrink the parameters directly
            # (decoupled from the gradient-based update below).
            torch._foreach_mul_(params_with_grad, 1 - lr * group['weight_decay'])

            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            step_counts = []
            for p in params_with_grad:
                state = self.state[p]
                # Lazy state initialization the first time p has a gradient.
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])
                if amsgrad:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])
                state['step'] += 1
                step_counts.append(state['step'])

            # Per-parameter bias corrections (step counts can differ between
            # parameters that started receiving gradients at different times).
            bias_correction1 = [1 - beta1 ** step for step in step_counts]
            bias_correction2_sqrt = [math.sqrt(1 - beta2 ** step) for step in step_counts]

            # Decay the first and second moment running average coefficient
            torch._foreach_mul_(exp_avgs, beta1)
            torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)
            torch._foreach_mul_(exp_avg_sqs, beta2)
            torch._foreach_addcmul_(exp_avg_sqs, grads, grads, 1 - beta2)

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                for max_sq, sq in zip(max_exp_avg_sqs, exp_avg_sqs):
                    torch.max(max_sq, sq, out=max_sq)
                # Use the max. for normalizing running avg. of gradient
                second_moments = max_exp_avg_sqs
            else:
                second_moments = exp_avg_sqs

            # denom = sqrt(v) / sqrt(1 - beta2^t) + eps. _foreach_sqrt returns
            # new tensors, so the stored moment buffers are not modified.
            second_moment_sqrt = torch._foreach_sqrt(second_moments)
            torch._foreach_div_scalar_list_(second_moment_sqrt, bias_correction2_sqrt)
            denom = torch._foreach_add(second_moment_sqrt, eps)

            # p -= lr / (1 - beta1^t) * m / denom
            for param, exp_avg, d, bc1 in zip(params_with_grad, exp_avgs, denom, bias_correction1):
                param.addcdiv_(exp_avg, d, value=-(lr / bc1))
        return loss
|