1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
|
# Avoid a hard jax dependency
def tree_map_with_key(func, *trees, key=(), is_leaf=None):
    """Apply ``func`` to the corresponding leaves of ``trees``, passing each leaf's path.

    Lists, tuples and dicts are traversed recursively as long as every tree
    has the same container kind at that position and the same length (for
    lists/tuples) or the same key set (for dicts). Anything else, including
    containers whose structure does not match across the trees, is handed to
    ``func`` as a leaf.

    Args:
        func: Called as ``func(*leaves, key=path)`` with one leaf per tree.
        *trees: Nested structures to walk in lockstep.
        key: Path prefix; a tuple that is extended with the list/tuple index
            or dict key on every level of descent.
        is_leaf: Optional predicate called as ``is_leaf(key, *trees)``. When
            it returns a truthy value the current nodes are not descended into.

    Returns:
        A structure rebuilt from plain ``list``/``tuple``/``dict`` containers
        holding the values returned by ``func``.
    """
    if is_leaf is not None and is_leaf(key, *trees):
        return func(*trees, key=key)

    # Lists and tuples are walked the same way; only the rebuilt container
    # type differs. Lists are tried first, then tuples.
    for container in (list, tuple):
        same_kind = all(isinstance(tree, container) for tree in trees)
        if same_kind and all(len(trees[0]) == len(tree) for tree in trees[1:]):
            return container(
                tree_map_with_key(func, *children, key=key + (index,), is_leaf=is_leaf)
                for index, children in enumerate(zip(*trees))
            )

    all_dicts = all(isinstance(tree, dict) for tree in trees)
    if all_dicts and all(trees[0].keys() == tree.keys() for tree in trees[1:]):
        mapped = {}
        # Iterate in the first tree's key order.
        for name in trees[0]:
            children = [tree[name] for tree in trees]
            mapped[name] = tree_map_with_key(
                func, *children, key=key + (name,), is_leaf=is_leaf
            )
        return mapped

    # Non-container nodes or mismatched structures are treated as leaves.
    return func(*trees, key=key)
def tree_map(func, *trees, is_leaf=None):
    """Apply ``func`` to the corresponding leaves of ``trees``.

    Lists, tuples and dicts are traversed recursively as long as every tree
    has the same container kind at that position and the same length (for
    lists/tuples) or the same key set (for dicts). Anything else, including
    containers whose structure does not match across the trees, is handed to
    ``func`` as a leaf.

    Args:
        func: Called as ``func(*leaves)`` with one leaf per tree.
        *trees: Nested structures to walk in lockstep.
        is_leaf: Optional predicate called as ``is_leaf(*trees)``. When it
            returns a truthy value the current nodes are not descended into.

    Returns:
        A structure rebuilt from plain ``list``/``tuple``/``dict`` containers
        holding the values returned by ``func``.
    """
    if is_leaf is not None and is_leaf(*trees):
        return func(*trees)

    # Lists and tuples are walked the same way; only the rebuilt container
    # type differs. No positional index is needed here, so zip() alone is
    # enough.
    for container in (list, tuple):
        if all(isinstance(tree, container) for tree in trees) and all(
            len(trees[0]) == len(tree) for tree in trees[1:]
        ):
            return container(
                tree_map(func, *children, is_leaf=is_leaf) for children in zip(*trees)
            )

    if all(isinstance(tree, dict) for tree in trees) and all(
        trees[0].keys() == tree.keys() for tree in trees[1:]
    ):
        # Iterate in the first tree's key order.
        return {
            k: tree_map(func, *[tree[k] for tree in trees], is_leaf=is_leaf)
            for k in trees[0]
        }

    # Non-container nodes or mismatched structures are treated as leaves.
    return func(*trees)
def tree_flatten(x, is_leaf=None):
    """Yield the leaves of ``x`` in depth-first order.

    Lists and tuples are descended element by element and dicts value by
    value (in the dict's iteration order). Every other object is yielded
    as a leaf.

    Args:
        x: Nested structure of lists, tuples and dicts, or a single leaf.
        is_leaf: Optional predicate called as ``is_leaf(node)``. When it
            returns a truthy value the node is yielded whole instead of
            being descended into.

    Yields:
        The leaves of ``x``. Empty containers contribute nothing.
    """
    if is_leaf is not None and is_leaf(x):
        yield x
    elif isinstance(x, (list, tuple)):
        for child in x:
            yield from tree_flatten(child, is_leaf=is_leaf)
    elif isinstance(x, dict):
        # Only the values are part of the tree. Iterating items() would feed
        # (key, value) tuples back into the tuple branch and emit the dict
        # keys as leaves too.
        for child in x.values():
            yield from tree_flatten(child, is_leaf=is_leaf)
    else:
        yield x
|