1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242
|
"""Compatibility analyzer for PyTorch models."""
# mypy: allow-untyped-defs
# flake8: noqa: B950 We do not need flake8 as it complains about line length
from __future__ import annotations
import dataclasses
import textwrap
import traceback
from collections import defaultdict
from typing import TYPE_CHECKING
import torch
import torch._export.serde.schema
from torch.export import graph_signature
from torch.onnx._internal.exporter import _dispatching, _registration
if TYPE_CHECKING:
import torch.fx
def _int_counter() -> defaultdict:
    """Return a fresh, empty ``defaultdict(int)`` for use as a dataclass default factory."""
    return defaultdict(int)


@dataclasses.dataclass
class ModelInfo:
    """Statistics and findings collected while analyzing an exported program.

    Every container field is produced by a ``default_factory`` so that each
    ``ModelInfo`` instance owns independent, initially empty containers.
    """

    # Total number of parameter elements, keyed by dtype.
    parameter_count: defaultdict[torch.dtype, int] = dataclasses.field(
        default_factory=_int_counter
    )
    # Total number of buffer elements, keyed by dtype.
    buffer_count: defaultdict[torch.dtype, int] = dataclasses.field(
        default_factory=_int_counter
    )
    # Number of nodes in the FX graph.
    fx_node_count: int = 0
    # Number of FX nodes per node op (e.g. "call_function").
    fx_node_op_count: defaultdict[str, int] = dataclasses.field(
        default_factory=_int_counter
    )
    # Number of call_function nodes per stringified target.
    fx_node_target_count: defaultdict[str, int] = dataclasses.field(
        default_factory=_int_counter
    )
    # (node, message) pairs for nodes the dispatcher could not handle.
    dispatch_failures: list[tuple[torch.fx.Node, str]] = dataclasses.field(
        default_factory=list
    )
    # Metadata of the user inputs, keyed by argument name.
    inputs: dict[str, torch._export.serde.schema.TensorMeta] = dataclasses.field(
        default_factory=dict
    )
    # Metadata of the user outputs, keyed by argument name.
    outputs: dict[str, torch._export.serde.schema.TensorMeta] = dataclasses.field(
        default_factory=dict
    )
