File: test_2682_custom_pickler.py

package info (click to toggle)
python-awkward 2.6.5-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 23,088 kB
  • sloc: python: 148,689; cpp: 33,562; sh: 432; makefile: 21; javascript: 8
file content (105 lines) | stat: -rw-r--r-- 3,104 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import multiprocessing
import os
import pickle
import sys
from concurrent.futures import ProcessPoolExecutor

if sys.version_info < (3, 12):
    import importlib_metadata
else:
    import importlib.metadata as importlib_metadata

import pytest

import awkward as ak


def has_entry_point():
    return bool(importlib_metadata.entry_points(group="awkward.pickle.reduce").names)


# Skip the whole module when an ``awkward.pickle.reduce`` plugin is already
# installed: the tests below assume that only the plugins they register
# themselves are active.
pytestmark = pytest.mark.skipif(
    has_entry_point(), reason="Custom pickler is already registered!"
)


def _init_process_with_pickler(pickler_source: str, tmp_path):
    # Create custom plugin
    (tmp_path / "impl_pickler.py").write_bytes(pickler_source.encode("UTF-8"))
    dist_info = tmp_path / "impl_pickler-0.0.0.dist-info"
    dist_info.mkdir()
    (dist_info / "entry_points.txt").write_bytes(
        b"[awkward.pickle.reduce]\nimpl = impl_pickler:plugin\n"
    )
    sys.path.insert(0, os.fsdecode(tmp_path))


def _pickle_complex_array_and_return_form_impl():
    """Pickle and unpickle an indexed jagged array, and return its form.

    The array is built by selecting rows 0 and 2 with an integer-list index.
    It is round-tripped through ``pickle`` in the current process, so the
    returned form shows which reducer (default or plugin) was used.
    """
    nested = ak.Array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
    selected = nested[[0, 2]]
    round_tripped = pickle.loads(pickle.dumps(selected))
    return round_tripped.layout.form


def pickle_complex_array_and_return_form(pickler_source, tmp_path):
    """Round-trip the sample array inside a freshly spawned worker process.

    The single worker is initialised with ``_init_process_with_pickler``, so
    ``pickler_source`` is registered as an ``awkward.pickle.reduce`` plugin
    (via entry points) before the pickling task runs. Returns the form of the
    unpickled array's layout. Any exception raised in the worker is re-raised
    here by ``Future.result()``.
    """
    # "spawn" starts a clean interpreter rather than forking this process
    # with all of its state.
    spawn_context = multiprocessing.get_context("spawn")
    with ProcessPoolExecutor(
        max_workers=1,
        mp_context=spawn_context,
        initializer=_init_process_with_pickler,
        initargs=(pickler_source, tmp_path),
    ) as pool:
        return pool.submit(_pickle_complex_array_and_return_form_impl).result()


def test_default_pickler():
    # With no plugin registered, the built-in reducer is expected to repack
    # the indexed array into a ListOffsetArray.
    round_tripped_form = _pickle_complex_array_and_return_form_impl()
    expected_form = ak.forms.from_dict(
        {"class": "ListOffsetArray", "offsets": "i64", "content": "int64"}
    )
    assert round_tripped_form == expected_form


def test_noop_pickler(tmp_path):
    # A plugin that declines every object by returning NotImplemented, so the
    # result should match the default reducer's output.
    declining_plugin = """
def plugin(obj, protocol: int):
    return NotImplemented"""
    round_tripped_form = pickle_complex_array_and_return_form(
        declining_plugin, tmp_path
    )
    expected_form = ak.forms.from_dict(
        {"class": "ListOffsetArray", "offsets": "i64", "content": "int64"}
    )
    assert round_tripped_form == expected_form


def test_non_packing_pickler(tmp_path):
    # A plugin that serialises ak.Array via ak.to_buffers without packing it
    # first, so the original ListArray layout should survive the round trip.
    buffers_plugin = """
def plugin(obj, protocol):
    import awkward as ak
    if isinstance(obj, ak.Array):
        form, length, container = ak.to_buffers(obj)
        return (
            object.__new__,
            (ak.Array,),
            (form.to_dict(), length, container, obj.behavior),
        )
    else:
        return NotImplemented"""
    round_tripped_form = pickle_complex_array_and_return_form(
        buffers_plugin, tmp_path
    )
    expected_form = ak.forms.from_dict(
        {"class": "ListArray", "starts": "i64", "stops": "i64", "content": "int64"}
    )
    assert round_tripped_form == expected_form


def test_malformed_pickler(tmp_path):
    # An exception raised by the plugin inside the worker should propagate
    # back to the caller unchanged.
    raising_plugin = """
def plugin(obj, protocol: int):
    raise RuntimeError('malformed pickler!')"""
    with pytest.raises(RuntimeError, match=r"malformed pickler!"):
        pickle_complex_array_and_return_form(raising_plugin, tmp_path)