1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
|
import torch
from torch import nn
from typing import Any, List
# Public names exported by ``from <module> import *``: the two weight
# parametrizations and the two forward-hook callables defined below.
__all__ = ['PruningParametrization', 'ZeroesParametrization', 'ActivationReconstruction', 'BiasHook']
class PruningParametrization(nn.Module):
    r"""Keep only the non-pruned entries along dim 0 of a tensor.

    ``forward`` returns ``x`` indexed along its first dimension by every index
    in ``original_outputs`` that is not in ``pruned_outputs``. Callers add
    indices to ``pruned_outputs`` to remove them.

    Args:
        original_outputs: number of outputs (size of dim 0) before pruning.
            Accepts a Python ``int`` or a single-element integer tensor.
    """

    def __init__(self, original_outputs):
        super().__init__()
        # int() accepts plain ints as well as one-element tensors; the previous
        # ``.item()`` call only worked for tensors.
        self.original_outputs = set(range(int(original_outputs)))
        self.pruned_outputs = set()  # Will contain indices of outputs to prune

    def forward(self, x):
        valid_outputs = self.original_outputs - self.pruned_outputs
        # NOTE(review): the row order comes from set iteration. ActivationReconstruction
        # computes the same set difference to scatter rows back, so any change
        # to this ordering must be made in both places together.
        return x[list(valid_outputs)]
class ZeroesParametrization(nn.Module):
    r"""Zero out pruned channels instead of removing them.

    E.g. used for Batch Norm pruning, which should match the previous Conv2d
    layer. The tensor keeps its full size; only the entries along dim 0 whose
    indices are in ``pruned_outputs`` are set to zero.

    Args:
        original_outputs: number of outputs (size of dim 0). Accepts a Python
            ``int`` or a single-element integer tensor.
    """

    def __init__(self, original_outputs):
        super().__init__()
        # int() accepts plain ints as well as one-element tensors; the previous
        # ``.item()`` call only worked for tensors.
        self.original_outputs = set(range(int(original_outputs)))
        self.pruned_outputs = set()  # Will contain indices of outputs to prune

    def forward(self, x):
        # Writes through ``.data``: the storage of ``x`` itself is zeroed in
        # place (not recorded by autograd), and the same tensor is returned.
        x.data[list(self.pruned_outputs)] = 0
        return x
class ActivationReconstruction:
    r"""Forward hook that scatters a pruned output back to its full channel count.

    Expects ``output`` to hold only the surviving channels along dim 1
    (presumably because the module's weight carries a
    ``PruningParametrization`` -- confirm at the registration site). Returns a
    new zero tensor, with the same shape as ``output`` except that dim 1 has
    ``len(parametrization.original_outputs)`` entries. ``output`` is written
    into the non-pruned channel positions.

    Args:
        parametrization: object exposing the ``original_outputs`` and
            ``pruned_outputs`` sets of channel indices.
    """

    def __init__(self, parametrization):
        self.param = parametrization

    def __call__(self, module, input, output):
        max_outputs = self.param.original_outputs
        pruned_outputs = self.param.pruned_outputs
        # Same set difference as PruningParametrization.forward, so each
        # surviving channel lands back at the index it was taken from.
        valid_columns = list(max_outputs - pruned_outputs)

        # Shape of the reconstructed output: dim 1 restored to full size.
        sizes = list(output.shape)
        sizes[1] = len(max_outputs)

        # Select everything on every dim except dim 1, where only the valid
        # channels are written. The index must be a tuple: indexing with a
        # list of slices is legacy NumPy-style behavior that PyTorch deprecates.
        indices: List[Any] = [slice(None)] * output.dim()
        indices[1] = valid_columns

        reconstructed_tensor = torch.zeros(sizes,
                                           dtype=output.dtype,
                                           device=output.device,
                                           layout=output.layout)
        reconstructed_tensor[tuple(indices)] = output
        return reconstructed_tensor
class BiasHook:
    r"""Forward hook that adds ``module._bias`` to the module's output.

    If ``prune_bias`` is true, the bias entries at the indices in
    ``parametrization.pruned_outputs`` are first set to zero through
    ``.data``, so the change persists on ``module._bias``. The bias is
    broadcast over every output dimension except dim 1. A module whose
    ``_bias`` is missing or ``None`` has its output returned untouched.
    The addition is done in place on ``output``.
    """

    def __init__(self, parametrization, prune_bias):
        self.param = parametrization
        self.prune_bias = prune_bias

    def __call__(self, module, input, output):
        pruned = self.param.pruned_outputs
        bias_param = getattr(module, '_bias', None)
        if bias_param is None:
            return output

        bias_values = bias_param.data
        if self.prune_bias:
            bias_values[list(pruned)] = 0

        # View shape [1, -1, 1, ...] aligns the bias with dim 1 of the output.
        view_shape = [1] * output.dim()
        view_shape[1] = -1
        output += bias_values.reshape(view_shape)
        return output
|