| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 
 | from .tracer import *
from .tensor import *
def get_signature(node):
    """Return the signature of the operation that produced ``node``.

    ``node.origin`` is the application that created the node. When it is
    ``None`` there is no producing operation and ``None`` is returned;
    otherwise the origin's ``signature`` attribute is returned.
    """
    origin = node.origin
    return None if origin is None else origin.signature
class Optimizer:
    """Single optimization pass over a traced computation graph.

    Calling an instance on a graph node returns a rebuilt node in which these
    redundant operations have been simplified:

    * two consecutive ``reshape`` applications are merged into one;
    * a ``reshape`` or ``broadcast_to`` whose input already has the target
      shape is removed;
    * a ``transpose`` with the identity permutation is removed.

    Every other application is copied, with its op, arguments, keyword
    arguments and dependencies optimized recursively. Results are memoized per
    input object, so shared sub-graphs stay shared in the output. ``changed``
    becomes ``True`` as soon as any of the rewrites above fires, so a caller
    can run fresh passes until it stays ``False``.
    """

    def __init__(self):
        # id(original node) -> optimized node.
        self.optimized_nodes = {}
        # id(original node) -> original node. Holding a reference keeps every
        # memoized key object alive for the lifetime of this optimizer: Python
        # may give a freed object's id() to a new object, which would
        # otherwise produce a false memo hit.
        self._memo_keys = {}
        # True once any rewrite rule has fired during this pass.
        self.changed = False

    def _memoize(self, node, new_node):
        """Record ``new_node`` as the optimized counterpart of ``node``."""
        self.optimized_nodes[id(node)] = new_node
        self._memo_keys[id(node)] = node

    def __call__(self, node):
        """Return the optimized counterpart of ``node`` (memoized by identity).

        Args:
            node: A ``TracableFunction``, a ``Tracer``, a list/tuple/dict whose
                elements are optimized recursively, or any other value, which
                is returned unchanged.

        Returns:
            The optimized node. Repeated calls with the same object return the
            same result object.

        Raises:
            ValueError: If a ``TracableFunction`` has no output.
        """
        if id(node) in self.optimized_nodes:
            return self.optimized_nodes[id(node)]

        if isinstance(node, TracableFunction):
            if node.output is None:
                raise ValueError("Function output is None")
            # args/kwargs are passed through as-is; only the function, the
            # virtual arguments and the output graph are optimized.
            new_node = TracableFunction(
                func=self(node.func),
                args=node.args,
                kwargs=node.kwargs,
                virtual_args=self(node.virtual_args),
                output=self(node.output),
            )
        elif isinstance(node, Tracer):
            if isinstance(node.origin, Application):
                if (
                    get_signature(node) == "reshape"
                    and get_signature(node.origin.tensor) == "reshape"
                ):
                    # Merge consecutive reshape ops:
                    # reshape(reshape(x, a), b) -> reshape(x, b), using the
                    # outer op and the outer target shape.
                    shape = node.origin.shape
                    new_node = apply(
                        self(node.origin.op),
                        [self(node.origin.tensor.origin.tensor), shape],
                        output=Tensor(shape),
                        signature="reshape",
                    )
                    self.changed = True
                elif (
                    get_signature(node) == "reshape"
                    and get_shape(node.origin.tensor) == node.origin.shape
                ):
                    # Skip reshape op if tensor already has right shape
                    new_node = self(node.origin.tensor)
                    self.changed = True
                elif (
                    get_signature(node) == "broadcast_to"
                    and get_shape(node.origin.tensor) == node.origin.shape
                ):
                    # Skip broadcast_to op if tensor already has right shape
                    new_node = self(node.origin.tensor)
                    self.changed = True
                elif (
                    get_signature(node) == "transpose"
                    and list(node.origin.permutation) == list(range(len(node.shape)))
                ):
                    # Skip transpose op if permutation is identity
                    new_node = self(node.origin.tensor)
                    self.changed = True
                else:
                    # No rewrite applies: rebuild the application with
                    # optimized inputs. Fresh copies of ALL of its outputs are
                    # memoized before recursing, so that (a) sibling outputs of
                    # a multi-output application resolve to the same rebuilt
                    # application through the memo check at the top, and
                    # (b) the inplace_updates lookup below finds the copies.
                    new_output_nodes = einx.tree_util.tree_map(
                        lambda node: node.__copy__(), node.origin.output
                    )

                    def store(new_node, node):
                        assert id(node) not in self.optimized_nodes
                        self._memoize(node, new_node)

                    einx.tree_util.tree_map(store, new_output_nodes, node.origin.output)
                    new_node = self.optimized_nodes[id(node)]
                    # The return value is not used: the copied output nodes are
                    # passed in as ``output`` (presumably apply attaches the new
                    # application to them as their origin -- TODO confirm).
                    apply(
                        self(node.origin.op),
                        self(node.origin.args),
                        self(node.origin.kwargs),
                        output=new_output_nodes,
                        signature=node.origin.signature,
                        inplace_updates=[
                            (
                                self.optimized_nodes[id(tensor_in)],
                                self.optimized_nodes[id(tensor_out)],
                            )
                            for tensor_in, tensor_out in node.origin.inplace_updates
                        ],
                        comment=node.origin.comment,
                        depend_on=self(node.origin.depend_on),
                    )
            else:
                # Tracers whose origin is not an Application are kept as-is.
                new_node = node
        elif isinstance(node, list):
            new_node = [self(x) for x in node]
        elif isinstance(node, tuple):
            new_node = tuple(self(x) for x in node)
        elif isinstance(node, dict):
            new_node = {k: self(v) for k, v in node.items()}
        else:
            new_node = node
        self._memoize(node, new_node)
        return new_node
def optimize(node):
    """Run optimization passes over ``node`` until one makes no change.

    Each pass uses a fresh ``Optimizer`` (and therefore a fresh memo). The
    graph produced by the final pass, which reported no rewrites, is returned.
    """
    changed = True
    while changed:
        optimizer = Optimizer()
        node = optimizer(node)
        changed = optimizer.changed
    return node
 |