File: test_einsum_optimizer.py

package info (click to toggle)
python-opt-einsum-fx 0.1.4-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 180 kB
  • sloc: python: 664; makefile: 13
file content (122 lines) | stat: -rw-r--r-- 3,050 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import pytest

import torch
import torch.fx

from opt_einsum_fx import optimize_einsums, optimize_einsums_full, jitable, EfficientShapeProp


def einmatmul(x, y):
    """Return the matrix product of ``x`` and ``y`` via a single two-operand einsum."""
    product = torch.einsum("ij,jk->ik", x, y)
    return product


def eintrace(x, y):
    """Return trace(x[:, :rows_x]) * trace(y[:, :rows_y]), each computed with einsum ``"ii"``/``"jj"``.

    Each input is cropped to its first ``shape[0]`` columns so the trace sees a
    square matrix (assumes every input has at least as many columns as rows).
    """
    x_square = x[:, : x.shape[0]]
    b = torch.einsum("ii", x_square)
    y_square = y[:, : y.shape[0]]
    return torch.einsum("jj", y_square) * b


def fusable(x, y):
    """Chain two einsums where the intermediate ``z`` feeds only the second one.

    ``z`` has no other consumer, which is what makes this pair fusable
    (contrast with ``unfusable``, where ``z`` is also used directly).
    """
    z = torch.einsum("ij,jk->ik", x, y)
    result = torch.einsum("ik,ij->i", z, x)
    return result


def fusable_w_scalars(x, y):
    """Like ``fusable``, but with scalar division and multiplication around the einsums."""
    z = torch.einsum("ij,jk->ik", x, y) / 3.0
    contracted = torch.einsum("ik,ij->i", z, x)
    return 4.0 * contracted


def unfusable(x, y):
    """Chain two einsums where the intermediate ``z`` is also consumed elsewhere."""
    z = torch.einsum("ij,jk->ik", x, y)
    contracted = torch.einsum("ik,ij->i", z, x)
    # Reading z outside the second einsum is what makes the pair unfusable.
    first_column = z[:, 0]
    return contracted + first_column


def unfusable_w_scalars(x, y):
    """Like ``unfusable``, but with scalar multiplications on the intermediate and extra term."""
    z = 2.7 * torch.einsum("ij,jk->ik", x, y)
    contracted = torch.einsum("ik,ij->i", z, x)
    # z is consumed again here, besides being an einsum input, so it cannot be fused away.
    scaled_column = 1.1 * z[:, 0]
    return contracted + scaled_column


def not_einsum(x, y):
    """Scalar-heavy arithmetic with no einsum at all, meant to trip up scalar fusion."""
    x_term = 3.0 * 2.7 * x.sum()
    y_term = 4.6 / y.relu().sum()
    return x_term + y_term


def not_einsum2(x, y):
    """Pointwise ops, reductions and scalar multiplies with no einsum."""
    positive_tanh_total = x.tanh().relu().sum()
    a = positive_tanh_total - y.sum()
    # y.sum() is deliberately computed a second time rather than reused.
    squashed = y.sum().tanh()
    b = 3.41 * squashed
    scaled_b = 6.7 * b
    return a - scaled_b


@pytest.fixture(
    params=(
        einmatmul,
        eintrace,
        fusable,
        fusable_w_scalars,
        unfusable,
        unfusable_w_scalars,
        not_einsum,
        not_einsum2,
    ),
    scope="module",
)
def einfunc(request):
    """Provide each example function in turn, so dependent tests run once per function."""
    chosen = request.param
    return chosen


def test_optimize_einsums(einfunc, allclose):
    """Check that ``optimize_einsums`` on a shape-propagated FX graph preserves the output.

    Verifies that:
    1. tracing alone reproduces the eager result exactly; and
    2. the optimized graph agrees with the traced graph up to ``allclose``.
    """
    x = torch.randn(3, 4)
    y = torch.randn(4, 5)

    func_res = einfunc(x, y)

    func_fx = torch.fx.symbolic_trace(einfunc)
    # Run shape propagation first; without it optimize_einsums only warns and
    # leaves the graph unchanged (see test_fallback).
    sp = EfficientShapeProp(func_fx)
    sp.run(x, y)

    func_fx_res = func_fx(x, y)
    # The traced graph replays the same ops, so results should be identical.
    # torch.equal checks size as well as values, whereas torch.all(a == b)
    # broadcasts and could hide a shape mismatch between eager and traced output.
    assert torch.equal(func_res, func_fx_res)

    graph_opt = optimize_einsums(func_fx.graph)
    func_fx.graph = graph_opt
    func_fx.recompile()

    func_opt_res = func_fx(x, y)
    assert allclose(func_opt_res, func_fx_res)


def test_optimize_einsums_full(einfunc, allclose):
    """Check end to end that ``optimize_einsums_full`` output matches the eager function."""
    x, y = torch.randn(3, 4), torch.randn(4, 5)
    expected = einfunc(x, y)
    optimized = optimize_einsums_full(einfunc, (x, y))
    actual = optimized(x, y)
    assert allclose(expected, actual)


def test_fallback():
    """Check that without shape propagation ``optimize_einsums`` warns and changes nothing.

    Only ``fusable`` is exercised; one function is enough to cover this path.
    """
    traced = torch.fx.symbolic_trace(fusable)
    code_before = traced.code

    # No EfficientShapeProp pass has been run on this graph.
    with pytest.warns(RuntimeWarning):
        new_graph = optimize_einsums(traced.graph)

    traced.graph = new_graph
    traced.recompile()
    assert traced.code == code_before


def test_torchscript(einfunc, allclose):
    """Check that the optimized module, after ``jitable``, compiles with TorchScript and matches eager."""
    x = torch.randn(3, 4)
    y = torch.randn(4, 5)
    reference = einfunc(x, y)
    scripted = torch.jit.script(jitable(optimize_einsums_full(einfunc, (x, y))))
    assert allclose(scripted(x, y), reference)