1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
|
import numpy as np
import pytest
from xarray.core import dtypes
@pytest.mark.parametrize(
    "args, expected",
    [
        # np.bytes_ / np.str_ are used instead of the np.string_ / np.unicode_
        # aliases, which were removed in NumPy 2.0; they name the same types
        # on every NumPy version.
        ([bool], bool),
        ([bool, np.bytes_], np.object_),
        ([np.float32, np.float64], np.float64),
        ([np.float32, np.bytes_], np.object_),
        ([np.str_, np.int64], np.object_),
        ([np.str_, np.str_], np.str_),
        ([np.bytes_, np.str_], np.object_),
    ],
)
def test_result_type(args, expected):
    """Check dtypes.result_type promotion for pairs of scalar types.

    Mixing strings (bytes or unicode) with any other kind is expected to
    promote to object; homogeneous or purely numeric inputs follow the
    ordinary numeric promotion (e.g. float32 + float64 -> float64).
    """
    actual = dtypes.result_type(*args)
    assert actual == expected
def test_result_type_scalar():
    """A Python float scalar (NaN) does not upcast a float32 array."""
    float32_values = np.arange(3, dtype=np.float32)
    assert dtypes.result_type(float32_values, np.nan) == np.float32
def test_result_type_dask_array():
    """result_type works on a dask array without computing it.

    The array is built from a delayed call that always raises, so any
    evaluation inside result_type would surface as a RuntimeError.
    """
    da = pytest.importorskip("dask.array")
    dask = pytest.importorskip("dask")

    def _explode():
        raise RuntimeError

    lazy = da.from_delayed(dask.delayed(_explode)(), (), np.float64)

    # Sanity check: computing the array really does fail.
    with pytest.raises(RuntimeError):
        lazy.compute()

    assert dtypes.result_type(lazy) == np.float64

    # A 0-d numpy array here would be promoted to float32; the 0-d dask
    # array is not, so the result stays float64.
    float32_values = np.array([0.5, 1.0], dtype=np.float32)
    assert dtypes.result_type(lazy, float32_values) == np.float64
@pytest.mark.parametrize("obj", [1.0, np.inf, "ab", 1.0 + 1.0j, True])
def test_inf(obj):
    """dtypes.INF compares greater, and dtypes.NINF less, than mixed-type values."""
    upper = dtypes.INF
    lower = dtypes.NINF
    assert upper > obj
    assert lower < obj
@pytest.mark.parametrize(
    "kind, expected",
    [
        # "a" is not listed: it is a deprecated alias of "S" (NumPy 2.0 warns
        # on np.dtype("a")), and the "S" case below covers the same dtype.
        ("b", (np.float32, "nan")),  # dtype('int8')
        ("B", (np.float32, "nan")),  # dtype('uint8')
        ("c", (np.dtype("O"), "nan")),  # dtype('S1')
        ("D", (np.complex128, "(nan+nanj)")),  # dtype('complex128')
        ("d", (np.float64, "nan")),  # dtype('float64')
        ("e", (np.float16, "nan")),  # dtype('float16')
        ("F", (np.complex64, "(nan+nanj)")),  # dtype('complex64')
        ("f", (np.float32, "nan")),  # dtype('float32')
        ("h", (np.float32, "nan")),  # dtype('int16')
        ("H", (np.float32, "nan")),  # dtype('uint16')
        ("i", (np.float64, "nan")),  # dtype('int32')
        ("I", (np.float64, "nan")),  # dtype('uint32')
        ("l", (np.float64, "nan")),  # C long: int64 on most 64-bit platforms
        ("L", (np.float64, "nan")),  # C ulong: uint64 on most 64-bit platforms
        ("m", (np.timedelta64, "NaT")),  # dtype('<m8')
        ("M", (np.datetime64, "NaT")),  # dtype('<M8')
        ("O", (np.dtype("O"), "nan")),  # dtype('O')
        ("p", (np.float64, "nan")),  # intp: int64 on 64-bit platforms
        ("P", (np.float64, "nan")),  # uintp: uint64 on 64-bit platforms
        ("q", (np.float64, "nan")),  # dtype('int64')
        ("Q", (np.float64, "nan")),  # dtype('uint64')
        ("S", (np.dtype("O"), "nan")),  # dtype('S')
        ("U", (np.dtype("O"), "nan")),  # dtype('<U')
        ("V", (np.dtype("O"), "nan")),  # dtype('V')
    ],
)
def test_maybe_promote(kind, expected):
    """Check the (dtype, fill_value) pair returned by dtypes.maybe_promote.

    ``expected`` holds the promoted dtype and the ``str()`` of the fill
    value; the fill value is compared as a string so that NaN/NaT (which
    never compare equal to themselves) can still be checked.
    """
    # 'g': np.float128 is not tested : not available on all platforms
    # 'G': np.complex256 is not tested : not available on all platforms
    actual = dtypes.maybe_promote(np.dtype(kind))
    assert actual[0] == expected[0]
    assert str(actual[1]) == expected[1]
|