1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
|
from __future__ import annotations
import pytest
pd = pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
pa = pytest.importorskip("pyarrow")
from distributed.shuffle._scheduler_extension import get_worker_for
from distributed.shuffle._worker_extension import (
ShuffleWorkerExtension,
split_by_partition,
split_by_worker,
)
from distributed.utils_test import gen_cluster
@gen_cluster([("", 1)])
async def test_installation(s, a):
    """The worker's "shuffle" extension is installed and its handlers are registered.

    Checks that ``a.extensions["shuffle"]`` is a ``ShuffleWorkerExtension``.
    Also checks that the worker's ``shuffle_receive`` and ``shuffle_inputs_done``
    handlers compare equal to the extension's methods of the same name.
    """
    extension = a.extensions["shuffle"]
    assert isinstance(extension, ShuffleWorkerExtension)

    handlers = a.handlers
    # ``==`` (not ``is``) on purpose: every attribute access builds a fresh
    # bound-method object, but equal bound methods share __func__/__self__.
    expected = (
        ("shuffle_receive", extension.shuffle_receive),
        ("shuffle_inputs_done", extension.shuffle_inputs_done),
    )
    for handler_name, bound_method in expected:
        assert handlers[handler_name] == bound_method
def test_split_by_worker():
    """Rows are grouped under the worker that owns their output partition.

    Each of the three partitions is assigned to "alice" or "bob" via
    ``get_worker_for``. Every row must land in exactly one shard, and that
    shard must be keyed by the worker owning the row's ``_partition`` value.
    """
    df = pd.DataFrame(
        {
            "x": [1, 2, 3, 4, 5],
            "_partition": [0, 1, 2, 0, 1],
        }
    )
    workers = ["alice", "bob"]
    npartitions = 3
    worker_for_mapping = {
        part: get_worker_for(part, workers, npartitions)
        for part in range(npartitions)
    }
    worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category")

    out = split_by_worker(df, "_partition", worker_for)

    assert set(out) == {"alice", "bob"}
    assert list(out["alice"].to_pandas().columns) == list(df.columns)
    assert sum(map(len, out.values())) == len(df)
    # Key names and row counts alone would not catch rows sent to the wrong
    # worker, so check that each shard only holds partitions its key owns.
    for worker, shard in out.items():
        partitions = shard.column("_partition").to_pylist()
        assert all(worker_for_mapping[p] == worker for p in partitions)
def test_split_by_worker_empty():
    """No shards are produced when no row's partition has an assigned worker.

    The worker mapping only covers partition 5, which none of the rows use,
    so ``split_by_worker`` is expected to return an empty dict.
    """
    frame = pd.DataFrame({"x": [1, 2, 3, 4, 5], "_partition": [0, 1, 2, 0, 1]})
    mapping = pd.Series({5: "chuck"}, name="_workers")
    worker_for = mapping.astype("category")

    assert split_by_worker(frame, "_partition", worker_for) == {}
def test_split_by_worker_many_workers():
    """Output keys are exactly the workers owning at least one row's partition.

    There are eight workers and ten partitions, but the rows only touch
    partitions 0, 1, 5 and 7. The result must contain the owners of those
    partitions and nothing else, and every row must be routed to its owner.
    """
    df = pd.DataFrame(
        {
            "x": [1, 2, 3, 4, 5],
            "_partition": [5, 7, 5, 0, 1],
        }
    )
    workers = ["a", "b", "c", "d", "e", "f", "g", "h"]
    npartitions = 10
    worker_for_mapping = {
        part: get_worker_for(part, workers, npartitions)
        for part in range(npartitions)
    }
    worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category")

    out = split_by_worker(df, "_partition", worker_for)

    # Exact set equality (instead of per-key ``in`` checks) also rejects
    # spurious entries for workers that own none of the rows.
    expected_workers = {worker_for_mapping[p] for p in df["_partition"].tolist()}
    assert set(out) == expected_workers
    assert sum(map(len, out.values())) == len(df)
    for worker, shard in out.items():
        partitions = shard.column("_partition").to_pylist()
        assert all(worker_for_mapping[p] == worker for p in partitions)
def test_split_by_partition():
    """A pyarrow table is split into one shard per distinct ``_partition`` value.

    Each shard keeps all original columns and contains exactly the rows whose
    ``_partition`` equals the shard's key.
    """
    # ``pa`` comes from the module-level ``pytest.importorskip("pyarrow")``,
    # which already skips this whole module when pyarrow is unavailable.
    df = pd.DataFrame(
        {
            "x": [1, 2, 3, 4, 5],
            "_partition": [3, 1, 2, 3, 1],
        }
    )
    t = pa.Table.from_pandas(df)

    out = split_by_partition(t, "_partition")

    assert set(out) == {1, 2, 3}
    assert out[1].column_names == list(df.columns)
    assert sum(map(len, out.values())) == len(df)
    # Row counts alone would not catch rows placed in the wrong shard.
    for partition, shard in out.items():
        assert set(shard.column("_partition").to_pylist()) == {partition}
    # Sorting avoids assuming any row order inside a shard.
    shard_values = {
        partition: sorted(shard.column("x").to_pylist())
        for partition, shard in out.items()
    }
    assert shard_values == {1: [2, 5], 2: [3], 3: [1, 4]}
|