File: util.py

package info (click to toggle)
python-einx 0.3.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,112 kB
  • sloc: python: 11,619; makefile: 13
file content (144 lines) | stat: -rw-r--r-- 5,152 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from . import stage1, stage2, stage3
import numpy as np
import einx


def _get_expansion(expr):
    """Return the expansion that can be read directly off ``expr``, or ``None``.

    - stage1 expression: a 1-tuple holding ``expr.expansion()``.
    - stage2 / stage3 expression: a 1-tuple holding ``len(expr)``.
    - numpy array: the array's shape as a tuple.
    - anything else (e.g. an unparsed string): ``None``.
    """
    if isinstance(expr, stage1.Expression):
        return (expr.expansion(),)
    if isinstance(expr, (stage2.Expression, stage3.Expression)):
        return (len(expr),)
    if isinstance(expr, np.ndarray):
        return tuple(expr.shape)
    # Nothing is known about the expansion yet.
    return None


def _input_expr(expr):
    """Normalize a user-supplied expression.

    ``None``, strings and already-built stage1/stage2/stage3 expressions are
    returned unchanged. Any other value is converted with ``np.asarray`` and
    must have an integer dtype. Its flattened values are then joined into a
    space-separated string, e.g. ``(2, 3)`` -> ``"2 3"``.

    Raises:
        ValueError: if ``np.asarray`` fails on ``expr``, or the resulting
            array does not have an integer dtype.
    """
    if expr is None:
        return expr
    if isinstance(expr, (str, stage1.Expression, stage2.Expression, stage3.Expression)):
        return expr

    if not isinstance(expr, np.ndarray):
        if expr == [] or expr == ():
            # np.asarray of an empty sequence defaults to a float dtype, which
            # would fail the integer check below; force an integer dtype instead.
            expr = np.asarray(expr).astype("int32")
        else:
            try:
                expr = np.asarray(expr)
            except Exception as e:
                raise ValueError(f"Invalid expression '{expr}'") from e

    if not np.issubdtype(expr.dtype, np.integer):
        raise ValueError(f"Invalid expression '{expr}', must be integers")
    return " ".join(str(value) for value in expr.flatten())


class Equation:
    """A single constraint ``expr1 = expr2`` (or a one-sided constraint on ``expr1``).

    Args:
        expr1: Left-hand expression. Accepts a string, a stage1/stage2/stage3
            expression, a numpy integer array, or an array-like of integers.
            Values that are not already strings or expressions are normalized
            to a space-separated string by ``_input_expr``.
        expr2: Optional right-hand expression, with the same accepted forms.
            ``None`` means the equation only constrains ``expr1``.
        depth1: Depth associated with ``expr1``.
        depth2: Depth associated with ``expr2``. It is stored as ``None``
            when ``expr2`` is ``None``.
    """

    def __init__(self, expr1, expr2=None, depth1=0, depth2=0):
        self.expr1 = _input_expr(expr1)
        self.expr2 = _input_expr(expr2)
        # Computed from the *raw* inputs: _input_expr turns a numpy array into a
        # flat string, which would lose the array's shape.
        self.expansion1 = _get_expansion(expr1)
        self.expansion2 = _get_expansion(expr2)
        self.depth1 = depth1
        self.depth2 = None if expr2 is None else depth2

    def __repr__(self):
        # Same layout as the verbose equation listing in solve(). The stored
        # attributes are strings/expressions, tuples or None, so plain str()
        # formatting already renders them on a single line.
        return (
            f"{self.expr1} (expansion={self.expansion1} at depth={self.depth1}) = "
            f"{self.expr2} (expansion={self.expansion2} at depth={self.depth2})"
        )


def _to_str(l):
    """Render ``l`` on a single line for verbose output.

    numpy arrays and lists are shown as tuples, so numpy does not insert
    line breaks. ``None`` becomes ``"None"``, and every other value uses
    ``str``.
    """
    if l is None:
        return "None"
    if isinstance(l, np.ndarray):
        items = l.tolist()
    elif isinstance(l, list):
        items = l
    else:
        return str(l)
    return str(tuple(items))


def _print_equations(title, exprs1, exprs2, expansions1, expansions2, depths1, depths2):
    # Verbose helper: print one line per equation, including expansions and depths.
    print(title)
    for expr1, expr2, expansion1, expansion2, depth1, depth2 in zip(
        exprs1, exprs2, expansions1, expansions2, depths1, depths2
    ):
        print(
            f"    {_to_str(expr1)} (expansion={_to_str(expansion1)} at depth={depth1}) = "
            f"{_to_str(expr2)} (expansion={_to_str(expansion2)} at depth={depth2})"
        )


