1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
|
# mypy: allow-untyped-defs
r"""
This package implements abstractions found in ``torch.cuda``
to facilitate writing device-agnostic code.
"""
from contextlib import AbstractContextManager
from typing import Any, Optional, Union
import torch
from .. import device as _device
from . import amp
# Names exported by a star-import of this package. Each entry is a function or
# class defined further down in this module.
__all__ = [
"is_available",
"synchronize",
"current_device",
"current_stream",
"stream",
"set_device",
"device_count",
"Stream",
"StreamContext",
"Event",
]
# Type alias used to annotate ``device`` parameters in this module: the parent
# package's ``device`` object, a string, an integer, or ``None``.
_device_t = Union[_device, str, int, None]
def _is_avx2_supported() -> bool:
    r"""Report whether the CPU supports AVX2.

    Returns:
        bool: the result of ``torch._C._cpu._is_avx2_supported()``.
    """
    cpu_probe = torch._C._cpu
    return cpu_probe._is_avx2_supported()
def _is_avx512_supported() -> bool:
    r"""Report whether the CPU supports AVX512.

    Returns:
        bool: the result of ``torch._C._cpu._is_avx512_supported()``.
    """
    cpu_probe = torch._C._cpu
    return cpu_probe._is_avx512_supported()
def _is_avx512_bf16_supported() -> bool:
    r"""Report whether the CPU supports AVX512_BF16.

    Returns:
        bool: the result of ``torch._C._cpu._is_avx512_bf16_supported()``.
    """
    cpu_probe = torch._C._cpu
    return cpu_probe._is_avx512_bf16_supported()
def _is_vnni_supported() -> bool:
    r"""Report whether the CPU supports VNNI.

    Returns:
        bool: the result of ``torch._C._cpu._is_avx512_vnni_supported()``.
    """
    # Only the avx512_vnni flavour is queried for now; avx2_vnni support is
    # planned for later.
    cpu_probe = torch._C._cpu
    return cpu_probe._is_avx512_vnni_supported()
def _is_amx_tile_supported() -> bool:
    r"""Report whether the CPU supports AMX_TILE.

    Returns:
        bool: the result of ``torch._C._cpu._is_amx_tile_supported()``.
    """
    cpu_probe = torch._C._cpu
    return cpu_probe._is_amx_tile_supported()
def _is_amx_fp16_supported() -> bool:
    r"""Report whether the CPU supports AMX FP16.

    Returns:
        bool: the result of ``torch._C._cpu._is_amx_fp16_supported()``.
    """
    cpu_probe = torch._C._cpu
    return cpu_probe._is_amx_fp16_supported()
def _init_amx() -> bool:
    r"""Initialize AMX instructions.

    Returns:
        bool: the result of ``torch._C._cpu._init_amx()``.
    """
    cpu_probe = torch._C._cpu
    return cpu_probe._init_amx()
def _is_arm_sve_supported() -> bool:
    r"""Report whether the CPU supports Arm SVE.

    Returns:
        bool: the result of ``torch._C._cpu._is_arm_sve_supported()``.
    """
    cpu_probe = torch._C._cpu
    return cpu_probe._is_arm_sve_supported()
def is_available() -> bool:
    r"""Report whether the CPU is currently available.

    N.B. This function only exists to facilitate device-agnostic code

    Returns:
        bool: ``True``, unconditionally.
    """
    cpu_present = True
    return cpu_present
def synchronize(device: _device_t = None) -> None:
    r"""Wait for all kernels in all streams on the CPU device to complete.

    The body performs no work and simply returns ``None``.

    Args:
        device (torch.device or int, optional): ignored, there's only one CPU device.

    N.B. This function only exists to facilitate device-agnostic code.
    """
    return None
