1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315
|
import copy
import logging
from collections import deque
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Sequence, Tuple

import torch
from torch.ao.ns.fx.utils import compute_sqnr
from torch.ao.quantization.pt2e.graph_utils import get_control_flow_submodules
from torch.export import ExportedProgram
from torch.fx import GraphModule, Node
from torch.nn import functional as F
# Debug handles live in node metadata as
#   node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] -> int
CUSTOM_KEY = "custom"
NUMERIC_DEBUG_HANDLE_KEY = "numeric_debug_handle"

# Module-level logger.
log = logging.getLogger(__name__)
def generate_numeric_debug_handle(ep: ExportedProgram) -> None:
"""
Attach numeric_debug_handle_id for all nodes in the graph module of the given
ExportedProgram, like conv2d, squeeze, conv1d, etc, except for placeholder.
Notice that nodes like getattr are out of scope since they are not in the graph.
The graph nodes of input exported program are modified inplace.
Here's an example of using debug handle quantize flow::
ep = export_for_training(eager_model, example_inputs)
generate_numeric_debug_handle(ep)
m = ep.module()
quantizer = XNNPACKQuantizer()
m = prepare_pt2e(m, quantizer)
m = convert_pt2e(m)
"""
# Sanity check the input data type
if not isinstance(ep, ExportedProgram):
raise ValueError(
f"Expected ep to be ExportedProgram, got {type(ExportedProgram)}"
)
unique_id = 0
def _bfs_trace_graph_with_node_process(node_op: Callable) -> None:
nonlocal ep
queue = [ep.graph_module]
while queue:
current_graph_module = queue.pop(0)
for node in current_graph_module.graph.nodes:
if node.op in ["output", "placeholder"]:
continue
node_op(node)
control_flow_submodules = [
submodule
for _, submodule, _ in get_control_flow_submodules(current_graph_module)
]
queue.extend(control_flow_submodules)
def _find_max_id(node: torch.fx.Node) -> None:
nonlocal unique_id
unique_id = max(
unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, 0)
)
def _assign_debug_handle(node: torch.fx.Node) -> None:
nonlocal unique_id
if CUSTOM_KEY not in node.meta:
node.meta[CUSTOM_KEY] = {}
if NUMERIC_DEBUG_HANDLE_KEY not in node.meta[CUSTOM_KEY]:
node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = unique_id
unique_id += 1
# Find the max ID that exists in the graph first, in case part of the graph
# has already been annotated. This way we guarantee there are no duplicate
# handle IDs.
_bfs_trace_graph_with_node_process(_find_max_id)
unique_id += 1
# Assign debug handles to all nodes in the graph that don't have one based on the
# max ID found in the previous step.
_bfs_trace_graph_with_node_process(_assign_debug_handle)
class OutputLogger(torch.nn.Module):
    """
    Base class for capturing output values for nodes in a GraphModule. It currently
    captures only Tensor outputs (detached copies are appended to ``stats``) and can
    be extended to other types of inputs later if needed. The input is always
    returned unchanged, so inserting the logger does not alter the computation.
    """

    # Mark as impure so that calls to it will not be removed during DCE.
    _is_impure = True

    def __init__(
        self,
        debug_handle: int,
        node_name: Optional[str] = None,
        nn_module_stack: Optional[object] = None,
    ) -> None:
        super().__init__()
        self.node_name = node_name
        self.nn_module_stack = nn_module_stack
        self.debug_handle = debug_handle
        self.stats: List[torch.Tensor] = []

    def forward(self, x: object) -> object:
        # Record tensors (detached from autograd); pass everything through.
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        return x

    def extra_repr(self) -> str:
        # `nn.Module.__repr__` calls `extra_repr`; both fragments must be
        # f-strings so every field is interpolated.
        return (
            f"debug_handle={self.debug_handle}, node_name={self.node_name}, "
            f"nn_module_stack={self.nn_module_stack}, num_stats={len(self.stats)}"
        )

    # Backward-compatible alias for the previously (mis)named method.
    __extra_repr__ = extra_repr
