File: nvfuser_executor.py

Package: pytorch 1.13.1+dfsg-4 (Debian bookworm)

from copy import deepcopy
from dataclasses import dataclass
from functools import lru_cache
from types import MappingProxyType
from warnings import warn

import torch
import torch.overrides
from torch._prims_common import (
    _torch_dtype_to_nvfuser_dtype_map,
    getnvFuserDtype,
    Number,
    number_type,
)

from torch.fx import GraphModule
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

if torch.cuda.is_available():
    from torch._C._nvfuser import (  # type: ignore[import]
        DataType,
        Fusion,
        FusionDefinition,
    )
else:
    DataType = None

DEFAULT_NVFUSER_PYTHON_CONFIG = MappingProxyType(
    {
        "use_python_fusion_cache": True,
        "allow_single_op_fusion": True,
    }
)
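
# For example (a hedged sketch; both executor entry points below accept this
# mapping through their `executor_parameters` argument):
#
#   params = {"use_python_fusion_cache": False, "allow_single_op_fusion": False}
#   nvfuser_execute_partitioned(gm, *args, executor_parameters=params)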

# nvFuserTensorTemplate and nvFuserScalarTemplate are hashable helper objects
# used for cached construction of nvFuser's Fusion
# TODO: change what is stored in the cache for nvFuser's Tensor objects
# https://github.com/pytorch/pytorch/issues/80551
@dataclass(frozen=True)
class nvFuserTensorTemplate:
    size: tuple
    stride: tuple
    dtype: DataType
    is_cpu: bool


@dataclass(frozen=True)
class nvFuserScalarTemplate:
    dtype: DataType


def to_nvfuser_template_args(args):
    def to_nvfuser(arg):
        if isinstance(arg, torch.Tensor):
            return nvFuserTensorTemplate(
                arg.size(),
                arg.stride(),
                getnvFuserDtype(arg.dtype),
                arg.is_cpu,  # type: ignore[attr-defined]
            )
        elif isinstance(arg, Number):
            return nvFuserScalarTemplate(getnvFuserDtype(number_type(arg)))
        else:
            return arg

    return tree_map(to_nvfuser, args)
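

# A hedged illustration of the mapping above (DataType members come from
# torch._C._nvfuser; exact reprs may differ across builds):
#
#   t = torch.empty(2, 3, device="cuda")
#   to_nvfuser_template_args((t, 1.0))
#   # -> (nvFuserTensorTemplate(size=(2, 3), stride=(3, 1),
#   #                           dtype=DataType.Float, is_cpu=False),
#   #     nvFuserScalarTemplate(dtype=DataType.Double))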


def _any_get_attr_used(call_function_nodes):
    # Returns True if any call_function node consumes the output of a
    # get_attr node, i.e. a constant tensor stored on the GraphModule
    return any(
        any(a.op == "get_attr" for a in n.args if isinstance(a, torch.fx.Node))
        for n in call_function_nodes
    )
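

# For illustration (a sketch): in a graph like
#
#   %weight = get_attr[target=weight]
#   %add = call_function[target=torch.ops.nvprims.add.default](args=(%x, %weight))
#
# _any_get_attr_used returns True because a call_function node consumes a
# get_attr result (a constant tensor stored on the module).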


# MyPy bug: https://github.com/python/mypy/issues/5107
@lru_cache(maxsize=1024)  # type: ignore[arg-type]
def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
    if not torch.cuda.is_available():
        raise RuntimeError(
            "Attempting to use nvFuser trace executor but CUDA is not available!"
        )

    # Everything in the graph must support nvfuser
    for node in gm.graph.nodes:
        if node.op == "call_function" and "getitem" in node.name:
            continue
        if (
            node.op == "call_function"
            and getattr(node.target, "impl_nvfuser", None) is None
        ):
            raise ValueError(
                "All call_function nodes in the graph must support nvfuser. "
                f"Node {node} with target {node.target} does not support nvfuser"
            )

    graph_input_nodes = list(filter(lambda n: n.op == "placeholder", gm.graph.nodes))
    call_function_nodes = list(
        filter(lambda n: n.op == "call_function", gm.graph.nodes)
    )
    assert len(graph_input_nodes) == len(
        nv_args_templates
    ), "Number of placeholder nodes in the graph must match number of args"
    assert len(nv_args_templates) > 0, "There must be at least one argument"
    assert (
        len(call_function_nodes) > 0
    ), "Graph must contain at least one call_function node"
    assert not _any_get_attr_used(
        call_function_nodes
    ), "Constant tensors that are saved in the graph and used as arguments are not supported yet"

    fusion = Fusion()
    with FusionDefinition(fusion) as fd:

        def _to_nvfuser_constant(arg):
            if isinstance(arg, Number):
                return fd.define_constant(arg)
            else:
                return arg

        class FusionInterpreter(torch.fx.Interpreter):
            def run_node(self, node):
                # nvprims.squeeze needs the original shape of args[0] passed
                # as an extra argument between the input and the squeeze dims
                if node.target in [
                    torch.ops.nvprims.squeeze,
                    torch.ops.nvprims.squeeze.default,
                ]:
                    original_shape = list(node.args[0].meta["tensor_meta"].shape)
                    assert len(node.args) == 2
                    args, kwargs = self.fetch_args_kwargs_from_env(node)
                    args = [args[0], original_shape, args[1]]
                    return self.call_function(node.target, args, node.kwargs)

                if node.target in [
                    torch.ops.nvprims.native_batch_norm,
                    torch.ops.nvprims.native_batch_norm.default,
                ]:
                    args, kwargs = self.fetch_args_kwargs_from_env(node)
                    assert len(args) == 8
                    # args[5] is the `training` flag; keep it as a Python bool
                    # and only convert the trailing numeric args (momentum,
                    # eps) to nvFuser constants
                    training = args[5]
                    args6_end = tuple(map(_to_nvfuser_constant, args[6:]))
                    args = args[:5] + (training,) + args6_end
                    return node.target.impl_nvfuser(fd, *args, **kwargs)

                return super().run_node(node)

            def call_function(self, target, args, kwargs):
                # This handles tuple unpacking
                if "getitem" in str(target):
                    assert isinstance(args[0], tuple)
                    return target(*args, **kwargs)
                args = tuple(map(_to_nvfuser_constant, args))
                target = target.impl_nvfuser
                args = (fd,) + args
                return target(*args, **kwargs)

        def templates_to_nvfuser_inputs(arg):
            if isinstance(arg, nvFuserTensorTemplate):
                return fd.define_tensor(arg.size, arg.stride, arg.dtype, arg.is_cpu)
            elif isinstance(arg, nvFuserScalarTemplate):
                return fd.define_scalar(arg.dtype)
            else:
                return arg

        # Run the graph through FusionInterpreter, recording the nvFuser
        # lowering of each node into the fusion definition
        nv_args = tuple(map(templates_to_nvfuser_inputs, nv_args_templates))
        out = FusionInterpreter(gm).run(*nv_args)
        flat_out, unflatten_spec = tree_flatten(out)
        for o in flat_out:
            fd.add_output(o)

    return fusion, unflatten_spec
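

# Because make_nvfuser_fusion is wrapped in lru_cache, repeated calls with the
# same GraphModule and the same frozen templates reuse the compiled Fusion.
# A quick way to observe this (standard functools API):
#
#   make_nvfuser_fusion(gm, *templates)   # miss: traces and builds the Fusion
#   make_nvfuser_fusion(gm, *templates)   # hit: returns the cached Fusion
#   make_nvfuser_fusion.cache_info()      # e.g. CacheInfo(hits=1, misses=1, ...)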


def nvfuser_execute(gm: GraphModule, *args, executor_parameters=None):
    executor_parameters = executor_parameters or DEFAULT_NVFUSER_PYTHON_CONFIG
    flat_args, _ = tree_flatten(args)

    # nvFuser requires CUDA inputs: proceed only if at least one argument is a
    # CUDA tensor and every tensor argument is either on CUDA or a zero-dim
    # CPU tensor (a scalar); otherwise fall back to the eager ATen executor
    if any(isinstance(arg, torch.Tensor) and arg.is_cuda for arg in flat_args) and all(  # type: ignore[attr-defined]
        (
            not isinstance(arg, torch.Tensor)
            or (arg.is_cpu and arg.ndim == 0)  # type: ignore[attr-defined]
            or arg.is_cuda  # type: ignore[attr-defined]
        )
        for arg in flat_args
    ):

        # Construction of the fusion is expensive and cached based on the GraphModule
        # and symbolic nvFuser args.
        nv_template_args = to_nvfuser_template_args(flat_args)
        use_cache = executor_parameters.get(
            "use_python_fusion_cache",
            DEFAULT_NVFUSER_PYTHON_CONFIG["use_python_fusion_cache"],
        )
        if use_cache:
            fusion, unflatten_spec = make_nvfuser_fusion(gm, *nv_template_args)  # type: ignore[misc]
        else:
            fusion, unflatten_spec = make_nvfuser_fusion.__wrapped__(gm, *nv_template_args)  # type: ignore[misc]

        # Inputs to fusion.execute correspond to the same template/symbolic inputs
        # marked with `define_tensor/scalar`
        concrete_fusion_inputs = tuple(
            arg for arg in flat_args if isinstance(arg, (torch.Tensor, Number))
        )

        return tree_unflatten(
            fusion.execute(concrete_fusion_inputs),  # type: ignore[has-type]
            unflatten_spec,  # type: ignore[has-type]
        )
    else:
        warn(
            "nvfuser_executor was called with non-CUDA args; falling back to the ATen executor"
        )
        return gm.forward(*args)
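

# A hedged usage sketch for nvfuser_execute: the GraphModule must already be
# lowered to nvprims (every call_function target carries impl_nvfuser). In
# PyTorch 1.13 one way to obtain such a module is tracing under
# TorchRefsNvfuserCapabilityMode:
#
#   from torch.fx.experimental.proxy_tensor import make_fx
#   from torch._prims.context import TorchRefsNvfuserCapabilityMode
#
#   with TorchRefsNvfuserCapabilityMode():
#       gm = make_fx(lambda a, b: (a + b).sin())(a, b)
#   out = nvfuser_execute(gm, a, b)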


class NvfuserPrimOperatorSupport(torch.fx.passes.operator_support.OperatorSupport):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        # special case to stop lowering to nvprim when converting to an unsupported type
        if (
            node.op == "call_function"
            and node.target == torch.ops.nvprims.convert_element_type.default
        ):
            return (
                _torch_dtype_to_nvfuser_dtype_map.get(node.args[1]) is not None
                and _torch_dtype_to_nvfuser_dtype_map.get(
                    node.args[0].meta["tensor_meta"].dtype  # type: ignore[union-attr]
                )
                is not None
            )
        return (
            node.op == "call_function"
            and getattr(node.target, "impl_nvfuser", None) is not None
        ) or "getitem" in node.name  # getitem is a special case
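

# Roughly (a sketch of the rule above): nvprims targets carry an impl_nvfuser
# attribute and are accepted, e.g. torch.ops.nvprims.add.default, while plain
# ATen targets such as torch.ops.aten.add.default do not and are rejected;
# getitem nodes, used to unpack multi-output ops, are always accepted.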


class PartitionedInterpreter(torch.fx.Interpreter):
    def call_module(self, target, args, kwargs):
        assert isinstance(target, str)
        assert len(kwargs) == 0
        submod = self.fetch_attr(target)
        # CapabilityBasedPartitioner hardcodes the name of the subgraphs with supported_ops as "fused_" + subgraph id
        if target.startswith("fused_"):
            return nvfuser_execute(submod, *args)
        else:
            return super().call_module(target, args, kwargs)


class NvfuserGraphModule(torch.nn.Module):
    def __init__(self, gm, use_python_fusion_cache):
        super().__init__()
        self.gm = gm
        self.executor_parameters = {"use_python_fusion_cache": use_python_fusion_cache}

    def __call__(self, *args):
        return nvfuser_execute(
            self.gm, *args, executor_parameters=self.executor_parameters
        )


# MyPy bug: https://github.com/python/mypy/issues/5107
@lru_cache(maxsize=1024)  # type: ignore[arg-type]
def maybe_partition_graph(
    gm: GraphModule, allow_single_op_fusion: bool, use_python_fusion_cache: bool
):
    supported_ops = NvfuserPrimOperatorSupport()
    call_function_nodes = list(
        filter(lambda n: n.op == "call_function", gm.graph.nodes)
    )
    # the graph is partitioned only if at least one node is not supported by nvFuser
    any_unsupported = any(
        not supported_ops.is_node_supported(None, node) for node in call_function_nodes
    )
    any_unsupported |= len(call_function_nodes) == 0

    # When the graph contains constant tensors (get_attr nodes) we can't
    # partition it because deepcopy fails, and a graph with no placeholders
    # has nothing for nvFuser to fuse. In both cases return the original
    # graph to be executed in eager mode
    # https://github.com/pytorch/pytorch/issues/84415
    if (
        _any_get_attr_used(call_function_nodes)
        or len(list(filter(lambda n: n.op == "placeholder", gm.graph.nodes))) == 0
    ):
        return gm, True

    if any_unsupported:
        # CapabilityBasedPartitioner modifies the graph in-place so we need to make a copy of the graph
        gm = deepcopy(gm)
        partitioner = CapabilityBasedPartitioner(
            gm, supported_ops, allows_single_node_partition=allow_single_op_fusion
        )
        partitions = partitioner.propose_partitions()
        if len(partitions) == 0:
            warn(
                "No partition found for the graph. "
                + "This is likely because the graph is not supported by nvFuser. "
                + "Please use the eager ATen mode to execute the graph.",
                category=RuntimeWarning,
            )
        partitioned_graph = partitioner.fuse_partitions(partitions)

        # Replace the graph's fused submodules with wrapper modules whose
        # __call__() invokes nvfuser_execute directly. This avoids having to
        # run an interpreter over the partitioned graph
        for node in partitioned_graph.graph.nodes:
            # TODO: use a better way to identify fused submodule
            if node.op == "call_module" and "fused_" in node.name:
                nvfuser_submodule = getattr(partitioned_graph, node.name)
                partitioned_graph.delete_submodule(node.target)
                partitioned_graph.add_submodule(
                    node.target,
                    NvfuserGraphModule(nvfuser_submodule, use_python_fusion_cache),
                )

        return partitioned_graph, any_unsupported
    else:
        return gm, any_unsupported
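

# A sketch of the contract: callers branch on the second return value.
#
#   gm_or_partitioned, is_partitioned = maybe_partition_graph(gm, True, True)
#   # is_partitioned == False -> every node is nvFuser-supported; run the
#   #     original gm through nvfuser_execute directly
#   # is_partitioned == True  -> call gm_or_partitioned(*args); its fused_*
#   #     submodules dispatch to nvFuser and the rest runs in eager mode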


def nvfuser_execute_partitioned(gm: GraphModule, *args, executor_parameters=None):
    executor_parameters = executor_parameters or DEFAULT_NVFUSER_PYTHON_CONFIG
    # maybe_partition_graph is lru_cached, so it can only take hashable
    # arguments (the executor_parameters dict is unpacked into bools here)
    allow_single_op_fusion = executor_parameters.get(
        "allow_single_op_fusion",
        DEFAULT_NVFUSER_PYTHON_CONFIG["allow_single_op_fusion"],
    )
    use_python_fusion_cache = executor_parameters.get(
        "use_python_fusion_cache",
        DEFAULT_NVFUSER_PYTHON_CONFIG["use_python_fusion_cache"],
    )
    # When possible it's better to use nvfuser_execute directly
    # because it avoids GraphModule's overhead
    gm, is_partitioned = maybe_partition_graph(
        gm,
        allow_single_op_fusion=allow_single_op_fusion,
        use_python_fusion_cache=use_python_fusion_cache,
    )
    if is_partitioned:
        return gm(*args)
    else:
        return nvfuser_execute(gm, *args, executor_parameters=executor_parameters)
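

# A minimal end-to-end sketch, not part of the original module. It assumes a
# CUDA build of PyTorch 1.13 where make_fx and TorchRefsNvfuserCapabilityMode
# are available: tracing under that context lowers ops to nvprims, which the
# partitioned executor then runs through nvFuser.
if __name__ == "__main__":
    if torch.cuda.is_available():
        from torch._prims.context import TorchRefsNvfuserCapabilityMode
        from torch.fx.experimental.proxy_tensor import make_fx

        def fn(a, b):
            return (a + b).sin()

        a = torch.randn(4, 4, device="cuda")
        b = torch.randn(4, 4, device="cuda")

        # Trace fn so that call_function nodes target nvprims ops
        with TorchRefsNvfuserCapabilityMode():
            gm = make_fx(fn)(a, b)

        # Execute through nvFuser (unsupported nodes fall back to eager)
        result = nvfuser_execute_partitioned(gm, a, b)
        torch.testing.assert_close(result, fn(a, b))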