1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
|
import pytest
import astra
# CuPy is only probed when ASTRA reports CUDA support. Any exception raised
# while importing it leaves the CuPy backend marked as unavailable.
cuda_present = astra.use_cuda()
cupy_present = False
if cuda_present:
    try:
        import cupy
    except Exception:
        pass
    else:
        cupy_present = True
# PyTorch: the CPU backend only needs ``torch`` to be importable; the CUDA
# backend additionally needs ``torch.cuda.is_available()`` to be truthy.
# Any exception during either probe leaves the corresponding flag False.
pytorch_present = False
pytorch_cuda_present = False
try:
    import torch
except Exception:
    pass
else:
    pytorch_present = True
    try:
        if torch.cuda.is_available():
            pytorch_cuda_present = True
    except Exception:
        pass
# JAX: the CPU backend only needs ``jax`` to be importable; the CUDA backend
# additionally needs ``jax.devices('cuda')`` to return at least one device.
# Any exception during either probe leaves the corresponding flag False.
jax_present = False
jax_cuda_present = False
try:
    import jax
except Exception:
    pass
else:
    jax_present = True
    try:
        jax_cuda_present = len(jax.devices('cuda')) > 0
    except Exception:
        pass
# Names of parametrized ``backend`` values whose runtime support was not
# detected above, in a fixed order; pytest_collection_modifyitems marks
# tests using any of these as skipped.
backends_to_skip = [
    backend
    for backend, available in (
        ('cupy', cupy_present),
        ('jax_cpu', jax_present),
        ('jax_cuda', jax_cuda_present),
        ('pytorch_cpu', pytorch_present),
        ('pytorch_cuda', pytorch_cuda_present),
    )
    if not available
]
def pytest_collection_modifyitems(config, items):
    """Skip collected tests parametrized with an unavailable ``backend``.

    pytest hook run after collection. Every item whose parametrization has a
    ``backend`` value listed in ``backends_to_skip`` receives an
    unconditional ``skip`` marker; all other items are left untouched.

    Args:
        config: The pytest config object. Unused, but it is part of the
            hook signature.
        items: The collected test items. They are modified in place by
            adding markers.
    """
    # Build the marker once and share it across every affected item.
    skip_unavailable = pytest.mark.skip(
        reason='Backend skipped because it is unavailable'
    )
    for item in items:
        # Only parametrized items carry a ``callspec``; others are ignored.
        callspec = getattr(item, 'callspec', None)
        if callspec is not None and callspec.params.get('backend') in backends_to_skip:
            item.add_marker(skip_unavailable)
|