File: yamlutil.py

package info (click to toggle)
python-asdf 4.3.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 7,032 kB
  • sloc: python: 24,068; makefile: 123
file content (449 lines) | stat: -rw-r--r-- 16,247 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
import warnings
from collections import OrderedDict
from types import GeneratorType

import numpy as np
import yaml

from . import config, schema, tagged, treeutil, util
from .constants import STSCI_SCHEMA_TAG_BASE, YAML_TAG_PREFIX
from .exceptions import AsdfConversionWarning, AsdfSerializationError
from .extension._serialization_context import BlockAccess
from .tags.core import AsdfObject
from .versioning import _YAML_VERSION, _yaml_base_loader

__all__ = ["custom_tree_to_tagged_tree", "tagged_tree_to_custom_tree"]


# Prefer the libyaml (C) dumper when pyyaml was built with it; fall back
# to the pure-Python SafeDumper otherwise.
_yaml_base_dumper = yaml.CSafeDumper if getattr(yaml, "__with_libyaml__", None) else yaml.SafeDumper


# Fully-qualified YAML tag for ordered mappings (yaml's "omap" tag).
YAML_OMAP_TAG = YAML_TAG_PREFIX + "omap"


# ----------------------------------------------------------------------
# Custom loader/dumpers


class AsdfDumper(_yaml_base_dumper):
    """
    YAML dumper aware of the "tagged basic Python data types" from the
    `tagged` module: any object carrying a ``_tag`` attribute has that
    tag applied to its representation node.
    """

    def __init__(self, *args, **kwargs):
        # None lets pyyaml choose flow vs. block style per node.
        kwargs["default_flow_style"] = None
        super().__init__(*args, **kwargs)

    def represent_data(self, data):
        node = super().represent_data(data)
        custom_tag = getattr(data, "_tag", None)
        if custom_tag is not None:
            node.tag = custom_tag
        return node


# Maps a Tagged* object's ``flow_style`` string onto pyyaml's boolean
# flow_style argument (missing keys map to None: "dumper decides").
_flow_style_map = {"flow": True, "block": False}


def represent_sequence(dumper, sequence):
    """Represent a TaggedList, honoring its requested flow style."""
    style = _flow_style_map.get(sequence.flow_style)
    return super(AsdfDumper, dumper).represent_sequence(None, sequence.data, style)


def represent_mapping(dumper, mapping):
    """Represent a TaggedDict, honoring flow style and property order.

    Keys listed in ``mapping.property_order`` are emitted first, in that
    order; all remaining keys follow in their original order.
    """
    style = _flow_style_map.get(mapping.flow_style)
    node = super(AsdfDumper, dumper).represent_mapping(None, mapping.data, style)

    if mapping.property_order:
        pairs = node.value
        # Index the represented (key, value) node pairs by key text.
        by_key = {}
        for key_node, value_node in pairs:
            by_key[key_node.value] = (key_node, value_node)

        # Explicitly ordered keys first...
        reordered = [by_key[name] for name in mapping.property_order if name in mapping]

        # ...then everything else, preserving the original order.
        explicit = set(mapping.property_order)
        reordered.extend(pair for pair in pairs if pair[0].value not in explicit)

        node.value = reordered

    return node


# Maps a TaggedString's ``style`` value onto pyyaml's scalar style
# character (missing keys map to None: "dumper decides").
_style_map = {"inline": '"', "folded": ">", "literal": "|"}


def represent_scalar(dumper, value):
    """Represent a TaggedString, honoring its requested scalar style."""
    scalar_style = _style_map.get(value.style)
    return super(AsdfDumper, dumper).represent_scalar(None, value.data, scalar_style)


def represent_ordered_mapping(dumper, tag, data):
    """Represent an ordered mapping as a sequence of one-entry mappings.

    For block style this produces the compact omap notation; flow style
    does not use the compact form.

    TODO: adjust for preferred flow style and other stylistic details;
    consider a mechanism letting classes that use this representer pick
    which values use flow style.
    """
    entries = []
    node = yaml.SequenceNode(tag, entries, flow_style=dumper.default_flow_style)
    # Register the node before recursing so aliases back to this object
    # resolve to it.
    if dumper.alias_key is not None:
        dumper.represented_objects[dumper.alias_key] = node
    for key, value in data.items():
        pair = (dumper.represent_data(key), dumper.represent_data(value))
        entries.append(yaml.MappingNode(YAML_OMAP_TAG, [pair], flow_style=False))
    return node


def represent_ordereddict(dumper, data):
    # Serialize OrderedDict with YAML's omap tag so key order is preserved.
    return represent_ordered_mapping(dumper, YAML_OMAP_TAG, data)


