"""Required functions for optimized contractions of numpy arrays using pytorch."""
from opt_einsum.helpers import has_array_interface
from opt_einsum.parser import convert_to_valid_einsum_chars
from opt_einsum.sharing import to_backend_cache_wrap
__all__ = [
"transpose",
"einsum",
"tensordot",
"to_torch",
"build_expression",
"evaluate_constants",
]
_TORCH_DEVICE = None
_TORCH_HAS_TENSORDOT = None
_torch_symbols_base = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
def _get_torch_and_device():
    """Lazily import torch and cache the module alongside the default device."""
    global _TORCH_DEVICE
    global _TORCH_HAS_TENSORDOT

    if _TORCH_DEVICE is None:
        import torch  # type: ignore

        device = "cuda" if torch.cuda.is_available() else "cpu"
        _TORCH_DEVICE = torch, device
        _TORCH_HAS_TENSORDOT = hasattr(torch, "tensordot")

    return _TORCH_DEVICE

def transpose(a, axes):
    """Normal torch transpose is only valid for 2D matrices."""
    return a.permute(*axes)

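# Illustrative only (not part of the original module): this wrapper exists
# because ``torch.Tensor.transpose`` swaps exactly two dimensions, whereas the
# numpy API expects a full axis permutation, e.g.:
#
#     t = torch.empty(2, 3, 4)
#     transpose(t, (2, 0, 1)).shape   # torch.Size([4, 2, 3])
#     t.transpose(0, 2).shape         # torch.Size([4, 3, 2])
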
def einsum(equation, *operands, **kwargs):
    """Variadic version of torch.einsum to match numpy api."""
    # rename symbols to support PyTorch 0.4.1 and earlier,
    # which allow only symbols a-z.
    equation = convert_to_valid_einsum_chars(equation)
    torch, _ = _get_torch_and_device()
    return torch.einsum(equation, operands)

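# Illustrative only: the wrapper takes operands positionally, like
# numpy.einsum, then hands torch the sequence form it accepts, e.g.
# (assuming torch tensors ``a`` and ``b``):
#
#     einsum("ij,jk->ik", a, b)   # matrix product, equivalent to a @ b
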
def tensordot(x, y, axes=2):
    """Simple translation of tensordot syntax to einsum."""
    torch, _ = _get_torch_and_device()

    if _TORCH_HAS_TENSORDOT:
        return torch.tensordot(x, y, dims=axes)

    xnd = x.ndimension()
    ynd = y.ndimension()

    # convert int argument to (list[int], list[int])
    if isinstance(axes, int):
        axes = range(xnd - axes, xnd), range(axes)

    # convert (int, int) to (list[int], list[int])
    if isinstance(axes[0], int):
        axes = (axes[0],), axes[1]
    if isinstance(axes[1], int):
        axes = axes[0], (axes[1],)

    # initialize empty indices
    x_ix = [None] * xnd
    y_ix = [None] * ynd
    out_ix = []

    # fill in repeated indices
    available_ix = iter(_torch_symbols_base)
    for ax1, ax2 in zip(*axes):
        repeat = next(available_ix)
        x_ix[ax1] = repeat
        y_ix[ax2] = repeat

    # fill in the rest, and maintain output order
    for i in range(xnd):
        if x_ix[i] is None:
            leave = next(available_ix)
            x_ix[i] = leave
            out_ix.append(leave)
    for i in range(ynd):
        if y_ix[i] is None:
            leave = next(available_ix)
            y_ix[i] = leave
            out_ix.append(leave)

    # form full string and contract!
    einsum_str = "{},{}->{}".format(*map("".join, (x_ix, y_ix, out_ix)))

    return einsum(einsum_str, x, y)

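# Illustrative only: a worked example of the fallback path above. For x of
# shape (2, 3, 4), y of shape (3, 4, 5) and ``axes=2``, the axis pairs
# (1, 0) and (2, 1) are contracted, the generated string is "cab,abd->cd",
# and the result has shape (2, 5).
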
@to_backend_cache_wrap
def to_torch(array):
    """Convert a numpy array to a torch tensor on the default device."""
    torch, device = _get_torch_and_device()

    if has_array_interface(array):
        return torch.from_numpy(array).to(device)

    return array

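# Illustrative only: arrays exposing ``__array_interface__`` (i.e. numpy
# arrays) are copied onto the lazily chosen device; anything else, such as an
# existing torch tensor, is passed through unchanged:
#
#     to_torch(numpy.ones((2, 2)))   # torch tensor, on "cuda" if available
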
def build_expression(_, expr):  # pragma: no cover
    """Build a torch function based on ``arrays`` and ``expr``."""

    def torch_contract(*arrays):
        torch_arrays = [to_torch(x) for x in arrays]
        torch_out = expr._contract(torch_arrays, backend="torch")

        if torch_out.device.type == "cpu":
            return torch_out.numpy()
        return torch_out.cpu().numpy()

    return torch_contract

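# Illustrative only: the returned closure converts numpy inputs to torch,
# contracts them on the torch backend, then hands back a numpy array, e.g.
# (assuming ``expr`` is an opt_einsum contraction expression; the first
# argument is unused here):
#
#     contract = build_expression(None, expr)
#     out = contract(a, b)   # numpy.ndarray computed via torch
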
def evaluate_constants(const_arrays, expr):
    """Convert constant arguments to torch, and perform any possible constant
    contractions.
    """
    const_arrays = [to_torch(x) for x in const_arrays]

    return expr(*const_arrays, backend="torch", evaluate_constants=True)