File: _efficient_shape_prop.py

Package: python-opt-einsum-fx 0.1.4-3

from typing import Any, NamedTuple

import opt_einsum
import torch
from torch.fx.node import Node

from ._fuse import _EINSUM_FUNCS


class SimpleMeta(NamedTuple):
    """
    The full ShapeProp defines and uses a NamedTuple to
    store a whole bunch of metadata about the tensors
    going into and out of the Node op. But we don't
    have most of that info, and anyway, I don't think
    most of it's used in opt_einsum or opt_einsum_fx.
    (These are only concerned with computing a summation
    order.)

    Rather than give dummy or default values, which I
    only *assume* would be fine, I'm defining a NamedTuple
    with only the values we actually know. So if I'm wrong
    we will get a very clear error message, rather than
    some invisible error.
    """

    shape: torch.Size
    dtype: torch.dtype


class EfficientShapeProp(torch.fx.Interpreter):
    """
    Like ShapeProp, traverses a graph Node-by-Node
    and records the shape and type of the result
    into each Node.

    Except that einsum (and tensordot) calls are treated
    as a special case: we don't actually execute them on
    tensors, since the einsums will typically not have been
    optimized yet (ShapeProp is called before optimization),
    and an inefficient summation order can create enormous
    intermediate tensors, often causing needless
    out-of-memory errors.

    So we override 'run_node' only for those calls.
    It's straightforward to determine the shape of the
    result from the output indices and the operand shapes.

    (The call to opt_einsum that typically follows this
    also doesn't actually build the tensors during its
    exploration.)
    """

    def run_node(self, n: Node) -> Any:
        if n.op == "call_function" and n.target in _EINSUM_FUNCS:
            args, kwargs = self.fetch_args_kwargs_from_env(n)
            equation, *operands = args
            shapes = [op.shape for op in operands]

            assert len({op.dtype for op in operands}) == 1
            meta = SimpleMeta(einsum_shape(equation, *shapes), operands[0].dtype)
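            # Stand-in for the real result: a single zero expanded (a broadcast
            # view, no full allocation) to the computed output shape.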
            result = torch.zeros((1,) * len(meta.shape), dtype=meta.dtype, device=operands[0].device).expand(meta.shape)
        elif n.op == "call_function" and n.target == torch.tensordot:
            args, kwargs = self.fetch_args_kwargs_from_env(n)
            shape_a = [dim for i, dim in enumerate(args[0].shape) if i not in kwargs['dims'][0]]
            shape_b = [dim for i, dim in enumerate(args[1].shape) if i not in kwargs['dims'][1]]

            assert len({op.dtype for op in args}) == 1
            meta = SimpleMeta(shape_a + shape_b, args[0].dtype)
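            # Same stand-in trick as for einsum: expand a single zero to the
            # output shape without allocating the full tensor.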
            result = torch.zeros((1,) * len(meta.shape), dtype=meta.dtype, device=args[0].device).expand(meta.shape)
        else:
            result = super().run_node(n)

            if isinstance(result, torch.Tensor):
                meta = SimpleMeta(result.shape, result.dtype)
            else:
                meta = None

        n.meta = dict()
        n.meta['tensor_meta'] = meta
        n.meta['type'] = type(result)

        return result

    def propagate(self, *args):
        return super().run(*args)
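
# A minimal usage sketch (illustrative; `my_module` and the example input are
# assumptions, not part of this file):
#
#     import torch
#     import torch.fx
#
#     gm = torch.fx.symbolic_trace(my_module)
#     EfficientShapeProp(gm).propagate(torch.randn(3, 4))
#     for node in gm.graph.nodes:
#         print(node.name, node.meta.get("tensor_meta"))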


def einsum_shape(subscripts, *shapes):
    """
    Given an einsum equation and input shapes, returns the output
    shape of the einsum.

    Args:
       subscripts: the einsum formula
       shapes: the input shapes
    """
    Shaped = NamedTuple('Shaped', [('shape', tuple)])
    input_subscripts, output_subscript, _ = opt_einsum.parser.parse_einsum_input(
        (subscripts,) + tuple(Shaped(shape) for shape in shapes)
    )
    dims = {
        i: dim
        for ii, shape in zip(input_subscripts.split(','), shapes)
        for i, dim in zip(ii, shape)
    }
    return tuple(dims[i] for i in output_subscript)
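
# Illustrative example (an assumption for clarity, not part of the library):
#
#     >>> einsum_shape("ij,jk->ik", (2, 3), (3, 4))
#     (2, 4)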