def _insert_logger(model: GraphModule, node: Node, debug_handle: int) -> Node:
    """Insert an ``OutputLogger`` right after ``node`` and reroute its users.

    A fresh ``OutputLogger`` submodule (tagged with ``debug_handle``, the node's
    name and its ``nn_module_stack`` metadata) is registered on ``model``. A
    ``call_module`` node taking ``node`` as input is inserted after ``node``.
    Every pre-existing user of ``node`` is then switched to consume the logger
    node instead. The debug_handle lets outputs be matched across graphs after
    transforms.

    Returns:
        The newly inserted logger node.
    """
    # Imported lazily to avoid a circular dependency.
    from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix

    with model.graph.inserting_after(node):
        make_unique_name = get_new_attr_name_with_prefix(f"{node.name}_logger")
        attr_name = make_unique_name(model)
        observer = OutputLogger(
            debug_handle, node.name, node.meta.get("nn_module_stack")
        )
        # Register the submodule before creating the call_module node for it.
        setattr(model, attr_name, observer)
        logger_node = model.graph.call_module(attr_name, (node,), {})

        # The logger itself must keep reading `node`; all other consumers now
        # read the logger's (pass-through) output.
        consumers = [user for user in node.users if user is not logger_node]
        for consumer in consumers:
            consumer.replace_input_with(node, logger_node)

    return logger_node
def prepare_for_propagation_comparison(model: GraphModule) -> GraphModule:
    """Add output loggers to nodes that have a numeric_debug_handle.

    The model is deep-copied first, so the caller's model is left untouched.
    For each node in the copy's top-level graph whose
    ``meta[CUSTOM_KEY]`` contains ``NUMERIC_DEBUG_HANDLE_KEY``, an output logger
    carrying that handle is inserted after the node. The graph is then recompiled.

    Args:
        model (GraphModule): original model

    Returns:
        A copy of the model with output loggers for all nodes that have a
        numeric_debug_handle.
    """
    # don't change the original model
    model = copy.deepcopy(model)
    # Iterate over a snapshot: _insert_logger adds nodes to this graph while we
    # walk it, and the freshly inserted logger nodes must never be revisited.
    # NOTE(review): only the top-level graph is instrumented; nodes inside
    # control-flow submodules are not visited here.
    for n in list(model.graph.nodes):
        if (
            CUSTOM_KEY not in n.meta
            or NUMERIC_DEBUG_HANDLE_KEY not in n.meta[CUSTOM_KEY]
        ):
            continue
        numeric_debug_handle = n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
        _insert_logger(model, n, numeric_debug_handle)
    model.recompile()
    return model
