1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
|
import sys
import os
import torch
import warnings
from contextlib import contextmanager
from torch.backends import ContextProp, PropModule, __allow_nonbracketed_mutation
try:
from torch._C import _cudnn
except ImportError:
_cudnn = None # type: ignore[assignment]
# Write:
#
# torch.backends.cudnn.enabled = False
#
# to globally disable CuDNN/MIOpen
# Cached integer cuDNN version; populated by _init() when the _cudnn binding is present.
__cudnn_version = None

if _cudnn is not None:

    def _init():
        """Query and cache the cuDNN version, validating it on the first call.

        On the first call the integer version reported by ``_cudnn`` is stored
        in the module-level ``__cudnn_version`` and the runtime library version
        is compared with the version PyTorch was compiled against.

        Returns:
            ``True`` whenever no exception is raised.

        Raises:
            RuntimeError: If the runtime and compile-time versions are
                incompatible.
        """
        global __cudnn_version
        if __cudnn_version is None:
            # NOTE: the version is cached *before* the compatibility check, so
            # the check below runs at most once; a later call returns True even
            # if the first one raised (assuming getVersionInt() is never None).
            __cudnn_version = _cudnn.getVersionInt()
            runtime_version = _cudnn.getRuntimeVersion()
            compile_version = _cudnn.getCompileVersion()
            runtime_major, runtime_minor, _ = runtime_version
            compile_major, compile_minor, _ = compile_version
            # Different major versions are always incompatible
            # Starting with cuDNN 7, minor versions are backwards-compatible
            # Not sure about MIOpen (ROCm), so always do a strict check
            if runtime_major != compile_major:
                cudnn_compatible = False
            elif runtime_major < 7 or not _cudnn.is_cuda:
                cudnn_compatible = runtime_minor == compile_minor
            else:
                cudnn_compatible = runtime_minor >= compile_minor
            if not cudnn_compatible:
                base_error_msg = (f'cuDNN version incompatibility: '
                                  f'PyTorch was compiled against {compile_version} '
                                  f'but found runtime version {runtime_version}. '
                                  f'PyTorch already comes bundled with cuDNN. '
                                  f'One option to resolving this error is to ensure PyTorch '
                                  f'can find the bundled cuDNN.')
                ld_library_path = os.environ.get('LD_LIBRARY_PATH')
                if ld_library_path is None:
                    raise RuntimeError(base_error_msg)
                # The hints are appended to base_error_msg, so each one needs an
                # explicit separating space (adjacent literals are joined as-is).
                if any(substring in ld_library_path for substring in ['cuda', 'cudnn']):
                    raise RuntimeError(
                        f'{base_error_msg} '
                        f'Looks like your LD_LIBRARY_PATH contains incompatible version of cudnn. '
                        f'Please either remove it from the path or install cudnn {compile_version}')
                raise RuntimeError(
                    f'{base_error_msg} '
                    f'One possibility is that there is a '
                    f'conflicting cuDNN in LD_LIBRARY_PATH.')
        return True
else:

    def _init():
        """Report that cuDNN cannot be initialised: the ``_cudnn`` binding is missing."""
        return False
def version():
    """Return the cached cuDNN version, or ``None`` when ``_init()`` returns a false value."""
    # _init() runs first so that the cache is populated before it is read.
    return __cudnn_version if _init() else None
# Floating-point dtypes that is_acceptable() lets through: half, single and double precision.
CUDNN_TENSOR_DTYPES = {torch.float16, torch.float32, torch.float64}
def is_available():
    r"""Return the ``torch._C.has_cudnn`` flag, i.e. whether CUDNN is currently available."""
    cudnn_built = torch._C.has_cudnn
    return cudnn_built
def is_acceptable(tensor):
    """Decide whether ``tensor`` passes the cuDNN/MIOpen eligibility checks.

    The checks run in order and the first failure returns ``False``:
    the cuDNN enabled flag, the tensor's device type and dtype, build-time
    support (``is_available()``), and finally ``_init()``. The last two
    failures also emit a warning.

    Returns:
        ``True`` only when every check passes.
    """
    if not torch._C._get_cudnn_enabled():
        return False

    # Device is tested first so the dtype lookup is skipped for non-CUDA tensors.
    if not (tensor.device.type == 'cuda' and tensor.dtype in CUDNN_TENSOR_DTYPES):
        return False

    if not is_available():
        warnings.warn(
            "PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild "
            "PyTorch making sure the library is visible to the build system.")
        return False

    if _init():
        return True

    # Name the library search-path variable for the current platform.
    search_path_vars = {
        'darwin': 'DYLD_LIBRARY_PATH',
        'win32': 'PATH'
    }
    libpath = search_path_vars.get(sys.platform, 'LD_LIBRARY_PATH')
    warnings.warn('cuDNN/MIOpen library not found. Check your {libpath}'.format(libpath=libpath))
    return False
