# mypy: allow-untyped-defs
import copy
from typing import Any, Dict
import torch
__all__ = [
"set_module_weight",
"set_module_bias",
"has_bias",
"get_module_weight",
"get_module_bias",
"max_over_ndim",
"min_over_ndim",
"channel_range",
"get_name_by_module",
"cross_layer_equalization",
"process_paired_modules_list_to_name",
"expand_groups_in_paired_modules_list",
"equalize",
"converged",
]
_supported_types = {torch.nn.Conv2d, torch.nn.Linear, torch.nn.Conv1d}
_supported_intrinsic_types = {
torch.ao.nn.intrinsic.ConvReLU2d,
torch.ao.nn.intrinsic.LinearReLU,
torch.ao.nn.intrinsic.ConvReLU1d,
}
_all_supported_types = _supported_types.union(_supported_intrinsic_types)
def set_module_weight(module, weight) -> None:
if type(module) in _supported_types:
module.weight = torch.nn.Parameter(weight)
else:
module[0].weight = torch.nn.Parameter(weight)
def set_module_bias(module, bias) -> None:
if type(module) in _supported_types:
module.bias = torch.nn.Parameter(bias)
else:
module[0].bias = torch.nn.Parameter(bias)
def has_bias(module) -> bool:
if type(module) in _supported_types:
return module.bias is not None
else:
return module[0].bias is not None
def get_module_weight(module):
if type(module) in _supported_types:
return module.weight
else:
return module[0].weight
def get_module_bias(module):
if type(module) in _supported_types:
return module.bias
else:
return module[0].bias
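# A minimal sanity sketch of the accessor helpers above (assumes only torch):
# plain modules expose .weight/.bias directly, while fused intrinsic modules
# are Sequential-like, so the helpers fall back to module[0].
def _example_accessors() -> None:
    plain = torch.nn.Linear(4, 3)
    fused = torch.ao.nn.intrinsic.LinearReLU(torch.nn.Linear(4, 3), torch.nn.ReLU())
    for mod in (plain, fused):
        assert has_bias(mod)
        with torch.no_grad():
            # doubling the weight and bias round-trips through
            # torch.nn.Parameter for both module kinds
            set_module_weight(mod, get_module_weight(mod) * 2.0)
            set_module_bias(mod, get_module_bias(mod) * 2.0)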
def max_over_ndim(input, axis_list, keepdim=False):
"""Apply 'torch.max' over the given axes."""
axis_list.sort(reverse=True)
for axis in axis_list:
input, _ = input.max(axis, keepdim)
return input
def min_over_ndim(input, axis_list, keepdim=False):
"""Apply 'torch.min' over the given axes."""
axis_list.sort(reverse=True)
for axis in axis_list:
input, _ = input.min(axis, keepdim)
return input
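# A quick sketch of the reduction helpers (hypothetical tensor): reducing a
# 3-D tensor over axes [1, 2] agrees with torch.amax/amin over the same dims.
# Note that both helpers sort, and therefore mutate, the axis list they are
# given.
def _example_over_ndim() -> None:
    t = torch.arange(24.0).reshape(2, 3, 4)
    assert torch.equal(max_over_ndim(t, [1, 2]), t.amax(dim=(1, 2)))
    assert torch.equal(min_over_ndim(t, [1, 2]), t.amin(dim=(1, 2)))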
def channel_range(input, axis=0):
"""Find the range of weights associated with a specific channel."""
size_of_tensor_dim = input.ndim
axis_list = list(range(size_of_tensor_dim))
axis_list.remove(axis)
mins = min_over_ndim(input, axis_list)
maxs = max_over_ndim(input, axis_list)
    assert mins.size(0) == input.size(
        axis
    ), "Dimension of resultant channel range does not match size of requested axis"
return maxs - mins
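# Illustrative sketch (hypothetical Conv2d-style weight): for a weight of
# shape (out_channels, in_channels, kH, kW), channel_range(w, axis=0) yields
# one max-minus-min value per output channel.
def _example_channel_range() -> None:
    w = torch.randn(8, 3, 3, 3)
    ranges = channel_range(w, axis=0)
    assert ranges.shape == (8,)
    flat = w.reshape(8, -1)  # flatten everything except the channel axis
    assert torch.allclose(ranges, flat.amax(dim=1) - flat.amin(dim=1))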
def get_name_by_module(model, module):
"""Get the name of a module within a model.
Args:
model: a model (nn.module) that equalization is to be applied on
module: a module within the model
Returns:
name: the name of the module within the model
"""
for name, m in model.named_modules():
if m is module:
return name
raise ValueError("module is not in the model")
def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
"""Scale the range of Tensor1.output to equal Tensor2.input.
Given two adjacent tensors', the weights are scaled such that
the ranges of the first tensors' output channel are equal to the
ranges of the second tensors' input channel
"""
if (
type(module1) not in _all_supported_types
or type(module2) not in _all_supported_types
):
        raise ValueError(
            f"module type not supported: {type(module1)} {type(module2)}"
        )
bias = get_module_bias(module1) if has_bias(module1) else None
weight1 = get_module_weight(module1)
weight2 = get_module_weight(module2)
if weight1.size(output_axis) != weight2.size(input_axis):
        raise TypeError(
            "Number of output channels of first arg does not match "
            "number of input channels of second arg"
        )
weight1_range = channel_range(weight1, output_axis)
weight2_range = channel_range(weight2, input_axis)
    # produce the scaling factors to be applied
    weight2_range += 1e-9  # avoid division by zero
scaling_factors = torch.sqrt(weight1_range / weight2_range)
inverse_scaling_factors = torch.reciprocal(scaling_factors)
if bias is not None:
bias = bias * inverse_scaling_factors
    # reshape the 1-D scaling tensors with singleton axes so that they
    # broadcast against the corresponding weight tensors
size1 = [1] * weight1.ndim
size1[output_axis] = weight1.size(output_axis)
size2 = [1] * weight2.ndim
size2[input_axis] = weight2.size(input_axis)
scaling_factors = torch.reshape(scaling_factors, size2)
inverse_scaling_factors = torch.reshape(inverse_scaling_factors, size1)
weight1 = weight1 * inverse_scaling_factors
weight2 = weight2 * scaling_factors
set_module_weight(module1, weight1)
if bias is not None:
set_module_bias(module1, bias)
set_module_weight(module2, weight2)
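# A minimal end-to-end sketch (module sizes here are hypothetical): because
# ReLU is positive-homogeneous, scaling module1's output channels down and
# module2's matching input channels up by the same factors leaves the
# composed function unchanged while balancing the two weight ranges.
def _example_cross_layer_equalization() -> None:
    torch.manual_seed(0)
    m1, m2 = torch.nn.Linear(4, 5), torch.nn.Linear(5, 3)
    x = torch.randn(2, 4)
    before = m2(torch.relu(m1(x)))
    with torch.no_grad():
        cross_layer_equalization(m1, m2)
    after = m2(torch.relu(m1(x)))
    assert torch.allclose(before, after, atol=1e-5)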
def process_paired_modules_list_to_name(model, paired_modules_list):
"""Processes a list of paired modules to a list of names of paired modules."""
for group in paired_modules_list:
for i, item in enumerate(group):
if isinstance(item, torch.nn.Module):
group[i] = get_name_by_module(model, item)
elif not isinstance(item, str):
raise TypeError("item must be a nn.Module or a string")
return paired_modules_list
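# Sketch (hypothetical two-layer model): module references and name strings
# can be mixed; both normalize to names, here the Sequential indices "0"/"1".
def _example_process_pairs() -> None:
    model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 2))
    pairs = process_paired_modules_list_to_name(model, [[model[0], "1"]])
    assert pairs == [["0", "1"]]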
def expand_groups_in_paired_modules_list(paired_modules_list):
"""Expands module pair groups larger than two into groups of two modules."""
new_list = []
for group in paired_modules_list:
        if len(group) < 2:
            raise ValueError("Group must have at least two modules")
elif len(group) == 2:
new_list.append(group)
elif len(group) > 2:
new_list.extend([group[i], group[i + 1]] for i in range(len(group) - 1))
return new_list
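# Sketch of the expansion rule: a chain ["a", "b", "c"] becomes the
# consecutive pairs [["a", "b"], ["b", "c"]], so each shared middle module is
# equalized against both of its neighbors.
def _example_expand_groups() -> None:
    expanded = expand_groups_in_paired_modules_list([["a", "b", "c"], ["d", "e"]])
    assert expanded == [["a", "b"], ["b", "c"], ["d", "e"]]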
def equalize(model, paired_modules_list, threshold=1e-4, inplace=True):
"""Equalize modules until convergence is achieved.
Given a list of adjacent modules within a model, equalization will
be applied between each pair, this will repeated until convergence is achieved
Keeps a copy of the changing modules from the previous iteration, if the copies
are not that different than the current modules (determined by converged_test),
then the modules have converged enough that further equalizing is not necessary
Reference is section 4.1 of this paper https://arxiv.org/pdf/1906.04721.pdf
Args:
model: a model (nn.Module) that equalization is to be applied on
paired_modules_list (List(List[nn.module || str])): a list of lists
where each sublist is a pair of two submodules found in the model,
for each pair the two modules have to be adjacent in the model,
with only piece-wise-linear functions like a (P)ReLU or LeakyReLU in between
to get expected results.
The list can contain either modules, or names of modules in the model.
If you pass multiple modules in the same list, they will all be equalized together.
threshold (float): a number used by the converged function to determine what degree
of similarity between models is necessary for them to be called equivalent
inplace (bool): determines if function is inplace or not
"""
paired_modules_list = process_paired_modules_list_to_name(
model, paired_modules_list
)
if not inplace:
model = copy.deepcopy(model)
paired_modules_list = expand_groups_in_paired_modules_list(paired_modules_list)
name_to_module: Dict[str, torch.nn.Module] = {}
previous_name_to_module: Dict[str, Any] = {}
name_set = {name for pair in paired_modules_list for name in pair}
for name, module in model.named_modules():
if name in name_set:
name_to_module[name] = module
previous_name_to_module[name] = None
while not converged(name_to_module, previous_name_to_module, threshold):
for pair in paired_modules_list:
previous_name_to_module[pair[0]] = copy.deepcopy(name_to_module[pair[0]])
previous_name_to_module[pair[1]] = copy.deepcopy(name_to_module[pair[1]])
cross_layer_equalization(name_to_module[pair[0]], name_to_module[pair[1]])
return model
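# A hypothetical usage sketch: equalize the two Linear layers of a small MLP
# in place, referring to them by their names in the Sequential. The model's
# function is preserved while the paired per-channel weight ranges are
# balanced until convergence.
def _example_equalize() -> None:
    torch.manual_seed(0)
    model = torch.nn.Sequential(
        torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2)
    )
    x = torch.randn(2, 4)
    before = model(x)
    equalize(model, [["0", "2"]], threshold=1e-4, inplace=True)
    assert torch.allclose(before, model(x), atol=1e-5)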
def converged(curr_modules, prev_modules, threshold=1e-4):
"""Test whether modules are converged to a specified threshold.
Tests for the summed norm of the differences between each set of modules
being less than the given threshold
Takes two dictionaries mapping names to modules, the set of names for each dictionary
should be the same, looping over the set of names, for each name take the difference
between the associated modules in each dictionary
"""
if curr_modules.keys() != prev_modules.keys():
raise ValueError(
"The keys to the given mappings must have the same set of names of modules"
)
    if None in prev_modules.values():
        return False
    summed_norms = torch.tensor(0.0)
for name in curr_modules.keys():
curr_weight = get_module_weight(curr_modules[name])
prev_weight = get_module_weight(prev_modules[name])
difference = curr_weight.sub(prev_weight)
summed_norms += torch.norm(difference)
return bool(summed_norms < threshold)
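# Sketch of the convergence test: identical snapshots converge, while a None
# snapshot (as on the first pass of equalize) reports not-converged.
def _example_converged() -> None:
    lin = torch.nn.Linear(4, 4)
    assert converged({"lin": lin}, {"lin": copy.deepcopy(lin)}, threshold=1e-4)
    assert not converged({"lin": lin}, {"lin": None}, threshold=1e-4)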