File: rmm.py

package info (click to toggle)
dask.distributed 2022.12.1%2Bds.1-3
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 10,164 kB
  • sloc: python: 81,938; javascript: 1,549; makefile: 228; sh: 100
file content (48 lines) | stat: -rw-r--r-- 1,507 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from __future__ import annotations

import numba
import numba.cuda
import numpy
import rmm

from distributed.protocol.cuda import cuda_deserialize, cuda_serialize
from distributed.protocol.serialize import dask_deserialize, dask_serialize

# Used for RMM 0.11.0+; otherwise the Numba serializers are used.
if hasattr(rmm, "DeviceBuffer"):

    @cuda_serialize.register(rmm.DeviceBuffer)
    def cuda_serialize_rmm_device_buffer(x):
        """Serialize ``x`` for the CUDA serializer family.

        The header is a copy of ``x.__cuda_array_interface__`` whose
        ``"strides"`` entry is set to ``(1,)``.  ``x`` itself is returned as
        the single frame.

        Returns
        -------
        tuple
            ``(header, frames)`` where ``frames`` is ``[x]``.
        """
        # Build a fresh dict so the mapping handed out by `x` is left alone.
        header = dict(x.__cuda_array_interface__)
        # NOTE(review): presumably `(1,)` spells out explicit per-byte strides
        # for consumers of the header -- confirm against the CUDA array
        # interface specification.
        header.update(strides=(1,))
        return header, [x]

    @cuda_deserialize.register(rmm.DeviceBuffer)
    def cuda_deserialize_rmm_device_buffer(header, frames):
        """Return the single frame in ``frames`` as the deserialized buffer.

        ``header`` is accepted to match the deserializer signature but is not
        consulted.  ``frames`` must contain exactly one element; unpacking
        fails with ``ValueError`` otherwise.
        """
        [device_buf] = frames

        # RMM is the preferred allocator whenever it is available (and it is
        # available here), so the incoming frame is expected to already be a
        # `DeviceBuffer`; no conversion is performed.
        assert isinstance(device_buf, rmm.DeviceBuffer)

        return device_buf

    @dask_serialize.register(rmm.DeviceBuffer)
    def dask_serialize_rmm_device_buffer(x):
        """Serialize ``x`` for the Dask (host memory) serializer family.

        Reuses ``cuda_serialize_rmm_device_buffer`` for the header, then
        replaces every device frame with the ``.data`` buffer of a host array
        obtained through ``numba.cuda.as_cuda_array(...).copy_to_host()``.

        Returns
        -------
        tuple
            ``(header, frames)`` with one host-side frame per device frame.
        """
        header, device_frames = cuda_serialize_rmm_device_buffer(x)

        host_frames = []
        for device_frame in device_frames:
            # View the frame as a Numba device array and pull it to the host.
            host_array = numba.cuda.as_cuda_array(device_frame).copy_to_host()
            host_frames.append(host_array.data)

        return header, host_frames

    @dask_deserialize.register(rmm.DeviceBuffer)
    def dask_deserialize_rmm_device_buffer(header, frames):
        """Build an ``rmm.DeviceBuffer`` from a single host frame.

        ``header`` is accepted to match the deserializer signature but is not
        consulted.  ``frames`` must contain exactly one buffer-protocol
        object; unpacking fails with ``ValueError`` otherwise.
        """
        [host_frame] = frames

        # Wrap the frame in a NumPy array to read its address and byte count.
        host_view = numpy.asarray(memoryview(host_frame))

        # NOTE(review): this relies on `rmm.DeviceBuffer(ptr=..., size=...)`
        # copying `size` bytes from the host pointer during construction --
        # confirm against the RMM documentation.  `host_view` remains
        # referenced for the duration of the call.
        return rmm.DeviceBuffer(
            ptr=host_view.__array_interface__["data"][0],
            size=host_view.nbytes,
        )