File: _msgpack_numpy.py

package info (click to toggle)
python-srsly 2.5.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,852 kB
  • sloc: python: 21,404; ansic: 4,160; cpp: 51; sh: 12; makefile: 6
file content (94 lines) | stat: -rw-r--r-- 2,717 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#!/usr/bin/env python

"""
Support for serialization of numpy data types with msgpack.
"""

# Copyright (c) 2013-2018, Lev E. Givon
# All rights reserved.
# Distributed under the terms of the BSD license:
# http://www.opensource.org/licenses/bsd-license
try:
    import numpy as np

    has_numpy = True
except ImportError:
    has_numpy = False

try:
    import cupy

    has_cupy = True
except ImportError:
    has_cupy = False


def encode_numpy(obj, chain=None):
    """
    Turn numpy values into plain dictionaries that msgpack can pack.

    Intended as a msgpack ``default`` hook. The following objects are
    converted (only when numpy could be imported):

    * ``np.ndarray`` (and ``cupy.ndarray``, after calling its ``get()``)
      becomes ``{b"nd": True, b"type", b"kind", b"shape", b"data"}``.
    * ``np.bool_`` / ``np.number`` scalars become
      ``{b"nd": False, b"type", b"data"}``.
    * ``complex`` values become ``{b"complex": True, b"data": repr}``.

    Anything else (or everything, when numpy is unavailable) is returned
    unchanged, or handed to ``chain`` if one was supplied.
    """
    if has_numpy:
        if has_cupy and isinstance(obj, cupy.ndarray):
            # Swap the cupy array for whatever ``get()`` returns so that the
            # numpy branch below can handle it.
            obj = obj.get()

        if isinstance(obj, np.ndarray):
            dtype = obj.dtype
            # Kind "V" dtypes are described by their interface description;
            # every other dtype by its array protocol type string.
            is_void = dtype.kind == "V"
            # A C-contiguous array can expose its buffer directly; any other
            # layout is serialized through ``tobytes()``.
            if obj.flags["C_CONTIGUOUS"]:
                payload = obj.data
            else:
                payload = obj.tobytes()
            return {
                b"nd": True,
                b"type": dtype.descr if is_void else dtype.str,
                b"kind": b"V" if is_void else b"",
                b"shape": obj.shape,
                b"data": payload,
            }

        if isinstance(obj, (np.bool_, np.number)):
            return {b"nd": False, b"type": obj.dtype.str, b"data": obj.data}

        if isinstance(obj, complex):
            return {b"complex": True, b"data": repr(obj)}

    # Not ours to encode: pass through, optionally via the next encoder.
    return obj if chain is None else chain(obj)


def tostr(x):
    """
    Return ``x`` as text.

    ``bytes`` are decoded with the default codec; any other object is
    converted with ``str()``.
    """
    return x.decode() if isinstance(x, bytes) else str(x)


def decode_numpy(obj, chain=None):
    """
    Decoder for deserializing numpy data types.

    Intended as a msgpack ``object_hook``; it reverses the dictionaries
    produced by ``encode_numpy``:

    * ``{b"nd": True, ...}`` is rebuilt into an array with
      ``np.frombuffer`` and reshaped to ``obj[b"shape"]``.
    * ``{b"nd": <anything else>, ...}`` is rebuilt into a one-element
      buffer view and its first element is returned.
    * ``{b"complex": ..., b"data": ...}`` is parsed back with ``complex()``.

    Any other object, and any marker dictionary that lacks one of the keys
    it needs, is returned unchanged or handed to ``chain`` when one is given.

    Only ``KeyError`` raised while reading the marker dictionary is handled
    here. Errors raised by numpy (for example for a malformed buffer) or by
    ``chain`` itself propagate to the caller, and ``chain`` is invoked at
    most once per call.
    """
    try:
        if b"nd" in obj:
            if obj[b"nd"] is True:
                # Check if b'kind' is in obj to enable decoding of data
                # serialized with older versions (#20):
                if b"kind" in obj and obj[b"kind"] == b"V":
                    # Field descriptions arrive as sequences whose text
                    # members may be bytes; np.dtype needs tuples of str.
                    descr = [
                        tuple(tostr(t) if isinstance(t, bytes) else t for t in d)
                        for d in obj[b"type"]
                    ]
                else:
                    descr = obj[b"type"]
                flat = np.frombuffer(obj[b"data"], dtype=np.dtype(descr))
                return flat.reshape(obj[b"shape"])
            # Scalar payload: the buffer holds exactly one element.
            descr = obj[b"type"]
            return np.frombuffer(obj[b"data"], dtype=np.dtype(descr))[0]
        if b"complex" in obj:
            return complex(tostr(obj[b"data"]))
    except KeyError:
        # The dict carries a marker key but not the rest of the expected
        # payload, so treat it as an ordinary object.
        pass
    # Kept outside the ``try`` so that a KeyError raised by ``chain`` is not
    # mistaken for a malformed marker dict (which would call ``chain`` twice).
    return obj if chain is None else chain(obj)