File: resolver.py

package info (click to toggle)
postfix-mta-sts-resolver 1.5.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 536 kB
  • sloc: python: 3,069; sh: 226; makefile: 47
file content (174 lines) | stat: -rw-r--r-- 6,940 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import asyncio
import enum
import logging
from io import BytesIO

import aiodns
import aiodns.error
import aiohttp

from . import defaults
from .utils import parse_mta_sts_record, parse_mta_sts_policy, is_plaintext, filter_text
from .constants import HARD_RESP_LIMIT, CHUNK


class BadSTSPolicy(Exception):
    """Signal that an HTTP response is not acceptable as an MTA-STS policy.

    Raised while fetching the policy file when the response status,
    content type or size is rejected.
    """


@enum.unique
class STSFetchResult(enum.Enum):
    """Outcome of an attempt to resolve the MTA-STS policy of a domain."""

    # No usable MTA-STS TXT record was found for the domain.
    NONE = 0
    # A policy was fetched and passed validation.
    VALID = 1
    # DNS lookup, policy download or policy validation failed.
    FETCH_ERROR = 2
    # The published policy ID equals the one the caller already knows.
    NOT_CHANGED = 3


# Request headers sent with every policy download.
_HEADERS = {
    "User-Agent": defaults.USER_AGENT,
}

