# opt_einsum_fx
Optimizing einsums and functions involving them using [`opt_einsum`](https://optimized-einsum.readthedocs.io/en/stable/) and PyTorch [FX](https://pytorch.org/docs/stable/fx.html) compute graphs.
Issues, questions, PRs, and any thoughts about further optimizing these kinds of operations are welcome!
For more information, please see [the docs](https://opt-einsum-fx.readthedocs.io/en/stable/).
## Installation
### PyPI
The latest release can be installed from PyPI:
```bash
$ pip install opt_einsum_fx
```
### Source
To get the latest code, run:
```bash
$ git clone https://github.com/Linux-cpp-lisp/opt_einsum_fx.git
```
and install it by running
```bash
$ cd opt_einsum_fx/
$ pip install .
```
You can run the tests with
```bash
$ pytest tests/
```
## Minimal example
```python
import torch
import torch.fx

import opt_einsum_fx


def einmatvecmul(a, b, vec):
    """Batched matrix-matrix-vector product using einsum"""
    return torch.einsum("zij,zjk,zk->zi", a, b, vec)


graph_mod = torch.fx.symbolic_trace(einmatvecmul)
print("Original code:\n", graph_mod.code)

graph_opt = opt_einsum_fx.optimize_einsums_full(
    model=graph_mod,
    example_inputs=(
        torch.randn(7, 4, 5),
        torch.randn(7, 5, 3),
        torch.randn(7, 3),
    ),
)
print("Optimized code:\n", graph_opt.code)
```
This prints:
```
Original code:
import torch
def forward(self, a, b, vec):
    einsum_1 = torch.functional.einsum('zij,zjk,zk->zi', a, b, vec); a = b = vec = None
    return einsum_1

Optimized code:
import torch
def forward(self, a, b, vec):
    einsum_1 = torch.functional.einsum('cb,cab->ca', vec, b); vec = b = None
    einsum_2 = torch.functional.einsum('cb,cab->ca', einsum_1, a); einsum_1 = a = None
    return einsum_2
```
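To see why the optimized code computes the same result, here is a dependency-free sketch (plain Python loops, no PyTorch) that evaluates both the original three-operand contraction and the two pairwise steps on random inputs with the example shapes. The loops are only an illustration of the index bookkeeping, not how PyTorch or `opt_einsum_fx` actually execute einsums; the mapping of the renamed letters `c`, `a`, `b` back to `z`, `i`, `j`, `k` is noted in the comments.

```python
import random

# Sizes of the example inputs: z=7 (batch), i=4, j=5, k=3
Z, I, J, K = 7, 4, 5, 3
random.seed(0)
a = [[[random.random() for _ in range(J)] for _ in range(I)] for _ in range(Z)]  # a[z][i][j]
b = [[[random.random() for _ in range(K)] for _ in range(J)] for _ in range(Z)]  # b[z][j][k]
vec = [[random.random() for _ in range(K)] for _ in range(Z)]  # vec[z][k]

# Original: "zij,zjk,zk->zi", summing over j and k in a single pass
orig = [
    [sum(a[z][i][j] * b[z][j][k] * vec[z][k] for j in range(J) for k in range(K)) for i in range(I)]
    for z in range(Z)
]

# Optimized step 1: 'cb,cab->ca' where letters c, a, b stand for z, j, k ("zk,zjk->zj")
einsum_1 = [[sum(vec[z][k] * b[z][j][k] for k in range(K)) for j in range(J)] for z in range(Z)]

# Optimized step 2: 'cb,cab->ca' where letters c, a, b stand for z, i, j ("zj,zij->zi")
einsum_2 = [[sum(einsum_1[z][j] * a[z][i][j] for j in range(J)) for i in range(I)] for z in range(Z)]

# Largest elementwise difference; only floating-point rounding should remain
max_err = max(abs(x - y) for row_o, row_n in zip(orig, einsum_2) for x, y in zip(row_o, row_n))
print(f"max abs difference: {max_err:.2e}")
```

The key point is that step 1 sums out `k` before `a` is involved at all, so the intermediate `einsum_1` only has shape `(z, j)`.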
We can measure the performance improvement (this is on a CPU):
```python
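import torch
from torch.utils.benchmark import Timer

# graph_mod and graph_opt come from the minimal example above.
# graph_opt's contraction order was chosen for the smaller example_inputs,
# but its einsum calls work for any compatible shapes.
batch = 1000
a, b, vec = torch.randn(batch, 4, 5), torch.randn(batch, 5, 8), torch.randn(batch, 8)
g = {"f": graph_mod, "a": a, "b": b, "vec": vec}

t_orig = Timer("f(a, b, vec)", globals=g)
print(t_orig.timeit(10_000))

g["f"] = graph_opt
t_opt = Timer("f(a, b, vec)", globals=g)
print(t_opt.timeit(10_000))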
```
This gives a roughly 2x speedup (276.58 us vs. 118.84 us per call):
```
f(a, b, vec)
276.58 us
1 measurement, 10000 runs , 1 thread
f(a, b, vec)
118.84 us
1 measurement, 10000 runs , 1 thread
```
Depending on your function and input dimensions, you may see even larger improvements.
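A rough sketch of where the speedup comes from, using the benchmark sizes (`z=1000, i=4, j=5, k=8`). It assumes a simplified cost model (one multiply-add per combination of all indices involved in a pairwise contraction), which is an assumption here and not `opt_einsum`'s exact cost function. With three operands there are three choices of which pair to contract first; the helper names below are hypothetical and only for illustration. Contracting `a` with `b` first corresponds to plain left-to-right evaluation, which the PyTorch docs describe as `torch.einsum`'s fallback order when `opt_einsum` is not available.

```python
from itertools import combinations
from math import prod

# Benchmark sizes: z=batch, i=4, j=5, k=8
SIZES = {"z": 1000, "i": 4, "j": 5, "k": 8}
OPERANDS = ["zij", "zjk", "zk"]  # a, b, vec
OUTPUT = "zi"


def pair_cost(x, y):
    # Simplified (assumed) cost: one multiply-add per combination of
    # every index that appears in either operand.
    return prod(SIZES[c] for c in set(x) | set(y))


def pair_result(x, y, rest):
    # An index survives a contraction only if the output or a
    # remaining operand still needs it.
    keep = set(OUTPUT).union(*rest)
    return "".join(c for c in dict.fromkeys(x + y) if c in keep)


def three_operand_costs(ops):
    # Total cost for each choice of which pair (p, q) to contract first.
    costs = {}
    for p, q in combinations(range(3), 2):
        (r,) = set(range(3)) - {p, q}
        intermediate = pair_result(ops[p], ops[q], [ops[r]])
        costs[(p, q)] = pair_cost(ops[p], ops[q]) + pair_cost(intermediate, ops[r])
    return costs


costs = three_operand_costs(OPERANDS)
for pair, cost in costs.items():
    print(pair, cost)
best = min(costs, key=costs.get)
print("cheapest first pair:", best)
```

This prints `(0, 1) 192000`, `(0, 2) 320000` and `(1, 2) 60000`, so the cheapest order contracts `b` with `vec` first, which is the order the optimized code above uses. The 3.2x gap in this model is larger than the measured ~2x because real timings also depend on memory traffic and per-call overheads.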
## License
`opt_einsum_fx` is distributed under an [MIT license](LICENSE).