from typing import Tuple

import numpy as np
import torch
from torch.nn import functional as F

# Stretch parameters of the rectified sigmoid, eq. (23) in
# https://arxiv.org/pdf/2004.10568.pdf
ADAROUND_ZETA: float = 1.1
ADAROUND_GAMMA: float = -0.1

class AdaptiveRoundingLoss(torch.nn.Module):
    """
    Adaptive rounding loss described in https://arxiv.org/pdf/2004.10568.pdf:
    the rounding regularization is eq. (24), and the reconstruction loss is
    eq. (25) without the regularization term.
    """
    def __init__(
        self,
        max_iter: int,
        warm_start: float = 0.2,
        beta_range: Tuple[int, int] = (20, 2),
        reg_param: float = 0.001,
    ) -> None:
        """
        :param max_iter: total number of optimization iterations.
        :param warm_start: fraction of iterations run without rounding
            regularization.
        :param beta_range: (start, end) values of the annealed temperature beta.
        :param reg_param: weight applied to the regularization term.
        """
        super().__init__()
        self.max_iter = max_iter
        self.warm_start = warm_start
        self.beta_range = beta_range
        self.reg_param = reg_param
    def rounding_regularization(
        self,
        V: torch.Tensor,
        curr_iter: int,
    ) -> torch.Tensor:
        """
        Apply rounding regularization to the continuous variable V.
        Core logic adapted from the official AdaRound implementation.
        """
        assert (
            curr_iter < self.max_iter
        ), "Current iteration must be strictly less than max iteration"
        # No regularization during the warm-start phase; only the
        # reconstruction loss drives the optimization early on.
        if curr_iter < self.warm_start * self.max_iter:
            return torch.tensor(0.0, device=V.device)
        else:
            start_beta, end_beta = self.beta_range
            warm_start_end_iter = self.warm_start * self.max_iter
            # Relative progress of the current iteration within the
            # post-warm-start window, in [0, 1).
            rel_iter = (curr_iter - warm_start_end_iter) / (
                self.max_iter - warm_start_end_iter
            )
            # Cosine annealing of the temperature beta from start_beta
            # down to end_beta.
            beta = end_beta + 0.5 * (start_beta - end_beta) * (
                1 + np.cos(rel_iter * np.pi)
            )
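            # Illustrative calculation (not from the original code): with the
            # default beta_range=(20, 2), halfway through annealing
            # (rel_iter = 0.5) this gives
            # beta = 2 + 0.5 * (20 - 2) * (1 + cos(pi / 2)) = 2 + 9 = 11.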
            # Rectified sigmoid for soft quantization, eq. (23) in
            # https://arxiv.org/pdf/2004.10568.pdf: sigmoid(V) is stretched to
            # (ADAROUND_GAMMA, ADAROUND_ZETA) = (-0.1, 1.1) and clamped, so
            # h_alpha can reach exactly 0 and 1.
            h_alpha = torch.clamp(
                torch.sigmoid(V) * (ADAROUND_ZETA - ADAROUND_GAMMA) + ADAROUND_GAMMA,
                min=0,
                max=1,
            )
            # Rounding regularization, eq. (24): encourages h_alpha to
            # converge to a binary solution (0 or 1) by the end of
            # optimization.
            inner_term = torch.add(2 * h_alpha, -1).abs().pow(beta)
            regularization_term = torch.add(1, -inner_term).sum()
            return regularization_term * self.reg_param

    def reconstruction_loss(
        self,
        soft_quantized_output: torch.Tensor,
        original_output: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the reconstruction (MSE) loss between the soft-quantized
        output and the original output.
        """
        return F.mse_loss(
            soft_quantized_output, original_output, reduction="none"
        ).mean()

    def forward(
        self,
        soft_quantized_output: torch.Tensor,
        original_output: torch.Tensor,
        V: torch.Tensor,
        curr_iter: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the two terms of the asymmetric reconstruction formulation,
        eq. (25), returning (regularization_term, reconstruction_term)
        separately so the caller can combine them.
        """
        regularization_term = self.rounding_regularization(V, curr_iter)
        reconstruction_term = self.reconstruction_loss(
            soft_quantized_output, original_output
        )
        return regularization_term, reconstruction_term
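
# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption-laden illustration, not part of the
# original module). The soft-quantization step below mirrors the
# rectified-sigmoid formulation (eq. (23)) but omits clipping to the integer
# grid; `weight`, `scale`, and `calib_input` are hypothetical names.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    max_iter = 1000
    weight = torch.randn(16, 8)
    scale = 0.05
    calib_input = torch.randn(32, 8)
    original_output = calib_input @ weight.t()

    # V parameterizes the per-weight rounding decision.
    V = torch.nn.Parameter(torch.zeros_like(weight))
    loss_fn = AdaptiveRoundingLoss(max_iter=max_iter)
    optimizer = torch.optim.Adam([V], lr=1e-2)

    for curr_iter in range(max_iter):
        # Soft-quantized weights: floor plus the learned rounding offset h(V).
        h_v = torch.clamp(
            torch.sigmoid(V) * (ADAROUND_ZETA - ADAROUND_GAMMA) + ADAROUND_GAMMA,
            min=0,
            max=1,
        )
        soft_weight = (torch.floor(weight / scale) + h_v) * scale
        soft_quantized_output = calib_input @ soft_weight.t()

        reg_term, recon_term = loss_fn(
            soft_quantized_output, original_output, V, curr_iter
        )
        loss = recon_term + reg_term
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # After optimization, hard rounding uses h(V) >= 0.5, i.e. V >= 0 once
    # h(V) has converged to {0, 1}.
    hard_weight = (torch.floor(weight / scale) + (V.detach() >= 0).float()) * scale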