File: test_datatree_mapping.py

package info (click to toggle)
python-xarray 2025.08.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 11,796 kB
  • sloc: python: 115,416; makefile: 258; sh: 47
file content (255 lines) | stat: -rw-r--r-- 9,628 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
import re

import numpy as np
import pytest

import xarray as xr
from xarray.core.datatree_mapping import map_over_datasets
from xarray.core.treenode import TreeIsomorphismError
from xarray.testing import assert_equal, assert_identical

empty = xr.Dataset()


class TestMapOverSubTree:
    def test_no_trees_passed(self):
        with pytest.raises(TypeError, match="must pass at least one tree object"):
            map_over_datasets(lambda x: x, "dt")

    def test_not_isomorphic(self, create_test_datatree):
        dt1 = create_test_datatree()
        dt2 = create_test_datatree()
        dt2["set1/set2/extra"] = xr.DataTree(name="extra")

        with pytest.raises(
            TreeIsomorphismError,
            match=re.escape(
                r"children at node 'set1/set2' do not match: [] vs ['extra']"
            ),
        ):
            map_over_datasets(lambda x, y: None, dt1, dt2)

    def test_no_trees_returned(self, create_test_datatree):
        dt1 = create_test_datatree()
        dt2 = create_test_datatree()
        expected = xr.DataTree.from_dict(dict.fromkeys(dt1.to_dict()))
        actual = map_over_datasets(lambda x, y: None, dt1, dt2)
        assert_equal(expected, actual)

    def test_single_tree_arg(self, create_test_datatree):
        dt = create_test_datatree()
        expected = create_test_datatree(modify=lambda x: 10.0 * x)
        result_tree = map_over_datasets(lambda x: 10 * x, dt)
        assert_equal(result_tree, expected)

    def test_single_tree_arg_plus_arg(self, create_test_datatree):
        dt = create_test_datatree()
        expected = create_test_datatree(modify=lambda ds: (10.0 * ds))
        result_tree = map_over_datasets(lambda x, y: x * y, dt, 10.0)
        assert_equal(result_tree, expected)

        result_tree = map_over_datasets(lambda x, y: x * y, 10.0, dt)
        assert_equal(result_tree, expected)

    def test_single_tree_arg_plus_kwarg(self, create_test_datatree):
        dt = create_test_datatree()
        expected = create_test_datatree(modify=lambda ds: (10.0 * ds))

        def multiply_by_kwarg(ds, **kwargs):
            ds = ds * kwargs.pop("multiplier")
            return ds

        result_tree = map_over_datasets(
            multiply_by_kwarg, dt, kwargs=dict(multiplier=10.0)
        )
        assert_equal(result_tree, expected)

    def test_multiple_tree_args(self, create_test_datatree):
        dt1 = create_test_datatree()
        dt2 = create_test_datatree()
        expected = create_test_datatree(modify=lambda ds: 2.0 * ds)
        result = map_over_datasets(lambda x, y: x + y, dt1, dt2)
        assert_equal(result, expected)

    def test_return_multiple_trees(self, create_test_datatree):
        dt = create_test_datatree()
        dt_min, dt_max = map_over_datasets(lambda x: (x.min(), x.max()), dt)
        expected_min = create_test_datatree(modify=lambda ds: ds.min())
        assert_equal(dt_min, expected_min)
        expected_max = create_test_datatree(modify=lambda ds: ds.max())
        assert_equal(dt_max, expected_max)

    def test_return_wrong_type(self, simple_datatree):
        dt1 = simple_datatree

        with pytest.raises(
            TypeError,
            match=re.escape(
                "the result of calling func on the node at position '.' is not a "
                "Dataset or None or a tuple of such types"
            ),
        ):
            map_over_datasets(lambda x: "string", dt1)  # type: ignore[arg-type,return-value]

    def test_return_tuple_of_wrong_types(self, simple_datatree):
        dt1 = simple_datatree

        with pytest.raises(
            TypeError,
            match=re.escape(
                "the result of calling func on the node at position '.' is not a "
                "Dataset or None or a tuple of such types"
            ),
        ):
            map_over_datasets(lambda x: (x, "string"), dt1)  # type: ignore[arg-type,return-value]

    def test_return_inconsistent_number_of_results(self, simple_datatree):
        dt1 = simple_datatree

        with pytest.raises(
            TypeError,
            match=re.escape(
                r"Calling func on the nodes at position set1 returns a tuple "
                "of 0 datasets, whereas calling func on the nodes at position "
                ". instead returns a tuple of 2 datasets."
            ),
        ):
            # Datasets in simple_datatree have different numbers of dims
            map_over_datasets(lambda ds: tuple((None,) * len(ds.dims)), dt1)

    def test_wrong_number_of_arguments_for_func(self, simple_datatree):
        dt = simple_datatree

        with pytest.raises(
            TypeError, match="takes 1 positional argument but 2 were given"
        ):
            map_over_datasets(lambda x: 10 * x, dt, dt)

    def test_map_single_dataset_against_whole_tree(self, create_test_datatree):
        dt = create_test_datatree()

        def nodewise_merge(node_ds, fixed_ds):
            return xr.merge([node_ds, fixed_ds])

        other_ds = xr.Dataset({"z": ("z", [0])})
        expected = create_test_datatree(modify=lambda ds: xr.merge([ds, other_ds]))
        result_tree = map_over_datasets(nodewise_merge, dt, other_ds)
        assert_equal(result_tree, expected)

    @pytest.mark.xfail
    def test_trees_with_different_node_names(self):
        # TODO test this after I've got good tests for renaming nodes
        raise NotImplementedError

    def test_tree_method(self, create_test_datatree):
        dt = create_test_datatree()

        def multiply(ds, times):
            return times * ds

        expected = create_test_datatree(modify=lambda ds: 10.0 * ds)
        result_tree = dt.map_over_datasets(multiply, 10.0)
        assert_equal(result_tree, expected)

    def test_tree_method_with_kwarg(self, create_test_datatree):
        dt = create_test_datatree()

        def multiply(ds, **kwargs):
            return kwargs.pop("times") * ds

        expected = create_test_datatree(modify=lambda ds: 10.0 * ds)
        result_tree = dt.map_over_datasets(multiply, kwargs=dict(times=10.0))
        assert_equal(result_tree, expected)

    def test_discard_ancestry(self, create_test_datatree):
        # Check for datatree GH issue https://github.com/xarray-contrib/datatree/issues/48
        dt = create_test_datatree()
        subtree = dt["set1"]
        expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"]
        result_tree = map_over_datasets(lambda x: 10.0 * x, subtree)
        assert_equal(result_tree, expected)

    def test_keep_attrs_on_empty_nodes(self, create_test_datatree):
        # GH278
        dt = create_test_datatree()
        dt["set1/set2"].attrs["foo"] = "bar"

        def empty_func(ds):
            return ds

        result = dt.map_over_datasets(empty_func)
        assert result["set1/set2"].attrs == dt["set1/set2"].attrs

    def test_error_contains_path_of_offending_node(self, create_test_datatree):
        dt = create_test_datatree()
        dt["set1"]["bad_var"] = 0
        print(dt)

        def fail_on_specific_node(ds):
            if "bad_var" in ds:
                raise ValueError("Failed because 'bar_var' present in dataset")

        with pytest.raises(
            ValueError,
            match=re.escape(
                r"Raised whilst mapping function over node with path 'set1'"
            ),
        ):
            dt.map_over_datasets(fail_on_specific_node)

    def test_inherited_coordinates_with_index(self):
        root = xr.Dataset(coords={"x": [1, 2]})
        child = xr.Dataset({"foo": ("x", [0, 1])})  # no coordinates
        tree = xr.DataTree.from_dict({"/": root, "/child": child})
        actual = tree.map_over_datasets(lambda ds: ds)  # identity
        assert isinstance(actual, xr.DataTree)
        assert_identical(tree, actual)

        actual_child = actual.children["child"].to_dataset(inherit=False)
        assert_identical(actual_child, child)


