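"""Graph optimization pass for traced expressions.

Merges consecutive reshape operations, removes reshape/broadcast_to/transpose
operations that leave a tensor unchanged, and repeats the pass until the graph
stops changing.
"""
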
import einx

from .tracer import *
from .tensor import *


def get_signature(node):
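    """Return ``node.origin.signature``, or ``None`` if the node has no origin."""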
    if node.origin is not None:
        return node.origin.signature
    else:
        return None


class Optimizer:
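    """A single optimization pass over a traced graph.

    Optimized nodes are memoized by ``id`` so that shared subgraphs are rewritten
    only once; ``changed`` records whether this pass modified anything.
    """
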
    def __init__(self):
        self.optimized_nodes = {}
        self.changed = False

    def __call__(self, node):
        # Reuse the result if this node has already been optimized
        if id(node) in self.optimized_nodes:
            return self.optimized_nodes[id(node)]

        if isinstance(node, TracableFunction):
            if node.output is None:
                raise ValueError("Function output is None")
            new_node = TracableFunction(
                func=self(node.func),
                args=node.args,
                kwargs=node.kwargs,
                virtual_args=self(node.virtual_args),
                output=self(node.output),
            )
        elif isinstance(node, Tracer):
            if isinstance(node.origin, Application):
                if (
                    get_signature(node) == "reshape"
                    and get_signature(node.origin.tensor) == "reshape"
                ):
                    # Merge consecutive reshape ops
                    shape = node.origin.shape
                    new_node = apply(
                        self(node.origin.op),
                        [self(node.origin.tensor.origin.tensor), shape],
                        output=Tensor(shape),
                        signature="reshape",
                    )
                    self.changed = True
                elif (
                    get_signature(node) == "reshape"
                    and get_shape(node.origin.tensor) == node.origin.shape
                ):
                    # Skip reshape op if tensor already has right shape
                    new_node = self(node.origin.tensor)
                    self.changed = True
                elif (
                    get_signature(node) == "broadcast_to"
                    and get_shape(node.origin.tensor) == node.origin.shape
                ):
                    # Skip broadcast_to op if tensor already has right shape
                    new_node = self(node.origin.tensor)
                    self.changed = True
                elif get_signature(node) == "transpose" and list(node.origin.permutation) == list(
                    range(len(node.shape))
                ):
                    # Skip transpose op if permutation is identity
                    new_node = self(node.origin.tensor)
                    self.changed = True
                else:
                    # Optimize only the arguments: copy the output nodes, register the
                    # copies as the optimized versions of the original outputs (so that
                    # recursive calls below resolve to them), then re-apply the op.
                    new_output_nodes = einx.tree_util.tree_map(
                        lambda node: node.__copy__(), node.origin.output
                    )

                    def store(new_node, node):
                        assert id(node) not in self.optimized_nodes
                        self.optimized_nodes[id(node)] = new_node

                    einx.tree_util.tree_map(store, new_output_nodes, node.origin.output)
                    new_node = self.optimized_nodes[id(node)]

                    apply(
                        self(node.origin.op),
                        self(node.origin.args),
                        self(node.origin.kwargs),
                        output=new_output_nodes,
                        signature=node.origin.signature,
                        inplace_updates=[
                            (
                                self.optimized_nodes[id(tensor_in)],
                                self.optimized_nodes[id(tensor_out)],
                            )
                            for tensor_in, tensor_out in node.origin.inplace_updates
                        ],
                        comment=node.origin.comment,
                        depend_on=self(node.origin.depend_on),
                    )
            else:
                new_node = node
        elif isinstance(node, list):
            new_node = [self(x) for x in node]
        elif isinstance(node, tuple):
            new_node = tuple(self(x) for x in node)
        elif isinstance(node, dict):
            new_node = {k: self(v) for k, v in node.items()}
        else:
            new_node = node

        self.optimized_nodes[id(node)] = new_node
        return new_node


def optimize(node):
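    """Run :class:`Optimizer` passes over ``node`` until a fixed point is reached."""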
    while True:
        optimizer = Optimizer()
        node = optimizer(node)
        if not optimizer.changed:
            break
    return node