1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
|
"""Pytest utils for integration tests."""
from __future__ import annotations
import json
from datetime import datetime, timezone
from pathlib import Path
from vcr.persisters.filesystem import FilesystemPersister
from vcr.serialize import deserialize, serialize
from tests.conftest import placeholders as _placeholders
def ensure_integration_test(cassette):  # pragma: no cover
    """Assert that the cassette handled at least one request for this test.

    A write-protected cassette must have played back something
    (``play_count > 0``); any other cassette must be ``dirty``.

    :param cassette: Object exposing ``write_protected``, ``play_count`` and
        ``dirty`` (presumably a VCR cassette -- confirm against callers).
    :raises AssertionError: If the check for the active mode fails.
    """
    if not cassette.write_protected:
        action, is_integration_test = "record", cassette.dirty
    else:
        action, is_integration_test = "play back", cassette.play_count > 0
    assert is_integration_test, (
        f"Cassette did not {action} any requests. This test can be a unit test."
    )
def filter_access_token(response):  # pragma: no cover
    """Mask OAuth tokens in a successful access-token response.

    Intended for use as a VCR response callback. Responses whose URL does not
    contain ``api/v1/access_token``, or whose status code is not 200, are
    returned untouched. Otherwise each ``access_token`` / ``refresh_token``
    found in the JSON body is replaced in the raw body bytes by
    ``<ACCESS_TOKEN>`` / ``<REFRESH_TOKEN>``, and the real value is stored in
    the module-level ``_placeholders`` mapping under the same key.

    :param response: Response dict with ``url``, ``status["code"]`` and
        ``body["string"]`` (bytes) entries; it is modified in place.
    :returns: The same ``response`` object.
    """
    if "api/v1/access_token" not in response["url"] or response["status"]["code"] != 200:
        return response
    body = response["body"]["string"].decode()
    # Parse once up front; the payload is the same for both token keys.
    try:
        payload = json.loads(body)
    except ValueError:
        return response
    if not isinstance(payload, dict):
        # Valid JSON that is not an object cannot carry the token keys.
        return response
    for token_key in ("access", "refresh"):
        token = payload.get(f"{token_key}_token")
        # Skip missing, non-string and empty tokens: replacing b"" would
        # insert the marker between every byte of the body.
        if not isinstance(token, str) or not token:
            continue
        response["body"]["string"] = response["body"]["string"].replace(
            token.encode("utf-8"), f"<{token_key.upper()}_TOKEN>".encode()
        )
        _placeholders[f"{token_key}_token"] = token
    return response
class CustomPersister(FilesystemPersister):
    """Filesystem persister that masks placeholder values with ``<NAME>`` markers.

    When a cassette is saved, every known placeholder value found in the
    serialized text is replaced by ``<NAME>`` (the upper-cased placeholder
    key). When a cassette is loaded, the markers are expanded back to the
    current values before deserialization.
    """

    # Extra ``name -> value`` placeholders registered at runtime. They are
    # merged with the module-level ``_placeholders``, which win on duplicates.
    additional_placeholders = {}

    @classmethod
    def _placeholder_pairs(cls):
        """Return ``(marker, value)`` pairs for every known placeholder."""
        combined = {**cls.additional_placeholders, **_placeholders}
        return [(f"<{name.upper()}>", value) for name, value in combined.items()]

    @classmethod
    def add_additional_placeholders(cls, placeholders: dict[str, str]):  # pragma: no cover
        """Add additional placeholders.

        :param placeholders: Mapping of placeholder name to the value it masks;
            merged into the existing class-level mapping.
        """
        cls.additional_placeholders.update(placeholders)

    @classmethod
    def clear_additional_placeholders(cls):  # pragma: no cover
        """Clear additional placeholders."""
        cls.additional_placeholders = {}

    @classmethod
    def load_cassette(cls, cassette_path, serializer):  # pragma: no cover
        """Load a cassette, expanding ``<NAME>`` markers to their values.

        :param cassette_path: Path of the cassette file to read.
        :param serializer: Serializer handed through to ``deserialize``.
        :returns: Whatever ``deserialize`` returns for the expanded content.
        :raises ValueError: If the cassette file cannot be opened or read.
        """
        try:
            cassette_content = Path(cassette_path).read_text(encoding="utf-8")
        except OSError as error:
            msg = "Cassette not found."
            raise ValueError(msg) from error
        for marker, value in cls._placeholder_pairs():
            cassette_content = cassette_content.replace(marker, value)
        return deserialize(cassette_content, serializer)

    @classmethod
    def save_cassette(cls, cassette_path, cassette_dict, serializer):  # pragma: no cover
        """Save a cassette with placeholder values masked as ``<NAME>``.

        :param cassette_path: Destination file; missing parent directories are
            created.
        :param cassette_dict: Cassette data handed through to ``serialize``.
        :param serializer: Serializer handed through to ``serialize``.
        """
        cassette_path = Path(cassette_path)
        data = serialize(cassette_dict, serializer)
        for marker, value in cls._placeholder_pairs():
            # An empty value would make ``str.replace`` insert the marker
            # between every character of the cassette, so skip it.
            if value:
                data = data.replace(value, marker)
        # exist_ok avoids a check-then-create race between concurrent writers.
        cassette_path.parent.mkdir(parents=True, exist_ok=True)
        cassette_path.write_text(data, encoding="utf-8")
