File: datatree_mapping.py

package info (click to toggle)
python-xarray 2025.08.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 11,796 kB
  • sloc: python: 115,416; makefile: 258; sh: 47
file content (215 lines) | stat: -rw-r--r-- 8,024 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
from __future__ import annotations

import functools
from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, Any, cast, overload

from xarray.core.dataset import Dataset
from xarray.core.treenode import group_subtrees
from xarray.core.utils import result_name

if TYPE_CHECKING:
    from xarray.core.datatree import DataTree


# overload: func returns a single Dataset (or None) -> a single DataTree
@overload
def map_over_datasets(
    func: Callable[
        ...,
        Dataset | None,
    ],
    *args: Any,
    kwargs: Mapping[str, Any] | None = None,
) -> DataTree: ...


# overload: func returns a 2-tuple of Dataset/None -> a 2-tuple of DataTrees
@overload
def map_over_datasets(
    func: Callable[..., tuple[Dataset | None, Dataset | None]],
    *args: Any,
    kwargs: Mapping[str, Any] | None = None,
) -> tuple[DataTree, DataTree]: ...


# add an explicit overload for the most common case of two return values
# (python typing does not have a way to match tuple lengths in general)
# overload: func returns a variable-length tuple -> a tuple of DataTrees
@overload
def map_over_datasets(
    func: Callable[..., tuple[Dataset | None, ...]],
    *args: Any,
    kwargs: Mapping[str, Any] | None = None,
) -> tuple[DataTree, ...]: ...


def map_over_datasets(
    func: Callable[..., Dataset | tuple[Dataset | None, ...] | None],
    *args: Any,
    kwargs: Mapping[str, Any] | None = None,
) -> DataTree | tuple[DataTree, ...]:
    """
    Applies a function to every dataset in one or more DataTree objects with
    the same structure (i.e., that are isomorphic), returning new trees which
    store the results.

    The function will be applied to any dataset stored in any of the nodes in
    the trees. The returned trees will have the same structure as the supplied
    trees.

    ``func`` needs to return a Dataset, tuple of Dataset objects or None in order
    to be able to rebuild the subtrees after mapping, as each result will be
    assigned to its respective node of a new tree via `DataTree.from_dict`. Any
    returned value that is one of these types will be stacked into a separate
    tree before returning all of them.

    ``map_over_datasets`` is essentially syntactic sugar for the combination of
    ``group_subtrees`` and ``DataTree.from_dict``. For example, in the case of
    a two argument function that returns one result, it is equivalent to::

        results = {}
        for path, (left, right) in group_subtrees(left_tree, right_tree):
            results[path] = func(left.dataset, right.dataset)
        return DataTree.from_dict(results)

    Parameters
    ----------
    func : callable
        Function to apply to datasets with signature:

        `func(*args: Dataset, **kwargs) -> Union[Dataset, tuple[Dataset, ...]]`.

        (i.e. func must accept at least one Dataset and return at least one Dataset.)
    *args : tuple, optional
        Positional arguments passed on to `func`. Any DataTree arguments will be
        converted to Dataset objects via `.dataset`.
    kwargs : dict, optional
        Optional keyword arguments passed directly to ``func``.

    Returns
    -------
    Result of applying `func` to each node in the provided trees, packed back
    into DataTree objects via `DataTree.from_dict`.

    See also
    --------
    DataTree.map_over_datasets
    group_subtrees
    DataTree.from_dict
    """
    # TODO examples in the docstring
    # TODO inspect function to work out immediately if the wrong number of arguments were passed for it?

    # local import to avoid a circular import at module load time
    from xarray.core.datatree import DataTree

    if kwargs is None:
        kwargs = {}

    # Walk all trees simultaneously, applying func to all nodes that lie in same position in different trees
    # We don't know which arguments are DataTrees so we zip all arguments together as iterables
    # Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return
    out_data_objects: dict[str, Dataset | tuple[Dataset | None, ...] | None] = {}

    tree_args = [arg for arg in args if isinstance(arg, DataTree)]
    name = result_name(tree_args)

    for path, node_tree_args in group_subtrees(*tree_args):
        node_dataset_args = [arg.dataset for arg in node_tree_args]
        # re-insert the non-DataTree positional arguments at their original
        # positions, so func receives arguments in the order the caller gave them
        for i, arg in enumerate(args):
            if not isinstance(arg, DataTree):
                node_dataset_args.insert(i, arg)

        func_with_error_context = _handle_errors_with_path_context(path)(func)
        results = func_with_error_context(*node_dataset_args, **kwargs)
        out_data_objects[path] = results

    # None means func returned a single (possibly None) Dataset at every node;
    # an int means it returned a tuple of that length at every node
    num_return_values = _check_all_return_values(out_data_objects)

    if num_return_values is None:
        # one return value
        out_data = cast(Mapping[str, Dataset | None], out_data_objects)
        return DataTree.from_dict(out_data, name=name)

    # multiple return values
    out_data_tuples = cast(Mapping[str, tuple[Dataset | None, ...]], out_data_objects)
    output_dicts: list[dict[str, Dataset | None]] = [
        {} for _ in range(num_return_values)
    ]
    for path, outputs in out_data_tuples.items():
        # distribute the i-th result at each node into the i-th output tree
        for output_dict, output in zip(output_dicts, outputs, strict=False):
            output_dict[path] = output

    return tuple(
        DataTree.from_dict(output_dict, name=name) for output_dict in output_dicts
    )


