File: nvfuser.py

Package: pytorch 1.13.1+dfsg-4 (Debian bookworm)

from typing import Dict

import torch
from torch.nn import Module
from torch._ops import OpOverload

from torch.fx import GraphModule
from torch.fx.node import Node, _get_qualified_name
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch._prims.executor import execute
from torch.fx.experimental.proxy_tensor import DecompositionInterpreter
from torch._decomp import decomposition_table

import typing as t

import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

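# Custom decomposition for aten.to.dtype: registered into aten2prim_decomp
# below, and intentionally limited to the plain to(self, dtype) form.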
def aten_to_dtype(self, dtype: torch.dtype, **kwargs):
    if len(kwargs) > 0 or not dtype:
        raise RuntimeError("to.dtype() is only supported in the to.dtype(self, dtype) form")
    return torch._prims.convert_element_type(self, dtype)

# decomposition_table currently contains both aten2aten and aten2prim decomposition
# this is a hack to separate them, as we only need aten2prim decomposition for nvfuser-supported aten graph lowering
aten2aten_decomp = {}
aten2prim_decomp = {}

# aten2aten decompositions we deliberately skip (these ops are left intact)
aten2aten_decomp_skips = {
    "aten.native_layer_norm_backward.default",
    "aten.embedding_dense_backward.default",   # This is hurting nvfuser's perf
    "aten.addmm.default"
}

for op, decomp_fn in decomposition_table.items():
    if "torch._refs" in decomp_fn.__module__:
        aten2prim_decomp[op] = decomp_fn
    elif str(op) not in aten2aten_decomp_skips:
        aten2aten_decomp[op] = decomp_fn


aten2prim_decomp[torch.ops.aten.to.dtype] = aten_to_dtype
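
# A minimal sketch (not part of the original module) of what the aten2prim
# table above is for: tracing a small function with make_fx so that aten ops
# which have reference decompositions are rewritten into prim ops. The helper
# name `_trace_to_prims_example` is illustrative only.
def _trace_to_prims_example(x: torch.Tensor) -> GraphModule:
    from torch.fx.experimental.proxy_tensor import make_fx

    def fn(t):
        return torch.relu(torch.add(t, 1.0))

    # make_fx applies the given decomposition table while tracing
    return make_fx(fn, decomposition_table=aten2prim_decomp)(x)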


class NvFuserOperatorSupport(OperatorSupport):
    """
    Operator support for nvFuser backend.

    Currently, partitioning is based on the FX ATen graph. The fused subgraph will later be decomposed into prims.
    To determine whether an ATen op is supported by nvFuser, we check the prim ops used in its reference decomposition.
    An ATen op is considered supported by nvFuser only if every prim op in its reference has an nvfuser_impl.

    Note: when adding a rule, please add it to the corresponding section and keep the
    list in alphabetical order.
    """

    def __init__(self):

        # TODO: the current list, copied from torch/csrc/jit/codegen/cuda/parser.cpp, is incorrect,
        # as that file is solely for TorchScript and doesn't reflect whether an operation
        # is actually runnable by primTorch+nvFuser.
        # We will iterate on this list to reflect reality.
        support_dict = {
            # ===============================================================
            # call_function aten
            # ===============================================================
            # The following supported aten ops are copied from torch/csrc/jit/codegen/cuda/parser.cpp
            # TODO: might need to update according to supported input types
            "torch.ops.aten.add": None,
            "torch.ops.aten.sub": None,
            # "torch.ops.aten.rsub": None,    # rsub decomp is supported at aten2aten level
            "torch.ops.aten.div": None,
            "torch.ops.aten.atan2": None,
            "torch.ops.aten.mul": None,
            "torch.ops.aten.max": None,
            "torch.ops.aten.min": None,
            "torch.ops.aten.pow": None,
            "torch.ops.aten.remainder": None,
            "torch.ops.aten.fmod": None,
            "torch.ops.aten.bitwise_and": None,
            "torch.ops.aten.__and__": None,
            "torch.ops.aten.bitwise_or": None,
            "torch.ops.aten.__or__": None,
            "torch.ops.aten.bitwise_xor": None,
            "torch.ops.aten.__xor__": None,
            "torch.ops.aten.bitwise_left_shift": None,
            "torch.ops.aten.__lshift__": None,
            "torch.ops.aten.bitwise_right_shift": None,
            "torch.ops.aten.__rshift__": None,
            "torch.ops.aten.eq": None,
            "torch.ops.aten.ne": None,
            "torch.ops.aten.ge": None,
            "torch.ops.aten.gt": None,
            "torch.ops.aten.le": None,
            "torch.ops.aten.lt": None,
            "torch.ops.aten.abs": None,
            "torch.ops.aten.bitwise_not": None,
            "torch.ops.aten.ceil": None,
            "torch.ops.aten.floor": None,
            "torch.ops.aten.frac": None,
            "torch.ops.aten.neg": None,
            "torch.ops.aten.relu": None,
            "torch.ops.aten.round": None,
            "torch.ops.aten.silu": None,
            "torch.ops.aten.trunc": None,
            "torch.ops.aten.log": None,
            "torch.ops.aten.log10": None,
            "torch.ops.aten.log1p": None,
            "torch.ops.aten.log2": None,
            "torch.ops.aten.lgamma": None,
            "torch.ops.aten.exp": None,
            "torch.ops.aten.expm1": None,
            "torch.ops.aten.erf": None,
            "torch.ops.aten.erfc": None,
            "torch.ops.aten.cos": None,
            "torch.ops.aten.acos": None,
            "torch.ops.aten.cosh": None,
            "torch.ops.aten.sin": None,
            "torch.ops.aten.asin": None,
            "torch.ops.aten.sinh": None,
            "torch.ops.aten.tan": None,
            "torch.ops.aten.atan": None,
            "torch.ops.aten.tanh": None,
            "torch.ops.aten.atanh": None,
            "torch.ops.aten.sqrt": None,
            "torch.ops.aten.rsqrt": None,
            "torch.ops.aten.reciprocal": None,
            "torch.ops.aten.sigmoid": None,
            "torch.ops.aten.isfinite": None,
            "torch.ops.aten.isinf": None,
            "torch.ops.aten.isnan": None,
            "torch.ops.aten.isneginf": None,
            "torch.ops.aten.isposinf": None,
            "torch.ops.aten.isreal": None,
            # "torch.ops.aten.rand_like": None,  # causing Node empty_like_default does not support nvfuser
            "torch.ops.aten.softplus": None,
            "torch.ops.aten.threshold": None,
            # relying on aten->aten->prim decomp, aten2aten is using unsupported aten.new_zero op
            # "torch.ops.aten.threshold_backward": None,
            "torch.ops.aten.clamp": None,
            # "torch.ops.aten.clone": None,
            # Failing with where(): incompatible function arguments: \
            # [<torch._C._nvfuser.TensorView, tensor, <torch._C._nvfuser.TensorView]
            # failing with BERT_pytorch_forward_0, which has aten.where.ScalarSelf in the decomps
            # "torch.ops.aten.where": None,
            # However, aten.where.self overload is fully supported
            "torch.ops.aten.where.self": None,
            "torch.ops.aten.lerp": None,
            "torch.ops.aten.addcmul": None,
            # "torch.ops.aten.native_dropout": None,    # missing refs for aten.rank_like
            "torch.ops.aten.dropout": None,
            # "torch.ops.aten.native_dropout_backward": None,   # missing refs for aten.type_as
            "torch.ops.aten.instance_norm": None,
            "torch.ops.aten._batch_norm_impl_index": None,
            # "torch.ops.aten.native_batch_norm": None,     # missing refs for aten.var
            "torch.ops.aten.batch_norm": None,
            "torch.ops.aten.cudnn_batch_norm": None,
            "torch.ops.aten._batch_norm_impl_index_backward": None,
            # "torch.ops.aten.native_batch_norm_backward": None,    # should have been handled at aten2aten decomp
            "torch.ops.aten.native_layer_norm": None,
            "torch.ops.aten.layer_norm": None,
            # relying on aten->aten->prim decomp, aten2aten is using unsupported aten.div
            # "torch.ops.aten.native_layer_norm_backward": None,
            "torch.ops.aten.softmax.int": None,
            "torch.ops.aten.log_softmax.int": None,
            # relying on aten->aten->prim decomp, aten2aten is using unsupported aten.amax
            # "torch.ops.aten._softmax": None,
            "torch.ops.aten._log_softmax_backward_data": None,
            # "torch.ops.aten._softmax_backward_data": None,  # Node _softmax_backward_data_default does not support nvfuser
            # "torch.ops.aten.var.dim": None,       # missing refs
            "torch.ops.aten.std.dim": None,
            "torch.ops.aten.sum": None,
            # "torch.ops.aten.mean.dim": None,      # missing refs
            "torch.ops.aten._grad_sum_to_size": None,
            "torch.ops.aten.sum_to_size": None,
            "torch.ops.aten._autocast_to_reduced_precision": None,
            "torch.ops.aten._autocast_to_full_precision": None,
            # "torch.ops.aten.to.dtype": None,      # causing segfault
            # "torch.ops.aten.type_as": None,       # missing refs
            "torch.ops.aten.linear": None,
            "torch.ops.aten.gelu": None,
            # "torch.ops.aten.gelu_backward": None,       # gelu_backward is handled at aten2aten decomp
            # "torch.ops.aten.hardtanh": None,        # has functional ref, using unsupported aten.clamp
            "torch.ops.aten.leaky_relu": None,
            "torch.ops.aten.square": None,
            # relying on aten->aten->prim decomp, aten2aten is using unsupported aten.conj_physical
            "torch.ops.aten.tanh_backward": None,
            # "torch.ops.aten.amax": None,      # missing prim decomp
            # "torch.ops.aten.amin": None,      # missing prim decomp
            # "torch.ops.aten.reshape": None,
            # "torch.ops.aten.view": None,      # missing prim decomp
            "torch.ops.aten.flatten.using_ints": None,

            # ===============================================================
            # call_function builtins and operator
            # ===============================================================
            "getattr": None,
            "_operator.getitem": None,
        }

        super().__init__(support_dict)

    def is_node_supported(
        self, submodules: t.Mapping[str, Module], node: Node
    ) -> bool:

        # nvFuser FX subgraph should be purely functional
        if node.op not in CALLABLE_NODE_OPS:
            return False

        # ops in support_dict don't carry an overload name, so for an OpOverload
        # we look it up by its overloadpacket's qualified name
        # (e.g. aten.add.Tensor is matched as "torch.ops.aten.add")
        if isinstance(node.target, OpOverload):
            target = _get_qualified_name(node.target.overloadpacket)
            if target in self._support_dict:
                return True

        return super().is_node_supported(submodules, node)


class NvFuserBackend:
    def __init__(self):
        self.supported_ops = NvFuserOperatorSupport()

        # TODO: this is a naive implementation of cache without proper guard
        self.partitioner_cache: Dict[GraphModule, GraphModule] = {}

        # TODO: this is a naive implementation of cache without proper guard, this will only work for identical inputs
        self.prim_decomp_cache: Dict[GraphModule, GraphModule] = {}

    def lower_to_prims_and_execute(self, graph_module: GraphModule, *args, **kwargs):
        # `graph_module` is an ATen-level FX graph
        # "lowering to prims" and "trace execution" are grouped into this function, as both are input-dependent

        if graph_module in self.prim_decomp_cache:
            logger.debug("prim_decomp_cache hit!")
            prim_module = self.prim_decomp_cache[graph_module]
        else:
            prim_graph = torch.fx.Graph()
            DecompositionInterpreter(graph_module, prim_graph, decomposition_table=aten2prim_decomp).run(*args, **kwargs)
            prim_module = torch.fx.GraphModule(graph_module, prim_graph)
            self.prim_decomp_cache[graph_module] = prim_module

            logger.debug("Lower to prims graph: ", prim_module.code)

        # invokes trace executor for running the prim graph
        return execute(prim_module, *args, executor="nvfuser")

    def compile(self, graph_module: GraphModule) -> GraphModule:
        # entry function for nvFuser backend
        logger.debug("Compiling graph_module: ", graph_module.code)

        # FX graph based partitioning based on nvfuser supported ops
        if graph_module in self.partitioner_cache:
            logger.debug("partitioner_cache hit!")
            fused_graph_module = self.partitioner_cache[graph_module]
        else:
            partitioner = CapabilityBasedPartitioner(
                graph_module, self.supported_ops, allows_single_node_partition=False)
            fused_graph_module = partitioner.partition_and_fuse()

            self.partitioner_cache[graph_module] = fused_graph_module

        # Override each fused submodule's call path (_wrapped_call) with lower_to_prims_and_execute()
        for node in fused_graph_module.graph.nodes:
            # TODO: use a better way to identify fused submodule
            if node.op == "call_module" and "fused_" in node.name:
                fused_module = getattr(fused_graph_module, node.name)
                fused_module._wrapped_call = self.lower_to_prims_and_execute

        return fused_graph_module

    def __call__(self, graph_module: GraphModule, _) -> GraphModule:
        # wrap self.compile as the __call__ method to match the interface of AOTAutograd's fw_compiler
        return self.compile(graph_module)
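

def _nvfuser_aot_example():
    # Hypothetical usage sketch (not part of the original module): wires
    # NvFuserBackend into AOTAutograd, whose fw_compiler/bw_compiler callbacks
    # expect a callable taking (GraphModule, example_inputs) -- which is the
    # __call__ signature above. Assumes functorch's `aot_function` is available
    # in this release and that a CUDA device with nvFuser support is present.
    from functorch.compile import aot_function

    def fn(x, y):
        return torch.tanh(x + y).relu()

    nvfuser = NvFuserBackend()
    compiled_fn = aot_function(fn, fw_compiler=nvfuser, bw_compiler=nvfuser)

    x = torch.randn(8, device="cuda", requires_grad=True)
    y = torch.randn(8, device="cuda")
    return compiled_fn(x, y)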