1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
|
from __future__ import annotations
import pytest
pytestmark = pytest.mark.gpu
from distributed.protocol import deserialize, serialize
numpy = pytest.importorskip("numpy")
cuda = pytest.importorskip("numba.cuda")
rmm = pytest.importorskip("rmm")
@pytest.mark.parametrize("size", [0, 3, 10])
@pytest.mark.parametrize("serializers", [("cuda",), ("dask",), ("pickle",)])
def test_serialize_rmm_device_buffer(size, serializers):
    """Round-trip an ``rmm.DeviceBuffer`` through ``serialize``/``deserialize``.

    A device buffer of ``size`` bytes is filled with ``numpy.arange(size)``
    (dtype ``u1``), serialized with the given ``serializers`` family,
    deserialized with that same family, and copied back to the host for a
    content comparison. For the ``"cuda"`` and ``"dask"`` families the
    header's ``sub-header`` strides and the kind of every frame are checked
    as well; ``"pickle"`` only gets the final content comparison.
    """
    if not hasattr(rmm, "DeviceBuffer"):
        pytest.skip("RMM pre-0.11.0 does not have DeviceBuffer")

    expected = numpy.arange(size, dtype="u1")
    device_buf = rmm.DeviceBuffer(size=size)
    # Write the host data into the RMM allocation through a numba view of it.
    cuda.to_device(expected, to=cuda.as_cuda_array(device_buf))

    header, frames = serialize(device_buf, serializers=serializers)
    restored = deserialize(header, frames, deserializers=serializers)
    actual = restored.copy_to_host()

    family = serializers[0]
    if family in ("cuda", "dask"):
        # Both families are expected to record a (1,) stride tuple.
        assert header["sub-header"]["strides"] == (1,)
        if family == "cuda":
            # Every "cuda" frame must expose the CUDA array interface.
            assert all(hasattr(f, "__cuda_array_interface__") for f in frames)
        else:
            # Every "dask" frame must be a host memoryview.
            assert all(isinstance(f, memoryview) for f in frames)

    assert (expected == actual).all()
|