File: test_dtypes.py

package info (click to toggle)
rasterio 1.5.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 22,732 kB
  • sloc: python: 23,119; sh: 947; makefile: 275; xml: 29
file content (169 lines) | stat: -rw-r--r-- 4,463 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import numpy as np
import pytest

import rasterio
from rasterio import (
    ubyte,
    uint8,
    uint16,
    uint32,
    uint64,
    int8,
    int16,
    int32,
    int64,
    float16,
    float32,
    float64,
    complex_,
    complex_int16,
)
from rasterio.dtypes import (
    _gdal_typename,
    is_ndarray,
    check_dtype,
    get_minimum_dtype,
    can_cast_dtype,
    validate_dtype,
    _is_complex_int,
    _getnpdtype,
    _get_gdal_dtype,
)
from tests.conftest import gdal_version, requires_gdal3_11


def test_is_ndarray():
    """is_ndarray accepts a NumPy array and rejects a list and a tuple."""
    array_input = np.zeros((1,))
    assert is_ndarray(array_input)
    # Plain Python sequences must not be mistaken for ndarrays.
    for sequence_input in ([0], (0,)):
        assert not is_ndarray(sequence_input)


def test_np_dt_uint8():
    """The NumPy ``uint8`` scalar type passes ``check_dtype``."""
    is_valid = check_dtype(np.uint8)
    assert is_valid


def test_dt_ubyte():
    """rasterio's ``ubyte`` alias passes ``check_dtype``."""
    is_valid = check_dtype(ubyte)
    assert is_valid


def test_check_dtype_invalid():
    """An unrecognised dtype name is rejected by ``check_dtype``."""
    bogus_name = 'foo'
    is_valid = check_dtype(bogus_name)
    assert not is_valid


@pytest.mark.parametrize(
    ["dtype", "name"],
    [
        # NumPy/rasterio spellings of the same types map to one GDAL name.
        (ubyte, "Byte"),
        (np.uint8, "Byte"),
        (np.uint16, "UInt16"),
        ("uint8", "Byte"),
        # Complex integer types have no NumPy equivalent.
        ("complex_int16", "CInt16"),
        (complex_int16, "CInt16"),
    ],
)
def test_gdal_name(dtype, name):
    """``_gdal_typename`` maps each dtype spelling to its GDAL type name."""
    gdal_name = _gdal_typename(dtype)
    assert gdal_name == name


def test_get_minimum_dtype():
    """get_minimum_dtype picks the narrowest dtype for lists and ndarrays."""
    # Plain Python sequences: unsigned, signed, then floating point ranges.
    list_cases = [
        ([0, 1], uint8),
        ([0, 1000], uint16),
        ([0, 100000], uint32),
        ([-1, 0, 1], int8),
        ([-1, 0, 128], int16),
        ([-1, 0, 100000], int32),
        ([-1.5e+5, 0, 1.5e+5], float32),
        ([-1.5e+100, 0, 1.5e+100], float64),
    ]
    for values, expected in list_cases:
        assert get_minimum_dtype(values) == expected

    # The same integer ranges supplied as ndarrays with wide integer dtypes;
    # the array's own dtype must not dictate the result.
    array_cases = [
        ([0, 1], np.uint, uint8),
        ([0, 1000], np.uint, uint16),
        ([0, 100000], np.uint, uint32),
        ([-1, 0, 1], int, int8),
        ([-1, 0, 128], int, int16),
        ([-1, 0, 100000], int, int32),
    ]
    for values, array_dtype, expected in array_cases:
        arr = np.array(values, dtype=array_dtype)
        assert get_minimum_dtype(arr) == expected



@pytest.mark.parametrize("values", [
    [-9.1, 0, 9.1],
    np.array([-1.5, 0, 1.5], dtype=np.float64),
    [0, 1.5, 5],  # Mixed int/float list whose min and max are both ints
])
def test_get_minimum_dtype__float16(values):
    """Small-range floating point values map to the narrowest float dtype.

    With GDAL >= 3.11 that is float16; on older GDAL the test expects
    float32 instead (GDAL's Float16 band type was introduced in 3.11).
    """
    minimum_dtype = get_minimum_dtype(values)
    if gdal_version.at_least("3.11"):
        assert minimum_dtype == float16
    else:
        assert minimum_dtype == float32


