File: test_dtype_registry.py

package info (click to toggle)
zarr 3.1.5-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 3,068 kB
  • sloc: python: 31,589; makefile: 10
file content (207 lines) | stat: -rw-r--r-- 8,199 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any, Literal, get_args

import numpy as np
import pytest

from tests.conftest import skip_object_dtype
from zarr.core.dtype import (
    AnyDType,
    DataTypeRegistry,
    TBaseDType,
    TBaseScalar,
    get_data_type_from_json,
)
from zarr.core.dtype.common import unpack_dtype_json
from zarr.dtype import (  # type: ignore[attr-defined]
    Bool,
    FixedLengthUTF32,
    ZDType,
    data_type_registry,
    parse_data_type,
    parse_dtype,
)

if TYPE_CHECKING:
    from zarr.core.common import ZarrFormat

from .test_dtype.conftest import zdtype_examples


@pytest.fixture
def data_type_registry_fixture() -> DataTypeRegistry:
    return DataTypeRegistry()


class TestRegistry:
    @staticmethod
    def test_register(data_type_registry_fixture: DataTypeRegistry) -> None:
        """
        Test that registering a dtype in a data type registry works.
        """
        data_type_registry_fixture.register(Bool._zarr_v3_name, Bool)
        assert data_type_registry_fixture.get(Bool._zarr_v3_name) == Bool
        assert isinstance(data_type_registry_fixture.match_dtype(np.dtype("bool")), Bool)

    @staticmethod
    def test_override(data_type_registry_fixture: DataTypeRegistry) -> None:
        """
        Test that registering a new dtype with the same name works (overriding the previous one).
        """
        data_type_registry_fixture.register(Bool._zarr_v3_name, Bool)

        class NewBool(Bool):
            def default_scalar(self) -> np.bool_:
                return np.True_

        data_type_registry_fixture.register(NewBool._zarr_v3_name, NewBool)
        assert isinstance(data_type_registry_fixture.match_dtype(np.dtype("bool")), NewBool)

    @staticmethod
    @pytest.mark.parametrize(
        ("wrapper_cls", "dtype_str"), [(Bool, "bool"), (FixedLengthUTF32, "|U4")]
    )
    def test_match_dtype(
        data_type_registry_fixture: DataTypeRegistry,
        wrapper_cls: type[ZDType[TBaseDType, TBaseScalar]],
        dtype_str: str,
    ) -> None:
        """
        Test that match_dtype resolves a numpy dtype into an instance of the correspond wrapper for that dtype.
        """
        data_type_registry_fixture.register(wrapper_cls._zarr_v3_name, wrapper_cls)
        assert isinstance(data_type_registry_fixture.match_dtype(np.dtype(dtype_str)), wrapper_cls)

    @staticmethod
    def test_unregistered_dtype(data_type_registry_fixture: DataTypeRegistry) -> None:
        """
        Test that match_dtype raises an error if the dtype is not registered.
        """
        outside_dtype_name = "int8"
        outside_dtype = np.dtype(outside_dtype_name)
        msg = f"No Zarr data type found that matches dtype '{outside_dtype!r}'"
        with pytest.raises(ValueError, match=re.escape(msg)):
            data_type_registry_fixture.match_dtype(outside_dtype)

        with pytest.raises(KeyError):
            data_type_registry_fixture.get(outside_dtype_name)

    @staticmethod
    @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
    @pytest.mark.parametrize("zdtype", zdtype_examples)
    def test_registered_dtypes_match_dtype(zdtype: ZDType[TBaseDType, TBaseScalar]) -> None:
        """
        Test that the registered dtypes can be retrieved from the registry.
        """
        skip_object_dtype(zdtype)
        assert data_type_registry.match_dtype(zdtype.to_native_dtype()) == zdtype

    @staticmethod
    @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
    @pytest.mark.parametrize("zdtype", zdtype_examples)
    def test_registered_dtypes_match_json(
        zdtype: ZDType[TBaseDType, TBaseScalar], zarr_format: ZarrFormat
    ) -> None:
        assert (
            data_type_registry.match_json(
                zdtype.to_json(zarr_format=zarr_format), zarr_format=zarr_format
            )
            == zdtype
        )

    @staticmethod
    @pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
    @pytest.mark.parametrize("zdtype", zdtype_examples)
    def test_match_dtype_unique(
        zdtype: ZDType[Any, Any],
        data_type_registry_fixture: DataTypeRegistry,
        zarr_format: ZarrFormat,
    ) -> None:
        """
        Test that the match_dtype method uniquely specifies a registered data type. We create a local registry
        that excludes the data type class being tested, and ensure that an instance of the wrapped data type
        fails to match anything in the registry
        """
        skip_object_dtype(zdtype)
        for _cls in get_args(AnyDType):
            if _cls is not type(zdtype):
                data_type_registry_fixture.register(_cls._zarr_v3_name, _cls)

        dtype_instance = zdtype.to_native_dtype()

        msg = f"No Zarr data type found that matches dtype '{dtype_instance!r}'"
        with pytest.raises(ValueError, match=re.escape(msg)):
            data_type_registry_fixture.match_dtype(dtype_instance)

        instance_dict = zdtype.to_json(zarr_format=zarr_format)
        msg = f"No Zarr data type found that matches {instance_dict!r}"
        with pytest.raises(ValueError, match=re.escape(msg)):
            data_type_registry_fixture.match_json(instance_dict, zarr_format=zarr_format)


