1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
|
# mypy: allow-untyped-defs
import logging
import operator
import torch
from torch.ao.quantization.pt2e.utils import (
_filter_sym_size_users,
_is_valid_annotation,
)
from torch.fx.node import map_arg
from torch.fx.passes.infra.pass_base import PassBase, PassResult
# Module-level logger; its level is pinned to WARNING at import time.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

# Only the pass class is part of this module's public API.
__all__ = ["DuplicateDQPass"]
# Call targets that DuplicateDQPass recognizes as a "quantize" node when it
# inspects the producer of a dequantize node.
_QUANTIZE_OPS = [
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.quantize_per_channel.default,
]
# Call targets that DuplicateDQPass treats as "dequantize" nodes, i.e. the
# nodes it considers duplicating.
_DEQUANTIZE_OPS = [
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
]
def _maybe_duplicate_dq(
gm: torch.fx.GraphModule, dq_node: torch.fx.Node, user: torch.fx.Node
):
annotation = user.meta.get("quantization_annotation", None)
if not _is_valid_annotation(annotation):
return
with gm.graph.inserting_after(dq_node):
new_node = gm.graph.node_copy(dq_node)
def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node:
if n == dq_node:
return new_node
else:
return n
new_args = map_arg(user.args, maybe_replace_node)
new_kwargs = map_arg(user.kwargs, maybe_replace_node)
user.args = new_args # type: ignore[assignment]
user.kwargs = new_kwargs # type: ignore[assignment]
class DuplicateDQPass(PassBase):
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target in _DEQUANTIZE_OPS:
dq_users = _filter_sym_size_users(node)
if len(dq_users) <= 1:
continue
# Do not duplicate dq for dynamic quantization
# Pattern: choose_qparam - getitem - q - dq
q_node = node.args[0]
if q_node.op == "call_function" and q_node.target in _QUANTIZE_OPS:
getitem_node = q_node.args[1]
if (
isinstance(getitem_node, torch.fx.node.Node)
and getitem_node.op == "call_function"
and getitem_node.target == operator.getitem
):
choose_qparam_node = getitem_node.args[0]
if (
isinstance(choose_qparam_node, torch.fx.node.Node)
and choose_qparam_node.op == "call_function"
and choose_qparam_node.target
== torch.ops.quantized_decomposed.choose_qparams.tensor
):
continue
for user in dq_users:
_maybe_duplicate_dq(graph_module, node, user)
graph_module.graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)
|