File: adaround_loss.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (98 lines) | stat: -rw-r--r-- 3,283 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from typing import Tuple

import numpy as np

import torch
from torch.nn import functional as F


# Stretch parameters of the AdaRound rectified sigmoid (eq. [23] of
# https://arxiv.org/pdf/2004.10568.pdf): sigmoid(V) is linearly stretched to
# the interval (ADAROUND_GAMMA, ADAROUND_ZETA) and then clipped to [0, 1].
ADAROUND_GAMMA: float = -0.1  # lower end of the stretched interval
ADAROUND_ZETA: float = 1.1  # upper end of the stretched interval


class AdaptiveRoundingLoss(torch.nn.Module):
    """
    Adaptive Rounding Loss functions described in https://arxiv.org/pdf/2004.10568.pdf

    * rounding regularization is eq [24]
    * reconstruction loss is eq [25] without the regularization term

    Args:
        max_iter: total number of optimization iterations; ``curr_iter`` passed
            to :meth:`rounding_regularization` must be strictly smaller.
        warm_start: fraction of ``max_iter`` during which the rounding
            regularization is disabled (a zero tensor is returned).
        beta_range: ``(start_beta, end_beta)``; after warm start the exponent
            ``beta`` of the regularizer is cosine-annealed from ``start_beta``
            towards ``end_beta``.
        reg_param: weight multiplying the summed rounding regularization.
    """

    def __init__(
        self,
        max_iter: int,
        warm_start: float = 0.2,
        beta_range: Tuple[int, int] = (20, 2),
        reg_param: float = 0.001,
    ) -> None:
        super().__init__()
        self.max_iter = max_iter
        self.warm_start = warm_start
        self.beta_range = beta_range
        self.reg_param = reg_param

    def rounding_regularization(
        self,
        V: torch.Tensor,
        curr_iter: int,
    ) -> torch.Tensor:
        """
        Apply rounding regularization (eq [24]) to the input tensor V.

        Major logics copied from official Adaround Implementation.

        Args:
            V: learnable rounding variable; every element contributes to the sum.
            curr_iter: current optimization iteration, must be ``< max_iter``.

        Returns:
            A scalar tensor on ``V``'s device and dtype during warm start (zero).
            Otherwise ``reg_param * sum(1 - |2 * h(V) - 1| ** beta)``.

        Raises:
            AssertionError: if ``curr_iter >= max_iter``.
        """
        # Raised explicitly (rather than via ``assert``) so the check survives
        # ``python -O``; AssertionError is kept for existing callers.
        if curr_iter >= self.max_iter:
            raise AssertionError(
                f"Current iteration ({curr_iter}) must be strictly less than "
                f"max iteration ({self.max_iter})"
            )

        warm_start_end_iter = self.warm_start * self.max_iter
        if curr_iter < warm_start_end_iter:
            # Regularization is disabled during warm start. Create the zero on
            # V's device/dtype so it combines cleanly with other loss terms.
            return torch.zeros((), dtype=V.dtype, device=V.device)

        start_beta, end_beta = self.beta_range

        # Relative position inside the post-warm-start phase
        # (in [0, 1) when 0 <= warm_start <= 1).
        rel_iter = (curr_iter - warm_start_end_iter) / (
            self.max_iter - warm_start_end_iter
        )
        # Cosine annealing: beta == start_beta at rel_iter == 0 and approaches
        # end_beta as rel_iter approaches 1.
        beta = end_beta + 0.5 * (start_beta - end_beta) * (
            1 + np.cos(rel_iter * np.pi)
        )

        # Rectified sigmoid for soft quantization, as formulated in eq [23] of
        # https://arxiv.org/pdf/2004.10568.pdf: stretch sigmoid(V) to
        # (ADAROUND_GAMMA, ADAROUND_ZETA) and clip to [0, 1].
        h_alpha = torch.clamp(
            torch.sigmoid(V) * (ADAROUND_ZETA - ADAROUND_GAMMA) + ADAROUND_GAMMA,
            min=0,
            max=1,
        )

        # Each element's term is 0 when h_alpha is exactly 0 or 1 and largest at
        # 0.5, which pushes h_alpha to a binary solution (0 or 1) by the end of
        # optimization.
        inner_term = (2 * h_alpha - 1).abs().pow(beta)
        regularization_term = (1 - inner_term).sum()
        return regularization_term * self.reg_param

    def reconstruction_loss(
        self,
        soft_quantized_output: torch.Tensor,
        original_output: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the reconstruction loss between the soft quantized output and the original output.

        Returns:
            The mean squared error averaged over all elements (a scalar tensor).
        """
        return F.mse_loss(soft_quantized_output, original_output, reduction="mean")

    def forward(
        self,
        soft_quantized_output: torch.Tensor,
        original_output: torch.Tensor,
        V: torch.Tensor,
        curr_iter: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the asymmetric reconstruction formulation as eq [25].

        Args:
            soft_quantized_output: output produced with soft-quantized weights.
            original_output: output produced with the original weights.
            V: learnable rounding variable passed to the regularizer.
            curr_iter: current optimization iteration (``< max_iter``).

        Returns:
            ``(regularization_term, reconstruction_term)``, returned separately
            so the caller decides how to combine or log them.
        """
        regularization_term = self.rounding_regularization(V, curr_iter)
        reconstruction_term = self.reconstruction_loss(
            soft_quantized_output, original_output
        )
        return regularization_term, reconstruction_term