class CustomSerializer:
    """JSON cassette serializer that inlines file objects as text.

    Values stored under a ``"file"`` dict key, and ``("file", obj)`` tuples
    inside lists, are replaced by the text content of the file named by
    ``obj.name`` so the cassette can be dumped as JSON.
    """

    @staticmethod
    def _serialize_file(file):  # pragma: no cover
        """Return the content of the file at ``file.name`` as text.

        The file is re-read from disk by name rather than through ``file``
        itself. Bytes that are not valid UTF-8 are substituted, not rejected.
        """
        return Path(file.name).read_bytes().decode("utf-8", "replace")

    @staticmethod
    def deserialize(cassette_string):
        """Parse a JSON cassette string and return the decoded object."""
        return json.loads(cassette_string)

    @classmethod
    def _serialize_dict(cls, data: dict):  # pragma: no cover
        """Return a copy of ``data`` with file objects replaced by their text.

        The value under a ``"file"`` key is passed to ``_serialize_file``;
        nested dicts and lists are processed recursively; everything else is
        copied as is.
        """
        new_dict = {}
        for key, value in data.items():
            if key == "file":
                new_dict[key] = cls._serialize_file(value)
            elif isinstance(value, dict):
                new_dict[key] = cls._serialize_dict(value)
            elif isinstance(value, list):
                new_dict[key] = cls._serialize_list(value)
            else:
                new_dict[key] = value
        return new_dict

    @classmethod
    def _serialize_list(cls, data: list):  # pragma: no cover
        """Return a copy of ``data`` with file objects replaced by their text.

        Nested dicts and lists are processed recursively. A tuple whose first
        item is ``"file"`` becomes ``("file", <file text>)``; any other tuple,
        including an empty or one-item tuple, is kept unchanged.
        """
        new_list = []
        for item in data:
            if isinstance(item, dict):
                new_list.append(cls._serialize_dict(item))
            elif isinstance(item, list):
                new_list.append(cls._serialize_list(item))
            elif isinstance(item, tuple) and len(item) >= 2 and item[0] == "file":
                # The length check keeps () and ("file",) from raising IndexError.
                new_list.append((item[0], cls._serialize_file(item[1])))
            else:
                new_list.append(item)
        return new_list

    @classmethod
    def serialize(cls, cassette_dict):  # pragma: no cover
        """Serialize a cassette to pretty-printed, key-sorted JSON.

        ``cassette_dict`` is stamped in place with a ``recorded_at`` entry
        holding the current UTC time as ``YYYY-MM-DDTHH:MM:SS``.

        :returns: The JSON text followed by a trailing newline.
        """
        now = datetime.now(tz=timezone.utc)
        # Format a naive UTC value with second precision so the shape is the
        # same on every call, including when the microsecond field is zero.
        cassette_dict["recorded_at"] = now.replace(tzinfo=None).isoformat(timespec="seconds")
        return f"{json.dumps(cls._serialize_dict(cassette_dict), sort_keys=True, indent=2)}\n"
|