def _count_weights(
exported_program: torch.export.ExportedProgram,
) -> tuple[defaultdict[torch.dtype, int], defaultdict[torch.dtype, int]]:
"""Count the size of the parameters in the exported program."""
parameter_count: defaultdict[torch.dtype, int] = defaultdict(int)
buffer_count: defaultdict[torch.dtype, int] = defaultdict(int)
for parameter in exported_program.parameters():
dtype = parameter.dtype
parameter_count[dtype] += parameter.numel()
for buffer in exported_program.buffers():
dtype = buffer.dtype
buffer_count[dtype] += buffer.numel()
return parameter_count, buffer_count
def _format_model_info(model_info: ModelInfo) -> str:
"""Format the information about the model."""
lines = [
textwrap.dedent(
f"""\
PyTorch ONNX Conversion Analysis
## Model Information
The model has {sum(model_info.parameter_count.values())} parameters and {sum(model_info.buffer_count.values())} buffers (non-trainable parameters).
Number of parameters per dtype:
```python
{model_info.parameter_count}
```
Number of buffers per dtype:
```python
{model_info.buffer_count}
```
"""
),
"Inputs:",
*[f"- `{name}`: `{meta}`" for name, meta in model_info.inputs.items()],
"",
"Outputs:",
*[f"- `{name}`: `{meta}`" for name, meta in model_info.outputs.items()],
"",
f"The FX graph has {model_info.fx_node_count} nodes in total. Number of FX nodes per op:",
]
for op, count in model_info.fx_node_op_count.items():
lines.append(f"- `{op}`: {count}")
lines.append("\n")
lines.append("Of the call_function nodes, the counts of operators used are:\n")
sorted_targets = sorted(
model_info.fx_node_target_count.items(), key=lambda x: x[1], reverse=True
)
for target, count in sorted_targets:
lines.append(f"- `{target}`: {count}")
lines.append("")
lines.append("## ONNX Conversion Information")
lines.append("")
if model_info.dispatch_failures:
lines.append(
"The model contains operators the dispatcher could not find registered ONNX decompositions for. "
"This may be due to missing implementations, decompositions not registered "
"correctly, or a bug in the dispatcher."
)
lines.append("")
lines.append("Errors grouped by operator:\n")
target_to_nodes = defaultdict(list)
for node, _ in model_info.dispatch_failures:
target_to_nodes[str(node.target)].append(node)
target_to_messages = {}
for node, message in model_info.dispatch_failures:
if str(node.target) not in target_to_messages:
target_to_messages[str(node.target)] = message
for target, nodes in sorted(
target_to_nodes.items(), key=lambda x: x[0], reverse=True
):
message = textwrap.indent(
f"{target_to_messages[target]}. Example node: `{nodes[0].format_node()}`. All nodes: `{nodes}`",
" ",
)
lines.append(f"- `{target}`: {message}")
else:
lines.append("All operators in the model have registered ONNX decompositions.")
return "\n".join(lines)
def _get_io_specs(exported_program: torch.export.ExportedProgram) -> tuple[dict, dict]:
"""Get the input and output specs of the exported program."""
nodes: dict[str, torch.fx.Node] = {
node.name: node for node in exported_program.graph.nodes
}
user_inputs = [
spec
for spec in exported_program.graph_signature.input_specs
if spec.kind == graph_signature.InputKind.USER_INPUT
]
user_outputs = [
spec
for spec in exported_program.graph_signature.output_specs
if spec.kind == graph_signature.OutputKind.USER_OUTPUT
]
inputs: dict[str, torch._export.serde.schema.TensorMeta] = {}
outputs: dict[str, torch._export.serde.schema.TensorMeta] = {}
for spec in user_inputs:
if isinstance(spec.arg, graph_signature.ConstantArgument):
continue
name = spec.arg.name
# FIXME: tensor_meta is None sometimes when the exported program still knows the shape/type
inputs[name] = nodes[name].meta["tensor_meta"]
for spec in user_outputs:
if isinstance(spec.arg, graph_signature.ConstantArgument):
continue
name = spec.arg.name
outputs[name] = nodes[name].meta["tensor_meta"]
return inputs, outputs
def _count_fx_targets(
exported_program: torch.export.ExportedProgram,
) -> defaultdict[str, int]:
"""Count the number of targets for each node in the exported program."""
fx_node_target_count: defaultdict[str, int] = defaultdict(int)
for node in exported_program.graph.nodes:
if node.op == "call_function":
fx_node_target_count[str(node.target)] += 1
return fx_node_target_count
def analyze(
    exported_program: torch.export.ExportedProgram,
    registry: _registration.ONNXRegistry | None = None,
    file=None,
) -> None:
    """Analyze the ONNX-export compatibility of an exported program and print a report.

    Gathers parameter/buffer element counts, FX node statistics and user
    input/output metadata. It then asks the dispatcher for an ONNX function
    for every ``call_function`` node. A node is recorded as a failure when the
    dispatcher returns no function or raises.

    Args:
        exported_program: The program to analyze.
        registry: The ONNX registry to dispatch against. Defaults to
            ``_registration.ONNXRegistry.from_torchlib()``.
        file: File-like object passed to ``print``; ``None`` means ``sys.stdout``.
    """
    # Basic information about the model.
    model_info = ModelInfo()
    model_info.parameter_count, model_info.buffer_count = _count_weights(
        exported_program
    )
    model_info.fx_node_count = len(exported_program.graph.nodes)
    model_info.fx_node_target_count = _count_fx_targets(exported_program)
    model_info.inputs, model_info.outputs = _get_io_specs(exported_program)

    if registry is None:
        registry = _registration.ONNXRegistry.from_torchlib()

    # Try to find an ONNX function for every call_function node.
    for node in exported_program.graph.nodes:
        model_info.fx_node_op_count[node.op] += 1
        if node.op != "call_function":
            continue
        try:
            onnx_function, message = _dispatching.dispatch(node, registry)
        except Exception as e:
            # Deliberately broad: a crash inside the dispatcher is reported
            # as a failure for this node instead of aborting the analysis.
            # format_exception() returns lines that already end with "\n",
            # so concatenate them with "" (joining with "\n" doubled every
            # line break in the report).
            formatted_exception = "".join(
                traceback.format_exception(type(e), e, e.__traceback__)
            ).rstrip("\n")
            message = "Critical Error in dispatcher:\n"
            message += f"```pytb\n{formatted_exception}\n```\n"
            onnx_function = None
        if onnx_function is None:
            model_info.dispatch_failures.append((node, message))

    # Print the results.
    report = _format_model_info(model_info)
    print(report, file=file, flush=True)
def compare_ops(
    program_a: torch.export.ExportedProgram, program_b: torch.export.ExportedProgram
) -> tuple[set[str], set[str]]:
    """Find the call_function targets used by only one of two exported programs.

    Targets are compared by their string form, as counted by
    ``_count_fx_targets``.

    Args:
        program_a: The first exported program.
        program_b: The second exported program.

    Returns:
        A pair ``(only_in_a, only_in_b)``. The first set holds targets that
        appear in ``program_a`` but not ``program_b``; the second holds the
        reverse.
    """
    targets_a = _count_fx_targets(program_a).keys()
    targets_b = _count_fx_targets(program_b).keys()
    # Subtracting dict key views yields plain sets.
    only_in_a = targets_a - targets_b
    only_in_b = targets_b - targets_a
    return only_in_a, only_in_b
|