import abc
import typing as t
import torch
import torch.fx
from torch.fx._compatibility import compatibility
from .shape_prop import TensorMetadata
from .tools_common import get_node_target, CALLABLE_NODE_OPS
__all__ = ['OperatorSupportBase', 'OperatorSupport', 'create_op_support', 'chain', 'OpSupports']
# fx.Node.target typename, as returned by `get_node_target()`
TargetTypeName = str
# Arguments' dtypes for a given node, see `OperatorSupport`
SupportedArgumentDTypes = t.Optional[
t.Tuple[
t.Sequence[t.Sequence[torch.dtype]],
t.Dict[str, t.Sequence[torch.dtype]],
]
]
SupportDict = t.Mapping[TargetTypeName, SupportedArgumentDTypes]
@compatibility(is_backward_compatible=False)
class OperatorSupportBase(abc.ABC):
"""Interface for determining if a fx.Node is supported by a backend"""
@abc.abstractmethod
def is_node_supported(
self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
) -> bool:
raise NotImplementedError()
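# Subclassing sketch (illustrative only, not part of this module): a trivial
# OperatorSupportBase that accepts every node.
#
#     class AllowAll(OperatorSupportBase):
#         def is_node_supported(self, submodules, node) -> bool:
#             return True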
@compatibility(is_backward_compatible=False)
class OperatorSupport(OperatorSupportBase):
"""
`_support_dict` maps node.target typename to supported inputs dtypes.
node.target typename is retrieved using helper function `get_node_target()`
If supported inputs dtypes is None, it means any dtype is supported, else
we should see a tuple like (([dtypes], ...), {"name":[dtypes], ...}).
The first tuple ([dtypes], ...) indicates what dtypes are supported for
inputs in node.args and the second dict {"name": [dtypes], ...} indicates
what dtypes are supported for inputs in node.kwargs.
For inputs in args, if we don't want to check it, we can put None there,
e.g. (None, [torch.float]) indicates that we don't care about the type of
the first input in args. And for inputs in kwargs, if not listed, will not
be checked.
"""
_support_dict: SupportDict
def __init__(
self,
support_dict: t.Optional[SupportDict] = None
):
self._support_dict = support_dict or {}
def is_node_supported(
self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
) -> bool:
"""
Args:
`submodules`: mapping from module name to the module. This can be
retrieved by calling model.named_modules().
`node`: an FX node that we want to determine whether it's supported.
Returns:
`is_supported`: whether the arg `node` is supported.
"""
if node.op not in CALLABLE_NODE_OPS:
return True
target = get_node_target(submodules, node)
# If the target is not found in _support_dict, we don't support this op at all
if target not in self._support_dict:
return False
# A None rule for the target means any dtype is accepted
if self._support_dict[target] is None:
return True
args_dtypes, kwargs_dtypes = self._support_dict[target] # type: ignore[misc]
# Check args dtypes
for i, dtypes in enumerate(args_dtypes):
if len(node.args) <= i:
break
# None indicates we don't care about the dtype of args[i]
if dtypes is None:
continue
# If arg is not a node then we don't check it
if not isinstance(node.args[i], torch.fx.Node):
continue
arg_dtype = _get_arg_dtype(node.args[i]) # type: ignore[arg-type]
if arg_dtype not in dtypes:
return False
# Check kwargs dtypes
for k, dtypes in kwargs_dtypes.items():
if k not in node.kwargs:
continue
# If arg is not a node then we don't check it
if not isinstance(node.kwargs[k], torch.fx.Node):
continue
kwarg_dtype = _get_arg_dtype(node.kwargs[k]) # type: ignore[arg-type]
if kwarg_dtype not in dtypes:
return False
return True
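# A minimal usage sketch (the target names and dtypes below are illustrative
# assumptions, not part of this module): "torch.add" accepts float16/float32
# for its first positional input, ignores its second, and checks the "alpha"
# kwarg only when present; "torch.relu" accepts any dtype.
#
#     support_dict = {
#         "torch.add": (
#             ([torch.float16, torch.float32], None),
#             {"alpha": [torch.float32]},
#         ),
#         "torch.relu": None,
#     }
#     op_support = OperatorSupport(support_dict)
#     # After ShapeProp has populated node.meta["tensor_meta"]:
#     # op_support.is_node_supported(dict(gm.named_modules()), node)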
# ======================================================================
# Functional interfaces and utils for defining basic operator support logic
# and composing them into more complex ones
# ======================================================================
IsNodeSupported = t.Callable[[t.Mapping[str, torch.nn.Module], torch.fx.Node], bool]
@compatibility(is_backward_compatible=False)
def create_op_support(is_node_supported: IsNodeSupported) -> OperatorSupportBase:
"""Wraps a `IsNodeSupported` function into an `OperatorSupportBase` instance
`IsNodeSupported` has the same call signature as
`OperatorSupportBase.is_node_supported`
"""
class FunctionalOperatorSupport(OperatorSupportBase):
def is_node_supported(
self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node
) -> bool:
return is_node_supported(submodules, node)
return FunctionalOperatorSupport()
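# Example sketch (the predicate below is an assumption, not part of this
# module): wrap a plain function into an OperatorSupportBase that rejects
# every call_module node.
#
#     def _no_call_modules(submodules, node) -> bool:
#         return node.op != "call_module"
#
#     no_call_modules_support = create_op_support(_no_call_modules)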
@compatibility(is_backward_compatible=False)
def chain(*op_support: OperatorSupportBase) -> OperatorSupportBase:
"""Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase`
instance by evaluating each input `OperatorSupportBase` instance, and returns False if
any of it reports False.
"""
def _chain(submods, node) -> bool:
return all(
x.is_node_supported(submods, node)
for x in op_support
)
return create_op_support(_chain)
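# Example sketch (names are illustrative; `support_dict` is the dict sketched
# above): a node is supported only if every chained instance reports it as
# supported.
#
#     combined_support = chain(
#         OperatorSupport(support_dict),
#         create_op_support(lambda submods, node: node.op != "call_module"),
#     )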
@compatibility(is_backward_compatible=False)
class OpSupports:
"""A set of atomic `OperatorSupportBase` instances that can be combined together
to form more complex operator support logic.
"""
@classmethod
def decline_if_input_dtype(cls, dtype: torch.dtype) -> OperatorSupportBase:
"""Report a node as non-supported, if any of its arguments is of dtype"""
def _decline_if_input_dtype(
submodules: t.Mapping[str, torch.nn.Module],
node: torch.fx.Node,
) -> bool:
for arg in node.all_input_nodes:
# escape dtype check for get_attr node
if arg.op == "get_attr":
continue
arg_dtype = _get_arg_dtype(arg)
if arg_dtype == dtype:
return False
return True
return create_op_support(_decline_if_input_dtype)
@classmethod
def decline_if_node_in_names(cls, disallow_set: t.Set[str]) -> OperatorSupportBase:
"""
If a node's name is in the disallow set, report it as non-supported.
"""
def _decline_if_node_in_names(
submodules: t.Mapping[str, torch.nn.Module],
node: torch.fx.Node,
) -> bool:
if node.name in disallow_set:
return False
else:
return True
return create_op_support(_decline_if_node_in_names)
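# Example sketch (the dtype and node names are illustrative): combine the
# atomic rules from OpSupports via `chain` to reject nodes that consume
# float64 inputs or that appear in an explicit deny-list.
#
#     strict_support = chain(
#         OpSupports.decline_if_input_dtype(torch.float64),
#         OpSupports.decline_if_node_in_names({"add_1", "mul_2"}),
#     )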
def _get_arg_dtype(arg: torch.fx.Node) -> t.Any:
assert isinstance(arg, torch.fx.Node)
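# Prefer the dtype recorded by ShapeProp in node.meta["tensor_meta"];
# otherwise fall back to the Python type stored in node.meta["type"].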
tensor_meta = arg.meta.get("tensor_meta") # type: ignore[union-attr]
dtype = tensor_meta.dtype if isinstance(tensor_meta, TensorMetadata) else arg.meta["type"]
return dtype