def test_get_minimum_dtype__int64():
    """A signed range reaching 2**31 no longer fits int32 and needs int64."""
    values = [-1, 0, 2147483648]
    assert get_minimum_dtype(values) == int64

def test_get_minimum_dtype__uint64():
    """A non-negative range reaching 2**32 exceeds uint32 and needs uint64."""
    values = [0, 4294967296]
    assert get_minimum_dtype(values) == uint64


def test_can_cast_dtype():
    """can_cast_dtype accepts in-range casts and rejects fraction-dropping ones."""
    # Each case builds its own input, as the original assertions did.
    castable_cases = [
        ((1, 2, 3), np.uint8),
        (np.array([1, 2, 3]), np.uint8),
        (np.array([1, 2, 3], dtype=np.uint8), np.uint8),
        (np.array([1, 2, 3]), np.float32),
        (np.array([1.4, 2.1, 3.65]), np.float32),
    ]
    for values, target in castable_cases:
        assert can_cast_dtype(values, target)

    # Values with fractional parts are reported as not castable to uint8.
    assert not can_cast_dtype(np.array([1.4, 2.1, 3.65]), np.uint8)


@pytest.mark.parametrize(
    "dtype",
    [
        "float64",
        "float32",
        # float16 is only exercised when GDAL 3.11+ is available.
        pytest.param("float16", marks=requires_gdal3_11),
    ],
)
def test_can_cast_dtype_nan(dtype):
    """NaN can be cast to each floating point dtype under test."""
    nan_only = [np.nan]
    castable = can_cast_dtype(nan_only, dtype)
    assert castable


@pytest.mark.parametrize(
    "dtype",
    ["uint8", "uint16", "uint32", "int32"],
)
def test_cant_cast_dtype_nan(dtype):
    """NaN is reported as not castable to any of the integer dtypes tested."""
    nan_only = [np.nan]
    castable = can_cast_dtype(nan_only, dtype)
    assert not castable


def test_validate_dtype():
    """validate_dtype accepts values fitting one of the named dtypes only."""
    small_unsigned = ('uint8', 'uint16')
    assert validate_dtype([1, 2, 3], small_unsigned)
    assert validate_dtype(np.array([1, 2, 3]), small_unsigned)

    # Fractional values fit the float choices but not uint8.
    assert validate_dtype(np.array([1.4, 2.1, 3.65]), ('float16', 'float32'))
    assert not validate_dtype(np.array([1.4, 2.1, 3.65]), ('uint8',))


def test_complex(tmpdir):
    """A complex-valued band written to a GeoTIFF reads back unchanged."""
    path = str(tmpdir.join("complex.tif"))
    written = np.ones((2, 2), dtype=complex_)
    creation_opts = {
        "driver": "GTiff",
        "width": 2,
        "height": 2,
        "count": 1,
        "dtype": complex_,
    }

    with rasterio.open(path, "w", **creation_opts) as dst:
        dst.write(written, 1)

    with rasterio.open(path) as src:
        read_back = src.read(1)

    assert np.array_equal(written, read_back)


def test_is_complex_int():
    """The name ``complex_int16`` is recognised as a complex integer type."""
    type_name = "complex_int16"
    assert _is_complex_int(type_name)


def test_not_is_complex_int():
    """The plain name ``complex`` is not treated as a complex integer type."""
    type_name = "complex"
    assert not _is_complex_int(type_name)


def test_get_npdtype():
    """``complex_int16`` resolves to NumPy complex64, a complex-kind dtype."""
    resolved = _getnpdtype("complex_int16")
    expected = np.complex64
    assert resolved == expected
    # dtype.kind "c" denotes a complex floating point dtype.
    assert resolved.kind == "c"


def test__get_gdal_dtype__int64():
    """The name "int64" resolves to GDAL's GDT_Int64 type code."""
    gdt_int64 = 13  # value of GDT_Int64 in GDAL's GDALDataType enum
    assert _get_gdal_dtype("int64") == gdt_int64