import torch
import copy
from torch.fx import GraphModule
from torch.fx.graph import Graph
from typing import Union, Dict, Any, Set

class FusedGraphModule(GraphModule):
    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]):
        self.preserved_attr_names = preserved_attr_names
        preserved_attrs = {attr: getattr(root, attr) for attr in self.preserved_attr_names if hasattr(root, attr)}
        super().__init__(root, graph)
        for attr in preserved_attrs:
            setattr(self, attr, preserved_attrs[attr])

    # GraphModule does not copy attributes which are not in the __dict__
    # of vanilla nn.Module. So, we override __deepcopy__ in order
    # to copy the quantization specific attributes correctly.
    def __deepcopy__(self, memo):
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
        return FusedGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names))
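
# Illustrative sketch (not part of the original module) of what the __deepcopy__
# override above guarantees: attributes listed in preserved_attr_names survive
# copy.deepcopy. `model` is any nn.Module and `_extra_meta` is a hypothetical
# custom attribute used only for illustration.
#
#     import torch.fx
#     model._extra_meta = {"note": "anything picklable"}   # hypothetical attribute
#     traced = torch.fx.symbolic_trace(model)
#     fused = FusedGraphModule(model, traced.graph, {"_extra_meta"})
#     fused_copy = copy.deepcopy(fused)
#     assert hasattr(fused_copy, "_extra_meta")   # carried over by the __deepcopy__ override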

class ObservedGraphModule(GraphModule):
    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]):
        self.preserved_attr_names = set([
            '_activation_post_process_map',
            '_activation_post_process_indexes',
            '_patterns',
            '_qconfig_map',
            '_prepare_custom_config',
            '_equalization_qconfig_map',
            '_node_name_to_scope',
            '_qconfig_mapping',
            '_is_qat',
            '_observed_node_names']).union(preserved_attr_names)
        preserved_attrs = {attr: getattr(root, attr) for attr in self.preserved_attr_names if hasattr(root, attr)}
        super().__init__(root, graph)
        for attr in preserved_attrs:
            setattr(self, attr, preserved_attrs[attr])

    # GraphModule does not copy attributes which are not in the __dict__
    # of vanilla nn.Module. So, we override __deepcopy__ in order
    # to copy the quantization specific attributes correctly.
    def __deepcopy__(self, memo):
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
        return ObservedGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names))

def is_observed_module(module: Any) -> bool:
    return isinstance(module, ObservedGraphModule)

class ObservedStandaloneGraphModule(ObservedGraphModule):
    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]):
        preserved_attr_names = preserved_attr_names.union(set([
            "_standalone_module_input_quantized_idxs",
            "_standalone_module_output_quantized_idxs"]))
        super().__init__(root, graph, preserved_attr_names)

    def __deepcopy__(self, memo):
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
        return ObservedStandaloneGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names))

def is_observed_standalone_module(module: Any) -> bool:
    return isinstance(module, ObservedStandaloneGraphModule)

def _save_packed_weight(self, destination, prefix, keep_vars):
    for attr_name in dir(self):
        if "_packed_weight" in attr_name and \
                isinstance(getattr(self, attr_name), torch._C.ScriptObject):  # type: ignore[attr-defined]
            packed_weight = getattr(self, attr_name)
            destination[prefix + attr_name] = packed_weight

class QuantizedGraphModule(GraphModule):
    """ This class makes sure PackedParams
    (e.g. LinearPackedParams, Conv2dPackedParams) appear in the state_dict
    so that we can serialize and deserialize a quantized graph module with
    torch.save(m.state_dict()) and m.load_state_dict(state_dict)
    """
    def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, preserved_attr_names: Set[str]):
        self.preserved_attr_names = preserved_attr_names
        preserved_attrs = {attr: getattr(root, attr) for attr in self.preserved_attr_names if hasattr(root, attr)}
        super().__init__(root, graph)
        for attr in preserved_attrs:
            setattr(self, attr, preserved_attrs[attr])
        self._register_state_dict_hook(_save_packed_weight)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        attrs_to_pop = []
        for attr_name in state_dict:
            if attr_name.startswith("_packed_weight") and isinstance(state_dict[attr_name], torch._C.ScriptObject):  # type: ignore[attr-defined]  # noqa: B950
                setattr(self, attr_name, state_dict[attr_name])
                attrs_to_pop.append(attr_name)
        # pop the packed param attributes
        for attr_name in attrs_to_pop:
            state_dict.pop(attr_name)
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    def __deepcopy__(self, memo):
        fake_mod = torch.nn.Module()
        fake_mod.__dict__ = copy.deepcopy(self.__dict__)
        return QuantizedGraphModule(fake_mod, copy.deepcopy(self.graph), copy.deepcopy(self.preserved_attr_names))
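
# Illustrative sketch (not part of the original module): the state_dict hook and the
# _load_from_state_dict override above are what let packed params round-trip through
# torch.save/load_state_dict, as the class docstring describes. `model`,
# `qconfig_mapping`, and `example_inputs` are hypothetical placeholders;
# prepare_fx/convert_fx come from torch.ao.quantization.quantize_fx.
#
#     from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
#     prepared = prepare_fx(model, qconfig_mapping, example_inputs)
#     ...  # calibrate `prepared` on representative data
#     quantized = convert_fx(prepared)
#     torch.save(quantized.state_dict(), "quantized.pt")
#     quantized.load_state_dict(torch.load("quantized.pt"))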