File: utils.py

package info (click to toggle)
dask.distributed 2024.12.1%2Bds-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 12,588 kB
  • sloc: python: 96,954; javascript: 1,549; sh: 390; makefile: 220
file content (55 lines) | stat: -rw-r--r-- 1,630 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from __future__ import annotations

import itertools
from typing import Any

from distributed.core import PooledRPCCall
from distributed.shuffle._core import ShuffleId, ShuffleRun

UNPACK_PREFIX = "shuffle_p2p"
try:
    import dask.dataframe as dd

    if dd._dask_expr_enabled():
        UNPACK_PREFIX = "p2pshuffle"
except ImportError:
    pass


class PooledRPCShuffle(PooledRPCCall):
    """``PooledRPCCall`` subclass that forwards calls to a wrapped shuffle.

    Any attribute that normal lookup does not find is resolved to a coroutine
    function.  Awaiting that coroutine calls the correspondingly named method
    on the wrapped ``shuffle`` object and returns its deserialized result.

    The parent constructor is not invoked; only ``shuffle`` is stored.
    """

    def __init__(self, shuffle: ShuffleRun):
        self.shuffle = shuffle

    def __getattr__(self, key):
        """Return a coroutine function dispatching ``key`` to ``self.shuffle``.

        The target method name is ``key`` with every ``"shuffle_"`` substring
        removed.  The returned coroutine accepts keyword arguments only.
        """

        async def _invoke(**kwargs):
            # Imported at call time, exactly where the handler runs.
            from distributed.protocol.serialize import _nested_deserialize

            target_name = key.replace("shuffle_", "")
            # These two keywords are dropped and never reach the target.
            for dropped in ("shuffle_id", "run_id"):
                kwargs.pop(dropped, None)
            # TODO: This is a bit awkward. At some point the arguments are
            # already getting wrapped with a `Serialize`. We only want to
            # unwrap here.
            unwrapped = _nested_deserialize(kwargs)
            # Look the method up only after the arguments were unwrapped.
            result = await getattr(self.shuffle, target_name)(**unwrapped)
            return _nested_deserialize(result)

        return _invoke


class AbstractShuffleTestPool:
    """Callable registry that maps addresses to shuffle objects.

    Calling an instance with an address returns a ``PooledRPCShuffle``
    wrapping the shuffle registered under that address in ``self.shuffles``.
    """

    # Class-level counter shared by all instances; it is not consumed by any
    # of the methods defined in this class.
    _shuffle_run_id_iterator = itertools.count()

    def __init__(self, *args, **kwargs):
        """Create an empty registry; all arguments are accepted and ignored."""
        self.shuffles = {}

    def __call__(self, addr: str, *args: Any, **kwargs: Any) -> PooledRPCShuffle:
        """Wrap the shuffle registered under ``addr``.

        Extra positional and keyword arguments are ignored.  Raises
        ``KeyError`` when ``addr`` has not been registered.
        """
        registered = self.shuffles[addr]
        return PooledRPCShuffle(registered)

    async def shuffle_barrier(
        self, id: ShuffleId, run_id: int, consistent: bool
    ) -> dict[str, None]:
        """Await ``inputs_done()`` on every registered shuffle, one at a time.

        ``id``, ``run_id`` and ``consistent`` are accepted but not used.
        Returns a dict mapping each address to the value its shuffle's
        ``inputs_done()`` produced, in registration order.
        """
        return {
            addr: await run.inputs_done() for addr, run in self.shuffles.items()
        }