File: __init__.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (55 lines) | stat: -rw-r--r-- 1,642 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# mypy: allow-untyped-defs
from functools import lru_cache as _lru_cache
from typing import Optional, TYPE_CHECKING

import torch
from torch.library import Library as _Library


# Public names of this module; a star-import exposes only these.
__all__ = ["is_built", "is_available", "is_macos13_or_newer", "is_macos_or_newer"]


def is_built() -> bool:
    r"""Return whether PyTorch is built with MPS support.

    Note that this doesn't necessarily mean MPS is available; just that
    if this PyTorch binary were run a machine with working MPS drivers
    and devices, we would be able to use it.
    """
    return torch._C._has_mps


@_lru_cache
def is_available() -> bool:
    r"""Return a bool indicating if MPS is currently available.

    The answer is memoised by ``lru_cache``: the underlying query runs on
    the first call only, and later calls return the cached value.
    """
    mps_available = torch._C._mps_is_available()
    return mps_available


@_lru_cache
def is_macos_or_newer(major: int, minor: int) -> bool:
    r"""Return a bool indicating whether MPS is running on the given MacOS or newer.

    Args:
        major: Major version number forwarded to the native version check.
        minor: Minor version number forwarded to the native version check.

    Results are memoised by ``lru_cache`` per ``(major, minor)`` pair.
    """
    version_check = torch._C._mps_is_on_macos_or_newer
    return version_check(major, minor)


@_lru_cache
def is_macos13_or_newer(minor: int = 0) -> bool:
    r"""Return a bool indicating whether MPS is running on MacOS 13 or newer.

    Args:
        minor: Minor version number forwarded to the native version check
            alongside the fixed major version 13. Defaults to ``0``.

    Results are memoised by ``lru_cache`` per ``minor`` value.
    """
    major = 13
    return torch._C._mps_is_on_macos_or_newer(major, minor)


# Library handle created by _init(). It stays None until _init() has built the
# "aten" IMPL library, and _init() checks it to avoid registering twice.
_lib: Optional[_Library] = None


def _init():
    r"""Register Python implementations of the group-norm ops under the "MPS" key.

    Creates an ``IMPL`` library for the ``aten`` namespace, stores it in the
    module-level ``_lib``, and registers ``native_group_norm`` and
    ``native_group_norm_backward`` on it. Does nothing if ``_lib`` is already
    set or if ``is_built()`` returns ``False``.
    """
    global _lib

    # Already registered on an earlier call.
    if _lib is not None:
        return
    # No MPS support in this binary, so there is nothing to register.
    if not is_built():
        return

    # Imported lazily, only once registration is actually going to happen.
    from torch._decomp.decompositions import native_group_norm_backward
    from torch._refs import native_group_norm

    _lib = _Library("aten", "IMPL")  # noqa: TOR901
    registrations = (
        ("native_group_norm", native_group_norm),
        ("native_group_norm_backward", native_group_norm_backward),
    )
    for op_name, op_impl in registrations:
        _lib.impl(op_name, op_impl, "MPS")