# Register representers for the tagged container/scalar wrappers and for
# OrderedDict (emitted as !!omap to preserve key order).
AsdfDumper.add_representer(tagged.TaggedList, represent_sequence)
AsdfDumper.add_representer(tagged.TaggedDict, represent_mapping)
AsdfDumper.add_representer(tagged.TaggedString, represent_scalar)
AsdfDumper.add_representer(OrderedDict, represent_ordereddict)

# ----------------------------------------------------------------------
# Handle numpy scalars


# Represent numpy scalar floats/ints as plain Python float/int so the
# (C)SafeDumper can serialize them as ordinary YAML numbers.
for scalar_type in util._iter_subclasses(np.floating):
    AsdfDumper.add_representer(scalar_type, lambda dumper, data: dumper.represent_float(float(data)))

for scalar_type in util._iter_subclasses(np.integer):
    AsdfDumper.add_representer(scalar_type, lambda dumper, data: dumper.represent_int(int(data)))


def represent_numpy_str(dumper, data):
    """Represent np.str_ by first coercing it to a builtin str.

    The CSafeDumper implementation will raise an error if it doesn't
    recognize data as a string; the Python SafeDumper has no problem
    with np.str_ directly.
    """
    text = str(data)
    return dumper.represent_str(text)


# np.str_ needs coercion to builtin str for the C dumper (see
# represent_numpy_str); np.bytes_ is a bytes subclass and can use the
# stock binary representer.
AsdfDumper.add_representer(np.str_, represent_numpy_str)
AsdfDumper.add_representer(np.bytes_, AsdfDumper.represent_binary)


class _IgnoreCustomTagsLoader(_yaml_base_loader):
    """
    A YAML loader that ignores tags unknown to the base (safe) loader.

    Used by `asdf.util.load_yaml` to read the ASDF tree as "basic"
    objects, discarding the custom tags.
    """

    def construct_undefined(self, node):
        # Fall back to the plain-YAML construction for each node kind.
        if isinstance(node, yaml.SequenceNode):
            return self.construct_yaml_seq(node)
        if isinstance(node, yaml.MappingNode):
            return self.construct_yaml_map(node)
        if isinstance(node, yaml.ScalarNode):
            return self.construct_scalar(node)
        return super().construct_undefined(node)


# pyyaml invokes the constructor registered under None for any node whose
# tag has no dedicated constructor, so this catches all custom tags.
_IgnoreCustomTagsLoader.add_constructor(None, _IgnoreCustomTagsLoader.construct_undefined)


class AsdfLoader(_yaml_base_loader):
    """
    A specialized YAML loader that can construct "tagged basic Python
    data types" as implemented in the `tagged` module.
    """

    def construct_undefined(self, node):
        # Invoked for every tag without a dedicated constructor: wrap the
        # plain value in a Tagged* object that records the node's tag.
        if isinstance(node, yaml.MappingNode):
            return self._construct_tagged_mapping(node)

        if isinstance(node, yaml.SequenceNode):
            return self._construct_tagged_sequence(node)

        if isinstance(node, yaml.ScalarNode):
            return self._construct_tagged_scalar(node)

        return super().construct_undefined(node)

    def _construct_tagged_mapping(self, node):
        # Two-pass (generator) construction: yield the empty container
        # first so pyyaml can resolve anchors/recursive references to it,
        # then fill it in.
        data = tagged.tag_object(node.tag, {})
        yield data
        data.update(self.construct_mapping(node))

    def _construct_tagged_sequence(self, node):
        # See _construct_tagged_mapping for why this is a generator.
        data = tagged.tag_object(node.tag, [])
        yield data
        data.extend(self.construct_sequence(node))

    def _construct_tagged_scalar(self, node):
        # Scalars cannot be self-referential, so no generator is needed.
        return tagged.tag_object(node.tag, self.construct_scalar(node))

    # Custom omap deserializer that builds an OrderedDict instead
    # of a list of tuples.  Code is mostly identical to pyyaml's SafeConstructor.
    def construct_yaml_omap(self, node):
        omap = OrderedDict()
        # Yield first for the same two-pass reason as the tagged containers.
        yield omap
        if not isinstance(node, yaml.SequenceNode):
            msg = "while constructing an ordered map"
            raise yaml.constructor.ConstructorError(
                msg,
                node.start_mark,
                f"expected a sequence, but found {node.id}",
                node.start_mark,
            )
        # An omap is a sequence of single-entry mappings; validate that
        # shape before extracting each key/value pair.
        for subnode in node.value:
            if not isinstance(subnode, yaml.MappingNode):
                msg = "while constructing an ordered map"
                raise yaml.constructor.ConstructorError(
                    msg,
                    node.start_mark,
                    f"expected a mapping of length 1, but found {subnode.id}",
                    subnode.start_mark,
                )
            if len(subnode.value) != 1:
                msg = "while constructing an ordered map"
                raise yaml.constructor.ConstructorError(
                    msg,
                    node.start_mark,
                    f"expected a single mapping item, but found {len(subnode.value)} items",
                    subnode.start_mark,
                )
            key_node, value_node = subnode.value[0]
            key = self.construct_object(key_node)
            value = self.construct_object(value_node)
            omap[key] = value


