# DTFIX-FUTURE: these tests need to be split so they can run under both module_utils and controller contexts
from __future__ import annotations
import contextlib
import dataclasses
import datetime
import hashlib
import itertools
import json
import pathlib
import pkgutil
import pprint
import typing as t
import pytest
from ansible.module_utils._internal._json import _profiles as target_serialization_profiles
from ansible.module_utils._internal import _json as _serialization
from ansible.module_utils._internal._datatag import AnsibleDatatagBase, NotTaggableError, AnsibleTagHelper
from ansible.module_utils._internal._datatag._tags import Deprecated
from ansible._internal._templating._lazy_containers import _AnsibleLazyTemplateMixin
from ansible._internal._templating._engine import TemplateEngine, TemplateOptions
from ansible._internal._templating._utils import TemplateContext
from ansible._internal._datatag._tags import TrustedAsTemplate, VaultedValue, Origin, SourceWasEncrypted
from ansible._internal._json import _profiles as controller_serialization_profiles
from ansible.module_utils.common.json import get_encoder, get_decoder
from ansible.module_utils._internal._json._profiles import _fallback_to_str
from ansible._internal._json._profiles import _cache_persistence
from ansible.errors import AnsibleRuntimeError
from ..mock.custom_types import CustomMapping, CustomSequence, CustomStr, CustomInt, CustomFloat
# Canonical sample values covering the scalar and container shapes every serialization
# profile must handle; each entry is round-tripped through every profile.
# NOTE: the exact contents and order feed the persisted expected-output fixtures,
# so any change here invalidates the stored data set (regenerate via `_generate`).
basic_values = (
    None,
    True,
    1,
    1.1,
    'hi',
    '汉语',  # non-ASCII string
    b'hi',
    datetime.datetime(2024, 1, 2, 3, 4, 5, 6, datetime.timezone.utc, fold=1),  # tz-aware, with microseconds and fold
    datetime.time(1, 2, 3, 4, datetime.timezone.utc, fold=1),  # tz-aware time with fold
    datetime.date(2024, 1, 2),
    (1,),  # tuple
    [1],  # list
    CustomSequence([1]),  # user-defined sequence type
    CustomStr('hello'),  # str subclass
    CustomInt(42),  # int subclass
    CustomFloat(42.0),  # float subclass
    {1},  # set
    dict(a=1),
    CustomMapping(dict(a=1)),  # user-defined mapping type
    {(1, 2): "three"},  # hashable non-scalar key
    {frozenset((1, 2)): "three"},  # hashable non-scalar key
    {1: "two"},  # int key
    {1.1: "two"},  # float key
    {True: "two"},  # bool key
    {None: "two"},  # None key
)
# DTFIX5: we need tests for recursion, specifically things like custom sequences and mappings when:
# 1) using the legacy serializer
# 2) containing types in the type map, such as tagged values
# e.g. -- does trust inversion get applied to a value inside a custom sequence or mapping
# One representative instance per data-tag type, used to exercise tagged-value serialization.
# These instances are persisted in the expected-output fixtures, so changing them
# invalidates the stored data set.
tag_values = {
    Deprecated: Deprecated(msg='x'),  # DTFIX5: we need more exhaustive testing of the values supported by this tag to ensure schema ID is robust
    TrustedAsTemplate: TrustedAsTemplate(),
    Origin: Origin(path='/tmp/x', line_num=1, col_num=2, description='y'),
    VaultedValue: VaultedValue(ciphertext='x'),
    SourceWasEncrypted: SourceWasEncrypted(),
}
def test_cache_persistence_schema() -> None:
    """
    Check the schema ID for the cache_persistence schema to ensure it is updated when the schema changes.
    Failure to update the schema ID will result in serialization/deserialization failures for persisted data for things like cache plugins.
    This test is only as comprehensive as these unit tests, so ensure profile data types are thoroughly covered.
    If additional capabilities are added to the cache_persistence profile which are not tested, they will go undetected, leading to runtime failures.
    """
    # DTFIX5: update tests to ensure new fields on contracts will fail this test if they have defaults which are omitted from serialization
    #         one possibility: monkeypatch the default field value omission away so that any new field will invalidate the schema
    # DTFIX5: ensure all types/attrs included in _profiles._common_module_response_types are represented here, since they can appear in cached responses
    expected_schema_id = 1
    expected_schema_hash = "0bc4bec94abe6ec0f62fc9f45ea1099ea65b13f00a5e5de1699e0dbcf0de2b2c"

    # Hash the persisted expected-output fixture file for the cache_persistence profile;
    # any change to the profile's serialized forms shows up here as a hash mismatch.
    # PROFILE_DIR / profile_name is already a pathlib.Path, so no extra Path() wrapper is needed.
    test_hash = hashlib.sha256()
    test_hash.update((DataSet.PROFILE_DIR / _cache_persistence._Profile.profile_name).with_suffix('.txt').read_bytes())

    actual_schema_id = _cache_persistence._Profile.schema_id
    actual_schema_hash = test_hash.hexdigest()
    next_schema_id = actual_schema_id + 1

    schema_check_failure_instructions = f"""The cache_persistence schema check hash has changed. The solution depends on the reason why:

1) The schema and tests have changed:
   i. Increment `ansible._internal._json._profiles._cache_persistence._Profile.schema_id` to {next_schema_id}.
   ii. Update `expected_schema_id` to {next_schema_id}.
   iii. Update `expected_schema_hash` to {actual_schema_hash!r}.

2) The schema is unchanged, but the tests have changed:
   i. Double-check that the schema really hasn't changed.
   ii. Don't forget about added/changed/removed types as well as fields on those types.
   iii. Update `expected_schema_hash` to {actual_schema_hash!r}.
"""

    if actual_schema_id != expected_schema_id:
        raise Exception(f"The actual schema ID {actual_schema_id} does not match the expected schema ID {expected_schema_id}.")

    if actual_schema_hash != expected_schema_hash:
        raise Exception(schema_check_failure_instructions)
