File: test_serialization_profiles.py

package info (click to toggle)
ansible-core 2.20.3-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 32,760 kB
  • sloc: python: 175,447; cs: 4,929; sh: 4,732; xml: 34; makefile: 21
file content (385 lines) | stat: -rw-r--r-- 15,017 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
# DTFIX-FUTURE: these tests need to be split so they can run under both module_utils and controller contexts

from __future__ import annotations

import contextlib
import dataclasses
import datetime
import hashlib
import itertools
import json
import pathlib
import pkgutil
import pprint
import typing as t

import pytest

from ansible.module_utils._internal._json import _profiles as target_serialization_profiles
from ansible.module_utils._internal import _json as _serialization
from ansible.module_utils._internal._datatag import AnsibleDatatagBase, NotTaggableError, AnsibleTagHelper
from ansible.module_utils._internal._datatag._tags import Deprecated
from ansible._internal._templating._lazy_containers import _AnsibleLazyTemplateMixin
from ansible._internal._templating._engine import TemplateEngine, TemplateOptions
from ansible._internal._templating._utils import TemplateContext
from ansible._internal._datatag._tags import TrustedAsTemplate, VaultedValue, Origin, SourceWasEncrypted
from ansible._internal._json import _profiles as controller_serialization_profiles
from ansible.module_utils.common.json import get_encoder, get_decoder
from ansible.module_utils._internal._json._profiles import _fallback_to_str
from ansible._internal._json._profiles import _cache_persistence
from ansible.errors import AnsibleRuntimeError

from ..mock.custom_types import CustomMapping, CustomSequence, CustomStr, CustomInt, CustomFloat


# Canonical matrix of values exercised against every serialization profile.
# Covers scalars, bytes, timezone-aware datetimes, containers, custom subclasses,
# and dict keys that JSON cannot natively represent (tuples, frozensets, non-str scalars).
basic_values = (
    None,
    True,
    1,
    1.1,
    'hi',
    '汉语',  # non-ASCII string
    b'hi',
    datetime.datetime(2024, 1, 2, 3, 4, 5, 6, datetime.timezone.utc, fold=1),
    datetime.time(1, 2, 3, 4, datetime.timezone.utc, fold=1),
    datetime.date(2024, 1, 2),
    (1,),
    [1],
    CustomSequence([1]),
    CustomStr('hello'),
    CustomInt(42),
    CustomFloat(42.0),
    {1},
    dict(a=1),
    CustomMapping(dict(a=1)),
    {(1, 2): "three"},  # hashable non-scalar key
    {frozenset((1, 2)): "three"},  # hashable non-scalar key
    {1: "two"},  # int key
    {1.1: "two"},  # float key
    {True: "two"},  # bool key
    {None: "two"},  # None key
)

# DTFIX5: we need tests for recursion, specifically things like custom sequences and mappings when:
#                1) using the legacy serializer
#                2) containing types in the type map, such as tagged values
#                e.g. -- does trust inversion get applied to a value inside a custom sequence or mapping

# One representative instance per tag type; applied to values to exercise tag round-tripping per profile.
tag_values = {
    Deprecated: Deprecated(msg='x'),  # DTFIX5: we need more exhaustive testing of the values supported by this tag to ensure schema ID is robust
    TrustedAsTemplate: TrustedAsTemplate(),
    Origin: Origin(path='/tmp/x', line_num=1, col_num=2, description='y'),
    VaultedValue: VaultedValue(ciphertext='x'),
    SourceWasEncrypted: SourceWasEncrypted(),
}


