1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353
|
from typing import List, Optional, Tuple
import string
import copy
import operator
import numbers
import torch
from torch import fx
from opt_einsum.parser import find_output_str
from .fx_utils import get_shape
# Targets of fx ``call_function`` nodes that the passes below treat as einsum
# calls; both spellings are listed so either one is recognized.
_EINSUM_FUNCS = {torch.functional.einsum, torch.einsum}
# == Einsum fusion ==
def _get_einstrs(einstr: str) -> Tuple[List[str], str]:
    """Split an einsum equation into per-operand subscripts and the output subscript.

    Whitespace is removed before parsing, so ``"ij, jk -> ik"`` parses the
    same as ``"ij,jk->ik"``. When the equation has no explicit ``->``, the
    output subscript is inferred with ``opt_einsum``'s ``find_output_str``.

    Args:
        einstr: an einsum equation such as ``"ij,jk->ik"`` or ``"ij,jk"``.

    Returns:
        ``(operand_subscripts, output_subscript)``, e.g. ``(["ij", "jk"], "ik")``.

    Raises:
        NotImplementedError: if the equation contains an ellipsis ``...``.
        ValueError: if the equation contains more than one ``->``.
    """
    # Labels are single characters; if whitespace were kept, callers that
    # zip/iterate over the subscripts would mistake a space for a label.
    compact = "".join(einstr.split())
    if "..." in compact:
        raise NotImplementedError("Ellipsis `...` in einsum string not supported yet")
    tmp = compact.split("->")
    if len(tmp) == 1:
        ops = tmp[0]
        out = find_output_str(ops)
    elif len(tmp) == 2:
        ops, out = tmp
    else:
        # Report the equation exactly as the caller wrote it.
        raise ValueError(f"Invalid einstr {einstr}")
    return ops.split(","), out
def _einsum_operands(node: fx.Node) -> tuple:
    """Return the operands of an einsum ``call_function`` node as a flat tuple.

    ``torch.einsum`` accepts operands either as varargs (``einsum(eq, a, b)``)
    or as a single list (``einsum(eq, [a, b])``); fx records whichever form
    the traced code used, so both are normalized here.
    """
    operands = node.args[1:]
    if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
        return tuple(operands[0])
    return tuple(operands)