def get_profile_names() -> tuple[str, ...]:
    """Return the sorted profile names discovered in the target and controller serialization profile packages."""
    found: list[str] = []

    for pkg in (target_serialization_profiles, controller_serialization_profiles):
        discovered = list(pkgutil.iter_modules(pkg.__path__, f'{pkg.__name__}.'))

        assert discovered  # ensure at least one serialization profile module was found

        found.extend(_serialization.get_serialization_profile(mod.name).profile_name for mod in discovered)

    return tuple(sorted(found))
@dataclasses.dataclass(frozen=True)
class _TestParameters:
    """A single serialization scenario: a profile, a value, optional tags to apply, and optional lazy-container wrapping."""

    profile_name: str  # name of the serialization profile under test
    value: t.Any  # the value to round-trip through the profile
    tags: tuple[AnsibleDatatagBase, ...] = ()  # data tags to apply to value before serialization
    lazy: bool = False  # when True, wrap value in a lazy template container before serializing

    def __hash__(self):
        # hash by repr since value may be unhashable (lists/dicts); `lazy` is deliberately
        # omitted -- a hash coarser than __eq__ is valid
        return hash((self.profile_name, repr(self.value), self.tags))

    def __repr__(self):
        # fields at their default values are omitted; this repr is persisted in the
        # expected-output fixture files, so its format must stay stable
        fields = ((field, getattr(self, field.name)) for field in dataclasses.fields(self))
        args = (f'{f.name}={v!r}' for f, v in fields if v != f.default)
        return f"{type(self).__name__}({', '.join(args)})"

    def get_test_output(self) -> _TestOutput:
        """Serialize and then deserialize this value with the profile's encoder/decoder, capturing payloads, errors, and tags."""
        encoder = get_encoder(self.profile_name)
        decoder = get_decoder(self.profile_name)

        # lazy container creation requires an active template context; otherwise use a no-op context
        ctx = TemplateContext(
            template_value=self.value,
            templar=TemplateEngine(),
            options=TemplateOptions.DEFAULT,
            stop_on_template=False
        ) if self.lazy else contextlib.nullcontext()

        with ctx:
            try:
                value = AnsibleTagHelper.tag(self.value, self.tags)
            except NotTaggableError:
                value = self.value  # some values cannot carry tags; fall back to the untagged value

            if self.lazy:
                value = _AnsibleLazyTemplateMixin._try_create(value)

            payload: str | Exception

            try:
                payload = json.dumps(value, cls=encoder)
            except Exception as ex:
                payload = ex  # serialization failure is itself part of the expected output
                round_trip = None
            else:
                try:
                    round_trip = json.loads(payload, cls=decoder)
                except Exception as ex:
                    round_trip = ex  # deserialization failure is captured the same way

        return _TestOutput(
            payload=payload,
            round_trip=AnsibleTagHelper.as_native_type(round_trip),
            tags=tuple(sorted(AnsibleTagHelper.tags(round_trip), key=lambda item: type(item).__name__)),
        )