def test_cache_persistence_schema() -> None:
    """
    Verify the schema ID for the cache_persistence profile is updated whenever its schema changes.
    A stale schema ID results in serialization/deserialization failures for persisted data, such as cache plugin entries.
    This check is only as comprehensive as these unit tests, so profile data types must be thoroughly covered here.
    Capabilities added to the cache_persistence profile without corresponding tests will go undetected, leading to runtime failures.
    """
    # DTFIX5: update tests to ensure new fields on contracts will fail this test if they have defaults which are omitted from serialization
    #                one possibility: monkeypatch the default field value omission away so that any new field will invalidate the schema

    # DTFIX5: ensure all types/attrs included in _profiles._common_module_response_types are represented here, since they can appear in cached responses

    expected_schema_id = 1
    expected_schema_hash = "0bc4bec94abe6ec0f62fc9f45ea1099ea65b13f00a5e5de1699e0dbcf0de2b2c"

    # hash the recorded expected-output fixture for this profile; any change to it implies a possible schema change
    profile_data_file = (DataSet.PROFILE_DIR / _cache_persistence._Profile.profile_name).with_suffix('.txt')

    actual_schema_id = _cache_persistence._Profile.schema_id
    actual_schema_hash = hashlib.sha256(profile_data_file.read_bytes()).hexdigest()

    next_schema_id = actual_schema_id + 1

    schema_check_failure_instructions = f"""The cache_persistence schema check hash has changed. The solution depends on the reason why:

1) The schema and tests have changed:

   i. Increment `ansible._internal._json._profiles._cache_persistence._Profile.schema_id` to {next_schema_id}.
   ii. Update `expected_schema_id` to {next_schema_id}.
   iii. Update `expected_schema_hash` to {actual_schema_hash!r}.

2) The schema is unchanged, but the tests have changed:

   i. Double-check that the schema really hasn't changed.
   ii. Don't forget about added/changed/removed types as well as fields on those types.
   iii. Update `expected_schema_hash` to {actual_schema_hash!r}.
"""

    if actual_schema_id != expected_schema_id:
        raise Exception(f"The actual schema ID {actual_schema_id} does not match the expected schema ID {expected_schema_id}.")

    if actual_schema_hash != expected_schema_hash:
        raise Exception(schema_check_failure_instructions)


def get_profile_names() -> tuple[str, ...]:
    """Return the sorted names of every serialization profile found in the module_utils and controller profile packages."""
    names: list[str] = []

    for package in (target_serialization_profiles, controller_serialization_profiles):
        found = list(pkgutil.iter_modules(package.__path__, f'{package.__name__}.'))

        assert found  # each package must contribute at least one serialization profile module

        names.extend(_serialization.get_serialization_profile(module.name).profile_name for module in found)

    return tuple(sorted(names))


@dataclasses.dataclass(frozen=True)
class _TestParameters:
    profile_name: str
    value: t.Any
    tags: tuple[AnsibleDatatagBase, ...] = ()
    lazy: bool = False

    def __hash__(self):
        return hash((self.profile_name, repr(self.value), self.tags))

    def __repr__(self):
        fields = ((field, getattr(self, field.name)) for field in dataclasses.fields(self))
        args = (f'{f.name}={v!r}' for f, v in fields if v != f.default)
        return f"{type(self).__name__}({', '.join(args)})"

    def get_test_output(self) -> _TestOutput:
        encoder = get_encoder(self.profile_name)
        decoder = get_decoder(self.profile_name)

        ctx = TemplateContext(
            template_value=self.value,
            templar=TemplateEngine(),
            options=TemplateOptions.DEFAULT,
            stop_on_template=False
        ) if self.lazy else contextlib.nullcontext()

        with ctx:
            try:
                value = AnsibleTagHelper.tag(self.value, self.tags)
            except NotTaggableError:
                value = self.value

            if self.lazy:
                value = _AnsibleLazyTemplateMixin._try_create(value)

            payload: str | Exception

            try:
                payload = json.dumps(value, cls=encoder)
            except Exception as ex:
                payload = ex
                round_trip = None
            else:
                try:
                    round_trip = json.loads(payload, cls=decoder)
                except Exception as ex:
                    round_trip = ex

            return _TestOutput(
                payload=payload,
                round_trip=AnsibleTagHelper.as_native_type(round_trip),
                tags=tuple(sorted(AnsibleTagHelper.tags(round_trip), key=lambda item: type(item).__name__)),
            )


