File: test_fuse.py

package info (click to toggle)
python-opt-einsum-fx 0.1.4-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 180 kB
  • sloc: python: 664; makefile: 13
file content (173 lines) | stat: -rw-r--r-- 4,778 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import pytest

import math
import operator

import torch
import torch.fx

from opt_einsum_fx import fuse_einsums, fuse_scalars, optimize_einsums_full


def test_einsum_fuse(allclose):
    """Two chained einsums whose intermediate has a single consumer fuse into one."""

    def fusable(x, y):
        intermediate = torch.einsum("ij,jk->ik", x, y)
        return torch.einsum("ik,ij->i", intermediate, x)

    traced = torch.fx.symbolic_trace(fusable)
    traced.graph = fuse_einsums(traced.graph)
    traced.recompile()

    a = torch.randn(3, 4)
    b = torch.randn(4, 5)
    # The fused module must agree numerically with the eager reference.
    assert allclose(traced(a, b), fusable(a, b))


def test_unfusable():
    """An intermediate with an extra consumer must be left completely untouched."""

    def unfusable(x, y):
        prod = torch.einsum("ij,jk->ik", x, y)
        # We use z as something besides an input to the second einsum, so it is unfusable
        return torch.einsum("ik,ij->i", prod, x) + prod[:, 0]

    traced = torch.fx.symbolic_trace(unfusable)
    code_before = traced.code
    traced.graph = fuse_einsums(traced.graph)
    traced.recompile()

    # Confirm numerical equivalence
    a, b = torch.randn(3, 4), torch.randn(4, 5)
    # Here we use normal allclose --- since unfusable is unfusable,
    # nothing should have changed.
    assert torch.allclose(traced(a, b), unfusable(a, b))
    # Confirm no fusion: identical generated code means an identical graph.
    assert code_before == traced.code


def test_doublefuse(allclose):
    """Three chained einsums (quadruple matmul + final transpose) fuse in one pass."""

    def doublefuse(a, b, c, d):
        # quadruple matmul with a final transpose
        first = torch.einsum("ij,jk->ik", a, b)
        second = torch.einsum("ab,bc->ac", first, c)
        return torch.einsum("tr,ry->yt", second, d)

    traced = torch.fx.symbolic_trace(doublefuse)
    traced.graph = fuse_einsums(traced.graph)
    traced.recompile()

    operands = (
        torch.randn(3, 4),
        torch.randn(4, 5),
        torch.randn(5, 2),
        torch.randn(2, 3),
    )
    assert allclose(traced(*operands), doublefuse(*operands))


def test_inconsistent():
    """Mismatched index-string lengths between producer and consumer must raise."""

    def inconsistent(x, y):
        out = torch.einsum("ij,jk->ik", x, y)
        # Note that the dimension labels for the intermediate have the wrong length
        return torch.einsum("i,ij->i", out, x)

    traced = torch.fx.symbolic_trace(inconsistent)
    with pytest.raises(RuntimeError):
        _ = fuse_einsums(traced.graph)


def scalar_fusable1(x, y):
    # One scalar on an operand plus two wrapped around the einsum; all fusable.
    # The op sequence is kept exactly as traced (node counts are checked below).
    scaled_y = y / 3
    prod = torch.einsum("ij,jk->ik", x, scaled_y)
    return 7.0 * prod / 2


def scalar_fusable2(x, y):
    # Scalar chain on one operand plus scalars around the einsum; all fusable.
    scaled_y = 2.0 * y / 3
    prod = torch.einsum("ij,jk->ik", x, scaled_y)
    return 4.0 * prod / 2


def scalar_fusable3(x, y):
    # Scalars on both operands and around the einsum; all fusable.
    scaled_x = x / 1.2
    scaled_y = 1.7 * 2.0 * y / 3
    prod = torch.einsum("ij,jk->ik", scaled_x, scaled_y)
    return 4.0 * prod / 2


def scalar_unfusable(x, y):
    prod = 3 * torch.einsum("ij,jk->ik", x, y) / 4.0
    # prod has a second consumer (the slice below) besides the second einsum,
    # so the scalar chain cannot be fully fused away.
    return (2.0 * torch.einsum("ik,ij->i", prod, x)) + prod[:, 0]


def just_scalars(x, y):
    # No einsum at all: a single scalar multiple of the first input.
    result = 3.0 * x
    return result


def just_many_scalars(x, y):
    # A chain of scalar ops and nothing else; fuses down to one scalar op.
    scaled = 3.0 / 3.4 * x
    return scaled / 4.0


def in_place(x, y):
    # This *shouldn't* be fused: the scalar is applied by an in-place op.
    copy = x.clone()
    scaled = copy.mul_(4.0)
    return 3.0 * scaled


def unused(x, y):
    # The dead scalar chain is deliberate: it must survive tracing untouched.
    dead = 2.3 * x / 4.5  # noqa
    return 4.6 * torch.einsum("ij,jk->ik", x, y)


def constants(x, y):
    # Scalars drawn from math-module constants rather than literals.
    scaled_y = math.e * y / 3
    prod = torch.einsum("ij,jk->ik", x, scaled_y)
    return math.pi * prod / 2


# In every case but scalar_unfusable, the fused graph should hold 5 nodes:
# two placeholders, one einsum, one mul, and one output.
@pytest.mark.parametrize(
    "func",
    [
        (scalar_fusable1, 5),
        (scalar_fusable2, 5),
        (scalar_fusable3, 5),
        (
            scalar_unfusable,
            9,  # two placeholders, one einsum one mul, one einsum one mul, one getitem, one sum, and one output = 9
        ),
        (just_scalars, 4),
        (just_many_scalars, 4),
        (in_place, 6),
        (constants, 5),
        (unused, 6),
    ],
)
def test_scalar_fuse(allclose, func):
    """Scalar fusion hits the expected node count and preserves numerics."""
    fn, expected_nodes = func
    traced = torch.fx.symbolic_trace(fn)
    print("old graph\n", traced.graph)
    fused = fuse_scalars(traced.graph)
    print("new graph\n", fused)
    traced.graph = fused
    assert len(traced.graph.nodes) == expected_nodes
    traced.recompile()
    a, b = torch.randn(3, 4), torch.randn(4, 5)
    assert allclose(traced(a, b), fn(a, b))


def test_scalar_positioning(allclose):
    """The fused scalar should be applied to the cheapest intermediate."""

    def f(x, y, z):
        return 0.784 * torch.einsum("ij,jk,kl->il", x, y, z)

    operands = (torch.randn(2, 100), torch.randn(100, 2), torch.randn(2, 100))

    # note that the smallest here is y
    traced = torch.fx.symbolic_trace(f)
    print("old graph\n", traced.graph)
    optimized = optimize_einsums_full(traced, operands)
    print("new graph\n", optimized.graph)
    # optimal placement is on the 2x2 intermediate
    assert list(optimized.graph.nodes)[4].target == operator.mul
    assert allclose(optimized(*operands), f(*operands))