File: pact.py

package info (click to toggle)
pytorch-ignite 0.5.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 11,712 kB
  • sloc: python: 46,874; sh: 376; makefile: 27
file content (34 lines) | stat: -rw-r--r-- 815 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# Implementation taken from https://discuss.pytorch.org/t/evaluator-returns-nan/107972/3
# Ref: https://arxiv.org/abs/1805.06085

import torch
import torch.nn as nn


class PACTClip(torch.autograd.Function):
    """Clip ``x`` to ``[0, alpha]`` with a hand-written gradient for both inputs.

    Ref: https://arxiv.org/abs/1805.06085

    NOTE(review): ``alpha`` is presumably a 0-dim tensor -- it is handed to
    ``torch.clamp`` as the upper bound next to a plain ``0`` lower bound, and
    its gradient is reduced to a single summed value. Confirm against callers.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Return ``x`` clamped between ``0`` and ``alpha``.

        Both inputs are stashed on ``ctx`` because ``backward`` needs them to
        rebuild the clipping masks.
        """
        ctx.save_for_backward(x, alpha)
        # Use a view of alpha that does not require grad as the upper bound.
        upper_bound = alpha.detach()
        return torch.clamp(x, 0, upper_bound)

    @staticmethod
    def backward(ctx, dy):
        """Return the gradients ``(d/dx, d/dalpha)`` for incoming gradient ``dy``.

        * ``x`` receives ``dy`` except where ``x < 0`` or ``x > alpha``, where
          the gradient is zero.
        * ``alpha`` receives the sum of ``dy`` over the positions that are not
          ``x <= alpha`` (i.e. where the upper clip was active).
        """
        x, alpha = ctx.saved_tensors

        # Positions where the output did not follow x (either clip is active).
        clipped = (x < 0) | (x > alpha)
        grad_x = dy.masked_fill(clipped, 0)

        # Only positions above the upper bound contribute to alpha's gradient.
        grad_alpha = dy.masked_fill(x <= alpha, 0).sum()

        return grad_x, grad_alpha


class PACTReLU(nn.Module):
    """Activation that clips its input to ``[0, alpha]`` with a learnable ``alpha``.

    Ref: https://arxiv.org/abs/1805.06085

    Args:
        alpha: initial value of the upper clipping bound (default ``6.0``).
            Integer (or other non-floating) values are promoted to the default
            floating-point dtype so they can be stored as a trainable
            parameter; floating-point values keep their dtype.
    """

    def __init__(self, alpha=6.0):
        super().__init__()
        initial = torch.as_tensor(alpha).detach().clone()
        # nn.Parameter requires gradients, which integer/bool tensors cannot
        # carry -- without this promotion ``PACTReLU(6)`` raises a RuntimeError.
        if not initial.is_floating_point():
            initial = initial.to(torch.get_default_dtype())
        self.alpha = nn.Parameter(initial)

    def forward(self, x):
        """Apply ``PACTClip`` to ``x`` using the module's ``alpha`` parameter."""
        return PACTClip.apply(x, self.alpha)