1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
|
import platform
import warnings
from packaging.version import Version
# Optional CuPy (CUDA array) support. Every name below is bound on both the
# success path and the failure path, so callers can branch on ``has_cupy`` /
# ``has_cupy_gpu`` without guarding imports themselves.
try:  # pragma: no cover
    import cupy
    import cupy.cublas
    import cupyx

    has_cupy = True
    cublas = cupy.cublas
    cupy_version = Version(cupy.__version__)
    try:
        # getDeviceCount() raises when the CUDA runtime cannot be used. Also
        # require a non-zero count, so that a successful call reporting zero
        # devices is not mistaken for a usable GPU (this mirrors the torch
        # check ``device_count() != 0``).
        has_cupy_gpu = cupy.cuda.runtime.getDeviceCount() > 0
    except cupy.cuda.runtime.CUDARuntimeError:
        has_cupy_gpu = False

    if cupy_version.major >= 10:
        # fromDlpack was deprecated in v10.0.0.
        cupy_from_dlpack = cupy.from_dlpack
    else:
        cupy_from_dlpack = cupy.fromDlpack
except (ImportError, AttributeError):
    # CuPy is absent, or an attribute used above is missing; fall back to
    # "no CuPy" placeholders.
    cublas = None
    cupy = None
    cupyx = None
    cupy_version = Version("0.0.0")
    has_cupy = False
    cupy_from_dlpack = None
    has_cupy_gpu = False
# Optional PyTorch support. Every flag below is bound on both the success path
# and the ImportError path, so the rest of the code can branch on them freely.
try:  # pragma: no cover
    import torch
    import torch.utils.dlpack

    has_torch = True
    # CUDA devices visible to PyTorch.
    has_torch_cuda_gpu = torch.cuda.device_count() != 0
    # Apple Metal (MPS): first, is the backend compiled into this torch build?
    has_torch_mps = hasattr(torch.backends, "mps")
    if has_torch_mps:
        has_torch_mps = torch.backends.mps.is_built()
    # ...and second, is it actually usable on this machine?
    has_torch_mps_gpu = has_torch_mps
    if has_torch_mps:
        has_torch_mps_gpu = torch.backends.mps.is_available()
    # Only CUDA counts as a "torch GPU" here; MPS is tracked separately above.
    has_torch_gpu = has_torch_cuda_gpu
    torch_version = Version(str(torch.__version__))
    # Mixed precision needs torch >= 1.9.0 and an environment that torch's
    # own AMP check does not rule out.
    has_torch_amp = torch_version >= Version("1.9.0")
    if has_torch_amp:
        has_torch_amp = not torch.cuda.amp.common.amp_definitely_not_available()
except ImportError:  # pragma: no cover
    torch = None  # type: ignore
    torch_version = Version("0.0.0")
    has_torch = has_torch_amp = False
    has_torch_gpu = has_torch_cuda_gpu = False
    has_torch_mps = has_torch_mps_gpu = False
def enable_tensorflow():
    """Opt in to the built-in (deprecated) TensorFlow integration.

    Emits a ``DeprecationWarning``, imports ``tensorflow`` and
    ``tensorflow.experimental.dlpack``, then sets the module-level
    ``tensorflow``, ``has_tensorflow`` and ``has_tensorflow_gpu`` globals.
    ``has_tensorflow_gpu`` is True when TensorFlow lists at least one visible
    "GPU" device.

    Raises:
        ImportError: if either import fails. The ``has_tensorflow`` /
            ``has_tensorflow_gpu`` flags are only updated after both imports
            succeed.
    """
    global tensorflow, has_tensorflow, has_tensorflow_gpu
    warn_msg = (
        "Built-in TensorFlow support will be removed in Thinc v9. If you need "
        "TensorFlow support in the future, you can transition to using a "
        "custom copy of the current TensorFlowWrapper in your package or "
        "project."
    )
    # stacklevel=2 attributes the warning to the caller. With the default
    # (stacklevel=1) it is attributed to this module, and Python's default
    # filters only show DeprecationWarning for code in ``__main__``, so a user
    # calling this from a script would never see it.
    warnings.warn(warn_msg, DeprecationWarning, stacklevel=2)
    import tensorflow
    import tensorflow.experimental.dlpack

    has_tensorflow = True
    has_tensorflow_gpu = len(tensorflow.config.get_visible_devices("GPU")) > 0
# TensorFlow is opt-in: these placeholders stay in place unless
# enable_tensorflow() replaces them with the real module and flags.
tensorflow = None
has_tensorflow = has_tensorflow_gpu = False
def enable_mxnet():
    """Opt in to the built-in (deprecated) MXNet integration.

    Emits a ``DeprecationWarning``, imports ``mxnet``, then sets the
    module-level ``mxnet`` and ``has_mxnet`` globals.

    Raises:
        ImportError: if MXNet cannot be imported; ``has_mxnet`` is then left
            unchanged.
    """
    global mxnet, has_mxnet
    warn_msg = (
        "Built-in MXNet support will be removed in Thinc v9. If you need "
        "MXNet support in the future, you can transition to using a "
        "custom copy of the current MXNetWrapper in your package or "
        "project."
    )
    # stacklevel=2 attributes the warning to the caller. With the default
    # (stacklevel=1) it is attributed to this module, and Python's default
    # filters only show DeprecationWarning for code in ``__main__``, so a user
    # calling this from a script would never see it.
    warnings.warn(warn_msg, DeprecationWarning, stacklevel=2)
    import mxnet

    has_mxnet = True
# MXNet is opt-in: enable_mxnet() overwrites these placeholders.
mxnet, has_mxnet = None, False
# Optional extras. Each name is bound to the imported module when it is
# installed, or to None otherwise; os_signpost and blis also get a has_* flag.
try:
    import h5py
except ImportError:  # pragma: no cover
    h5py = None

try:  # pragma: no cover
    import os_signpost
except ImportError:
    os_signpost = None
    has_os_signpost = False
else:
    has_os_signpost = True

try:  # pragma: no cover
    import blis
except ImportError:
    blis = None
    has_blis = False
else:
    has_blis = True
# The AppleOps backend has no extra dependency to probe: it is offered
# whenever the host OS reports itself as macOS ("Darwin").
has_apple_ops = (platform.system() == "Darwin")
# A GPU counts as usable when PyTorch's MPS backend is available or CuPy can
# reach a CUDA device. A CUDA device seen only by PyTorch is not counted here.
has_gpu = has_torch_mps_gpu or has_cupy_gpu
# Optional backend modules re-exported by a star-import of this module; each
# is bound either to the imported module or to None.
__all__ = ["cupy", "cupyx", "torch", "tensorflow", "mxnet", "h5py", "os_signpost"]
|