# pyyaml will invoke the constructor associated with None when a node's
# tag is not explicitly handled by another constructor.
AsdfLoader.add_constructor(None, AsdfLoader.construct_undefined)
# Load !!omap nodes as OrderedDict rather than pyyaml's list of tuples.
AsdfLoader.add_constructor(YAML_TAG_PREFIX + "omap", AsdfLoader.construct_yaml_omap)


def custom_tree_to_tagged_tree(tree, ctx, _serialization_context=None):
    """
    Convert a tree, possibly containing custom data types that aren't
    directly representable in YAML, to a tree of basic data types,
    annotated with tags.

    Parameters
    ----------
    tree : object
        Tree of custom objects to convert.

    ctx : Context
        Context used to create a serialization context (when one is not
        supplied) and to share the tree-modification context.

    _serialization_context : optional
        Existing serialization context to reuse; a new WRITE context is
        created from ``ctx`` when not provided.
    """
    if _serialization_context is None:
        _serialization_context = ctx._create_serialization_context(BlockAccess.WRITE)

    extension_manager = _serialization_context.extension_manager

    def _convert_obj(obj, converter):
        # Generator protocol understood by treeutil.walk_and_modify:
        # first yield the tagged node, then (optionally) delegate to the
        # converter's own generator for deferred work.
        tag = converter.select_tag(obj, _serialization_context)
        # if select_tag returns None, converter.to_yaml_tree should return a new
        # object which will be handled by a different converter
        converters_used = set()
        while tag is None:
            converters_used.add(converter)
            obj = converter.to_yaml_tree(obj, tag, _serialization_context)
            try:
                converter = extension_manager.get_converter_for_type(type(obj))
            except KeyError:
                # no converter supports this type, return it as-is
                yield obj
                return
            if converter in converters_used:
                # Revisiting a converter means the delegation chain will
                # never settle on a tag.
                msg = "Conversion cycle detected"
                raise TypeError(msg)
            tag = converter.select_tag(obj, _serialization_context)
        _serialization_context.assign_object(obj)
        node = converter.to_yaml_tree(obj, tag, _serialization_context)
        _serialization_context.assign_blocks()

        if isinstance(node, GeneratorType):
            # The converter deferred part of its work; take the node now
            # and exhaust the rest of the generator after yielding it.
            generator = node
            node = next(generator)
        else:
            generator = None

        # Wrap the basic node in the Tagged* type matching its shape so
        # AsdfDumper can re-attach the tag at dump time.
        if isinstance(node, dict):
            tagged_node = tagged.TaggedDict(node, tag)
        elif isinstance(node, list):
            tagged_node = tagged.TaggedList(node, tag)
        elif isinstance(node, str):
            tagged_node = tagged.TaggedString(node)
            tagged_node._tag = tag
        else:
            msg = f"Converter returned illegal node type: {util.get_class_name(node)}"
            raise TypeError(msg)

        _serialization_context._mark_extension_used(converter.extension)

        yield tagged_node
        if generator is not None:
            yield from generator

    cfg = config.get_config()
    convert_ndarray_subclasses = cfg.convert_unknown_ndarray_subclasses
    # Cache the per-type dispatch decision to avoid repeated extension
    # manager lookups while walking large trees.
    converters_cache = {}

    def _walker(obj):
        typ = type(obj)
        if typ in converters_cache:
            return converters_cache[typ](obj)
        if extension_manager.handles_type(typ):
            converter = extension_manager.get_converter_for_type(typ)
            converters_cache[typ] = lambda obj, _converter=converter: _convert_obj(obj, _converter)
            return _convert_obj(obj, converter)
        if convert_ndarray_subclasses and isinstance(obj, np.ndarray):
            # Deprecated fallback: serialize an unknown ndarray subclass
            # with the plain ndarray converter (config-controlled).
            warnings.warn(
                f"A ndarray subclass ({type(obj)}) was converted as a ndarray. "
                "This behavior will be removed from a future version of ASDF. "
                "See https://asdf.readthedocs.io/en/latest/asdf/config.html#convert-unknown-ndarray-subclasses",
                AsdfConversionWarning,
            )
            converter = extension_manager.get_converter_for_type(np.ndarray)
            converters_cache[typ] = lambda obj, _converter=converter: _convert_obj(obj, _converter)
            return _convert_obj(obj, converter)

        # No converter handles this type: pass the object through
        # unchanged (and remember that decision).
        converters_cache[typ] = lambda obj: obj
        return obj

    return treeutil.walk_and_modify(
        tree,
        _walker,
        # Walk the tree in preorder, so that extensions can return
        # container nodes with unserialized children.
        postorder=False,
        _context=ctx._tree_modification_context,
    )


