File: __init__.py (pytorch-cuda 2.6.0+dfsg-7)

# mypy: allow-untyped-defs
import sys
from contextlib import contextmanager
from functools import lru_cache as _lru_cache
from typing import Any

from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule


try:
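    # opt_einsum is an optional dependency: if the import fails, _opt_einsum
    # stays None and is_available() below reports False.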
    import opt_einsum as _opt_einsum  # type: ignore[import]
except ImportError:
    _opt_einsum = None


@_lru_cache
def is_available() -> bool:
    r"""Return a bool indicating if opt_einsum is currently available.

    You must install opt-einsum in order for torch to automatically optimize einsum. To
    make opt-einsum available, you can install it along with torch: ``pip install torch[opt-einsum]``
    or by itself: ``pip install opt-einsum``. If the package is installed, torch will import
    it automatically and use it accordingly. Use this function to check whether opt-einsum
    was installed and properly imported by torch.
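
    Example (the ``True`` output assumes opt-einsum is installed)::

        >>> import torch
        >>> torch.backends.opt_einsum.is_available()
        True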
    """
    return _opt_einsum is not None


def get_opt_einsum() -> Any:
    r"""Return the opt_einsum package if opt_einsum is currently available, else None."""
    return _opt_einsum


def _set_enabled(_enabled: bool) -> None:
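    # Enabling requires opt_einsum to be importable; disabling is always allowed.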
    if not is_available() and _enabled:
        raise ValueError(
            f"opt_einsum is not available, so setting `enabled` to {_enabled} will not reap "
            "the benefits of calculating an optimal path for einsum. torch.einsum will "
            "fall back to contracting from left to right. To enable this optimal path "
            "calculation, please install opt-einsum."
        )
    global enabled
    enabled = _enabled


def _get_enabled() -> bool:
    return enabled


def _set_strategy(_strategy: str) -> None:
    if not is_available():
        raise ValueError(
            f"opt_einsum is not available, so setting `strategy` to {_strategy} will not be meaningful. "
            "torch.einsum will bypass path calculation and simply contract from left to right. "
            "Please install opt_einsum or unset `strategy`."
        )
    if not enabled:
        raise ValueError(
            f"opt_einsum is not enabled, so setting `strategy` to {_strategy} will not be meaningful. "
            "torch.einsum will bypass path calculation and simply contract from left to right. "
            "Please set `enabled` to `True` as well or unset `strategy`."
        )
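    # The accepted strategies mirror opt_einsum's path-optimization modes:
    # "auto" lets opt_einsum pick based on input size, "greedy" chooses the
    # cheapest pairwise contraction at each step, and "optimal" searches over
    # contraction orders for the lowest-cost path.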
    if _strategy not in ["auto", "greedy", "optimal"]:
        raise ValueError(
            f"`strategy` must be one of the following: [auto, greedy, optimal] but is {_strategy}"
        )
    global strategy
    strategy = _strategy


def _get_strategy() -> str:
    return strategy


def set_flags(_enabled=None, _strategy=None):
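    # Record the previous (enabled, strategy) pair so callers, notably the
    # flags() context manager below, can restore it afterwards.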
    orig_flags = (enabled, None if not is_available() else strategy)
    if _enabled is not None:
        _set_enabled(_enabled)
    if _strategy is not None:
        _set_strategy(_strategy)
    return orig_flags


@contextmanager
def flags(enabled=None, strategy=None):
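    r"""Context manager to temporarily set ``enabled`` and/or ``strategy``; the
    previous values are restored on exit.
    """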
    with __allow_nonbracketed_mutation():
        orig_flags = set_flags(enabled, strategy)
    try:
        yield
    finally:
        # recover the previous values
        with __allow_nonbracketed_mutation():
            set_flags(*orig_flags)


# The magic here is to allow us to intercept code like this:
#
#   torch.backends.opt_einsum.enabled = True
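#
# and, when opt_einsum is available and enabled:
#
#   torch.backends.opt_einsum.strategy = "greedy"
#
# Reads and writes of these attributes go through the ContextProp descriptors
# defined on OptEinsumModule, so invalid values are rejected at assignment time.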


class OptEinsumModule(PropModule):
    def __init__(self, m, name):
        super().__init__(m, name)

    # These must be class attributes so that they act as descriptors on the
    # module object installed below: reads and writes of `enabled` and
    # `strategy` then go through the getters/setters defined above. (With the
    # previous `global` statements, the ContextProp objects were bound to
    # module globals instead and immediately overwritten at the bottom of
    # this file, so attribute assignment was never intercepted.)
    enabled = ContextProp(_get_enabled, _set_enabled)
    strategy = None
    if is_available():
        strategy = ContextProp(_get_strategy, _set_strategy)


# This is the sys.modules replacement trick, see
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = OptEinsumModule(sys.modules[__name__], __name__)

enabled = is_available()
strategy = "auto" if is_available() else None
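
# A minimal usage sketch (assumes opt-einsum is installed; the equation and
# tensor shapes are arbitrary examples, not part of this module):
#
#   import torch
#   if torch.backends.opt_einsum.is_available():
#       with torch.backends.opt_einsum.flags(enabled=True, strategy="greedy"):
#           a, b, c = torch.rand(4, 5), torch.rand(5, 6), torch.rand(6, 7)
#           out = torch.einsum("ij,jk,kl->il", a, b, c)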