File: test_protocols.py

package info (click to toggle)
python-moto 5.1.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 116,520 kB
  • sloc: python: 636,725; javascript: 181; makefile: 39; sh: 3
file content (99 lines) | stat: -rw-r--r-- 3,499 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# mypy: ignore-errors
import copy
import json
import os
from enum import Enum

import pytest

from moto.core.model import OperationModel, ServiceModel
from moto.core.serialize import SERIALIZERS

# Fixture files live in the "protocols" directory that sits next to this module.
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))
TEST_DIR = os.path.join(_MODULE_DIR, "protocols")

# Suite descriptions whose cases are skipped when the compliance tests are
# generated (compared against each suite's "description" field).
PROTOCOL_TEST_BLACKLIST = ["REST XML Event Stream", "RPC JSON Event Stream"]


class TestType(Enum):
    """Kind of compliance case that ``_compliance_tests`` can be limited to."""

    INPUT = "input"
    OUTPUT = "output"

    # Not a test case despite the "Test" prefix: tell the test runner to
    # ignore this class during collection.
    __test__ = False


def _compliance_tests(test_type=None):
    """Yield one ``pytest.param`` per selected compliance case.

    Every ``.json`` file returned by ``_walk_files`` is loaded through
    ``_load_cases``.  Cases belonging to a suite whose description is listed
    in ``PROTOCOL_TEST_BLACKLIST`` are skipped.

    Args:
        test_type: ``TestType.INPUT`` keeps only cases that have a
            ``"params"`` key, ``TestType.OUTPUT`` only cases that have a
            ``"response"`` key.  ``None`` (the default) keeps both kinds.

    Yields:
        ``pytest.param(model, case, protocol)`` with an id of the form
        ``"<protocol>-protocol-<description>"``.
    """
    want_input = test_type is None or test_type is TestType.INPUT
    want_output = test_type is None or test_type is TestType.OUTPUT

    for full_path in _walk_files():
        # Only JSON fixtures are of interest; skip anything else in the tree.
        if not full_path.endswith(".json"):
            continue
        for model, case, protocol in _load_cases(full_path):
            if model.get("description") in PROTOCOL_TEST_BLACKLIST:
                continue
            description = case["description"]
            test_name = f"{protocol}-protocol-{description}"
            is_selected_input = "params" in case and want_input
            is_selected_output = "response" in case and want_output
            # A case is emitted at most once, even when it qualifies both ways.
            if is_selected_input or is_selected_output:
                yield pytest.param(model, case, protocol, id=test_name)


def _walk_files():
    """Yield the joined path of every file found by walking ``TEST_DIR``."""
    for dirpath, _dirnames, names in os.walk(TEST_DIR):
        yield from (os.path.join(dirpath, name) for name in names)


def _load_cases(full_path):
    all_test_data = json.load(open(full_path))
    protocol = os.path.basename(full_path).split(".")[0]
    for test_data in all_test_data:
        cases = test_data.pop("cases")
        description = test_data["description"]
        for index, case in enumerate(cases):
            case["description"] = description
            case["id"] = index
            yield test_data, case, protocol


@pytest.mark.parametrize(
    "json_description, case, protocol", _compliance_tests(TestType.OUTPUT)
)
def test_output_compliance(json_description: dict, case: dict, protocol):
    """Serialize one output case and compare it with the expected response.

    ``json_description`` is the suite-level dict used to build the
    ``ServiceModel``, ``case`` is a single case from that suite, and
    ``protocol`` selects the serializer class from ``SERIALIZERS``.
    """
    # Content-Type the serializer for each protocol is expected to emit.
    expected_content_types = {
        "ec2": "text/xml",
        "json": "application/x-amz-json-1.0",
        "query": "text/xml",
        "query-json": "application/json",
        "rest-xml": "text/xml",
        "rest-json": "application/json",
    }

    # The suite dict is shared by every case of the suite, so the model is
    # built from a deep copy rather than from the fixture itself.
    service_model = ServiceModel(copy.deepcopy(json_description))
    operation = OperationModel(case["given"], service_model)
    serializer = SERIALIZERS[protocol](operation)

    # Error cases are serialized from a synthetic exception object; all other
    # cases from the plain "result" value.
    if "error" in case:
        payload = _create_exception(case)
    else:
        payload = case["result"]
    actual = serializer.serialize(payload)

    assert actual["body"] == case["response"]["body"]
    assert "Content-Type" in actual["headers"]
    assert actual["headers"]["Content-Type"] == expected_content_types[protocol]

    expected_headers = case["response"]["headers"]
    # TODO: Drop this condition once the headers are sorted out for all
    # responses.  For now headers are only compared when the case expects
    # some, and Content-Type (already checked above) is excluded.
    if expected_headers:
        del actual["headers"]["Content-Type"]
        assert actual["headers"] == expected_headers
    assert actual["status_code"] == case["response"]["status_code"]


def _create_exception(case):
    exc = type(case["errorCode"], (Exception,), {})()
    exc.code = case["errorCode"]
    if "errorMessage" in case:
        exc.message = case["errorMessage"]
        exc.Message = case["errorMessage"]
    for key, value in case["error"].items():
        setattr(exc, key, value)
    return exc