class Stream:
    """
    Placeholder stream type for the CPU device; it stores no state.

    N.B. This class only exists to facilitate device-agnostic code
    """

    def __init__(self, priority: int = -1) -> None:
        """Accept ``priority`` (default ``-1``) and discard it."""
        return None

    def wait_stream(self, stream) -> None:
        """Do nothing and return ``None``; ``stream`` is ignored."""
        return None
class Event:
    """
    Placeholder event type for the CPU device; every method is a stub.

    N.B. This class only exists to facilitate device-agnostic code
    """

    def query(self) -> bool:
        """Return ``True``, unconditionally."""
        return True

    def record(self, stream=None) -> None:
        """Do nothing and return ``None``; ``stream`` is ignored."""
        return None

    def synchronize(self) -> None:
        """Do nothing and return ``None``."""
        return None

    def wait(self, stream=None) -> None:
        """Do nothing and return ``None``; ``stream`` is ignored."""
        return None
# One shared ``Stream`` instance serves as the default stream. At import time
# it is also the currently selected stream; ``StreamContext`` rebinds
# ``_current_stream`` while a ``with`` block is active.
_default_cpu_stream = _current_stream = Stream()
def current_stream(device: _device_t = None) -> Stream:
    r"""Return the currently selected :class:`Stream` for a given device.

    Args:
        device (torch.device or int, optional): Ignored.

    Returns:
        Stream: whatever object the module-level ``_current_stream`` is
        bound to at the time of the call.

    N.B. This function only exists to facilitate device-agnostic code
    """
    selected = _current_stream
    return selected
class StreamContext(AbstractContextManager):
    r"""Context manager that makes a given stream the current one.

    On entry the module-level ``_current_stream`` is rebound to the wrapped
    stream; on exit it is rebound to the stream that was current at entry.
    When the wrapped stream is ``None`` both steps are skipped.

    N.B. This class only exists to facilitate device-agnostic code
    """

    # NOTE(review): this attribute is declared but never assigned in this
    # class; instances carry ``stream`` and ``prev_stream`` instead.
    cur_stream: Optional[Stream]

    def __init__(self, stream):
        """Remember ``stream`` and start with the default stream as fallback."""
        self.prev_stream = _default_cpu_stream
        self.stream = stream

    def __enter__(self):
        """Select the wrapped stream, saving the previous one; returns ``None``."""
        global _current_stream
        target = self.stream
        if target is not None:
            # Save what was current so __exit__ can put it back.
            self.prev_stream = _current_stream
            _current_stream = target

    def __exit__(self, type: Any, value: Any, traceback: Any) -> None:
        """Restore the stream saved by ``__enter__``; exceptions propagate."""
        global _current_stream
        if self.stream is not None:
            _current_stream = self.prev_stream
def stream(stream: Optional[Stream]) -> AbstractContextManager:
    r"""Wrapper around the Context-manager StreamContext that
    selects a given stream.

    Args:
        stream (Stream, optional): stream to select for the duration of the
            ``with`` block. ``None`` is accepted, in which case the returned
            ``StreamContext`` leaves the current stream untouched.

    Returns:
        AbstractContextManager: a new ``StreamContext`` wrapping ``stream``.

    N.B. This function only exists to facilitate device-agnostic code
    """
    return StreamContext(stream)
def device_count() -> int:
    r"""Return the number of CPU devices (not cores). Always 1.

    N.B. This function only exists to facilitate device-agnostic code

    Returns:
        int: ``1``, unconditionally.
    """
    num_cpu_devices = 1
    return num_cpu_devices
def set_device(device: _device_t) -> None:
    r"""Set the current device; on CPU this does nothing.

    Args:
        device: accepted and ignored.

    N.B. This function only exists to facilitate device-agnostic code
    """
    return None
def current_device() -> str:
    r"""Return the current device for cpu. Always 'cpu'.

    N.B. This function only exists to facilitate device-agnostic code

    Returns:
        str: the literal string ``"cpu"``.
    """
    device_name = "cpu"
    return device_name
|