# pylint: disable=too-few-public-methods
# pylint: disable=too-many-instance-attributes
# pylint: disable=too-many-statements
class STSResolver:
    """Asynchronous resolver of MTA-STS policies (RFC 8461).

    Looks up the ``_mta-sts`` TXT record of a domain and, when that record
    advertises a policy ID different from the one the caller already knows,
    downloads the policy file over HTTPS and validates it.
    """

    # Largest accepted value of the policy ``max_age`` field, in seconds
    # (the maximum defined by RFC 8461).
    _MAX_AGE_LIMIT = 31557600
    # Policy modes for which at least one ``mx`` entry is required.
    _MX_MODES = ('testing', 'enforce')

    def __init__(self, *, timeout=defaults.TIMEOUT, loop):
        """Create a resolver bound to an event loop.

        Args:
            timeout: time limit applied to the DNS resolver, to the DNS
                query as a whole and, as a total, to the HTTP request.
            loop: event loop handed to the DNS resolver and HTTP session.
        """
        self._loop = loop
        self._timeout = timeout
        self._resolver = aiodns.DNSResolver(timeout=timeout, loop=loop)
        self._http_timeout = aiohttp.ClientTimeout(total=timeout)
        # Honour an HTTPS proxy configured through the environment, if any.
        self._proxy_info = aiohttp.helpers.proxies_from_env().get('https', None)
        self._logger = logging.getLogger("RES")

        if self._proxy_info is None:
            self._proxy = None
            self._proxy_auth = None
        else:
            self._proxy = self._proxy_info.proxy
            self._proxy_auth = self._proxy_info.proxy_auth

    # pylint: disable=too-many-return-statements
    async def resolve(self, domain, last_known_id=None):
        """Resolve the MTA-STS policy published for ``domain``.

        Args:
            domain: mail domain to resolve; trailing dots are stripped.
            last_known_id: policy ID the caller already holds, if any.

        Returns:
            A ``(status, payload)`` tuple. ``payload`` is
            ``(policy_id, policy)`` when ``status`` is
            ``STSFetchResult.VALID`` and ``None`` for every other status.
        """
        # Names with a leading dot are never looked up.
        if domain.startswith('.'):
            return STSFetchResult.NONE, None
        # Cleanup domain name
        domain = domain.rstrip('.')

        # Construct name of corresponding MTA-STS DNS record for domain
        sts_txt_domain = '_mta-sts.' + domain
        self._logger.debug("Got STS resolve request: sts_txt_domain=%s, "
                           "known_id=%s", sts_txt_domain, last_known_id)

        status, txt_records = await self._query_txt_records(sts_txt_domain)
        if status is not None:
            return status, None

        mta_sts_record = self._select_record(txt_records)
        if mta_sts_record is None:
            return STSFetchResult.NONE, None

        self._logger.debug("Parsed STS record for domain %s: %s",
                           repr(domain), repr(mta_sts_record))

        # Skip the download when the published ID equals the known one.
        if mta_sts_record['id'] == last_known_id:
            return STSFetchResult.NOT_CHANGED, None

        policy_text = await self._fetch_policy_text(domain)
        if policy_text is None:
            return STSFetchResult.FETCH_ERROR, None

        # Parse policy
        pol = parse_mta_sts_policy(policy_text)

        self._logger.debug("Parsed policy for domain %s: %s", domain, repr(pol))

        if not self._validate_policy(pol):
            return STSFetchResult.FETCH_ERROR, None

        # Policy is valid. Returning result.
        return STSFetchResult.VALID, (mta_sts_record['id'], pol)

    async def _query_txt_records(self, name):
        """Query the TXT records of ``name``.

        Returns:
            ``(None, texts)`` on success, where ``texts`` is the result of
            ``filter_text`` over the answers, or ``(status, None)`` when the
            lookup failed and ``resolve`` should return ``status``.
        """
        try:
            answers = await asyncio.wait_for(
                self._resolver.query(name, 'TXT'),
                timeout=self._timeout)
        except aiodns.error.DNSError as error:
            code = error.args[0]
            if code in (aiodns.error.ARES_ENOTFOUND,  # pylint: disable=no-member
                        aiodns.error.ARES_ENODATA):  # pylint: disable=no-member
                # Neither the name nor TXT data exists: no policy published.
                return STSFetchResult.NONE, None
            # Resolver timeouts (ARES_ETIMEOUT) and any other DNS failure.
            # It's hard to decide what to do in case of timeout; it is
            # probably better to treat it as a fetch error so the caller can
            # report such cases. The timeout path is not covered by tests
            # because of an aiodns bug:
            # https://github.com/saghul/aiodns/pull/64
            return STSFetchResult.FETCH_ERROR, None  # pragma: no cover
        except asyncio.TimeoutError:
            return STSFetchResult.FETCH_ERROR, None

        # workaround for floating return type of pycares
        return None, filter_text(rec.text for rec in answers)

    @staticmethod
    def _select_record(txt_records):
        """Pick and parse the single MTA-STS record among TXT strings.

        Returns:
            The parsed record, or ``None`` when there is not exactly one
            candidate or the candidate lacks a valid version or an ``id``.
        """
        # RFC 8461 strictly defines version string as first field
        candidates = [txt for txt in txt_records
                      if txt.startswith('v=STSv1')]

        # Exactly one record should exist
        if len(candidates) != 1:
            return None

        record = parse_mta_sts_record(candidates[0])
        if record.get('v', None) != 'STSv1' or 'id' not in record:
            return None
        return record

    async def _fetch_policy_text(self, domain):
        """Download the policy file of ``domain`` and decode it to text.

        Returns:
            The decoded policy body, or ``None`` when the request failed or
            the response was rejected; the reason is logged as a warning.
        """
        # Construct corresponding URL of MTA-STS policy
        sts_policy_url = ('https://mta-sts.' +
                          domain +
                          '/.well-known/mta-sts.txt')

        try:
            async with aiohttp.ClientSession(loop=self._loop,
                                             timeout=self._http_timeout) \
                                                 as session:
                async with session.get(sts_policy_url,
                                       allow_redirects=False,
                                       proxy=self._proxy, headers=_HEADERS,
                                       proxy_auth=self._proxy_auth) as resp:
                    if resp.status != 200:
                        raise BadSTSPolicy(
                            "unexpected HTTP status %s" % (resp.status,))
                    content_type = resp.headers.get('Content-Type', '')
                    if not is_plaintext(content_type):
                        raise BadSTSPolicy(
                            "unacceptable Content-Type %r" % (content_type,))
                    if (int(resp.headers.get('Content-Length', '0')) >
                            HARD_RESP_LIMIT):
                        raise BadSTSPolicy(
                            "declared Content-Length exceeds limit of "
                            "%s bytes" % (HARD_RESP_LIMIT,))
                    policy_file = BytesIO()
                    # Read chunk by chunk so an oversized body is rejected
                    # even when Content-Length is absent or understated.
                    while policy_file.tell() <= HARD_RESP_LIMIT:
                        chunk = await resp.content.read(CHUNK)
                        if not chunk:
                            break
                        policy_file.write(chunk)
                    else:
                        raise BadSTSPolicy(
                            "policy body exceeds limit of %s bytes"
                            % (HARD_RESP_LIMIT,))
                    charset = (resp.charset if resp.charset is not None
                               else 'ascii')
                    policy_text = policy_file.getvalue().decode(charset)
        except Exception as exc:
            # Some exceptions (e.g. timeouts) have an empty message; fall
            # back to repr() so the warning always names the failure.
            self._logger.warning("STS policy fetch for domain %s failed with "
                                 "error: %s", repr(domain),
                                 str(exc) or repr(exc))
            return None
        return policy_text

    @classmethod
    def _validate_policy(cls, pol):
        """Check a parsed policy and normalise its ``max_age`` in place.

        On success ``pol['max_age']`` holds an ``int``.

        Returns:
            ``True`` when the policy is acceptable, ``False`` otherwise.
        """
        if pol.get('version', None) != 'STSv1':
            return False

        try:
            max_age = int(pol.get('max_age', '-1'))
        except ValueError:
            return False
        pol['max_age'] = max_age

        # A missing max_age defaults to -1 and is rejected here as well.
        if not 0 <= max_age <= cls._MAX_AGE_LIMIT:
            return False

        mode = pol.get('mode')
        # No MX check required for 'none' policy:
        if mode == 'none':
            return True
        if mode not in cls._MX_MODES:
            return False

        # NOTE(review): the parser presumably always provides 'mx'; .get()
        # keeps a missing key from escaping as KeyError.
        return bool(pol.get('mx'))