@pytest.mark.usefixtures("set_path")
def test_entrypoint_dtype(zarr_format: ZarrFormat) -> None:
    from package_with_entrypoint import TestDataType

    data_type_registry._lazy_load()
    instance = TestDataType()
    dtype_json = instance.to_json(zarr_format=zarr_format)
    assert get_data_type_from_json(dtype_json, zarr_format=zarr_format) == instance
    data_type_registry.unregister(TestDataType._zarr_v3_name)


@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
@pytest.mark.parametrize("data_type", zdtype_examples, ids=str)
@pytest.mark.parametrize("json_style", [(2, "internal"), (2, "metadata"), (3, None)], ids=str)
@pytest.mark.parametrize(
    "dtype_parser_func", [parse_dtype, parse_data_type], ids=["parse_dtype", "parse_data_type"]
)
def test_parse_data_type(
    data_type: ZDType[Any, Any],
    json_style: tuple[ZarrFormat, None | Literal["internal", "metadata"]],
    dtype_parser_func: Any,
) -> None:
    """
    Test the parsing of data types into ZDType instances.

    This function tests the ability of `dtype_parser_func` to correctly
    interpret and parse data type specifications into `ZDType` instances
    according to the specified Zarr format and JSON style.

    Parameters
    ----------
    data_type : ZDType[Any, Any]
        The data type to be tested for parsing.
    json_style : tuple[ZarrFormat, None or Literal["internal", "metadata"]]
        A tuple specifying the Zarr format version and the JSON style
        for Zarr V2 2. For Zarr V2 there are 2 JSON styles: "internal", and
        "metadata". The internal style takes the form {"name": <data type identifier>, "object_codec_id": <object codec id>},
        while the metadata style is just <data type identifier>.
    dtype_parser_func : Any
        The function to be tested for parsing the data type. This is necessary for compatibility
        reasons, as we support multiple functions that perform the same data type parsing operation.
    """
    zarr_format, style = json_style
    dtype_spec: Any

    if zarr_format == 2:
        dtype_spec = data_type.to_json(zarr_format=zarr_format)
        if style == "internal":
            pass
        elif style == "metadata":
            dtype_spec = unpack_dtype_json(dtype_spec)
        else:
            raise ValueError(f"Invalid zarr v2 json style: {style}")
    else:
        dtype_spec = data_type.to_json(zarr_format=zarr_format)

    if dtype_spec == "|O":
        # The object data type on its own is ambiguous and should fail to resolve.
        msg = "Zarr data type resolution from object failed."
        with pytest.raises(ValueError, match=msg):
            dtype_parser_func(dtype_spec, zarr_format=zarr_format)
    else:
        observed = dtype_parser_func(dtype_spec, zarr_format=zarr_format)
        assert observed == data_type