1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174
|
import asyncio
import enum
import logging
from io import BytesIO
import aiodns
import aiodns.error
import aiohttp
from . import defaults
from .utils import parse_mta_sts_record, parse_mta_sts_policy, is_plaintext, filter_text
from .constants import HARD_RESP_LIMIT, CHUNK
class BadSTSPolicy(Exception):
    """Signals that a downloaded MTA-STS policy response is unacceptable."""
@enum.unique
class STSFetchResult(enum.Enum):
    """Outcome codes for an MTA-STS policy lookup."""

    # No usable MTA-STS record is published for the domain.
    NONE = 0
    # A policy was fetched and passed validation.
    VALID = 1
    # The lookup, the download or the policy validation failed.
    FETCH_ERROR = 2
    # The published policy ID equals the ID the caller already knows.
    NOT_CHANGED = 3
# Extra HTTP request headers attached to policy download requests.
_HEADERS = {
    "User-Agent": defaults.USER_AGENT,
}
# pylint: disable=too-few-public-methods
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-statements
class STSResolver:
    """Asynchronous resolver of MTA-STS policies for mail domains.

    A lookup queries the ``_mta-sts.<domain>`` TXT record and, when that
    record carries a policy ID different from the one the caller already
    knows, downloads the policy from
    ``https://mta-sts.<domain>/.well-known/mta-sts.txt`` and validates it.
    """

    # Largest accepted policy ``max_age`` value (RFC 8461 caps the field
    # at 31557600 seconds).
    _MAX_AGE_LIMIT = 31557600

    def __init__(self, *, timeout=defaults.TIMEOUT, loop):
        """Prepare the DNS resolver and the HTTP client settings.

        Args:
            timeout: time limit handed to the DNS resolver, used as the
                ``asyncio.wait_for`` limit around each DNS query, and used
                as the total timeout of each HTTP request.
            loop: event loop passed to the DNS resolver and to every HTTP
                client session.
        """
        self._loop = loop
        self._timeout = timeout
        self._resolver = aiodns.DNSResolver(timeout=timeout, loop=loop)
        self._http_timeout = aiohttp.ClientTimeout(total=timeout)
        # HTTPS proxy settings as reported by aiohttp's proxies_from_env().
        self._proxy_info = aiohttp.helpers.proxies_from_env().get('https', None)
        self._logger = logging.getLogger("RES")
        if self._proxy_info is None:
            self._proxy = None
            self._proxy_auth = None
        else:
            self._proxy = self._proxy_info.proxy
            self._proxy_auth = self._proxy_info.proxy_auth

    async def _fetch_policy_text(self, sts_policy_url):
        """Download the policy file at *sts_policy_url* and return its text.

        Redirects are not followed. The body is read in ``CHUNK``-sized
        pieces and decoded with the response charset, falling back to
        ASCII when the response declares none.

        Raises:
            BadSTSPolicy: on a non-200 status, an unacceptable
                Content-Type, or a body larger than ``HARD_RESP_LIMIT``.
            Exception: any error raised by aiohttp, by parsing the
                Content-Length header or by decoding the body propagates
                unchanged to the caller.
        """
        async with aiohttp.ClientSession(loop=self._loop,
                                         timeout=self._http_timeout) \
                as session:
            async with session.get(sts_policy_url,
                                   allow_redirects=False,
                                   proxy=self._proxy, headers=_HEADERS,
                                   proxy_auth=self._proxy_auth) as resp:
                if resp.status != 200:
                    raise BadSTSPolicy(
                        "unexpected HTTP status %d" % resp.status)
                if not is_plaintext(resp.headers.get('Content-Type', '')):
                    raise BadSTSPolicy("unacceptable Content-Type")
                # Cheap early rejection based on the declared size; the
                # read loop below enforces the limit on the actual body.
                if (int(resp.headers.get('Content-Length', '0')) >
                        HARD_RESP_LIMIT):
                    raise BadSTSPolicy(
                        "declared policy size exceeds limit")
                policy_file = BytesIO()
                while policy_file.tell() <= HARD_RESP_LIMIT:
                    chunk = await resp.content.read(CHUNK)
                    if not chunk:
                        break
                    policy_file.write(chunk)
                else:
                    # Reached only when the loop condition fails, i.e.
                    # more than HARD_RESP_LIMIT bytes were buffered
                    # before the end of the body.
                    raise BadSTSPolicy("policy body exceeds size limit")
                charset = (resp.charset if resp.charset is not None
                           else 'ascii')
                return policy_file.getvalue().decode(charset)

    @classmethod
    def _is_valid_policy(cls, pol):
        """Tell whether the parsed policy *pol* is acceptable.

        As a side effect ``pol['max_age']`` is replaced by its integer
        value once it parses successfully.
        """
        if pol.get('version', None) != 'STSv1':
            return False
        try:
            max_age = int(pol.get('max_age', '-1'))
            pol['max_age'] = max_age
        except ValueError:
            return False
        if not 0 <= max_age <= cls._MAX_AGE_LIMIT:
            return False
        mode = pol.get('mode', None)
        # No MX check required for 'none' policy.
        if mode == 'none':
            return True
        if mode not in ('testing', 'enforce'):
            return False
        # 'testing' and 'enforce' policies need at least one MX pattern.
        return bool(pol.get('mx'))

    # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements
    async def resolve(self, domain, last_known_id=None):
        """Resolve the MTA-STS policy published for *domain*.

        Args:
            domain: domain name to look up; trailing dots are stripped.
            last_known_id: policy ID already known to the caller, if any.

        Returns:
            A ``(status, payload)`` tuple where ``status`` is an
            ``STSFetchResult`` member:

            * ``NONE`` - the name starts with a dot, DNS reports no such
              name or no data, there is not exactly one TXT record
              starting with ``v=STSv1``, or that record lacks a valid
              version or an ``id`` field.
            * ``NOT_CHANGED`` - the record ID equals *last_known_id*.
            * ``FETCH_ERROR`` - the DNS query timed out or failed in any
              other way, the policy download failed, or the policy did
              not pass validation.
            * ``VALID`` - the policy was fetched and validated.

            ``payload`` is ``(policy_id, policy)`` for ``VALID`` and
            ``None`` otherwise.
        """
        if domain.startswith('.'):
            return STSFetchResult.NONE, None
        # Cleanup domain name
        domain = domain.rstrip('.')

        # Construct name of corresponding MTA-STS DNS record for domain
        sts_txt_domain = '_mta-sts.' + domain
        self._logger.debug("Got STS resolve request: sts_txt_domain=%s, "
                           "known_id=%s", sts_txt_domain, last_known_id)

        # Try to fetch it
        try:
            txt_records = await asyncio.wait_for(
                self._resolver.query(sts_txt_domain, 'TXT'),
                timeout=self._timeout)
        except aiodns.error.DNSError as error:
            # pylint: disable=no-member
            if error.args[0] in (aiodns.error.ARES_ENOTFOUND,
                                 aiodns.error.ARES_ENODATA):
                return STSFetchResult.NONE, None
            # Timeouts (ARES_ETIMEOUT) and every other resolver failure:
            # it's hard to decide what to do in case of timeout, probably
            # it's better to treat this as fetch error so the caller can
            # report such cases. The timeout path is not covered by tests
            # because of an aiodns bug:
            # https://github.com/saghul/aiodns/pull/64
            return STSFetchResult.FETCH_ERROR, None  # pragma: no cover
        except asyncio.TimeoutError:
            return STSFetchResult.FETCH_ERROR, None

        # workaround for floating return type of pycares
        txt_records = filter_text(rec.text for rec in txt_records)

        # RFC 8461 strictly defines version string as first field
        txt_records = [txt for txt in txt_records
                       if txt.startswith('v=STSv1')]

        # Exactly one record should exist
        if len(txt_records) != 1:
            return STSFetchResult.NONE, None

        # Validate record
        mta_sts_record = parse_mta_sts_record(txt_records[0])
        if (mta_sts_record.get('v', None) != 'STSv1'
                or 'id' not in mta_sts_record):
            return STSFetchResult.NONE, None
        self._logger.debug("Parsed STS record for domain %s: %s",
                           repr(domain), repr(mta_sts_record))

        # Obtain policy ID and return NOT_CHANGED if ID is equal to last known
        if mta_sts_record['id'] == last_known_id:
            return STSFetchResult.NOT_CHANGED, None

        # Construct corresponding URL of MTA-STS policy
        sts_policy_url = ('https://mta-sts.' +
                          domain +
                          '/.well-known/mta-sts.txt')

        # Fetch actual policy
        try:
            policy_text = await self._fetch_policy_text(sts_policy_url)
        except asyncio.CancelledError:
            # Before Python 3.8 CancelledError derives from Exception;
            # task cancellation must never be reported as a fetch error.
            raise
        except Exception as exc:
            self._logger.warning("STS policy fetch for domain %s failed with "
                                 "error: %s", repr(domain), str(exc))
            return STSFetchResult.FETCH_ERROR, None

        # Parse policy
        pol = parse_mta_sts_policy(policy_text)
        self._logger.debug("Parsed policy for domain %s: %s", domain, repr(pol))

        # Validate policy
        if not self._is_valid_policy(pol):
            return STSFetchResult.FETCH_ERROR, None

        # Policy is valid. Returning result.
        return STSFetchResult.VALID, (mta_sts_record['id'], pol)
|