def fuse_einsums(graph: fx.Graph, in_place: bool = False) -> fx.Graph:
    """Fuse einsums when possible.

    When the output of one einsum is only used as an operand in another einsum, the two einsums can be fused into one.

    Example:

    .. code-block:: python

        def fusable(x, y):
            z = torch.einsum("ij,jk->ik", x, y)
            return torch.einsum("ik,ij->i", z, x)

        g = torch.fx.symbolic_trace(fusable)
        print(fuse_einsums(g.graph).python_code(""))

    gives::

        import torch
        def forward(self, x, y):
            einsum_2 = torch.functional.einsum('ia,ak,ij->i', x, y, x);  x = y = None
            return einsum_2

    Labels that a fused einsum contracts away are renamed to the
    alphabetically-first lowercase letters not already used by the consuming
    einsum, so the generated equation is the same on every run.

    Args:
        graph: the graph to process.
        in_place (bool, optional): whether to process ``graph`` in place.

    Returns:
        The graph with fused einsums.

    Raises:
        ValueError: if an einsum's equation and operand count disagree.
        RuntimeError: if a fused einsum's output rank disagrees with its label
            in the consuming einsum.
        NotImplementedError: if fusing would need more than 26 distinct labels.
    """
    if not in_place:
        graph = copy.deepcopy(graph)
    for node in graph.nodes:
        if not (node.op == "call_function" and node.target in _EINSUM_FUNCS):
            continue
        our_inp_einstrs, our_out_einstr = _get_einstrs(node.args[0])
        our_operands = _einsum_operands(node)
        if len(our_inp_einstrs) != len(our_operands):
            raise ValueError(
                f"einsum `{node}` has {len(our_operands)} operands, but its equation `{node.args[0]}` describes {len(our_inp_einstrs)}"
            )
        # Fresh labels must not collide with any label this einsum already
        # uses. Walk the alphabet in order rather than iterating a set: string
        # hashing is randomized per process, so set order is not reproducible.
        used_letters = set().union(*our_inp_einstrs)
        avail_letters = iter(
            [c for c in string.ascii_lowercase if c not in used_letters]
        )
        new_our_einstrs = []
        new_our_args = []
        # dict used as an insertion-ordered set: an inner einsum that appears
        # several times among our operands must still be erased only once.
        we_fused_nodes = {}
        for inp_idex, inp in enumerate(our_operands):
            fusable = (
                isinstance(inp, fx.Node)
                and inp.op == "call_function"
                and inp.target in _EINSUM_FUNCS
                and len(inp.users) == 1
            )
            if not fusable:
                # Not produced by an einsum, or that einsum is used elsewhere
                # as well: pass the operand through unchanged.
                new_our_einstrs.append(our_inp_einstrs[inp_idex])
                new_our_args.append(inp)
                continue
            # This operand is the output of another einsum used nowhere else,
            # so its operands can be inlined into this einsum.
            its_inp_einstrs, its_out_einstr = _get_einstrs(inp.args[0])
            if len(its_out_einstr) != len(our_inp_einstrs[inp_idex]):
                raise RuntimeError(
                    f"Inconsistent rank: einsum `{node}`'s input {inp_idex} is the result of einsum {inp}; the output of `{inp}` is labeled `{its_out_einstr}` (rank {len(its_out_einstr)}), but the corresponding input of `{node}` is labeled `{our_inp_einstrs[inp_idex]}` (rank {len(our_inp_einstrs[inp_idex])})"
                )
            # Its output labels map positionally onto the labels we give it.
            its_dim_to_ours = dict(zip(its_out_einstr, our_inp_einstrs[inp_idex]))
            # Labels it contracts away have no counterpart here: give each a
            # fresh letter. A plain loop (not a generator expression) is used
            # so exhaustion is detected directly; a StopIteration raised inside
            # a generator would surface as RuntimeError (PEP 479).
            its_remaining_labels = set().union(*its_inp_einstrs) - set(its_dim_to_ours)
            for label in sorted(its_remaining_labels):
                letter = next(avail_letters, None)
                if letter is None:
                    raise NotImplementedError(
                        f"At einsum {node}, ran out of letters when trying to fuse parameter einsum {inp}. A fallback for this case is not yet implimented."
                    )
                its_dim_to_ours[label] = letter
            new_our_args.extend(_einsum_operands(inp))
            new_our_einstrs.extend(
                "".join(its_dim_to_ours[d] for d in es) for es in its_inp_einstrs
            )
            we_fused_nodes[inp] = None
        # Rewrite this einsum in varargs form with the fused equation.
        node.args = (f"{','.join(new_our_einstrs)}->{our_out_einstr}",) + tuple(
            new_our_args
        )
        # The fused einsums now have no users and can be dropped.
        for to_remove in we_fused_nodes:
            graph.erase_node(to_remove)
    return graph
# == Scalar fusion ==
#
# Note that in general we do not support scalar fusion through in-place operations; it complicates following things through the compute graph too much
# TODO: ^ ???
# TODO: should the accumulation of constants happen in more than double precision?
def _get_node_and_scalar(node: fx.Node) -> Tuple[fx.Node, Optional[numbers.Number]]:
    """Get a multiplicative scalar for an operation, if applicable.

    Recognizes multiplication of a tensor by a Python number
    (``operator.mul``, ``torch.mul``, ``Tensor.mul``) and true division of a
    tensor by a nonzero Python number (``operator.truediv``, ``torch.div``,
    ``Tensor.div``). Division is reported as multiplication by the reciprocal.

    Returns:
        ``(tensor_input, scalar)`` if ``node`` is such a scaling, otherwise
        ``(node, None)``.
    """
    # fx traces in-place *= and /= as plain operator.mul / operator.truediv,
    # so those are covered as well.
    # Non-None keyword arguments (e.g. ``rounding_mode="floor"``, ``out=``,
    # ``other=``) mean this is not a plain scaling we can read off the
    # positional arguments; rounded division in particular is not
    # multiplication by 1/d.
    if any(value is not None for value in node.kwargs.values()):
        return node, None
    if len(node.args) != 2:
        return node, None
    lhs, rhs = node.args
    if node.op == "call_function":
        if node.target == operator.mul or node.target == torch.mul:
            if isinstance(lhs, numbers.Number):
                return rhs, lhs
            if isinstance(rhs, numbers.Number):
                return lhs, rhs
        elif node.target == operator.truediv or node.target == torch.div:
            # x / 0 is not a finite scaling (and 1.0 / 0 would raise here).
            if isinstance(rhs, numbers.Number) and rhs != 0:
                return lhs, 1.0 / rhs
    elif node.op == "call_method":
        # TODO: this could _technically_ be wrong if the node's `self` argument
        # is not a (proxy to) a Tensor.
        if node.target == "mul":
            if isinstance(rhs, numbers.Number):
                return lhs, rhs
        elif node.target == "div":
            if isinstance(rhs, numbers.Number) and rhs != 0:
                return lhs, 1.0 / rhs
    return node, None
