1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
|
import numpy as np
import pytest
import rasterio
from rasterio import (
ubyte,
uint8,
uint16,
uint32,
uint64,
int8,
int16,
int32,
int64,
float16,
float32,
float64,
complex_,
complex_int16,
)
from rasterio.dtypes import (
_gdal_typename,
is_ndarray,
check_dtype,
get_minimum_dtype,
can_cast_dtype,
validate_dtype,
_is_complex_int,
_getnpdtype,
_get_gdal_dtype,
)
from tests.conftest import gdal_version, requires_gdal3_11
def test_is_ndarray():
    """is_ndarray is true for a NumPy array and false for a list or tuple."""
    array = np.zeros((1,))
    assert is_ndarray(array)

    # Plain Python sequences must not be mistaken for arrays.
    for not_an_array in ([0], (0,)):
        assert not is_ndarray(not_an_array)
def test_np_dt_uint8():
    """check_dtype reports NumPy's uint8 type as valid."""
    is_valid = check_dtype(np.uint8)
    assert is_valid
def test_dt_ubyte():
    """check_dtype reports the `ubyte` name exported by rasterio as valid."""
    is_valid = check_dtype(ubyte)
    assert is_valid
def test_check_dtype_invalid():
    """check_dtype reports an unknown dtype string as invalid."""
    is_valid = check_dtype('foo')
    assert not is_valid
@pytest.mark.parametrize(
    ("dtype", "name"),
    [
        (ubyte, "Byte"),
        (np.uint8, "Byte"),
        (np.uint16, "UInt16"),
        ("uint8", "Byte"),
        ("complex_int16", "CInt16"),
        (complex_int16, "CInt16"),
    ],
)
def test_gdal_name(dtype, name):
    """_gdal_typename maps each dtype spelling to the expected type name."""
    gdal_name = _gdal_typename(dtype)
    assert gdal_name == name
def test_get_minimum_dtype():
    """get_minimum_dtype returns the expected dtype for lists and arrays."""
    # Plain Python lists: unsigned, signed, then floating-point ranges.
    list_cases = [
        ([0, 1], uint8),
        ([0, 1000], uint16),
        ([0, 100000], uint32),
        ([-1, 0, 1], int8),
        ([-1, 0, 128], int16),
        ([-1, 0, 100000], int32),
        ([-1.5e+5, 0, 1.5e+5], float32),
        ([-1.5e+100, 0, 1.5e+100], float64),
    ]
    for values, expected in list_cases:
        assert get_minimum_dtype(values) == expected

    # The same integer ranges supplied as NumPy arrays with np.uint / int
    # element types.
    array_cases = [
        (np.array([0, 1], dtype=np.uint), uint8),
        (np.array([0, 1000], dtype=np.uint), uint16),
        (np.array([0, 100000], dtype=np.uint), uint32),
        (np.array([-1, 0, 1], dtype=int), int8),
        (np.array([-1, 0, 128], dtype=int), int16),
        (np.array([-1, 0, 100000], dtype=int), int32),
    ]
    for values, expected in array_cases:
        assert get_minimum_dtype(values) == expected
@pytest.mark.parametrize("values", [
    [-9.1, 0, 9.1],
    np.array([-1.5, 0, 1.5], dtype=np.float64),
    [0, 1.5, 5],  # Mixed type list where min/max are same type
])
def test_get_minimum_dtype__float16(values):
    """Expect float16 when gdal_version is at least 3.11, else float32."""
    result = get_minimum_dtype(values)
    expected = float16 if gdal_version.at_least("3.11") else float32
    assert result == expected
def test_get_minimum_dtype__int64():
    """A signed range reaching 2**31 requires int64."""
    # 2147483648 == 2**31, one more than the largest 32-bit signed value.
    values = [-1, 0, 2147483648]
    assert get_minimum_dtype(values) == int64
def test_get_minimum_dtype__uint64():
    """An unsigned range reaching 2**32 requires uint64."""
    # 4294967296 == 2**32, one more than the largest 32-bit unsigned value.
    values = [0, 4294967296]
    assert get_minimum_dtype(values) == uint64
def test_can_cast_dtype():
    """can_cast_dtype accepts the casts below and rejects fractions to uint8."""
    castable = [
        ((1, 2, 3), np.uint8),
        (np.array([1, 2, 3]), np.uint8),
        (np.array([1, 2, 3], dtype=np.uint8), np.uint8),
        (np.array([1, 2, 3]), np.float32),
        (np.array([1.4, 2.1, 3.65]), np.float32),
    ]
    for values, target in castable:
        assert can_cast_dtype(values, target)

    # Fractional values are expected not to be castable to an integer type.
    fractional = np.array([1.4, 2.1, 3.65])
    assert not can_cast_dtype(fractional, np.uint8)
@pytest.mark.parametrize("dtype", [
    "float64",
    "float32",
    pytest.param(
        "float16",
        marks=requires_gdal3_11,
    ),
])
def test_can_cast_dtype_nan(dtype):
    """A NaN value is castable to each parametrized floating-point dtype."""
    nan_values = [np.nan]
    assert can_cast_dtype(nan_values, dtype)
@pytest.mark.parametrize("dtype", ["uint8", "uint16", "uint32", "int32"])
def test_cant_cast_dtype_nan(dtype):
    """A NaN value is not castable to any parametrized integer dtype."""
    nan_values = [np.nan]
    assert not can_cast_dtype(nan_values, dtype)
def test_validate_dtype():
    """validate_dtype accepts values that fit one of the allowed dtypes."""
    unsigned = ('uint8', 'uint16')
    assert validate_dtype([1, 2, 3], unsigned)
    assert validate_dtype(np.array([1, 2, 3]), unsigned)

    # Fractional values pass against float dtypes but not against uint8.
    floats_ok = validate_dtype(
        np.array([1.4, 2.1, 3.65]), ('float16', 'float32',)
    )
    assert floats_ok
    uint8_ok = validate_dtype(np.array([1.4, 2.1, 3.65]), ('uint8',))
    assert not uint8_ok
def test_complex(tmpdir):
    """A complex-valued band written to a GTiff reads back unchanged."""
    path = str(tmpdir.join("complex.tif"))
    written = np.ones((2, 2), dtype=complex_)
    creation_options = {
        "driver": 'GTiff',
        "width": 2,
        "height": 2,
        "count": 1,
        "dtype": complex_,
    }

    # Write the array as band 1 of a new dataset.
    with rasterio.open(path, 'w', **creation_options) as dst:
        dst.write(written, 1)

    # Reopen the file and read band 1 back for comparison.
    with rasterio.open(path) as src:
        read_back = src.read(1)

    assert np.array_equal(written, read_back)
def test_is_complex_int():
    """_is_complex_int is true for the "complex_int16" dtype string."""
    outcome = _is_complex_int("complex_int16")
    assert outcome
def test_not_is_complex_int():
    """_is_complex_int is false for the plain "complex" dtype string."""
    outcome = _is_complex_int("complex")
    assert not outcome
def test_get_npdtype():
    """_getnpdtype resolves "complex_int16" to NumPy's complex64 dtype."""
    resolved = _getnpdtype("complex_int16")
    assert resolved == np.complex64
    # Kind "c" is NumPy's code for complex floating-point dtypes.
    assert resolved.kind == "c"
def test__get_gdal_dtype__int64():
    """_get_gdal_dtype maps "int64" to the integer code 13."""
    # NOTE(review): 13 is presumably GDAL's GDT_Int64 enum value — confirm
    # against the GDALDataType definition.
    gdal_code = _get_gdal_dtype("int64")
    assert gdal_code == 13
|