File: cuda.py

package info (click to toggle)
dask.distributed 2022.12.1+ds.1-3
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 10,164 kB
  • sloc: python: 81,938; javascript: 1,549; makefile: 228; sh: 100
file content (44 lines) | stat: -rw-r--r-- 1,181 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from __future__ import annotations

import dask
from dask.utils import typename

from distributed.protocol import pickle
from distributed.protocol.serialize import (
    ObjectDictSerializer,
    register_serialization_family,
)

# Per-type registries for the "cuda" serialization family: cuda_dumps looks up
# its per-type dumps function in cuda_serialize, and cuda_loads looks up its
# per-type loads function in cuda_deserialize.
cuda_serialize, cuda_deserialize = (
    dask.utils.Dispatch(registry_name)
    for registry_name in ("cuda_serialize", "cuda_deserialize")
)


def cuda_dumps(x):
    """Serialize ``x`` with the "cuda" serialization family.

    The serializer registered for ``type(x)`` in ``cuda_serialize`` produces a
    ``(sub_header, frames)`` pair. That pair is wrapped in a header from which
    ``cuda_loads`` can recover the type and find the matching deserializer.

    Parameters
    ----------
    x : object
        Object to serialize.

    Returns
    -------
    header : dict
        Holds the following entries:

        - ``"sub-header"``: the registered serializer's own header.
        - ``"type-serialized"``: the pickled ``type(x)``.
        - ``"serializer"``: the string ``"cuda"``.
        - ``"compression"``: a tuple with one ``False`` entry per frame.
    frames
        The frames returned by the registered serializer, unchanged. They must
        support ``len()``.

    Raises
    ------
    NotImplementedError
        If ``cuda_serialize`` has no entry for ``type(x)``. The message is
        ``typename(type(x))`` and the original lookup ``TypeError`` is chained
        as the cause.
    """
    typ = type(x)
    try:
        dumps = cuda_serialize.dispatch(typ)
    except TypeError as err:
        # Translate the failed lookup into NotImplementedError. This presumably
        # lets the caller fall back to another serialization family (confirm).
        # Chaining keeps the original lookup error visible when debugging.
        raise NotImplementedError(typename(typ)) from err

    sub_header, frames = dumps(x)
    header = {
        "sub-header": sub_header,
        "type-serialized": pickle.dumps(typ),
        "serializer": "cuda",
        "compression": (False,) * len(frames),  # no compression for gpu data
    }
    return header, frames


def cuda_loads(header, frames):
    """Rebuild an object serialized by the "cuda" serialization family.

    The original type is recovered by unpickling ``header["type-serialized"]``.
    The loader registered for that type in ``cuda_deserialize`` is then called
    with ``header["sub-header"]`` and ``frames``.

    Parameters
    ----------
    header : dict
        Header carrying the ``"type-serialized"`` and ``"sub-header"`` keys,
        i.e. the shape written by ``cuda_dumps``.
    frames
        Passed through unchanged to the registered loader.

    Returns
    -------
    object
        Whatever the registered loader returns.

    Raises
    ------
    TypeError
        Presumably raised by ``cuda_deserialize.dispatch`` when no loader is
        registered for the recovered type (``cuda_dumps`` catches the same
        error from ``cuda_serialize.dispatch``); not caught here.
    """
    # NOTE(review): pickle.loads can execute arbitrary code. The header
    # presumably arrives from another cluster member, so it must come from a
    # trusted peer.
    original_type = pickle.loads(header["type-serialized"])
    loader = cuda_deserialize.dispatch(original_type)
    sub_header = header["sub-header"]
    return loader(sub_header, frames)


# Make cuda_dumps / cuda_loads available under the serialization family
# name "cuda".
register_serialization_family("cuda", cuda_dumps, cuda_loads)


# ObjectDictSerializer bound to the "cuda" family. Its ``deserialize`` method
# is installed below as the cuda loader for ``dict``.
cuda_object_with_dict_serializer = ObjectDictSerializer("cuda")

# The two-argument form of Dispatch.register stores the function directly.
# It is equivalent to applying the decorator returned by register(dict).
cuda_deserialize.register(dict, cuda_object_with_dict_serializer.deserialize)