def tagged_tree_to_custom_tree(tree, ctx, force_raw_types=False, _serialization_context=None):
    """
    Convert a tree containing only basic data types, annotated with
    tags, to a tree containing custom data types.
    """
    if _serialization_context is None:
        _serialization_context = ctx._create_serialization_context(BlockAccess.READ)

    extension_manager = _serialization_context.extension_manager

    def _deserialize_node(node):
        tag = getattr(node, "_tag", None)
        # Untagged nodes (and everything, when raw types are forced)
        # pass through unchanged.
        if force_raw_types or tag is None:
            return node

        if not extension_manager.handles_tag(tag):
            if not ctx._ignore_unrecognized_tag:
                warnings.warn(
                    f"{tag} is not recognized, converting to raw Python data structure",
                    AsdfConversionWarning,
                )
            return node

        converter = extension_manager.get_converter_for_tag(tag)
        custom = converter.from_yaml_tree(node.data, tag, _serialization_context)
        _serialization_context.assign_object(custom)
        _serialization_context.assign_blocks()
        _serialization_context._mark_extension_used(converter.extension)
        return custom

    return treeutil.walk_and_modify(
        tree,
        _deserialize_node,
        # Walk the tree in postorder, so that converters receive
        # container nodes with children already deserialized.
        postorder=True,
        _context=ctx._tree_modification_context,
    )


def load_tree(stream):
    """
    Load YAML, returning a tree of objects.

    Parameters
    ----------
    stream : readable file-like object
        Stream containing the raw YAML content.

    Returns
    -------
    object
        The deserialized tree; nodes with unrecognized tags are wrapped
        in Tagged* objects (see `AsdfLoader`).
    """
    # The following call to yaml.load is safe because we're
    # using a loader that inherits from pyyaml's SafeLoader.
    return yaml.load(stream, Loader=AsdfLoader)  # noqa: S506


def dump_tree(tree, fd, ctx, tree_finalizer=None, _serialization_context=None):
    """
    Dump a tree of objects, possibly containing custom types, to YAML.

    Parameters
    ----------
    tree : object
        Tree of objects, possibly containing custom data types.

    fd : asdf.generic_io.GenericFile
        A file object to dump the serialized YAML to.

    ctx : Context
        The writing context.

    tree_finalizer : callable, optional
        Callback that receives the tagged tree before it is validated
        and defaults are removed.  `asdf.AsdfFile` uses this to update
        extension metadata on the tagged tree before it is fully
        serialized to YAML.

    _serialization_context : optional
        Allows `asdf.AsdfFile` to track which extensions were used when
        converting the tree's custom types.  In 3.0, it will be passed
        as the `ctx` instead of the AsdfFile itself.

    Raises
    ------
    TypeError
        If the root node of the tree is not an `AsdfObject`.
    AsdfSerializationError
        If an object in the tree cannot be represented in YAML.
    """
    if type(tree) is not AsdfObject:
        msg = "Root node of ASDF tree must be of type AsdfObject"
        raise TypeError(msg)

    tags = {"!": STSCI_SCHEMA_TAG_BASE + "/"}
    tree = custom_tree_to_tagged_tree(tree, ctx, _serialization_context=_serialization_context)
    if tree_finalizer is not None:
        tree_finalizer(tree)
    schema.validate(tree, ctx)

    # add yaml %TAG definitions from extensions
    # (explicit None check for consistency with the other entry points)
    if _serialization_context is not None:
        for ext in _serialization_context._extensions_used:
            for key, val in ext.yaml_tag_handles.items():
                if key not in tags:
                    tags[key] = val

    try:
        yaml.dump_all(
            [tree],
            stream=fd,
            Dumper=AsdfDumper,
            explicit_start=True,
            explicit_end=True,
            version=_YAML_VERSION,
            allow_unicode=True,
            encoding="utf-8",
            tags=tags,
        )
    except yaml.representer.RepresenterError as err:
        if len(err.args) < 2:
            # Not the (reason, object) shape we expect; re-raise with the
            # original traceback intact.
            raise
        # inspect the exception arguments to determine what object failed
        obj = err.args[1]
        msg = (
            f"Object of type[{type(obj)}] is not serializable by asdf. "
            "Please convert the object to a supported type or implement "
            "a Converter for this type to allow the tree to be serialized."
        )
        raise AsdfSerializationError(msg, obj) from err