def set_flags(_enabled=None, _benchmark=None, _benchmark_limit=None, _deterministic=None, _allow_tf32=None):
    """Update the cuDNN flags and return the values they had beforehand.

    Every argument left as ``None`` keeps its flag untouched. The benchmark
    limit is only read and written when ``is_available()`` is true.

    Returns:
        A tuple ``(enabled, benchmark, benchmark_limit, deterministic,
        allow_tf32)`` of the previous values; ``benchmark_limit`` is ``None``
        when ``is_available()`` is false.
    """
    # Snapshot the current state before anything is modified.
    previous_enabled = torch._C._get_cudnn_enabled()
    previous_benchmark = torch._C._get_cudnn_benchmark()
    if is_available():
        previous_limit = torch._C._cuda_get_cudnn_benchmark_limit()
    else:
        previous_limit = None
    previous_deterministic = torch._C._get_cudnn_deterministic()
    previous_allow_tf32 = torch._C._get_cudnn_allow_tf32()

    if _enabled is not None:
        torch._C._set_cudnn_enabled(_enabled)
    if _benchmark is not None:
        torch._C._set_cudnn_benchmark(_benchmark)
    if _benchmark_limit is not None:
        if is_available():
            torch._C._cuda_set_cudnn_benchmark_limit(_benchmark_limit)
    if _deterministic is not None:
        torch._C._set_cudnn_deterministic(_deterministic)
    if _allow_tf32 is not None:
        torch._C._set_cudnn_allow_tf32(_allow_tf32)

    return (
        previous_enabled,
        previous_benchmark,
        previous_limit,
        previous_deterministic,
        previous_allow_tf32,
    )
@contextmanager
def flags(enabled=False, benchmark=False, benchmark_limit=10, deterministic=False, allow_tf32=True):
    """Context manager that applies the given cuDNN flags for the ``with`` body.

    The flags are set through ``set_flags`` on entry, and the values that
    ``set_flags`` returned are passed back to it on exit, including when the
    body raises.
    """
    with __allow_nonbracketed_mutation():
        saved = set_flags(enabled, benchmark, benchmark_limit, deterministic, allow_tf32)
    try:
        yield
    finally:
        # Put back whatever was in effect before entering the context.
        with __allow_nonbracketed_mutation():
            set_flags(*saved)
# The magic here is to allow us to intercept code like this:
#
# torch.backends.<cudnn|mkldnn>.enabled = True
class CudnnModule(PropModule):
    """``PropModule`` subclass exposing the cuDNN flags as ``ContextProp`` attributes.

    Each attribute pairs a ``torch._C`` getter with its setter.
    ``benchmark_limit`` is ``None`` when ``is_available()`` is false.
    """

    def __init__(self, m, name):
        super().__init__(m, name)

    enabled = ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled)
    deterministic = ContextProp(torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic)
    benchmark = ContextProp(torch._C._get_cudnn_benchmark, torch._C._set_cudnn_benchmark)
    # The benchmark-limit bindings are only looked up when is_available() is true.
    benchmark_limit = (
        ContextProp(torch._C._cuda_get_cudnn_benchmark_limit, torch._C._cuda_set_cudnn_benchmark_limit)
        if is_available()
        else None
    )
    allow_tf32 = ContextProp(torch._C._get_cudnn_allow_tf32, torch._C._set_cudnn_allow_tf32)
# This is the sys.modules replacement trick, see
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
# The sys.modules entry for this module is swapped for a CudnnModule instance
# that is constructed from the original module object and this module's name.
sys.modules[__name__] = CudnnModule(sys.modules[__name__], __name__)
# Add type annotation for the replaced module
# (bare annotations: they declare the attribute types without binding any value here)
enabled: bool
deterministic: bool
benchmark: bool
allow_tf32: bool
benchmark_limit: int
|