File: parametrization.py

package info: pytorch 1.13.1+dfsg-4
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (77 lines, 2,641 bytes)
import torch
from torch import nn
from typing import Any, List

__all__ = ['PruningParametrization', 'ZeroesParametrization', 'ActivationReconstruction', 'BiasHook']

class PruningParametrization(nn.Module):
    r"""Remove pruned output channels from the parametrized tensor (along dim 0)."""
    def __init__(self, original_outputs):
        super().__init__()
        self.original_outputs = set(range(original_outputs.item()))
        self.pruned_outputs = set()  # Will contain indices of outputs to prune

    def forward(self, x):
        valid_outputs = self.original_outputs - self.pruned_outputs
        # Sort so the surviving channels keep a deterministic, ascending order.
        return x[sorted(valid_outputs)]

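
# Illustrative sketch (not part of the upstream file): one way a
# PruningParametrization could be attached to a layer via
# torch.nn.utils.parametrize. The layer sizes, the pruned index, and the
# helper name below are arbitrary choices for demonstration only.
def _example_prune_linear():
    from torch.nn.utils import parametrize

    linear = nn.Linear(4, 8)
    # The constructor calls .item(), so it expects a 0-dim tensor.
    param = PruningParametrization(torch.tensor(linear.out_features))
    parametrize.register_parametrization(linear, "weight", param)

    param.pruned_outputs.add(2)           # mark output channel 2 for pruning
    assert linear.weight.shape == (7, 4)  # that channel's row is now removed
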

class ZeroesParametrization(nn.Module):
    r"""Zero out pruned channels instead of removing.
    E.g. used for Batch Norm pruning, which should match previous Conv2d layer."""
    def __init__(self, original_outputs):
        super().__init__()
        self.original_outputs = set(range(original_outputs.item()))
        self.pruned_outputs = set()  # Will contain indices of outputs to prune

    def forward(self, x):
        # Zero the pruned channels in place through .data; the shape is unchanged.
        x.data[list(self.pruned_outputs)] = 0
        return x

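
# Illustrative sketch (not part of the upstream file): zeroing a BatchNorm2d
# scale for a pruned channel while keeping its shape, mirroring the docstring
# above. The sizes, the pruned index, and the helper name are arbitrary.
def _example_zero_batchnorm():
    from torch.nn.utils import parametrize

    bn = nn.BatchNorm2d(8)
    zero_param = ZeroesParametrization(torch.tensor(bn.num_features))
    parametrize.register_parametrization(bn, "weight", zero_param)

    zero_param.pruned_outputs.add(2)     # same channel as the preceding conv
    assert bn.weight.shape == (8,)       # the shape is preserved...
    assert bn.weight[2].item() == 0      # ...but the pruned scale reads as zero
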

class ActivationReconstruction:
    """Forward hook that re-expands a pruned module's output back to its
    original channel count, filling the pruned channels with zeros."""
    def __init__(self, parametrization):
        self.param = parametrization

    def __call__(self, module, input, output):
        max_outputs = self.param.original_outputs
        pruned_outputs = self.param.pruned_outputs
        # Sort to keep the surviving channels in their original order.
        valid_columns = sorted(max_outputs - pruned_outputs)

        # get size of reconstructed output
        sizes = list(output.shape)
        sizes[1] = len(max_outputs)

        # get valid indices of reconstructed output
        indices: List[Any] = []
        for size in output.shape:
            indices.append(slice(0, size, 1))
        indices[1] = valid_columns

        reconstructed_tensor = torch.zeros(sizes,
                                           dtype=output.dtype,
                                           device=output.device,
                                           layout=output.layout)
        reconstructed_tensor[tuple(indices)] = output
        return reconstructed_tensor

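
# Illustrative sketch (not part of the upstream file): using the hook above to
# re-expand a pruned convolution's activations back to the original channel
# count, so downstream modules still see the full width. The bias-free conv,
# its sizes, the pruned index, and the helper name are arbitrary assumptions.
def _example_reconstruct_activation():
    from torch.nn.utils import parametrize

    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
    param = PruningParametrization(torch.tensor(8))
    parametrize.register_parametrization(conv, "weight", param)
    conv.register_forward_hook(ActivationReconstruction(param))

    param.pruned_outputs.add(2)
    out = conv(torch.randn(1, 3, 16, 16))
    assert out.shape == (1, 8, 16, 16)   # back to 8 channels after the hook
    assert bool((out[:, 2] == 0).all())  # the pruned channel is all zeros
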

class BiasHook:
    """Forward hook that adds back a bias stashed on module._bias; when
    prune_bias is True, the pruned entries of that bias are zeroed first."""
    def __init__(self, parametrization, prune_bias):
        self.param = parametrization
        self.prune_bias = prune_bias

    def __call__(self, module, input, output):
        pruned_outputs = self.param.pruned_outputs

        if getattr(module, '_bias', None) is not None:
            bias = module._bias.data
            if self.prune_bias:
                bias[list(pruned_outputs)] = 0

            # reshape bias to broadcast over output dimensions
            idx = [1] * len(output.shape)
            idx[1] = -1
            bias = bias.reshape(idx)

            output += bias
        return output
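

# Illustrative sketch (not part of the upstream file): wiring the pieces above
# together on a single convolution. The bias is stashed on a _bias attribute
# (the name BiasHook looks up) and removed from the module, the parametrization
# produces the pruned weight, ActivationReconstruction restores the full channel
# count, and BiasHook re-applies the (optionally pruned) bias. The sizes, the
# pruned index, and the helper name are arbitrary assumptions.
def _example_full_pipeline():
    from torch.nn.utils import parametrize

    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
    param = PruningParametrization(torch.tensor(8))
    parametrize.register_parametrization(conv, "weight", param)

    # Stash the bias where BiasHook expects it and drop it from the module, so
    # the convolution itself runs with the narrower pruned weight only.
    conv._bias = nn.Parameter(conv.bias.detach())
    conv.bias = None

    # Forward hooks run in registration order: reconstruct first, then add bias.
    conv.register_forward_hook(ActivationReconstruction(param))
    conv.register_forward_hook(BiasHook(param, prune_bias=True))

    param.pruned_outputs.add(2)
    out = conv(torch.randn(1, 3, 16, 16))
    assert out.shape == (1, 8, 16, 16)
    assert bool((out[:, 2] == 0).all())  # prune_bias=True zeroed that bias entry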