# Operations that are (almost) "multilinear", in the sense that they commute with scalar multiplication of their operands:
# scaling one tensor operand by a constant scales the result by that same constant.
# Entries are functions (matched against ``call_function`` targets) or method
# names (matched against ``call_method`` targets).
# NOTE(review): division is linear only in its numerator, and division with a
# ``rounding_mode`` is not linear at all; membership in this list alone does not
# make every use of ``div``/``truediv`` safe to move a scalar through.
SCALAR_COMMUTE_OPS = [
    torch.einsum,
    torch.functional.einsum,
    torch.tensordot,
    torch.functional.tensordot,
    "permute",
    # NOTE(review): "reshape" is disabled below; the reason is not recorded here.
    # "reshape",
    "mul",
    "div",
    operator.mul,
    operator.truediv,
]
def prod(x):
    """Multiply together every element of the iterable ``x``.

    The empty product is ``1``, so ``prod(())`` returns ``1``.
    """
    total = 1
    for factor in x:
        # Same semantics as ``total *= factor``.
        total = operator.imul(total, factor)
    return total
def _commutes_with_scalars(node: fx.Node) -> bool:
    """Return whether a constant factor can be moved through ``node``.

    ``node`` must be a ``call_function``/``call_method`` whose target is in
    ``SCALAR_COMMUTE_OPS``. Divisions additionally need a nonzero Python-number
    divisor and no ``rounding_mode``: a division is linear only in its
    numerator, and rounded division is not linear at all.
    """
    if node.op not in ("call_function", "call_method"):
        # Other node kinds (placeholder, get_attr, call_module, output) have
        # string targets that could coincidentally equal "mul"/"permute"/...
        return False
    if node.target not in SCALAR_COMMUTE_OPS:
        return False
    if node.target in (operator.truediv, torch.div, "div"):
        if node.kwargs.get("rounding_mode") is not None:
            return False
        divisor = node.args[1] if len(node.args) > 1 else node.kwargs.get("other")
        return isinstance(divisor, numbers.Number) and divisor != 0
    return True


def _sole_consumer(node: fx.Node) -> Optional[fx.Node]:
    """Return the only user of ``node`` if that user consumes it exactly once, else ``None``.

    A scalar chain may only be extended through such a user. If ``node`` had
    other users, moving its scalar elsewhere would change what they see. If the
    user consumed ``node`` twice (e.g. ``node * node``), the scalar would reach
    its output squared.
    """
    users = list(node.users)
    if len(users) != 1:
        return None
    user = users[0]

    def count(arg) -> int:
        # Count occurrences of ``node`` in a (possibly nested) fx argument.
        if arg is node:
            return 1
        if isinstance(arg, (list, tuple)):
            return sum(count(a) for a in arg)
        if isinstance(arg, dict):
            return sum(count(a) for a in arg.values())
        return 0

    return user if count(user.args) + count(user.kwargs) == 1 else None