class TestMutableOperations:
    def test_construct_using_type(self):
        # from datatree GH issue https://github.com/xarray-contrib/datatree/issues/188
        # xarray's .weighted is unusual because it uses type() to create a Dataset/DataArray

        a = xr.DataArray(
            np.random.rand(3, 4, 10),
            dims=["x", "y", "time"],
            coords={"area": (["x", "y"], np.random.rand(3, 4))},
        ).to_dataset(name="data")
        b = xr.DataArray(
            np.random.rand(2, 6, 14),
            dims=["x", "y", "time"],
            coords={"area": (["x", "y"], np.random.rand(2, 6))},
        ).to_dataset(name="data")
        dt = xr.DataTree.from_dict({"a": a, "b": b})

        def weighted_mean(ds):
            if "area" not in ds.coords:
                return None
            return ds.weighted(ds.area).mean(["x", "y"])

        dt.map_over_datasets(weighted_mean)

    def test_alter_inplace_forbidden(self):
        simpsons = xr.DataTree.from_dict(
            {
                "/": xr.Dataset({"age": 83}),
                "/Herbert": xr.Dataset({"age": 40}),
                "/Homer": xr.Dataset({"age": 39}),
                "/Homer/Bart": xr.Dataset({"age": 10}),
                "/Homer/Lisa": xr.Dataset({"age": 8}),
                "/Homer/Maggie": xr.Dataset({"age": 1}),
            },
            name="Abe",
        )

        def fast_forward(ds: xr.Dataset, years: float) -> xr.Dataset:
            """Add some years to the age, but by altering the given dataset"""
            ds["age"] = ds["age"] + years
            return ds

        with pytest.raises(AttributeError):
            simpsons.map_over_datasets(fast_forward, 10)