File: torch.py

Package: python-opt-einsum 3.4.0-2
"""Required functions for optimized contractions of numpy arrays using pytorch."""

from opt_einsum.helpers import has_array_interface
from opt_einsum.parser import convert_to_valid_einsum_chars
from opt_einsum.sharing import to_backend_cache_wrap

__all__ = [
    "transpose",
    "einsum",
    "tensordot",
    "to_torch",
    "build_expression",
    "evaluate_constants",
]

_TORCH_DEVICE = None
_TORCH_HAS_TENSORDOT = None

# Pool of single-letter index symbols used to build einsum equations in the
# ``tensordot`` fallback below (52 symbols, so at most 52 distinct indices).
_torch_symbols_base = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"


def _get_torch_and_device():
    """Lazily import torch, caching the module, chosen device and tensordot support."""
    global _TORCH_DEVICE
    global _TORCH_HAS_TENSORDOT

    if _TORCH_DEVICE is None:
        import torch  # type: ignore

        device = "cuda" if torch.cuda.is_available() else "cpu"
        _TORCH_DEVICE = torch, device
        _TORCH_HAS_TENSORDOT = hasattr(torch, "tensordot")

    return _TORCH_DEVICE
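
# A minimal usage sketch: the first call imports torch and picks the device;
# subsequent calls return the cached pair.
#
#     torch, device = _get_torch_and_device()
#     # device is "cuda" when a GPU is visible, otherwise "cpu"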


def transpose(a, axes):
    """Normal torch transpose is only valid for 2D matrices."""
    return a.permute(*axes)
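
# For example, ``transpose(t, (2, 0, 1))`` reorders a 3D tensor's axes in one
# call, mirroring ``numpy.transpose`` semantics.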


def einsum(equation, *operands, **kwargs):
    """Variadic version of torch.einsum to match numpy api."""
    # rename symbols to support PyTorch 0.4.1 and earlier,
    # which allow only symbols a-z.
    equation = convert_to_valid_einsum_chars(equation)

    torch, _ = _get_torch_and_device()
    return torch.einsum(equation, operands)
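
# Usage sketch (shapes are illustrative; ``to_torch`` handles numpy inputs):
#
#     a = to_torch(numpy.ones((2, 3)))
#     b = to_torch(numpy.ones((3, 4)))
#     c = einsum("ij,jk->ik", a, b)  # matrix product of shape (2, 4)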


def tensordot(x, y, axes=2):
    """Simple translation of tensordot syntax to einsum."""
    torch, _ = _get_torch_and_device()

    if _TORCH_HAS_TENSORDOT:
        return torch.tensordot(x, y, dims=axes)

    xnd = x.ndimension()
    ynd = y.ndimension()

    # convert an int argument to a pair of axis sequences
    if isinstance(axes, int):
        axes = range(xnd - axes, xnd), range(axes)

    # promote bare int entries to single-element sequences
    if isinstance(axes[0], int):
        axes = (axes[0],), axes[1]
    if isinstance(axes[1], int):
        axes = axes[0], (axes[1],)

    # initialize empty indices
    x_ix = [None] * xnd
    y_ix = [None] * ynd
    out_ix = []

    # fill in repeated indices
    available_ix = iter(_torch_symbols_base)
    for ax1, ax2 in zip(*axes):
        repeat = next(available_ix)
        x_ix[ax1] = repeat
        y_ix[ax2] = repeat

    # fill in the rest, and maintain output order
    for i in range(xnd):
        if x_ix[i] is None:
            leave = next(available_ix)
            x_ix[i] = leave
            out_ix.append(leave)
    for i in range(ynd):
        if y_ix[i] is None:
            leave = next(available_ix)
            y_ix[i] = leave
            out_ix.append(leave)

    # form full string and contract!
    einsum_str = "{},{}->{}".format(*map("".join, (x_ix, y_ix, out_ix)))
    return einsum(einsum_str, x, y)
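
# Worked example of the fallback: for x of shape (2, 3), y of shape (3, 4) and
# axes=([1], [0]), the shared symbol "a" labels x's axis 1 and y's axis 0, the
# free axes become "b" and "c", and the contraction runs as
# einsum("ba,ac->bc", x, y) -- an ordinary matrix product.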


@to_backend_cache_wrap
def to_torch(array):
    """Convert a numpy array to a torch tensor on the cached device."""
    torch, device = _get_torch_and_device()

    if has_array_interface(array):
        return torch.from_numpy(array).to(device)

    return array
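
# Note: only objects exposing ``__array_interface__`` (e.g. numpy arrays) are
# converted; torch tensors and other backends' arrays pass through untouched.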


def build_expression(_, expr):  # pragma: no cover
    """Build a torch function based on ``expr``; the example arrays are unused."""

    def torch_contract(*arrays):
        torch_arrays = [to_torch(x) for x in arrays]
        torch_out = expr._contract(torch_arrays, backend="torch")

        if torch_out.device.type == "cpu":
            return torch_out.numpy()

        return torch_out.cpu().numpy()

    return torch_contract
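
# The returned ``torch_contract`` is the bridge opt_einsum's backend dispatch
# uses: numpy arrays in, contraction evaluated on torch (on the cached
# device), numpy array back out.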


def evaluate_constants(const_arrays, expr):
    """Convert constant arguments to torch, and perform any possible constant
    contractions.
    """
    const_arrays = [to_torch(x) for x in const_arrays]
    return expr(*const_arrays, backend="torch", evaluate_constants=True)