File: test_graph.py

Package: dask.distributed 2022.12.1+ds.1-3 (Debian bookworm)
from __future__ import annotations

import asyncio

import pytest

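# These tests need pandas, dask.dataframe, and pyarrow; skip the whole module
# if any of them is unavailable.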
pd = pytest.importorskip("pandas")
pytest.importorskip("dask.dataframe")
pytest.importorskip("pyarrow")

import dask
import dask.dataframe as dd
from dask.blockwise import Blockwise
from dask.utils_test import hlg_layer_topological

from distributed.utils_test import gen_cluster


def test_basic(client):
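    """Run a single p2p shuffle on a demo timeseries.

    Checks that the optimized graph starts with a Blockwise layer and that the
    p2p result matches the task-based shuffle.
    """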
    df = dd.demo.make_timeseries(freq="15D", partition_freq="30D")
    df["name"] = df["name"].astype("string[python]")
    shuffled = df.shuffle("id", shuffle="p2p")

    (opt,) = dask.optimize(shuffled)
    assert isinstance(hlg_layer_topological(opt.dask, 0), Blockwise)
    # Expected layer order: blockwise -> barrier -> unpack -> drop_by_shallow_copy

    dd.utils.assert_eq(shuffled, df.shuffle("id", shuffle="tasks"), scheduler=client)
    # ^ NOTE: this works because `assert_eq` sorts the rows before comparing


@gen_cluster([("", 2)] * 4, client=True)
async def test_basic_state(c, s, *workers):
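    """Check the worker-side shuffle state lifecycle on a 4-worker cluster.

    Each worker's shuffle extension should track the shuffle while it is
    running and drop it again once the computation has finished.
    """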
    df = dd.demo.make_timeseries(freq="15D", partition_freq="30D")
    shuffled = df.shuffle("id", shuffle="p2p")

    exts = [w.extensions["shuffle"] for w in workers]
    for ext in exts:
        assert not ext.shuffles

    f = c.compute(shuffled)
    # TODO: this is a weak test. The `f.done()` check is needed in case the shuffle
    # finishes very quickly. To test the state more thoroughly, we'd need a way to
    # "stop the world" at various stages, e.g. have the scheduler pause everything
    # when the barrier is reached. It's not yet clear how to implement that.
    while not all(len(ext.shuffles) == 1 for ext in exts) and not f.done():
        await asyncio.sleep(0.1)

    await f
    assert all(not ext.shuffles for ext in exts)


def test_multiple_linear(client):
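    """Chain two p2p shuffles linearly, with a column update in between.

    The result should match the same pipeline run with the task-based shuffle.
    """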
    df = dd.demo.make_timeseries(freq="15D", partition_freq="30D")
    df["name"] = df["name"].astype("string[python]")
    s1 = df.shuffle("id", shuffle="p2p")
    s1["x"] = s1["x"] + 1
    s2 = s1.shuffle("x", shuffle="p2p")

    # TODO eventually test for fusion between s1's unpacks, the `+1`, and s2's `transfer`s

    dd.utils.assert_eq(
        s2,
        df.assign(x=lambda df: df.x + 1).shuffle("x", shuffle="tasks"),
        scheduler=client,
    )