def _handle_errors_with_path_context(path: str):
    """Wraps given function so that if it fails it also raises path to node on which it failed."""

    def decorator(func):
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                # Add the context information to the error message
                add_note(
                    e, f"Raised whilst mapping function over node with path {path!r}"
                )
                raise

        return wrapper

    return decorator


def add_note(err: BaseException, msg: str) -> None:
    """Attach *msg* to *err* as an exception note (PEP 678 ``BaseException.add_note``)."""
    err.add_note(msg)


def _check_single_set_return_values(path_to_node: str, obj: Any) -> int | None:
    """Validate one node's result from func.

    Returns ``None`` for a bare ``Dataset``/``None`` result (nothing to
    unpack), the tuple length for a tuple of such values, and raises
    ``TypeError`` for anything else.
    """
    if isinstance(obj, Dataset | None):
        # a single result — no packing into multiple trees required
        return None

    if isinstance(obj, tuple) and all(isinstance(item, Dataset | None) for item in obj):
        return len(obj)

    raise TypeError(
        f"the result of calling func on the node at position '{path_to_node}' is"
        f" not a Dataset or None or a tuple of such types:\n{obj!r}"
    )


def _check_all_return_values(returned_objects) -> int | None:
    """Validate the per-node results of mapping func over the subtrees.

    Every node must return the same shape of result: either a single
    ``Dataset``/``None`` everywhere (returns ``None``) or a tuple of the same
    length everywhere (returns that length). Raises ``TypeError`` on any
    invalid or inconsistent result.
    """
    pairs = list(returned_objects.items())

    # the first node's result fixes the expected shape for all the others
    first_path, first_result = pairs[0]
    expected_count = _check_single_set_return_values(first_path, first_result)

    for path, obj in pairs[1:]:
        actual_count = _check_single_set_return_values(path, obj)
        if actual_count == expected_count:
            continue

        if expected_count is None:
            raise TypeError(
                f"Calling func on the nodes at position {path} returns "
                f"a tuple of {actual_count} datasets, whereas calling func on the "
                f"nodes at position {first_path} instead returns a single dataset."
            )
        elif actual_count is None:
            raise TypeError(
                f"Calling func on the nodes at position {path} returns "
                f"a single dataset, whereas calling func on the nodes at position "
                f"{first_path} instead returns a tuple of {expected_count} datasets."
            )
        else:
            raise TypeError(
                f"Calling func on the nodes at position {path} returns "
                f"a tuple of {actual_count} datasets, whereas calling func on "
                f"the nodes at position {first_path} instead returns a tuple of "
                f"{expected_count} datasets."
            )

    return expected_count