File: tutorial.rst

Tutorial
========

``opt_einsum_fx`` is a library for optimizing einsums and functions involving them using `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_ and `PyTorch FX <https://pytorch.org/docs/stable/fx.html>`_ compute graphs.

The library currently supports:

 - Fusing multiple einsums into one
 - Optimizing einsums using the `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_ library
 - Fusing multiplication and division with scalar constants, including fusing *through* operations, like einsum, that commute with scalar multiplication.
 - Placing multiplication by fused scalar constants onto the smallest intermediate in a chain of operations that commute with scalar multiplication. 

``opt_einsum_fx`` is based on `torch.fx <https://pytorch.org/docs/stable/fx.html>`_, a framework for converting between PyTorch Python code and a programmatically manipulable compute graph. To use this package, it must be possible to get your function or model as a ``torch.fx.Graph``: the limitations of FX's symbolic tracing are discussed `here <https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing>`_.
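
For example, data-dependent Python control flow cannot be traced. A minimal sketch (not part of the library) of the kind of function FX rejects; the exact exception depends on your PyTorch version:

 .. code-block:: python

    import torch
    import torch.fx

    def data_dependent(x):
        # The branch depends on the *values* in x, which the symbolic
        # tracer cannot evaluate.
        if x.sum() > 0:
            return 2.0 * x
        return x

    # torch.fx.symbolic_trace(data_dependent)  # raises a tracing error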

Minimal example
---------------

 .. code-block:: python

    import torch
    import torch.fx
    import opt_einsum_fx

    def einmatvecmul(a, b, vec):
        """Batched matrix-matrix-vector product using einsum"""
        return torch.einsum("zij,zjk,zk->zi", a, b, vec)

    graph_mod = torch.fx.symbolic_trace(einmatvecmul)
    print("# Original code:\n", graph_mod.code)
    graph_opt = opt_einsum_fx.optimize_einsums_full(
        model=graph_mod,
        example_inputs=(
            torch.randn(7, 4, 5),
            torch.randn(7, 5, 3),
            torch.randn(7, 3)
        )
    )
    print("# Optimized code:\n", graph_opt.code)

outputs:

 .. code-block:: python

    # Original code:
    import torch
    def forward(self, a, b, vec):
        einsum_1 = torch.functional.einsum('zij,zjk,zk->zi', a, b, vec);  a = b = vec = None
        return einsum_1
        
    # Optimized code:
    import torch
    def forward(self, a, b, vec):
        einsum_1 = torch.functional.einsum('cb,cab->ca', vec, b);  vec = b = None
        einsum_2 = torch.functional.einsum('cb,cab->ca', einsum_1, a);  einsum_1 = a = None
        return einsum_2
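
The optimized module computes the same result as the original, up to floating-point differences from the reordered contraction, which is easy to check directly (reusing ``graph_mod`` and ``graph_opt`` from above):

 .. code-block:: python

    a, b, vec = torch.randn(7, 4, 5), torch.randn(7, 5, 3), torch.randn(7, 3)
    assert torch.allclose(graph_mod(a, b, vec), graph_opt(a, b, vec), atol=1e-6)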

The ``optimize_einsums_full`` function has four passes, illustrated by the sketch after this list:

 1. Scalar accumulation --- use the multilinearity of einsum to fuse all constant coefficients and divisors of operands and outputs
 2. Fusing einsums --- gives greater flexibility to (3)
 3. Optimized contraction with ``opt_einsum``
 4. Moving constant scalar coefficients through operations they commute with in order to place them on the smallest possible intermediate results
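
As an illustration, the sketch below (function and variable names are ours) chains two einsums with a constant coefficient. ``optimize_einsums_full`` fuses the einsums, re-optimizes the contraction for the example shapes, and accumulates the constant onto a small intermediate; the exact optimized code printed depends on the contraction path ``opt_einsum`` chooses for these shapes:

 .. code-block:: python

    import torch
    import torch.fx
    import opt_einsum_fx

    def scaled_double_contract(a, b, vec):
        # Two chained einsums plus a constant scalar: work for all four passes.
        tmp = torch.einsum("zij,zjk->zik", a, b)
        return 4.0 * torch.einsum("zik,zk->zi", tmp, vec)

    chain_mod = torch.fx.symbolic_trace(scaled_double_contract)
    chain_opt = opt_einsum_fx.optimize_einsums_full(
        model=chain_mod,
        example_inputs=(
            torch.randn(7, 4, 5),
            torch.randn(7, 5, 3),
            torch.randn(7, 3),
        ),
    )
    print(chain_opt.code)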

We can measure the performance improvement (this is on a CPU):

 .. code-block:: python

    from torch.utils.benchmark import Timer

    batch = 1000
    a, b, vec = torch.randn(batch, 4, 5), torch.randn(batch, 5, 8), torch.randn(batch, 8)

    g = {"f": graph_mod, "a": a, "b": b, "vec": vec}
    t_orig = Timer("f(a, b, vec)", globals=g)
    print(t_orig.timeit(10_000))

    g["f"] = graph_opt
    t_opt = Timer("f(a, b, vec)", globals=g)
    print(t_opt.timeit(10_000))

gives ~2x improvement:

 .. code-block:: none

    f(a, b, vec)
    276.58 us
    1 measurement, 10000 runs , 1 thread

    f(a, b, vec)
    118.84 us
    1 measurement, 10000 runs , 1 thread

Depending on your function and dimensions, you may see even larger improvements.

JIT
---

Currently, pure Python and TorchScript have different call signatures for ``torch.tensordot`` and ``torch.permute``, both of which can appear in optimized einsums:

 .. code-block:: python

    graph_script = torch.jit.script(graph_opt)  # => RuntimeError: Arguments for call are not valid...

A function is provided to convert a ``torch.fx.GraphModule`` containing these operations from the Python signatures (the default) to a TorchScript-compatible form:

 .. code-block:: python

    graph_script = torch.jit.script(opt_einsum_fx.jitable(graph_opt))
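
The scripted module behaves like any other TorchScript module; a short sketch of checking and saving it (the filename here is just an example):

 .. code-block:: python

    # Sanity check: scripting does not change the computed result.
    a, b, vec = torch.randn(7, 4, 5), torch.randn(7, 5, 3), torch.randn(7, 3)
    assert torch.allclose(graph_script(a, b, vec), graph_opt(a, b, vec))

    # TorchScript modules can be serialized for use outside of Python.
    torch.jit.save(graph_script, "einmatvecmul_opt.pt")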