from typing import Any, Dict, Optional

from torch import nn

__all__ = [
    "module_to_fqn",
    "fqn_to_module",
    "get_arg_info_from_tensor_fqn",
    "FakeSparsity",
]


def module_to_fqn(
    model: nn.Module, module: nn.Module, prefix: str = ""
) -> Optional[str]:
    """
    Returns the fqn (fully qualified name) for a module, or None if the module
    is not a descendant of the model.
    """
    if module is model:
        return ""
    for name, child in model.named_children():
        fqn = module_to_fqn(child, module, ".")
        if isinstance(fqn, str):
            return prefix + name + fqn
    return None
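

# Illustrative usage sketch (not part of the original module); the toy
# `nn.Sequential` model below is hypothetical:
#
#     model = nn.Sequential(nn.Linear(4, 4))
#     module_to_fqn(model, model[0])          # -> "0"
#     module_to_fqn(model, model)             # -> ""
#     module_to_fqn(model, nn.Linear(4, 4))   # -> None (not a descendant)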


def fqn_to_module(model: Optional[nn.Module], path: str) -> Optional[nn.Module]:
    """
    Given an fqn, returns the corresponding module or tensor, or None if the fqn
    given by `path` doesn't correspond to anything. Similar to
    model.get_submodule(path) but works for tensors.
    """
    if path != "":
        for name in path.split("."):
            model = getattr(model, name, None)
    return model
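

# Illustrative usage sketch (not part of the original module), reusing the
# hypothetical toy model from the sketch above:
#
#     fqn_to_module(model, "")           # -> model itself
#     fqn_to_module(model, "0")          # -> the nn.Linear submodule
#     fqn_to_module(model, "0.weight")   # -> the weight tensor of that submodule
#     fqn_to_module(model, "0.typo")     # -> None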


def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> Dict[str, Any]:
    """
    Uses tensor_fqn to obtain a dict containing module_fqn, module and tensor_name
    """
    # string manip to split tensor_fqn into module_fqn and tensor_name
    # if tensor_fqn is 'weight' then module_fqn and tensor_name are '' and 'weight'
    # if tensor_fqn is 'linear.weight' then module_fqn and tensor_name are 'linear' and 'weight'
    tensor_name = tensor_fqn.split(".")[-1]
    module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)]

    module = fqn_to_module(model, module_fqn)

    return {
        "module_fqn": module_fqn,
        "module": module,
        "tensor_name": tensor_name,
        "tensor_fqn": tensor_fqn,
    }
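

# Illustrative usage sketch (not part of the original module), reusing the
# hypothetical toy model from the sketch above:
#
#     get_arg_info_from_tensor_fqn(model, "0.weight")
#     # -> {"module_fqn": "0", "module": model[0],
#     #     "tensor_name": "weight", "tensor_fqn": "0.weight"}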


# Parametrizations
class FakeSparsity(nn.Module):
    r"""Parametrization for the weights. Should be attached to 'weight' or
    any other parameter that requires a mask applied to it.

    Note::

        Once the mask is passed, the variable should not change its id. The
        contents of the mask can change, but the mask reference itself should
        not.
    """

    def __init__(self, mask):
        super().__init__()
        self.register_buffer("mask", mask)

    def forward(self, x):
        assert self.mask.shape == x.shape
        return self.mask * x

    def state_dict(self, *args, **kwargs):
        # We don't want the parametrization to save the mask.
        # This makes sure that the linear module doesn't store the masks
        # alongside their parametrizations.
        return {}
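

# Illustrative usage sketch (not part of the original module): since FakeSparsity
# is a parametrization module, it can be attached with
# torch.nn.utils.parametrize, e.g.
#
#     import torch
#     from torch.nn.utils import parametrize
#
#     linear = nn.Linear(4, 4)
#     mask = torch.ones_like(linear.weight)
#     parametrize.register_parametrization(linear, "weight", FakeSparsity(mask))
#     mask[0].zero_()   # mutate the same mask tensor in place...
#     linear.weight     # ...and the first row of the parametrized weight is zeroed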