@dataclasses.dataclass(frozen=True)
class _TestOutput:
    """The observed (or expected) result of one serialization round-trip."""

    payload: str | Exception  # the JSON payload, or the exception raised during serialization
    round_trip: t.Any  # the deserialized value converted to its native type, or the exception raised during deserialization
    tags: tuple[AnsibleDatatagBase, ...]  # tags present on the round-tripped value, sorted by type name
@dataclasses.dataclass(frozen=True)
class _TestCase:
    """Pairs the parameters of a serialization round-trip with its expected output."""

    parameters: _TestParameters
    expected: _TestOutput

    def __str__(self) -> str:
        """Render a compact test ID: profile, value, and optional tag/lazy markers."""
        params = self.parameters
        pieces = [f'profile={params.profile_name}', f'value={params.value}']

        if params.tags:
            tag_names = sorted(type(tag).__name__ for tag in params.tags)
            pieces.append(f"tags={','.join(tag_names)}")

        if params.lazy:
            pieces.append('lazy')

        return '; '.join(pieces)
class DataSet:
    """Loads, stores, and (optionally) regenerates the persisted expected serialization outputs, one file per profile."""

    PROFILE_DIR = pathlib.Path(__file__).parent / 'expected_serialization_profiles'

    def __init__(self, generate: bool) -> None:
        self.data: dict[_TestParameters, _TestOutput] = {}
        self.path = self.PROFILE_DIR
        self.generate = generate  # when True, outputs are computed fresh instead of loaded

    def load(self) -> None:
        """Populate self.data from the persisted fixture files, unless regenerating."""
        if self.generate:
            return

        # the fixtures are trusted test data authored by save() below; eval is intentional here
        for fixture in self.path.glob('*.txt'):
            self.data.update(eval(fixture.read_text()))

    def save(self) -> None:
        """Write self.data back out as one fixture file per profile; no-op unless regenerating."""
        if not self.generate:
            return

        # groupby requires input sorted by the grouping key; items appended out of order would otherwise split groups
        ordered = sorted(self.data.items(), key=lambda item: item[0].profile_name)
        by_profile = {name: dict(entries) for name, entries in itertools.groupby(ordered, key=lambda item: item[0].profile_name)}

        for profile_name, entries in by_profile.items():
            (self.path / f'{profile_name}.txt').write_text(self.generate_content(entries))

    @staticmethod
    def generate_content(profiles: dict[_TestParameters, _TestOutput]) -> str:
        """Render a dict literal with exactly one entry per line for stable, reviewable diffs."""
        lines = ["{"]

        for key, value in profiles.items():
            rendered_key = pprint.pformat(key, width=10000, indent=0, sort_dicts=False)
            rendered_value = pprint.pformat(value, width=10000, indent=0, sort_dicts=False)
            lines.append(f"{rendered_key}: {rendered_value},")

        lines.append("}")

        return '\n'.join(lines) + '\n'

    def fetch_or_create_expected(self, test_params: _TestParameters) -> _TestOutput:
        """Return the expected output for test_params, computing and storing it when in generate mode."""
        if self.generate:
            output = self.data[test_params] = test_params.get_test_output()
            return output

        try:
            return self.data[test_params]
        except KeyError:
            raise Exception(f'Missing {test_params} in data set. Use `generate=True` to update the data set and then review the changes.') from None
