File: test_cupy.py

package info (click to toggle)
python-xarray 0.16.2-2
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 6,568 kB
  • sloc: python: 60,570; makefile: 236; sh: 38
file content (60 lines) | stat: -rw-r--r-- 1,632 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import numpy as np
import pandas as pd
import pytest

import xarray as xr

cp = pytest.importorskip("cupy")


@pytest.fixture
def toy_weather_data():
    """Construct the example DataSet from the Toy weather data example.

    http://xarray.pydata.org/en/stable/examples/weather-data.html

    Here we construct the DataSet exactly as shown in the example and then
    convert the numpy arrays to cupy.

    Random values are drawn from a local ``RandomState`` seeded with 123.
    This yields the same numbers as ``np.random.seed(123)`` followed by
    ``np.random.randn``, but it does not reseed NumPy's global generator,
    so this fixture leaves no random state behind for other tests.
    """
    rng = np.random.RandomState(123)
    times = pd.date_range("2000-01-01", "2001-12-31", name="time")
    annual_cycle = np.sin(2 * np.pi * (times.dayofyear.values / 365.25 - 0.28))

    base = 10 + 15 * annual_cycle.reshape(-1, 1)
    # Draw tmin before tmax so the sequence matches the original example.
    tmin_values = base + 3 * rng.randn(annual_cycle.size, 3)
    tmax_values = base + 10 + 3 * rng.randn(annual_cycle.size, 3)

    ds = xr.Dataset(
        {
            "tmin": (("time", "location"), tmin_values),
            "tmax": (("time", "location"), tmax_values),
        },
        {"time": times, "location": ["IA", "IN", "IL"]},
    )

    # Move only the two data variables onto the GPU; the "time" and
    # "location" coordinates are left as they were built above.
    ds.tmax.data = cp.asarray(ds.tmax.data)
    ds.tmin.data = cp.asarray(ds.tmin.data)

    return ds


def test_cupy_import():
    """Confirm that the ``cupy`` module was imported successfully."""
    assert cp is not None


def test_check_data_stays_on_gpu(toy_weather_data):
    """Perform some operations and check the data stays on the GPU.

    A boolean comparison followed by a month-wise groupby mean over
    ``time`` must still yield a CuPy array rather than a NumPy one.
    """
    freeze = (toy_weather_data["tmin"] <= 0).groupby("time.month").mean("time")
    # Use the public ``cp.ndarray`` alias, as ``test_where`` does.
    # ``cp.core.core.ndarray`` is a private module path that CuPy has renamed.
    assert isinstance(freeze.data, cp.ndarray)


def test_where():
    """Check ``duck_array_ops.where`` on a CuPy input returns a CuPy result."""
    from xarray.core.duck_array_ops import where

    zeros = cp.zeros(10)
    result = where(zeros < 1, 1, zeros).all()

    assert isinstance(result, cp.ndarray)
    assert result