def fuse_scalars(graph: fx.Graph, in_place: bool = False) -> fx.Graph:
    """Use the multilinearity of einsum to unify and remove constant scalars around einsums.

    Finds chains of operations through which a constant factor commutes (see
    ``SCALAR_COMMUTE_OPS``). Every multiplication/division by a Python number
    inside a chain is removed, and the product of those scalars is applied
    once: on the chain operand with the fewest elements according to
    ``get_shape``, or on the chain's output if that is smaller or no operand
    shape is known.

    Args:
        graph: the graph to process.
        in_place (bool, optional): whether to process ``graph`` in place.

    Returns:
        The graph with fused scalars.
    """
    if not in_place:
        graph = copy.deepcopy(graph)

    # -- Find chains of multilinear ops --
    # Maps each node already placed in a chain to that chain's index.
    chain_of = {}
    linear_chains = []
    for start in graph.nodes:
        cur_linear_chain = []
        node = start
        successor = None
        while node not in chain_of and _commutes_with_scalars(node):
            chain_of[node] = len(linear_chains)  # tentative; fixed up on merge
            cur_linear_chain.append(node)
            successor = _sole_consumer(node)
            if successor is None:
                # Multiple (or no) users, or a repeated use: end the chain here.
                break
            node = successor
        if not cur_linear_chain:
            continue
        if successor is not None and successor in chain_of:
            # Our chain feeds exclusively into a node of an existing chain, so
            # the two can share one scalar. Prepend, so that the existing
            # chain's terminal node stays last (it is treated as the output).
            merge_into = chain_of[successor]
            for n in cur_linear_chain:
                chain_of[n] = merge_into
            linear_chains[merge_into][:0] = cur_linear_chain
        else:
            linear_chains.append(cur_linear_chain)

    # -- Accumulate and remove scalars --
    scalars = []
    for lin_chain in linear_chains:
        if len(lin_chain) < 2:
            # A lone operation, even a scalar multiplication, has nothing to
            # trade its scalar with.
            scalars.append(None)
            continue
        scalar_node_idexes = []
        total_scalar = 1.0
        for node_i, node in enumerate(lin_chain):
            _, scalar = _get_node_and_scalar(node)
            if scalar is not None:
                total_scalar *= scalar
                scalar_node_idexes.append(node_i)
        is_all_scalars = len(scalar_node_idexes) == len(lin_chain)
        for node_i in scalar_node_idexes:
            node = lin_chain[node_i]
            # Recompute: earlier removals may have rewired this node's input.
            new_node, scalar = _get_node_and_scalar(node)
            assert scalar is not None
            if is_all_scalars and node_i == len(lin_chain) - 1:
                # A chain of only scalings collapses into one multiplication
                # at its end; nothing is left to re-apply later.
                with graph.inserting_after(node):
                    new_node = graph.call_function(
                        operator.mul,
                        (total_scalar, new_node),
                    )
                total_scalar = None
            node.replace_all_uses_with(new_node)
            graph.erase_node(node)
        scalars.append(total_scalar)
        # Drop the erased nodes from the chain; delete from the back so the
        # remaining indices stay valid.
        for index in sorted(
            (scalar_node_idexes[:-1] if is_all_scalars else scalar_node_idexes),
            reverse=True,
        ):
            del lin_chain[index]

    graph.lint()

    # -- Re-apply each chain's scalar in the cheapest place --
    for lin_chain_i, lin_chain in enumerate(linear_chains):
        scalar = scalars[lin_chain_i]
        if len(lin_chain) == 0 or scalar is None or scalar == 1.0:
            # Empty chain, already handled above, or a no-op factor.
            continue
        smallest_node_i = None
        smallest_arg_i = None
        smallest_size = float("inf")
        for node_i, node in enumerate(lin_chain):
            for arg_i, arg in enumerate(node.args):
                if not isinstance(arg, fx.Node):
                    continue
                shape = get_shape(arg)
                if shape is not None and prod(shape) < smallest_size:
                    smallest_node_i = node_i
                    smallest_arg_i = arg_i
                    smallest_size = prod(shape)
        out_shape = get_shape(lin_chain[-1])
        if smallest_node_i is None or (
            out_shape is not None and prod(out_shape) < smallest_size
        ):
            # The output is the smallest, or no operand shape is known:
            # scale the chain's output.
            with graph.inserting_after(lin_chain[-1]):
                new_node = graph.call_function(operator.mul, tuple())  # placeholder
            # Redirect users first, then fill in args, so the new node does not
            # end up consuming itself.
            lin_chain[-1].replace_all_uses_with(new_node)
            new_node.args = (lin_chain[-1], scalar)
        else:
            # Scale the smallest operand where it enters the chain.
            target = lin_chain[smallest_node_i]
            with graph.inserting_before(target):
                new_arg = graph.call_function(
                    operator.mul,
                    (target.args[smallest_arg_i], scalar),
                )
            new_args = list(target.args)
            new_args[smallest_arg_i] = new_arg
            target.args = tuple(new_args)
    graph.lint()
    return graph
|