from functools import reduce
from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from .base_sparsifier import BaseSparsifier
__all__ = ["WeightNormSparsifier"]
def _flat_idx_to_2d(idx, shape):
rows = idx // shape[1]
cols = idx % shape[1]
return rows, cols
class WeightNormSparsifier(BaseSparsifier):
r"""Weight-Norm Sparsifier
This sparsifier computes the norm of every sparse block and "zeroes out" the
ones with the lowest norm. The level of sparsity defines how many of the
blocks are removed.
This sparsifier is controlled by three variables:
1. `sparsity_level` defines the target fraction of *sparse blocks* that are zeroed out
2. `sparse_block_shape` defines the shape of the sparse blocks. Note that
the sparse blocks originate at the zero-index of the tensor.
3. `zeros_per_block` is the number of zeros that we expect in each sparse
block. By default all elements within a block are zeroed out. Setting this
variable overrides that default and sets the target number of zeros per
block. The zeros within each block are chosen as the elements with the
*smallest absolute values*.
Args:
sparsity_level: The target level of sparsity
sparse_block_shape: The shape of a sparse block (see note below)
zeros_per_block: Number of zeros in a sparse block
norm: Norm to use. Can be either an `int` or a callable.
If an `int`, only L1 and L2 are implemented.
Note::
The `sparse_block_shape` is a tuple representing (block_ROWS, block_COLS),
irrespective of what the rows / cols mean in the data tensor. For example,
if you were to sparsify the weight tensor of an nn.Linear, which has a
weight shape `(Cout, Cin)`, the `block_ROWS` would refer to the output
channels, while the `block_COLS` would refer to the input channels.
Note::
All arguments to the WeightNormSparsifier constructor are "default"
arguments and can be overridden by the configuration provided in the
`prepare` step.
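Example::
A minimal usage sketch (assumes the `prepare` / `step` / `squash_mask`
flow inherited from BaseSparsifier; the model and `tensor_fqn` value
below are illustrative only):
>>> model = torch.nn.Sequential(torch.nn.Linear(16, 16))
>>> sparsifier = WeightNormSparsifier(sparsity_level=0.5,
...                                   sparse_block_shape=(1, 4))
>>> sparsifier.prepare(model, config=[{"tensor_fqn": "0.weight"}])
>>> sparsifier.step()         # compute and apply the masks
>>> sparsifier.squash_mask()  # fold the masks into the weights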
"""
def __init__(self,
sparsity_level: float = 0.5,
sparse_block_shape: Tuple[int, int] = (1, 4),
zeros_per_block: Optional[int] = None,
norm: Optional[Union[Callable, int]] = None):
if zeros_per_block is None:
zeros_per_block = reduce((lambda x, y: x * y), sparse_block_shape)
defaults = {
"sparsity_level": sparsity_level,
"sparse_block_shape": sparse_block_shape,
"zeros_per_block": zeros_per_block,
}
if norm is None:
norm = 2
if callable(norm):
self.norm_fn = norm
elif norm == 1:
self.norm_fn = lambda T: T.abs()
elif norm == 2:
self.norm_fn = lambda T: T * T
else:
raise NotImplementedError(f"L-{norm} is not yet implemented.")
super().__init__(defaults=defaults)
def _scatter_fold_block_mask(self, output_shape, dim, indices, block_shape,
mask=None, input_shape=None, device=None):
r"""Creates patches of size `block_shape` after scattering the indices."""
if mask is None:
assert input_shape is not None
mask = torch.ones(input_shape, device=device)
mask.scatter_(dim=dim, index=indices, value=0)
mask.data = F.fold(mask, output_size=output_shape, kernel_size=block_shape, stride=block_shape)
return mask
def _make_tensor_mask(self, data, input_shape, sparsity_level, sparse_block_shape, mask=None):
r"""Creates a tensor-level mask.
A tensor-level mask is a mask whose smallest granularity of sparsification is the
sparse_block_shape. That means that for a given mask and a sparse_block_shape, the
smallest "patch" of zeros/ones has the shape sparse_block_shape.
In this context, `sparsity_level` describes the fraction of patches that are zeroed out.
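For example (illustrative numbers only): for a 4x8 `data` tensor with
`sparse_block_shape=(1, 4)` and `sparsity_level=0.5`, the tensor is tiled
into eight 1x4 blocks and the four blocks with the smallest norm are
zeroed out entirely.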
"""
h, w = data.shape[-2:]
block_h, block_w = sparse_block_shape
dh = (block_h - h % block_h) % block_h
dw = (block_w - w % block_w) % block_w
if mask is None:
mask = torch.ones(h + dh, w + dw, device=data.device)  # padded to block boundaries
if sparsity_level >= 1.0:
mask.data = torch.zeros_like(mask)
return mask
elif sparsity_level <= 0.0:
mask.data = torch.ones_like(mask)
return mask
values_per_block = reduce((lambda x, y: x * y), sparse_block_shape)
if values_per_block > 1:
# Reduce the data
data = F.avg_pool2d(
data[None, None, :], kernel_size=sparse_block_shape, stride=sparse_block_shape, ceil_mode=True
)
data = data.flatten()
num_blocks = len(data)
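# Repeat the per-block scores `values_per_block` times so the layout matches
# the unfolded mask (1, values_per_block, num_blocks); scattering zeros at a
# block index then clears every element of that block once folded back.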
data = data.repeat(1, values_per_block, 1)
threshold_idx = int(round(sparsity_level * num_blocks))
threshold_idx = max(0, min(num_blocks - 1, threshold_idx)) # Sanity check
_, sorted_idx = torch.topk(data, k=threshold_idx, dim=2, largest=False)
# Temp reshape for mask
mask_reshape = mask.reshape(data.shape) # data might be reshaped
self._scatter_fold_block_mask(
dim=2, output_shape=(h + dh, w + dw),
indices=sorted_idx, block_shape=sparse_block_shape, mask=mask_reshape
)
mask.data = mask_reshape.squeeze().reshape(mask.shape)[:h, :w].contiguous()
return mask
def _make_block_mask(self, data, sparse_block_shape, zeros_per_block, mask=None):
r"""Creates a block-level mask.
A block-level mask is a mask whose largest granularity of sparsification is the
sparse_block_shape. That means that for a given mask and a sparse_block_shape,
the sparsity is computed only within each patch of size sparse_block_shape.
In this context, `zeros_per_block` describes the number of zeroed-out elements within each patch.
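For example (illustrative numbers only): with `sparse_block_shape=(1, 4)`
and `zeros_per_block=2`, every 1x4 block has its two smallest-norm elements
zeroed out, regardless of how that block compares to other blocks.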
"""
h, w = data.shape[-2:]
block_h, block_w = sparse_block_shape
dh = (block_h - h % block_h) % block_h
dw = (block_w - w % block_w) % block_w
# Pad the mask to a multiple of the block shape so it can be reshaped to the
# unfolded layout below; the padding is sliced away again before returning.
if mask is None:
mask = torch.ones((h + dh, w + dw), device=data.device)
values_per_block = reduce((lambda x, y: x * y), sparse_block_shape)
if values_per_block == zeros_per_block:
# Everything should be sparsified
mask.data = torch.zeros_like(mask)
return mask
# create a new padded tensor like data (to match the block_shape)
padded_data = torch.ones(h + dh, w + dw, dtype=data.dtype, device=data.device)
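# Fill the padding with NaN so that topk(..., largest=False) never selects
# padded positions; the padding is sliced away again at the end.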
padded_data.fill_(torch.nan)
padded_data[:h, :w] = data
unfolded_data = F.unfold(padded_data[None, None, :], kernel_size=sparse_block_shape, stride=sparse_block_shape)
# Temp reshape for mask
mask_reshape = mask.reshape(unfolded_data.shape)
_, sorted_idx = torch.topk(unfolded_data, k=zeros_per_block, dim=1, largest=False)
self._scatter_fold_block_mask(
dim=1, indices=sorted_idx, output_shape=padded_data.shape, block_shape=sparse_block_shape, mask=mask_reshape
)
mask.data = mask_reshape.squeeze().reshape(mask.shape)[:h, :w].contiguous()
return mask
def update_mask(self, module, tensor_name, sparsity_level, sparse_block_shape,
zeros_per_block, **kwargs):
values_per_block = reduce((lambda x, y: x * y), sparse_block_shape)
if zeros_per_block > values_per_block:
raise ValueError(
"Number of zeros per block cannot be more than " "the total number of elements in that block."
)
if zeros_per_block < 0:
raise ValueError("Number of zeros per block should be positive.")
mask = getattr(module.parametrizations, tensor_name)[0].mask
if sparsity_level <= 0 or zeros_per_block == 0:
mask.data = torch.ones_like(mask)
elif sparsity_level >= 1.0 and (zeros_per_block == values_per_block):
mask.data = torch.zeros_like(mask)
else:
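# Compute elementwise importance scores with the norm function. The
# tensor-level mask picks which blocks to sparsify; when only part of each
# block should be zeroed, it is OR-ed with the block-level mask so that the
# selected blocks end up with `zeros_per_block` zeros at their
# smallest-norm positions.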
ww = self.norm_fn(getattr(module, tensor_name))
tensor_mask = self._make_tensor_mask(
data=ww, input_shape=ww.shape, sparsity_level=sparsity_level, sparse_block_shape=sparse_block_shape
)
if values_per_block != zeros_per_block:
block_mask = self._make_block_mask(data=ww, sparse_block_shape=sparse_block_shape,
zeros_per_block=zeros_per_block)
tensor_mask = torch.logical_or(tensor_mask, block_mask)
mask.data = tensor_mask