How does einx handle input and output tensors?
##############################################

einx functions accept an operation string that specifies einx expressions for the input and output tensors. These expressions may
contain nested compositions and concatenations, which prevent the backend functions from directly accessing the required axes. To resolve this, einx
first flattens the input tensors of each operation so that they contain only a flat list of axes. After the backend operation is applied, the
resulting tensors are unflattened to match the requested output expressions.

Compositions are flattened by applying a ``reshape`` operation:

..  code::

    import einx
    import numpy as np
    x = np.arange(200)  # one axis of length 200, read as "(a b)"
    einx.rearrange("(a b) -> a b", x, a=10, b=20)
    # same as
    np.reshape(x, (10, 20))
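
Assuming the row-major (C) order that NumPy's ``np.reshape`` uses by default, the axis ``a`` in ``(a b)`` is the outer axis and ``b`` varies fastest. The following NumPy-only sketch checks this index correspondence; the sizes are arbitrary and chosen only for illustration.

..  code:: python

    import numpy as np

    a, b = 10, 20
    x = np.arange(a * b)         # single composed axis "(a b)" with 200 elements
    y = np.reshape(x, (a, b))    # flattened into the two axes "a b"

    # Row-major order: element (i, j) of y comes from position i * b + j of x
    i, j = 3, 7
    same_element = bool(y[i, j] == x[i * b + j])

    # Reshaping back recovers the original composed axis
    x_roundtrip = np.reshape(y, (a * b,))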

Concatenations are flattened by splitting the input tensor into multiple tensors along the concatenated axis:

..  code::

    x = np.arange(30)  # one axis of length 30, read as "(a + b)"
    einx.rearrange("(a + b) -> a, b", x, a=10, b=20)
    # same as
    np.split(x, [10], axis=0)
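
Splitting is undone by concatenating the parts along the same axis, which corresponds to rebuilding a concatenated output expression. A minimal NumPy sketch of both directions for ``(a + b)``, with arbitrary sizes:

..  code:: python

    import numpy as np

    a, b = 10, 20
    x = np.arange(a + b)                   # concatenated axis "(a + b)" with 30 elements

    # Flatten: split at the boundary between the a part and the b part
    x_a, x_b = np.split(x, [a], axis=0)    # shapes (10,) and (20,)

    # Unflatten (inverse): concatenate the parts back along the same axis
    x_roundtrip = np.concatenate([x_a, x_b], axis=0)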

After the backend operation is applied to the flattened tensors, the results are reshaped and concatenated, and any missing axes are inserted and broadcast
to match the requested output expressions.
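
The following NumPy sketch walks through these unflattening steps for a hypothetical output expression ``a ((b + 1) d)``, assuming two flattened results: ``p`` with axes ``a b`` and ``q`` with axes ``a``. The names, sizes and the order of the steps are illustrative choices and not necessarily the order einx uses internally.

..  code:: python

    import numpy as np

    a, b, d = 2, 3, 4
    p = np.ones((a, b))     # flattened result with axes "a b"
    q = np.zeros((a,))      # flattened result with axes "a"

    # Insert the missing length-1 axis so that q matches "a 1"
    q1 = q[:, None]                                        # shape (2, 1)

    # Concatenate along the second axis to form "a (b + 1)"
    pq = np.concatenate([p, q1], axis=1)                   # shape (2, 4)

    # Insert the missing axis d and broadcast it to form "a (b + 1) d"
    pqd = np.broadcast_to(pq[:, :, None], (a, b + 1, d))   # shape (2, 4, 4)

    # Reshape to compose the last two axes into "a ((b + 1) d)"
    out = np.reshape(pqd, (a, (b + 1) * d))                # shape (2, 16)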

When multiple input and output tensors are specified, einx must find an assignment between inputs and outputs that is consistent with the given
axis names. This assignment is sometimes ambiguous:

..  code::

    # Broadcast and stack x and y along the last axis. x or y first?
    einx.rearrange("a, b -> a b (1 + 1)", x, y)

To find an assignment, einx iterates over the outputs in the order they appear in the operation string, and for each output tries to find the first input
expression that allows for a successful assignment. In most cases, this leads to input and output expressions being assigned in the same order:

..  code::

    x, y = np.arange(3), np.arange(4)  # shapes (3,) and (4,)
    einx.rearrange("a, b -> a b (1 + 1)", x, y)
    # same as
    np.stack(np.broadcast_arrays(x[:, None], y[None, :]), axis=-1)
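
The greedy search described above can be modeled in a few lines of plain Python. This is a simplified, hypothetical sketch and not einx's implementation: expressions are lists of axis names, an input counts as assignable to an output if all of its axes appear in that output (missing axes can be broadcast), and each input is used at most once. The function name ``assign`` is made up for this example.

..  code:: python

    def assign(inputs, outputs):
        """Greedily assign inputs to outputs (simplified, hypothetical model).

        inputs, outputs: lists of expressions, each a list of axis names.
        Returns the index of the chosen input for each output, in output order.
        """
        used = set()
        assignment = []
        for out in outputs:                   # outputs in the order they appear
            for i, inp in enumerate(inputs):  # the first assignable input wins
                # Simplified check: every axis of the input appears in the output,
                # so any axes missing from the input can be broadcast.
                if i not in used and set(inp) <= set(out):
                    used.add(i)
                    assignment.append(i)
                    break
            else:
                raise ValueError(f"no input can be assigned to output {out}")
        return assignment


    # "a, b -> a b (1 + 1)" with the concatenation flattened into two outputs
    print(assign([["a"], ["b"]], [["a", "b", "1"], ["a", "b", "1"]]))  # [0, 1]

    # Inputs "a b, a" and outputs "a, a b" are assigned out of order
    print(assign([["a", "b"], ["a"]], [["a"], ["a", "b"]]))  # [1, 0]

In the first call, the first output ``a b 1`` receives ``x`` and the second receives ``y``, matching the stacking order. The second call shows a case where the order differs, because the first input ``a b`` cannot be assigned to the output ``a`` under this model.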

The function :func:`einx.rearrange` performs only the flattening and unflattening of tensors described by the operation string. Other functions
such as :func:`einx.vmap` and :func:`einx.dot` perform the same flattening and unflattening and additionally apply an operation to the flattened tensors.
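
As a rough illustration of this pattern, the NumPy sketch below mimics an operation along the lines of ``b (h c), c d -> b (h d)``. The composed input axis is flattened with a reshape, the backend operation (here ``np.einsum``) contracts the now directly accessible axis ``c``, and the result is unflattened into the composed output axis. The operation string, the sizes and the use of ``np.einsum`` are assumptions for illustration only.

..  code:: python

    import numpy as np

    b, h, c, d = 2, 4, 3, 5
    rng = np.random.default_rng(0)
    x = rng.random((b, h * c))    # input with axes "b (h c)"
    w = rng.random((c, d))        # input with axes "c d"

    # 1. Flatten: split the composed axis so that c is directly accessible
    x_flat = np.reshape(x, (b, h, c))               # axes "b h c"

    # 2. Backend operation on the flat axes: contract over c
    y_flat = np.einsum("bhc,cd->bhd", x_flat, w)    # axes "b h d"

    # 3. Unflatten: compose h and d into the requested output axis
    y = np.reshape(y_flat, (b, h * d))              # axes "b (h d)"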