def _print_expr_pairs(title, exprs1, exprs2):
    # Verbose helper: print one "lhs = rhs" line per equation.
    print(title)
    for expr1, expr2 in zip(exprs1, exprs2):
        print(f"    {_to_str(expr1)} = {_to_str(expr2)}")


def solve(
    equations, cse=True, cse_concat=True, cse_in_markers=False, after_stage2=None, verbose=False
):
    """Solve a system of ``Equation`` objects through the stage1 -> stage2 -> stage3 pipeline.

    Steps:
    1. String expressions are parsed with ``stage1.parse_arg``. Expansions not
       already known from the raw input are read off the parsed expressions.
    2. ``stage2.solve`` is applied. If ``cse`` is set, ``stage2.cse`` is then
       applied to all expressions together.
    3. If ``after_stage2`` is given, it is called as
       ``after_stage2(exprs1, exprs2)`` on the stage-2 result. The
       equations it returns are appended, and the whole system is solved again
       from scratch without the hook.
    4. ``stage3.solve`` is applied.

    Args:
        equations: Iterable of ``Equation`` instances (lists, tuples and
            generators are all accepted).
        cse: Whether to run ``stage2.cse`` after ``stage2.solve``.
        cse_concat: Forwarded to ``stage2.cse``.
        cse_in_markers: Forwarded to ``stage2.cse``.
        after_stage2: Optional callable that returns an iterable of additional
            ``Equation`` objects.
        verbose: Print the intermediate expressions after each stage.

    Returns:
        The left-hand-side results returned by ``stage3.solve``.

    Raises:
        ValueError: if any element of ``equations`` is not an ``Equation``.
    """
    # Materialize once: the type check below would otherwise exhaust a generator
    # or iterator before the per-equation attribute lists are built. A list also
    # makes the concatenation in the after_stage2 branch work for tuple input.
    equations = list(equations)
    if any(not isinstance(c, Equation) for c in equations):
        raise ValueError("All arguments must be of type Equation")

    exprs1 = [t.expr1 for t in equations]
    exprs2 = [t.expr2 for t in equations]
    expansions1 = [t.expansion1 for t in equations]
    expansions2 = [t.expansion2 for t in equations]
    depths1 = [t.depth1 for t in equations]
    depths2 = [t.depth2 for t in equations]

    if verbose:
        _print_equations("Stage0:", exprs1, exprs2, expansions1, expansions2, depths1, depths2)

    # Stage 1: parse string expressions; non-strings are passed through unchanged.
    exprs1 = [(stage1.parse_arg(expr) if isinstance(expr, str) else expr) for expr in exprs1]
    exprs2 = [(stage1.parse_arg(expr) if isinstance(expr, str) else expr) for expr in exprs2]

    # Expansions not known from the raw input can now be read off the parsed expressions.
    expansions1 = [
        expansion if expansion is not None else _get_expansion(expr)
        for expansion, expr in zip(expansions1, exprs1)
    ]
    expansions2 = [
        expansion if expansion is not None else _get_expansion(expr)
        for expansion, expr in zip(expansions2, exprs2)
    ]

    if verbose:
        _print_equations("Stage1:", exprs1, exprs2, expansions1, expansions2, depths1, depths2)

    exprs1, exprs2 = stage2.solve(exprs1, exprs2, expansions1, expansions2, depths1, depths2)

    if verbose:
        _print_expr_pairs("Stage2:", exprs1, exprs2)

    if cse:
        # CSE runs over both sides together, then the result is split back into lhs/rhs.
        exprs = stage2.cse(exprs1 + exprs2, cse_concat=cse_concat, cse_in_markers=cse_in_markers)
        exprs1, exprs2 = exprs[: len(exprs1)], exprs[len(exprs1) :]

        if verbose:
            _print_expr_pairs("Stage2.CSE:", exprs1, exprs2)

    if after_stage2 is not None:
        # Re-solve from scratch with the extra equations; passing after_stage2=None
        # limits this to a single level of recursion.
        extra_equations = list(after_stage2(exprs1, exprs2))
        return solve(
            equations + extra_equations,
            cse=cse,
            cse_concat=cse_concat,
            cse_in_markers=cse_in_markers,
            after_stage2=None,
            verbose=verbose,
        )

    exprs1, exprs2 = stage3.solve(exprs1, exprs2)

    if verbose:
        print("Stage3:")
        for expr1, expr2 in zip(exprs1, exprs2):
            assert expr1 is None or expr2 is None or expr1.shape == expr2.shape
            shape = expr1.shape if expr1 is not None else expr2.shape
            shape = " ".join(str(i) for i in shape)
            print(f"    {_to_str(expr1)} = {_to_str(expr2)} = {shape}")

    return exprs1