@dataclasses.dataclass(frozen=True)
class _TestOutput:
    """Observable result of serializing and round-tripping a value under a profile."""

    payload: str | Exception  # serialized JSON, or the exception raised during serialization
    round_trip: t.Any  # deserialized value converted to its native type, or the exception raised during deserialization
    tags: tuple[AnsibleDatatagBase, ...]  # tags present on the round-tripped value, sorted by type name


@dataclasses.dataclass(frozen=True)
class _TestCase:
    parameters: _TestParameters
    expected: _TestOutput

    def __str__(self) -> str:
        parts = [f'profile={self.parameters.profile_name}', f'value={self.parameters.value}']

        if self.parameters.tags:
            parts.append(f"tags={','.join(sorted(type(obj).__name__ for obj in self.parameters.tags))}")

        if self.parameters.lazy:
            parts.append('lazy')

        return '; '.join(parts)


class DataSet:
    """Loads, records and persists the expected serialization outputs, one fixture file per profile."""

    PROFILE_DIR = pathlib.Path(__file__).parent / 'expected_serialization_profiles'

    def __init__(self, generate: bool) -> None:
        self.data: dict[_TestParameters, _TestOutput] = {}
        self.path = self.PROFILE_DIR
        self.generate = generate

    def load(self) -> None:
        """Populate self.data from the fixture files, unless regenerating them."""
        if self.generate:
            return

        for data_file in self.path.glob('*.txt'):
            # fixture files are python literal dicts written by save(); eval of trusted repo data, not external input
            self.data.update(eval(data_file.read_text()))

    def save(self) -> None:
        """Write the collected data set to one file per profile. No-op unless in generate mode."""
        if not self.generate:
            return

        # sort by profile name first; additional items appended to the end mean the data set is otherwise unsorted
        grouped: dict[str, dict[_TestParameters, _TestOutput]] = {}

        for parameters, output in sorted(self.data.items(), key=lambda item: item[0].profile_name):
            grouped.setdefault(parameters.profile_name, {})[parameters] = output

        for profile_name, profiles in grouped.items():
            (self.path / f'{profile_name}.txt').write_text(self.generate_content(profiles))

    @staticmethod
    def generate_content(profiles: dict[_TestParameters, _TestOutput]) -> str:
        """Render a profile's expected-output mapping as a python literal dict, one entry per line for reviewable diffs."""
        def fmt(obj: t.Any) -> str:
            # wide width + zero indent keeps each entry on a single line; insertion order is significant
            return pprint.pformat(obj, width=10000, indent=0, sort_dicts=False)

        entries = [f"{fmt(key)}: {fmt(value)}," for key, value in profiles.items()]

        return '\n'.join(['{', *entries, '}']) + '\n'

    def fetch_or_create_expected(self, test_params: _TestParameters) -> _TestOutput:
        """Return the expected output for test_params, computing and recording it when in generate mode."""
        if self.generate:
            self.data[test_params] = test_params.get_test_output()

            return self.data[test_params]

        try:
            return self.data[test_params]
        except KeyError:
            raise Exception(f'Missing {test_params} in data set. Use `generate=True` to update the data set and then review the changes.') from None


