File: proxystore.py

package info (click to toggle)
python-parsl 2025.01.13%2Bds-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 12,072 kB
  • sloc: python: 23,817; makefile: 349; sh: 276; ansic: 45
file content (50 lines) | stat: -rw-r--r-- 1,624 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import io
import typing as t

import dill
from proxystore.store import Store

from parsl.serialize.base import SerializerBase


class ProxyStoreDeepPickler(dill.Pickler):
    """A dill pickler that can divert selected objects into ProxyStore.

    Each object handed to :meth:`reducer_override` during pickling is checked
    against the user-supplied ``should_proxy`` policy. Objects the policy
    accepts are placed in ``store`` and the resulting proxy's reduction is
    written in their place. All other objects are pickled by dill as usual.
    """

    def __init__(
        self,
        *args: t.Any,
        should_proxy: t.Callable[[t.Any], bool],
        store: Store,
        **kwargs: t.Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        # The two keyword-only settings are ours; everything else is forwarded
        # unchanged to dill.Pickler above.
        self._should_proxy = should_proxy
        self._store = store

    def reducer_override(self, o: t.Any) -> t.Any:
        """Reduce ``o`` to a proxy if the policy selects it, else defer to dill."""
        if not self._should_proxy(o):
            # Returning NotImplemented makes pickle fall back to its normal
            # (dill-provided) reduction for this object.
            return NotImplemented
        return self._store.proxy(o).__reduce__()


class ProxyStoreSerializer(SerializerBase):
    """Serializer that moves policy-selected objects into ProxyStore.

    When serializing, objects for which ``should_proxy`` returns ``True`` are
    stored in ``store`` and replaced in the pickle stream by a proxy. All
    other objects are pickled directly by dill.

    Both constructor arguments are optional, so an instance can be created
    without configuration. :meth:`deserialize` needs neither argument. Both
    must be supplied before :meth:`serialize` is called.
    """

    def __init__(self, *, should_proxy: t.Optional[t.Callable[[t.Any], bool]] = None, store: t.Optional[Store] = None) -> None:
        self._store = store
        self._should_proxy = should_proxy

    def serialize(self, data: t.Any) -> bytes:
        """Pickle ``data``, proxying any objects selected by the policy.

        :param data: object to serialize; must not be ``None``.
        :returns: the pickled bytes.
        :raises ValueError: if no ``store`` or ``should_proxy`` policy was
            configured, or if ``data`` is ``None``.
        """
        # Explicit checks instead of ``assert``: asserts are stripped when
        # Python runs with -O. Without these checks, a missing store or policy
        # would only surface later as an obscure "NoneType is not callable".
        if self._store is None:
            raise ValueError("ProxyStoreSerializer requires a store in order to serialize")
        if self._should_proxy is None:
            raise ValueError("ProxyStoreSerializer requires a should_proxy policy in order to serialize")
        # NOTE(review): None was rejected by the original code as well. The
        # reason is not evident from this file; confirm before relaxing this.
        if data is None:
            raise ValueError("ProxyStoreSerializer cannot serialize None")

        with io.BytesIO() as f:
            pickler = ProxyStoreDeepPickler(file=f, store=self._store, should_proxy=self._should_proxy)
            pickler.dump(data)
            # Read the buffer before the context manager closes it.
            return f.getvalue()

    def deserialize(self, body: bytes) -> t.Any:
        """Unpickle ``body`` using regular dill.

        Deserialization is not customised. Proxied objects were written using
        the proxy's own ``__reduce__`` tuple, so plain ``dill.loads`` rebuilds
        them.

        .. warning:: Like any pickle-based loader, this can execute arbitrary
           code. Only pass it bytes that come from a trusted source.
        """
        return dill.loads(body)