1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
|
import torch
import torch.fx
import traceback
from torch.fx.node import Node, map_aggregate
from typing import Any, Tuple, NamedTuple, Optional, Dict
from torch.fx._compatibility import compatibility
# Public API of this module (the names exported by ``from ... import *``).
__all__ = [
    "TensorMetadata",
    "ShapeProp",
]
@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):
    """
    Pertinent information about a tensor observed within a PyTorch program.

    Built by ``_extract_tensor_metadata``; field order matches the positional
    arguments it passes.
    """

    # General Tensor metadata
    shape : torch.Size
    dtype : torch.dtype
    requires_grad : bool
    # One entry per dimension (the value of ``Tensor.stride()``), so this is a
    # variable-length tuple rather than a 1-tuple.
    stride : Tuple[int, ...]
    # The first of contiguous_format / channels_last / channels_last_3d that
    # the tensor is contiguous in, or None if it matches none of them.
    memory_format : Optional[torch.memory_format]

    # Quantization metadata
    is_quantized : bool
    # Empty for non-quantized tensors. Otherwise holds "qscheme" plus, for the
    # supported per-tensor / per-channel schemes, "scale" and "zero_point"
    # (and "axis" for per-channel schemes).
    qparams: Dict[str, Any]


def _extract_tensor_metadata(result : torch.Tensor) -> TensorMetadata:
    """
    Extract a TensorMetadata NamedTuple describing `result`.

    Args:
        result (torch.Tensor): the tensor to describe. ``stride()`` and
            ``is_contiguous()`` are queried on it, so it presumably must be a
            strided (dense) tensor.

    Returns:
        TensorMetadata: shape, dtype, requires_grad, stride, the highest-priority
        memory format the tensor is contiguous in (or None), and quantization
        parameters.
    """
    shape = result.shape
    dtype = result.dtype
    requires_grad = result.requires_grad
    stride = result.stride()

    # Candidate layouts in priority order. A tensor can be contiguous in
    # several of them at once (for example when its spatial dims have size 1),
    # and a set has no defined iteration order, so a tuple is used to make the
    # reported format deterministic, preferring the plain contiguous layout.
    memory_formats = (
        torch.contiguous_format,
        torch.channels_last,
        torch.channels_last_3d,
    )
    memory_format = None
    for query_format in memory_formats:
        if result.is_contiguous(memory_format=query_format):
            memory_format = query_format
            break

    is_quantized = result.is_quantized
    qparams: Dict[str, Any] = {}
    if is_quantized:
        qscheme = result.qscheme()
        qparams["qscheme"] = qscheme
        if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
            qparams["scale"] = result.q_scale()
            qparams["zero_point"] = result.q_zero_point()
        elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}:
            # Per-channel scales and zero points are tensors; store them as
            # plain Python lists for easier serialization downstream.
            qparams["scale"] = result.q_per_channel_scales().tolist()
            qparams["zero_point"] = result.q_per_channel_zero_points().tolist()
            qparams["axis"] = result.q_per_channel_axis()

    return TensorMetadata(
        shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams)
@compatibility(is_backward_compatible=True)
class ShapeProp(torch.fx.Interpreter):
    """
    Execute an FX graph Node-by-Node and
    record the shape and type of the result
    into the corresponding node.

    Example:
        In this example, we record the shape
        and data type of a module given
        an example input ``torch.randn(50, D_in)``.
        We print the name, shape and dtype of each node.

        class TwoLayerNet(torch.nn.Module):
            def __init__(self, D_in, H, D_out):
                super().__init__()
                self.linear1 = torch.nn.Linear(D_in, H)
                self.linear2 = torch.nn.Linear(H, D_out)

            def forward(self, x):
                h_relu = self.linear1(x).clamp(min=0)
                y_pred = self.linear2(h_relu)
                return y_pred

        N, D_in, H, D_out = 64, 1000, 100, 10
        x = torch.randn(N, D_in)
        y = torch.randn(N, D_out)
        model = TwoLayerNet(D_in, H, D_out)
        gm = torch.fx.symbolic_trace(model)
        sample_input = torch.randn(50, D_in)
        ShapeProp(gm).propagate(sample_input)

        for node in gm.graph.nodes:
            print(node.name, node.meta['tensor_meta'].dtype,
                  node.meta['tensor_meta'].shape)

        The output of this code is:

        x torch.float32 torch.Size([50, 1000])
        linear1 torch.float32 torch.Size([50, 100])
        clamp_1 torch.float32 torch.Size([50, 100])
        linear2 torch.float32 torch.Size([50, 10])
        output torch.float32 torch.Size([50, 10])

        Note that ``node.meta['tensor_meta']`` is only set for nodes whose
        result contains a tensor; every node gets ``node.meta['type']``.

    Args:
        module (GraphModule): The module to be executed
    """

    def run_node(self, n : Node) -> Any:
        """
        Execute node ``n`` and annotate it with metadata about its result.

        Always sets ``n.meta['type']`` to ``type(result)``. If the result
        contains at least one ``torch.Tensor`` (possibly inside containers
        traversed by ``map_aggregate``), also sets ``n.meta['tensor_meta']``
        to the result with every tensor replaced by its TensorMetadata.

        Args:
            n (Node): the node to execute.

        Returns:
            Any: the node's result, unchanged.

        Raises:
            RuntimeError: if executing the node fails. The traceback of the
                original exception is printed, and that exception is chained
                as ``__cause__``.
        """
        try:
            result = super().run_node(n)
        except Exception as e:
            # Print the underlying traceback, then raise an error naming the
            # failing node; `from e` keeps the original as the explicit cause.
            traceback.print_exc()
            raise RuntimeError(
                f"ShapeProp error for: node={n.format_node()} with "
                f"meta={n.meta}"
            ) from e

        found_tensor = False

        def extract_tensor_meta(obj):
            # Replace tensors with their metadata; leave everything else as-is.
            nonlocal found_tensor
            if isinstance(obj, torch.Tensor):
                found_tensor = True
                return _extract_tensor_metadata(obj)
            return obj

        meta = map_aggregate(result, extract_tensor_meta)
        if found_tensor:
            n.meta['tensor_meta'] = meta

        n.meta['type'] = type(result)
        return result

    def propagate(self, *args):
        """
        Run `module` via interpretation and return the result and
        record the shape and type of each node.

        Args:
            *args (Tensor): the sample input.

        Returns:
            Any: The value returned from executing the Module
        """
        return super().run(*args)
|