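"""Tests for postfix_mta_sts_resolver.utils: configuration defaults, MTA-STS
policy record parsing, plaintext content-type detection, text filtering and
the asynchronous logging helpers."""
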
import tempfile
import collections.abc
import enum
import itertools
import time

import pytest

import postfix_mta_sts_resolver.utils as utils
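

# populate_cfg_defaults() is expected to return a fully populated configuration
# for each representative input: no config at all, an empty mapping, and a
# mapping that defines only per-zone sections.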
@pytest.mark.parametrize("cfg", [None,
{},
{
"zones": {
"aaa": {},
"bbb": {},
}
},
])
def test_populate_cfg_defaults(cfg):
    res = utils.populate_cfg_defaults(cfg)
    assert isinstance(res['host'], str)
    assert isinstance(res['port'], int)
    assert 0 < res['port'] < 65536
    assert isinstance(res['cache_grace'], (int, float))
    assert isinstance(res['proactive_policy_fetching']['enabled'], bool)
    assert isinstance(res['proactive_policy_fetching']['interval'], int)
    assert isinstance(res['proactive_policy_fetching']['concurrency_limit'], int)
    assert isinstance(res['proactive_policy_fetching']['grace_ratio'], (int, float))
    assert isinstance(res['cache'], collections.abc.Mapping)
    assert res['cache']['type'] in ('redis', 'sqlite', 'postgres', 'internal')
    assert isinstance(res['default_zone'], collections.abc.Mapping)
    assert isinstance(res['zones'], collections.abc.Mapping)
    for zone in list(res['zones'].values()) + [res['default_zone']]:
        assert isinstance(zone, collections.abc.Mapping)
        assert 'timeout' in zone
        assert 'strict_testing' in zone
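

# Loading an empty file (/dev/null) should be equivalent to passing no
# configuration at all.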
def test_empty_config():
    assert utils.load_config('/dev/null') == utils.populate_cfg_defaults(None)
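

# parse_mta_sts_record() splits a policy TXT record into key/value pairs; the
# cases below cover optional whitespace, missing or repeated semicolons, and
# empty fields.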
@pytest.mark.parametrize("rec,expected", [
("v=STSv1; id=20160831085700Z;", {"v": "STSv1", "id": "20160831085700Z"}),
("v=STSv1;id=20160831085700Z;", {"v": "STSv1", "id": "20160831085700Z"}),
("v=STSv1; id=20160831085700Z", {"v": "STSv1", "id": "20160831085700Z"}),
("v=STSv1;id=20160831085700Z", {"v": "STSv1", "id": "20160831085700Z"}),
("v=STSv1; id=20160831085700Z ", {"v": "STSv1", "id": "20160831085700Z"}),
("", {}),
(" ", {}),
(" ; ; ", {}),
("v=STSv1; id=20160831085700Z;;;", {"v": "STSv1", "id": "20160831085700Z"}),
])
def test_parse_mta_sts_record(rec, expected):
    assert utils.parse_mta_sts_record(rec) == expected
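

# is_plaintext() should accept the text/plain media type case-insensitively,
# with or without parameters such as charset, and reject anything else.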
@pytest.mark.parametrize("contenttype,expected", [
("text/plain", True),
("TEXT/PLAIN", True),
("TeXT/PlAiN", True),
("text/plain;charset=utf-8", True),
("text/plain;charset=UTF-8", True),
("text/plain; charset=UTF-8", True),
("text/plain ; charset=UTF-8", True),
("application/octet-stream", False),
("application/octet-stream+text/plain", False),
("application/json+text/plain", False),
("text/plain+", False),
])
def test_is_plaintext(contenttype, expected):
    assert utils.is_plaintext(contenttype) == expected
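

# Inputs for filter_text(): ASCII byte strings and unicode strings are expected
# to pass through, non-ASCII byte strings to be skipped, and non-text values to
# raise TypeError.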
class TextType(enum.Enum):
    ascii_byte_string = 1
    nonascii_byte_string = 2
    unicode_string = 3
    invalid_string = 4


text_args = [
    (b"aaa", TextType.ascii_byte_string),
    (b"\xff", TextType.nonascii_byte_string),
    ("aaa", TextType.unicode_string),
    (None, TextType.invalid_string),
    (0, TextType.invalid_string),
]
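
# Build every possible input sequence of length 0 to 4 from text_args;
# itertools.product(text_args, repeat=n) enumerates all ordered n-tuples.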
text_params = []
for length in range(0, 5):
    text_params.extend(itertools.product(text_args, repeat=length))


@pytest.mark.parametrize("vector", text_params)
def test_filter_text(vector):
    if any(typ is TextType.invalid_string for (_, typ) in vector):
        with pytest.raises(TypeError):
            for _ in utils.filter_text(val for (val, _) in vector):
                pass
    else:
        res = list(utils.filter_text(val for (val, _) in vector))
        nonskipped = (pair for pair in vector if pair[1] is not TextType.nonascii_byte_string)
        for left, (right_val, right_type) in zip(res, nonskipped):
            if right_type is TextType.unicode_string:
                assert left == right_val
            else:
                assert left.encode('ascii') == right_val
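

# AsyncLoggingHandler appears to write log records asynchronously, hence the
# short sleep before checking that the message reached the log file.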
def test_setup_logger():
    with tempfile.NamedTemporaryFile('r') as tmpfile:
        with utils.AsyncLoggingHandler(tmpfile.name) as log_handler:
            logger = utils.setup_logger("test", utils.LogLevel.info, log_handler)
            logger.info("Hello World!")
            time.sleep(1)
            assert "Hello World!" in tmpfile.read()
def test_setup_logger_overflow():
    with tempfile.NamedTemporaryFile('r') as tmpfile:
        with utils.AsyncLoggingHandler(tmpfile.name, 1) as log_handler:
            logger = utils.setup_logger("test", utils.LogLevel.info, log_handler)
            for _ in range(10):
                logger.info("Hello World!")
            time.sleep(1)
            assert "Hello World!" in tmpfile.read()
def test_setup_logger_stderr(capsys):
    with utils.AsyncLoggingHandler() as log_handler:
        logger = utils.setup_logger("test", utils.LogLevel.info, log_handler)
        logger.info("Hello World!")
        time.sleep(1)
    captured = capsys.readouterr()
    assert "Hello World!" in captured.err