File: solve.py

package info (click to toggle)
python-einx 0.3.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,112 kB
  • sloc: python: 11,619; makefile: 13
file content (137 lines) | stat: -rw-r--r-- 4,853 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import einx
import numpy as np
from collections import defaultdict
from typing import Mapping, Optional
import numpy.typing as npt


@einx.lru_cache
def _solve(description, *tensor_shapes, cse=True, **parameters):
    """Solve for axis values given tensor shapes instead of tensors.

    NOTE(review): this function is wrapped in ``einx.lru_cache``; presumably results are
    memoized per argument set, so the returned dict may be shared -- confirm before mutating.

    Args:
        description: Description string in einx notation.
        *tensor_shapes: One shape per expression parsed from ``description``.
        cse: Forwarded to ``einx.expr.solve``.
        **parameters: Extra constraints. Each entry becomes an additional equation whose
            value is converted with ``np.asarray`` and given a trailing axis of length one.

    Returns:
        A dict keyed by the part of each solved axis name before the first ``"."``. An axis
        without a dotted integer suffix maps to a plain ``int``; otherwise the entry is an
        ``int32`` array indexed by the integers of the suffix. Returns ``None`` if the solver
        raises ``SolveDepthException``, ``SolveExpansionException`` or ``SolveValueException``.

    Raises:
        ValueError: If the number of shapes differs from the number of parsed expressions.
    """
    description, parameters = einx.op.util._clean_description_and_parameters(
        description, parameters
    )

    parsed = einx.expr.stage1.parse_args(description)
    if len(parsed) != len(tensor_shapes):
        raise ValueError(f"Expected {len(parsed)} tensors, got {len(tensor_shapes)}")

    try:
        # One equation per (expression, shape) pair ...
        equations = [
            einx.expr.Equation(expr, shape) for expr, shape in zip(parsed, tensor_shapes)
        ]
        # ... followed by one equation per keyword constraint.
        equations += [
            einx.expr.Equation(key, np.asarray(val)[..., np.newaxis], depth1=None, depth2=None)
            for key, val in parameters.items()
        ]
        solved = einx.expr.solve(equations, cse=cse)
    except (
        einx.expr.stage2.SolveDepthException,
        einx.expr.stage2.SolveExpansionException,
        einx.expr.stage3.SolveValueException,
    ):
        # These solver failures are reported to the caller as "no solution".
        return None

    # Group every solved axis by its base name; the dotted suffix, if any, is an integer
    # coordinate locating this value inside the per-name result array.
    entries_by_name = defaultdict(list)
    for root in solved:
        for node in root.all():
            if not isinstance(node, einx.expr.stage3.Axis):
                continue
            base, *suffix = node.name.split(".")
            entries_by_name[base].append((tuple(int(part) for part in suffix), node.value))

    result = {}
    for base, entries in entries_by_name.items():
        # The largest coordinate seen along each dimension (plus one) gives the array extent.
        extent = np.amax([index for index, _ in entries], axis=0) + 1
        grid = np.zeros(extent, dtype="int32")
        for index, axis_value in entries:
            grid[index] = axis_value
        # Axes without a coordinate suffix end up zero-dimensional: return them as ints.
        result[base] = int(grid) if grid.shape == () else grid

    return result


def solve(
    description: str, *tensors: einx.Tensor, cse: bool = False, **parameters: npt.ArrayLike
) -> Optional[Mapping[str, npt.ArrayLike]]:
    """Compute the axis values implied by a description and a set of tensors.

    Only the shape of each tensor (obtained via ``einx.tracer.get_shape``) is used; the
    actual solving is delegated to ``_solve``.

    Args:
        description: Description string for the tensors in einx notation.
        tensors: Input tensors or tensor factories matching the description string.
        cse: Whether to apply common subexpression elimination to the expressions.
            Defaults to False.
        **parameters: Additional parameters that specify values for single axes, e.g. ``a=4``.

    Returns:
        A mapping from axis name to axis value, or ``None`` if no solution was found.

    Raises:
        ValueError: If the number of tensors does not match the number of expressions
            in ``description``.

    Examples:
        >>> x = np.zeros((10, 5))
        >>> einx.solve("a b", x)
        {'a': 10, 'b': 5}
    """
    shapes = [einx.tracer.get_shape(tensor) for tensor in tensors]
    return _solve(description, *shapes, cse=cse, **parameters)


def matches(
    description: str, *tensors: einx.Tensor, cse: bool = True, **parameters: npt.ArrayLike
) -> bool:
    """Report whether the tensors are consistent with the description.

    This is a thin wrapper around ``solve``: the tensors match exactly when ``solve``
    returns something other than ``None``.

    Args:
        description: Description string for the tensors in einx notation.
        tensors: Input tensors or tensor factories matching the description string.
        cse: Whether to apply common subexpression elimination to the expressions.
            Defaults to True.
        **parameters: Additional parameters that specify values for single axes, e.g. ``a=4``.

    Returns:
        True if the expressions and tensors match, False otherwise.

    Examples:
        >>> x = np.zeros((10, 5))
        >>> einx.matches("a b", x)
        True
        >>> einx.matches("a b c", x)
        False
    """
    solution = solve(description, *tensors, cse=cse, **parameters)
    return solution is not None


@einx.traceback_util.filter
def check(
    description: str, *tensors: einx.Tensor, cse: bool = True, **parameters: npt.ArrayLike
) -> None:
    """Verify that the tensors match the description, raising if they do not.

    Unlike ``matches``, solver failures are not converted into a return value: whatever
    ``einx.expr.solve`` raises is propagated to the caller.

    Args:
        description: Description string for the tensors in einx notation.
        tensors: Input tensors or tensor factories matching the description string.
        cse: Whether to apply common subexpression elimination to the expressions.
            Defaults to True.
        **parameters: Additional parameters that specify values for single axes, e.g. ``a=4``.

    Raises:
        ValueError: If the number of tensors does not match the number of expressions
            in ``description``.
    """

    description, parameters = einx.op.util._clean_description_and_parameters(
        description, parameters
    )

    parsed = einx.expr.stage1.parse_args(description)
    if len(parsed) != len(tensors):
        raise ValueError(f"Expected {len(parsed)} tensors, got {len(tensors)}")

    # Shapes are only extracted once the tensor count is known to be right.
    shapes = [einx.tracer.get_shape(tensor) for tensor in tensors]

    shape_equations = [einx.expr.Equation(expr, shape) for expr, shape in zip(parsed, shapes)]
    # Each keyword constraint gets a trailing axis of length one before being equated.
    parameter_equations = [
        einx.expr.Equation(key, np.asarray(val)[..., np.newaxis], depth1=None, depth2=None)
        for key, val in parameters.items()
    ]

    # The result is discarded on purpose: the solver raises if no solution is found.
    einx.expr.solve(shape_equations + parameter_equations, cse=cse)