File: test_rmm.py

package info (click to toggle)
dask.distributed 2022.12.1+ds.1-3
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 10,164 kB
  • sloc: python: 81,938; javascript: 1,549; makefile: 228; sh: 100
file content (35 lines) | stat: -rw-r--r-- 1,166 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from __future__ import annotations

import pytest

# Apply the ``gpu`` marker to every test in this module (presumably so that
# runs without a GPU can deselect them, e.g. with ``-m "not gpu"``).
pytestmark = pytest.mark.gpu

from distributed.protocol import deserialize, serialize

# ``pytest.importorskip`` skips the whole module at collection time if any of
# these optional dependencies cannot be imported.
numpy = pytest.importorskip("numpy")
cuda = pytest.importorskip("numba.cuda")
rmm = pytest.importorskip("rmm")


@pytest.mark.parametrize("size", [0, 3, 10])
@pytest.mark.parametrize("serializers", [("cuda",), ("dask",), ("pickle",)])
def test_serialize_rmm_device_buffer(size, serializers):
    """Round-trip an ``rmm.DeviceBuffer`` through ``serialize``/``deserialize``.

    A device buffer of ``size`` bytes is filled with ``0, 1, ..., size - 1``
    (``uint8``). It is serialized with ``serializers`` and deserialized with
    the same family. The host copy of the result must equal the original
    bytes, in both shape and values.

    For the ``"cuda"`` and ``"dask"`` families, the test also checks the
    ``strides`` recorded in the header and the kind of object used for each
    frame.
    """
    if not hasattr(rmm, "DeviceBuffer"):
        pytest.skip("RMM pre-0.11.0 does not have DeviceBuffer")

    x_np = numpy.arange(size, dtype="u1")
    x = rmm.DeviceBuffer(size=size)
    # View the device buffer as a numba device array and use it as the copy
    # destination, so the host bytes land inside ``x``.
    cuda.to_device(x_np, to=cuda.as_cuda_array(x))

    header, frames = serialize(x, serializers=serializers)
    y = deserialize(header, frames, deserializers=serializers)
    y_np = y.copy_to_host()

    if serializers[0] == "cuda":
        assert header["sub-header"]["strides"] == (1,)
        # Every frame must expose the CUDA array interface.
        assert all(hasattr(f, "__cuda_array_interface__") for f in frames)
    elif serializers[0] == "dask":
        assert header["sub-header"]["strides"] == (1,)
        # Every frame must be a host-side memoryview.
        assert all(isinstance(f, memoryview) for f in frames)

    # ``assert_array_equal`` checks shapes as well as values. A plain
    # ``(x_np == y_np).all()`` would broadcast: for example, an empty ``x_np``
    # compared with a length-1 ``y_np`` yields an empty array, and ``.all()``
    # on that is vacuously True. A wrong-length result would then go unnoticed.
    numpy.testing.assert_array_equal(y_np, x_np)