import pytest

import math
import operator

import torch
import torch.fx

from opt_einsum_fx import fuse_einsums, fuse_scalars, optimize_einsums_full


def test_einsum_fuse(allclose):
    def fusable(x, y):
        z = torch.einsum("ij,jk->ik", x, y)
        return torch.einsum("ik,ij->i", z, x)

    g = torch.fx.symbolic_trace(fusable)
    new_graph = fuse_einsums(g.graph)
    g.graph = new_graph
    g.recompile()
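    # Hedged structural check (assuming fuse_einsums collapses the two-einsum
    # chain): only a single torch.einsum call should remain in the fused graph.
    einsum_nodes = [
        n
        for n in g.graph.nodes
        if n.op == "call_function" and n.target == torch.einsum
    ]
    assert len(einsum_nodes) == 1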
    x, y = torch.randn(3, 4), torch.randn(4, 5)
    out_truth = fusable(x, y)
    out_fused = g(x, y)
    assert allclose(out_fused, out_truth)

def test_unfusable():
    def unfusable(x, y):
        z = torch.einsum("ij,jk->ik", x, y)
        # We use z for something besides an input to the second einsum,
        # so the two einsums cannot be fused
        return torch.einsum("ik,ij->i", z, x) + z[:, 0]

    g = torch.fx.symbolic_trace(unfusable)
    old_code = g.code
    new_graph = fuse_einsums(g.graph)
    g.graph = new_graph
    g.recompile()
    # Confirm numerical equivalence
    x, y = torch.randn(3, 4), torch.randn(4, 5)
    out_truth = unfusable(x, y)
    out_fused = g(x, y)
    # Plain torch.allclose is enough here --- since unfusable is unfusable,
    # nothing should have changed.
    assert torch.allclose(out_fused, out_truth)
    # Confirm no fusion happened:
    assert old_code == g.code

def test_doublefuse(allclose):
    def doublefuse(a, b, c, d):
        # a chain of matmuls over four matrices, with a transposed final output
        e1 = torch.einsum("ij,jk->ik", a, b)
        e2 = torch.einsum("ab,bc->ac", e1, c)
        return torch.einsum("tr,ry->yt", e2, d)

    g = torch.fx.symbolic_trace(doublefuse)
    new_graph = fuse_einsums(g.graph)
    g.graph = new_graph
    g.recompile()
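    # Hedged structural check (assuming the whole chain fuses): the graph
    # should reduce to four placeholders, one einsum, and one output node.
    assert len(g.graph.nodes) == 6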
    a, b, c, d = (
        torch.randn(3, 4),
        torch.randn(4, 5),
        torch.randn(5, 2),
        torch.randn(2, 3),
    )
    out_truth = doublefuse(a, b, c, d)
    out_fused = g(a, b, c, d)
    assert allclose(out_fused, out_truth)

def test_inconsistent():
    def inconsistent(x, y):
        z = torch.einsum("ij,jk->ik", x, y)
        # The index string for z has the wrong number of labels:
        # z is 2D, but "i" labels only one dimension
        return torch.einsum("i,ij->i", z, x)

    g = torch.fx.symbolic_trace(inconsistent)
    with pytest.raises(RuntimeError):
        _ = fuse_einsums(g.graph)

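# Helper functions for the scalar-fusion tests below: each sprinkles Python
# scalar multiplications/divisions around (or instead of) an einsum, and
# fuse_scalars should combine them wherever that is safe.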
def scalar_fusable1(x, y):
    return 7.0 * torch.einsum("ij,jk->ik", x, y / 3) / 2


def scalar_fusable2(x, y):
    return 4.0 * torch.einsum("ij,jk->ik", x, 2.0 * y / 3) / 2


def scalar_fusable3(x, y):
    return 4.0 * torch.einsum("ij,jk->ik", x / 1.2, 1.7 * 2.0 * y / 3) / 2


def scalar_unfusable(x, y):
    z = 3 * torch.einsum("ij,jk->ik", x, y) / 4.0
    # We use z for something besides an input to the second einsum,
    # so the scalars around it cannot all be fused away
    return (2.0 * torch.einsum("ik,ij->i", z, x)) + z[:, 0]


def just_scalars(x, y):
    return 3.0 * x


def just_many_scalars(x, y):
    return 3.0 / 3.4 * x / 4.0
def in_place(x, y):
    # This *shouldn't* be fused: mul_ mutates `a` in place, so moving
    # the scalar through it could change the program's semantics.
    a = x.clone()
    b = a.mul_(4.0)
    return 3.0 * b
def unused(x, y):
    b = 2.3 * x / 4.5  # noqa
    return 4.6 * torch.einsum("ij,jk->ik", x, y)


def constants(x, y):
    return math.pi * torch.einsum("ij,jk->ik", x, math.e * y / 3) / 2

# In the simple fusable cases, after fusion, the graph should have 5 nodes:
# two placeholders, one einsum, one mul, and one output
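# For example, after fuse_scalars, scalar_fusable1 should compute something
# roughly like (a sketch --- the exact coefficient value and where the fused
# mul lands are implementation details):
#     torch.einsum("ij,jk->ik", x, y) * (7.0 / 3 / 2)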
@pytest.mark.parametrize(
    "func,truth_num_nodes",
    [
        (scalar_fusable1, 5),
        (scalar_fusable2, 5),
        (scalar_fusable3, 5),
        (
            scalar_unfusable,
            # two placeholders, one einsum + one mul, one einsum + one mul,
            # one getitem, one add, and one output = 9
            9,
        ),
        (just_scalars, 4),
        (just_many_scalars, 4),
        (in_place, 6),
        (constants, 5),
        (unused, 6),
    ],
)
def test_scalar_fuse(allclose, func, truth_num_nodes):
    g = torch.fx.symbolic_trace(func)
    print("old graph\n", g.graph)
    new_graph = fuse_scalars(g.graph)
    print("new graph\n", new_graph)
    g.graph = new_graph
    assert len(g.graph.nodes) == truth_num_nodes
    g.recompile()
    x, y = torch.randn(3, 4), torch.randn(4, 5)
    out_truth = func(x, y)
    out_fused = g(x, y)
    assert allclose(out_fused, out_truth)

def test_scalar_positioning(allclose):
    def f(x, y, z):
        return 0.784 * torch.einsum("ij,jk,kl->il", x, y, z)

    x, y, z = torch.randn(2, 100), torch.randn(100, 2), torch.randn(2, 100)
    g = torch.fx.symbolic_trace(f)
    print("old graph\n", g.graph)
    g = optimize_einsums_full(g, (x, y, z))
    print("new graph\n", g.graph)
    # The cheapest placement for the scalar is on the small 2x2 intermediate
    # produced by contracting x with y
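    # A sketch of the expected node order, assuming the optimizer contracts
    # x and y first: [x, y, z, einsum(x, y), mul, einsum(..., z), output],
    # which puts the mul at index 4.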
    assert list(g.graph.nodes)[4].target == operator.mul
    out_truth = f(x, y, z)
    out_fused = g(x, y, z)
    assert allclose(out_fused, out_truth)