File: rearrange.py

package: python-einx 0.3.0-1 (main)

import einx
from . import util
import numpy as np
from typing import Union, Tuple
import numpy.typing as npt


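# Note on the decorator below (illustrative reading, not upstream documentation):
# the `trace` lambda maps each input tensor through the tracer `t` before
# re-invoking the op via the continuation `c`, which is what lets einx.jit
# trace the operation once and reuse the cached graph on later calls.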
@einx.jit(
    trace=lambda t, c: lambda exprs_in, tensors_in, exprs_out, backend=None: c(
        exprs_in, [t(x) for x in tensors_in], exprs_out
    )
)
def rearrange_stage3(exprs_in, tensors_in, exprs_out, backend=None):
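    # At stage 3 the expressions are fully solved (all axis sizes are known), so
    # no shape inference happens here: this function materializes tensor
    # factories, flattens the expressions, aligns inputs with outputs,
    # transposes/broadcasts, and unflattens into the requested output structure.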
    if len(exprs_in) != len(tensors_in):
        raise ValueError(f"Expected {len(exprs_in)} input tensor(s), got {len(tensors_in)}")
    markers = [
        expr
        for root in list(exprs_in) + list(exprs_out)
        for expr in root.all()
        if isinstance(expr, einx.expr.stage3.Marker)
    ]
    if len(markers) > 0:
        # Brackets/markers have no meaning in a pure rearrange and are rejected.
        # (Collecting the offending marker first also fixes the original error
        # path, where `expr` was referenced outside the generator's scope.)
        raise ValueError(f"Marker '{markers[0]}' is not allowed")

    # Call tensor factories
    tensors_in = [
        einx.tracer.call_factory(tensor, expr.shape, backend, name="embedding", init="rearrange")
        for tensor, expr in zip(tensors_in, exprs_in)
    ]
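    # Convert any remaining non-tensor inputs (e.g. Python scalars) to backend tensors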
    tensors_in = backend.all_to_tensor(tensors_in, convert_scalars=True)

    # Flatten expressions
    exprs_in, tensors_in = util.flatten(exprs_in, tensors_in, backend=backend)
    exprs_out_flat = util.flatten(exprs_out)
    assert all(einx.expr.stage3.is_flat(expr) for expr in exprs_in)
    assert all(einx.expr.stage3.is_flat(expr) for expr in exprs_out_flat)
    if len(exprs_in) != len(exprs_out_flat):
        raise ValueError(
            f"Got different number of input ({len(exprs_in)}) and output expressions "
            f"({len(exprs_out_flat)}) (after flattening)"
        )  # TODO:

    # Order inputs to align with output expressions
    indices = util.assignment(exprs_in, exprs_out_flat)
    exprs_in = [exprs_in[i] for i in indices]
    tensors_in = [tensors_in[i] for i in indices]
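    # (illustrative: for "a b, c b -> c b, a b" the two inputs are swapped here so
    # that the i-th input lines up with the i-th flattened output expression)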

    # Transpose and broadcast missing output dimensions
    tensors = [
        util.transpose_broadcast(expr_in, tensor, expr_out, backend=backend)[0]
        for expr_in, tensor, expr_out in zip(exprs_in, tensors_in, exprs_out_flat)
    ]
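    # (illustrative: "b h w c -> b w h c" reduces to a single transpose here, while
    # an axis that appears only in the output is broadcast to its solved size)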

    # Unflatten output expressions; for composed outputs like "(a + c) b" this is
    # also where the flat per-input tensors are joined into a single output tensor
    tensors = util.unflatten(exprs_out_flat, tensors, exprs_out, backend=backend)

    return tensors, exprs_out


@einx.lru_cache
def parse(description, *tensor_shapes, cse=True, **parameters):
    description, parameters = einx.op.util._clean_description_and_parameters(
        description, parameters
    )

    op = einx.expr.stage1.parse_op(description)
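    # op[0] holds the parsed (stage-1) input expressions, op[1] the output expressions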

    if len(op[0]) != len(tensor_shapes):
        raise ValueError(f"Expected {len(op[0])} input tensors, but got {len(tensor_shapes)}")

    exprs = einx.expr.solve(
        [
            einx.expr.Equation(expr_in, tensor_shape)
            for expr_in, tensor_shape in zip(op[0], tensor_shapes)
        ]
        + [einx.expr.Equation(expr_out) for expr_out in op[1]]
        + [
            einx.expr.Equation(k, np.asarray(v)[..., np.newaxis], depth1=None, depth2=None)
            for k, v in parameters.items()
        ],
        cse=cse,
    )[: len(op[0]) + len(op[1])]
    exprs_in, exprs_out = exprs[: len(op[0])], exprs[len(op[0]) :]

    return exprs_in, exprs_out
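
# A minimal sketch of what `parse` returns (illustrative only; the elements are
# solved einx stage-3 expression trees, shown here just by their shapes):
#
#   exprs_in, exprs_out = parse("b h w c -> b w h c", (4, 64, 48, 3))
#   # exprs_in[0].shape  == (4, 64, 48, 3)
#   # exprs_out[0].shape == (4, 48, 64, 3)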


@einx.traceback_util.filter
@einx.jit(
    trace=lambda t, c: lambda description, *tensors, backend=None, **kwargs: c(
        description, *[t(x) for x in tensors], **kwargs
    )
)
def rearrange(
    description: str,
    *tensors: einx.Tensor,
    backend: Union[einx.Backend, str, None] = None,
    cse: bool = True,
    **parameters: npt.ArrayLike,
) -> Union[einx.Tensor, Tuple[einx.Tensor, ...]]:
    """Rearranges the input tensors to match the output expressions.

    Args:
        description: Description string for the operation in einx notation. Must not contain
            brackets.
        tensors: Input tensors or tensor factories matching the description string.
        backend: Backend to use for all operations. If None, determines the backend from
            the input tensors. Defaults to None.
        cse: Whether to apply common subexpression elimination to the expressions. Defaults
            to True.
        graph: Whether to return the graph representation of the operation instead of
            computing the result. Defaults to False.
        **parameters: Additional parameters that specify values for single axes, e.g. ``a=4``.

    Returns:
        The result of the rearrange operation if ``graph=False``, otherwise the graph
        representation of the operation.

    Examples:
        Transpose the row and column axes of a batch of images:

        >>> x = np.random.uniform(size=(4, 64, 48, 3))
        >>> einx.rearrange("b h w c -> b w h c", x).shape
        (4, 48, 64, 3)

        Insert a new axis (repeats elements along the new axis):

        >>> x = np.random.uniform(size=(10, 10))
        >>> einx.rearrange("a b -> a c b", x, c=100).shape
        (10, 100, 10)

        Concatenate two tensors along the first axis:

        >>> a, b = (
        ...     np.random.uniform(size=(10, 10)),
        ...     np.random.uniform(size=(20, 10)),
        ... )
        >>> einx.rearrange("a b, c b -> (a + c) b", a, b).shape
        (30, 10)

        Split a tensor:

        >>> x = np.random.uniform(size=(10, 2))
        >>> a, b = einx.rearrange("a (1 + 1) -> a, a", x)
        >>> a.shape, b.shape
        ((10,), (10,))

        Swap the first and last third of a tensor along a given axis:

        >>> x = np.arange(6)
        >>> einx.rearrange("(b + c + d) -> (d + c + b)", x, b=2, c=2)
        array([4, 5, 2, 3, 0, 1])
    """
    exprs_in, exprs_out = parse(
        description, *[einx.tracer.get_shape(tensor) for tensor in tensors], cse=cse, **parameters
    )
    tensors, exprs_out = rearrange_stage3(exprs_in, tensors, exprs_out, backend=backend)
    return tensors[0] if len(exprs_out) == 1 else tensors


rearrange.parse = parse
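

if __name__ == "__main__":
    # Minimal smoke test mirroring the doctests above. This block is illustrative
    # and not part of the upstream module; it assumes only numpy is available and
    # lets einx pick the backend from the input tensors.
    x = np.random.uniform(size=(4, 64, 48, 3))
    assert rearrange("b h w c -> b w h c", x).shape == (4, 48, 64, 3)

    a, b = np.random.uniform(size=(10, 10)), np.random.uniform(size=(20, 10))
    assert rearrange("a b, c b -> (a + c) b", a, b).shape == (30, 10)

    print("rearrange smoke test passed")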