class ProfileHelper:
    """Builds the test-parameter combinations (plain, tagged, lazy) for a single serialization profile."""

    def __init__(self, profile_name: str) -> None:
        self.profile_name = profile_name

        profile = _serialization.get_serialization_profile(profile_name)
        profile_tag_types = {obj: None for obj in profile.serialize_map if issubclass(obj, AnsibleDatatagBase)}

        if profile_tag_types:
            self.supported_tag_values = tuple(value for tag_type, value in tag_values.items() if tag_type in profile_tag_types)

            if not self.supported_tag_values:
                raise Exception(f'Profile {profile} supports tags {profile_tag_types}, but no supported tag value is available.')
        else:
            self.supported_tag_values = tuple()

        # pick any one tag value the profile does NOT support, to verify unsupported tags are handled
        self.unsupported_tag_value = next((value for tag_type, value in tag_values.items() if tag_type not in profile_tag_types), None)

        if not self.unsupported_tag_value and profile.profile_name != _cache_persistence._Profile.profile_name:
            raise Exception(f'Profile {profile} supports tags {profile_tag_types}, but no unsupported tag value is available.')

    def create_parameters_from_values(self, *values: t.Any) -> list[_TestParameters]:
        """Expand each value into its full set of test parameter combinations."""
        combined: list[_TestParameters] = []

        for value in values:
            combined.extend(self.create_parameters_from_value(value))

        return combined

    def create_parameters_from_value(self, value: t.Any) -> list[_TestParameters]:
        """Return untagged, supported-tag, and unsupported-tag variants for value, plus lazy copies where applicable."""
        params = [
            _TestParameters(
                profile_name=self.profile_name,
                value=value,
            )
        ]

        if self.supported_tag_values:
            params.append(_TestParameters(
                profile_name=self.profile_name,
                value=value,
                tags=self.supported_tag_values,
            ))

        if self.unsupported_tag_value:
            params.append(_TestParameters(
                profile_name=self.profile_name,
                value=value,
                tags=(self.unsupported_tag_value,),
            ))

        # test lazy containers on all non m2c profiles
        if isinstance(value, (list, dict)) and not self.profile_name.endswith("_m2c"):
            params += [dataclasses.replace(item, lazy=True) for item in params]

        return params
# DTFIX5: need better testing for containers, especially for tagged values in containers
# Edge-case inputs exercised only against the fallback_to_str profile.
additional_test_parameters: list[_TestParameters] = ProfileHelper(_fallback_to_str._Profile.profile_name).create_parameters_from_values(
    b'\x00',  # valid utf-8 strict, JSON escape sequence required
    b'\x80',  # utf-8 strict decoding fails, forcing the use of an error handler such as surrogateescape, JSON escape sequence required
    '\udc80',  # same as above, but already a string (verify that the string version is handled the same as the bytes version)
    {1: "1"},  # integer key
    {b'hi': "1"},  # bytes key
    {TrustedAsTemplate().tag(b'hi'): "2"},  # tagged bytes key
    {(b'hi',): 3},  # tuple[bytes] key
)
_generate: bool = False
"""Set to True to regenerate all test data; a test failure will occur until it is set back to False."""
def get_test_cases() -> list[_TestCase]:
    """Assemble all test cases by pairing every generated parameter set with its expected output from the data set."""
    data_set = DataSet(generate=_generate)
    data_set.load()

    all_parameters: list[_TestParameters] = []

    for name in get_profile_names():
        all_parameters.extend(ProfileHelper(name).create_parameters_from_values(*basic_values))

    all_parameters.extend(additional_test_parameters)

    cases = [_TestCase(parameters=params, expected=data_set.fetch_or_create_expected(params)) for params in all_parameters]

    data_set.save()  # persists fresh outputs only when _generate is True

    return cases
@pytest.mark.parametrize("test_case", get_test_cases(), ids=str)
def test_profile(test_case: _TestCase) -> None:
    """Round-trip the test case's value and compare payload, deserialized value, and tags against expectations."""
    actual = test_case.parameters.get_test_output()

    if isinstance(actual.payload, Exception):
        # exceptions don't compare equal directly; match on type and message instead
        if type(actual.payload) is not type(test_case.expected.payload):
            raise Exception('unexpected exception') from actual.payload

        assert str(actual.payload) == str(test_case.expected.payload)
    else:
        assert actual.payload == test_case.expected.payload

    assert type(actual.round_trip) is type(test_case.expected.round_trip)

    if isinstance(actual.round_trip, AnsibleRuntimeError):
        # compare the original messages, since error instances won't compare equal
        assert str(actual.round_trip._original_message) == str(test_case.expected.round_trip._original_message)
    else:
        assert actual.round_trip == test_case.expected.round_trip

    assert set(actual.tags) == set(test_case.expected.tags)
def test_not_generate_mode():
    """Fail if the module was accidentally left in data-regeneration mode."""
    generate_enabled = _generate

    assert not generate_enabled, "set _generate=False to statically test expected behavior"