class ProfileHelper:
    """Builds the _TestParameters variants (plain, tagged, lazy) appropriate for a single serialization profile."""

    def __init__(self, profile_name: str) -> None:
        self.profile_name = profile_name

        profile = _serialization.get_serialization_profile(profile_name)

        # dict used as an ordered set of the tag types this profile can serialize
        supported_tags = dict.fromkeys(obj for obj in profile.serialize_map if issubclass(obj, AnsibleDatatagBase))

        if not supported_tags:
            self.supported_tag_values = tuple()
        else:
            self.supported_tag_values = tuple(tag_value for tag_type, tag_value in tag_values.items() if tag_type in supported_tags)

            if not self.supported_tag_values:
                raise Exception(f'Profile {profile} supports tags {supported_tags}, but no supported tag value is available.')

        # find one tag value the profile does NOT support, to verify unsupported tags are handled
        self.unsupported_tag_value = None

        for tag_type, tag_value in tag_values.items():
            if tag_type not in supported_tags:
                self.unsupported_tag_value = tag_value
                break

        if not self.unsupported_tag_value and profile.profile_name != _cache_persistence._Profile.profile_name:
            raise Exception(f'Profile {profile} supports tags {supported_tags}, but no unsupported tag value is available.')

    def create_parameters_from_values(self, *values: t.Any) -> list[_TestParameters]:
        """Return the parameter variants for each of the given values, concatenated in order."""
        collected: list[_TestParameters] = []

        for value in values:
            collected.extend(self.create_parameters_from_value(value))

        return collected

    def create_parameters_from_value(self, value: t.Any) -> list[_TestParameters]:
        """Return the parameter variants for a single value: untagged, plus tagged and lazy variants where applicable."""
        base = _TestParameters(
            profile_name=self.profile_name,
            value=value,
        )

        params = [base]

        if self.supported_tag_values:
            params.append(dataclasses.replace(base, tags=self.supported_tag_values))

        if self.unsupported_tag_value:
            params.append(dataclasses.replace(base, tags=(self.unsupported_tag_value,)))

        # test lazy containers on all non m2c profiles
        if isinstance(value, (list, dict)) and not self.profile_name.endswith("_m2c"):
            params += [dataclasses.replace(p, lazy=True) for p in params]

        return params


# Profile-specific scenarios beyond the shared basic_values matrix.
additional_test_parameters: list[_TestParameters] = []

# DTFIX5: need better testing for containers, especially for tagged values in containers

# The fallback_to_str profile gets extra coverage for bytes/surrogate handling and non-str dict keys.
additional_test_parameters.extend(ProfileHelper(_fallback_to_str._Profile.profile_name).create_parameters_from_values(
    b'\x00',  # valid utf-8 strict, JSON escape sequence required
    b'\x80',  # utf-8 strict decoding fails, forcing the use of an error handler such as surrogateescape, JSON escape sequence required
    '\udc80',  # same as above, but already a string (verify that the string version is handled the same as the bytes version)
    {1: "1"},  # integer key
    {b'hi': "1"},  # bytes key
    {TrustedAsTemplate().tag(b'hi'): "2"},  # tagged bytes key
    {(b'hi',): 3},  # tuple[bytes] key
))


_generate = False
"""Set to True to regenerate all test data; a test failure will occur until it is set back to False."""


def get_test_cases() -> list[_TestCase]:
    """Build the full set of profile/value test cases, pairing each with its recorded (or freshly generated) expected output."""
    data_set = DataSet(generate=_generate)
    data_set.load()

    all_parameters: list[_TestParameters] = []

    for profile_name in get_profile_names():
        helper = ProfileHelper(profile_name)
        all_parameters.extend(helper.create_parameters_from_values(*basic_values))

    all_parameters += additional_test_parameters

    cases = [_TestCase(parameters=parameters, expected=data_set.fetch_or_create_expected(parameters)) for parameters in all_parameters]

    data_set.save()  # persists the data set only when running in generate mode

    return cases


@pytest.mark.parametrize("test_case", get_test_cases(), ids=str)
def test_profile(test_case: _TestCase) -> None:
    """Round-trip the parameterized value under its profile and compare against the recorded expected output."""
    actual = test_case.parameters.get_test_output()
    expected = test_case.expected

    if isinstance(actual.payload, Exception):
        # serialization failed; the failure type and message must match the recorded expectation
        if type(actual.payload) is not type(expected.payload):
            raise Exception('unexpected exception') from actual.payload

        assert str(actual.payload) == str(expected.payload)
        return

    assert actual.payload == expected.payload
    assert type(actual.round_trip) is type(expected.round_trip)

    if isinstance(actual.round_trip, AnsibleRuntimeError):
        # exception instances don't compare equal by value; compare their original messages instead
        assert str(actual.round_trip._original_message) == str(expected.round_trip._original_message)
    else:
        assert actual.round_trip == expected.round_trip

    assert set(actual.tags) == set(expected.tags)


def test_not_generate_mode():
    """Guard against accidentally committing the data set regeneration toggle enabled."""
    generating = bool(_generate)

    assert not generating, "set _generate=False to statically test expected behavior"