@dataclass(frozen=True)
class QuantizationComparisonResult:
    """One pair of outputs recorded for the same debug handle.

    ``actual`` comes from the transformed model and ``ref`` from the reference
    model. Construction checks that both are tensors of identical shape. Every
    metric casts both tensors to float32 before computing.
    """

    actual: torch.Tensor
    ref: torch.Tensor

    def __post_init__(self) -> None:
        # Fail fast at construction time on malformed pairs.
        for field_name, value in (("actual", self.actual), ("ref", self.ref)):
            if not isinstance(value, torch.Tensor):
                raise ValueError(
                    f"`self.{field_name}` value must be a Tensor, got: {value}"
                )
        if self.actual.shape != self.ref.shape:
            raise ValueError(
                "Cannot compare tensors with different shapes: "
                f"ref={self.ref.shape} vs actual={self.actual.shape}"
            )

    def _as_float32(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(actual, ref)`` cast to float32."""
        return (
            self.actual.to(dtype=torch.float32),
            self.ref.to(dtype=torch.float32),
        )

    @property
    def mse_loss(self) -> torch.Tensor:
        """``F.mse_loss`` between ``actual`` and ``ref`` (as float32)."""
        actual_fp32, ref_fp32 = self._as_float32()
        return F.mse_loss(actual_fp32, ref_fp32)

    @property
    def sqnr(self) -> torch.Tensor:
        """``compute_sqnr`` of ``actual`` against ``ref`` (as float32)."""
        actual_fp32, ref_fp32 = self._as_float32()
        return compute_sqnr(actual_fp32, ref_fp32)

    def loss(
        self, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
    ) -> torch.Tensor:
        """Evaluate ``loss_function(actual, ref)`` on float32 copies.

        Raises:
            ValueError: if the two tensors have different shapes.
        """
        if self.actual.shape != self.ref.shape:
            raise ValueError(
                "Cannot compare tensors with different shapes: "
                f"{self.actual.shape} vs {self.ref.shape}"
            )
        return loss_function(*self._as_float32())

    def __repr__(self) -> str:
        # Summary metrics only: the tensors themselves may be too large to print.
        return f"QuantizationComparisonResult(mse_loss={self.mse_loss}, sqnr={self.sqnr})"
@dataclass(frozen=True)
class NodeAccuracySummary:
    """Comparison summary for one numeric debug handle.

    Produced by ``compare_results``. It pairs the matching nodes of the reference
    and actual models with one ``QuantizationComparisonResult`` per recorded
    output pair.
    """

    # Numeric debug handle shared by the matched nodes.
    handle: int
    # Node name and stringified module stack in the actual (transformed) model.
    actual_node_name: str
    actual_module_stack: str
    # Node name and stringified module stack in the reference model.
    ref_node_name: str
    ref_module_stack: str
    # One comparison per pair of recorded outputs.
    results: Sequence[QuantizationComparisonResult]
def _module_stack_to_str(module_stack: object) -> str:
"""Simplifies the stack from ("mod", "mod.foo", "mod.foo.0", "mod.foo.0.linear")
to "mod.foo.0.linear"
"""
if not isinstance(module_stack, dict):
return str(module_stack)
module_values_list = list(module_stack.values())
if len(module_values_list) > 0:
owning_module = module_values_list[-1][0]
return str(owning_module)
else:
return str(module_stack)
def extract_results_from_loggers(
    model: GraphModule,
) -> Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]]:
    """Collect recorded outputs from the model's ``OutputLogger`` children.

    Only direct children of ``model`` are inspected, and loggers that recorded
    nothing are skipped. If two loggers share a debug handle, the one visited
    later wins.

    Returns:
        A dict keyed by debug_handle id. Each value is
        ``(node_name, nn_module_stack, stats)``, where ``stats`` is the list of
        tensors recorded by that logger.
    """
    return {
        child.debug_handle: (child.node_name, child.nn_module_stack, child.stats)
        for child in model.children()
        if isinstance(child, OutputLogger) and child.stats
    }
def compare_results(
    ref_results: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]],
    actual_results: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]],
) -> Dict[int, NodeAccuracySummary]:
    """Compare per-handle results from a reference and an actual model.

    Given two dicts mapping `debug_handle_id` (int) to
    ``(node_name, nn_module_stack, list_of_tensors)``, return a map from
    `debug_handle_id` to a `NodeAccuracySummary` that contains comparison
    information such as SQNR and MSE. Handles present only in ``ref_results``
    are skipped and logged at debug level. Recorded tensors are paired
    positionally. If the two lists differ in length, only the common prefix
    is compared and a warning is logged.

    Args:
        ref_results (Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]]):
            reference results for each debug_handle_id
        actual_results (Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]]):
            actual results for each debug_handle_id

    Returns:
        Dict[int, NodeAccuracySummary]

    Raises:
        ValueError: if building a comparison fails for a handle (e.g. mismatched
            shapes). The original error is chained and the message names the
            handle and both nodes.
    """
    comparisons: Dict[int, NodeAccuracySummary] = {}
    for debug_handle, (ref_name, ref_stack, ref_stats) in ref_results.items():
        if debug_handle not in actual_results:
            log.debug(
                "Cannot compare for handle %s because it wasn't found in the transformed model",
                debug_handle,
            )
            continue
        actual_name, actual_stack, actual_stats = actual_results[debug_handle]
        if len(actual_stats) != len(ref_stats):
            # zip() below truncates to the shorter list; surface that instead
            # of silently dropping recorded outputs.
            log.warning(
                "Handle %s recorded %d reference outputs but %d actual outputs; "
                "only the first %d pairs are compared",
                debug_handle,
                len(ref_stats),
                len(actual_stats),
                min(len(ref_stats), len(actual_stats)),
            )
        try:
            results = [
                QuantizationComparisonResult(actual=a, ref=b)
                for a, b in zip(actual_stats, ref_stats)
            ]
        except Exception as e:
            # Add extra information for an exception from QuantizationComparisonResult
            # if the shapes didn't match, to include the handle and the node names.
            raise ValueError(
                f"For numeric_debug_handle={debug_handle} from ref node {ref_name} and actual node {actual_name}"
            ) from e

        comparisons[debug_handle] = NodeAccuracySummary(
            handle=debug_handle,
            actual_node_name=actual_name or "",
            actual_module_stack=_module_stack_to_str(actual_stack),
            ref_node_name=ref_name or "",
            ref_module_stack=_module_stack_to_str(ref_stack),
            results=results,
        )

    return comparisons
|