File: network.py

package info (click to toggle)
python-aioxmpp 0.12.2-1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 6,152 kB
  • sloc: python: 96,969; xml: 215; makefile: 155; sh: 72
file content (465 lines) | stat: -rw-r--r-- 15,052 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
########################################################################
# File name: network.py
# This file is part of: aioxmpp
#
# LICENSE
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this program.  If not, see
# <http://www.gnu.org/licenses/>.
#
########################################################################
"""
:mod:`~aioxmpp.network` --- DNS resolution utilities
####################################################

This module uses :mod:`dns` to handle DNS queries.

.. versionchanged:: 0.5.4

   The module was completely rewritten in 0.5.4. The documented API stayed
   mostly the same though.

Configure the resolver
======================

.. versionadded:: 0.5.4

   The whole thread-local resolver thing was added in 0.5.4. This includes the
   magic to re-configure the used resolver when a query fails.

The module uses a thread-local resolver instance. It can be accessed using
:func:`get_resolver`. Re-read of the system-wide resolver configuration can be
forced by calling :func:`reconfigure_resolver`. To configure a custom resolver
instance, use :func:`set_resolver`.

By setting a custom resolver instance, the facilities which *automatically*
reconfigure the resolver whenever DNS timeouts occur are disabled.

.. note::

   Currently, there is no way to set a resolver per XMPP client. If such a way
   is desired, feel free to open a bug against :mod:`aioxmpp`. I cannot really
   imagine such a situation, but if you encounter one, please let me know.

.. autofunction:: get_resolver

.. autofunction:: reconfigure_resolver

.. autofunction:: set_resolver

Querying records
================

In addition to using the :class:`dns.resolver.Resolver` instance returned by
:func:`get_resolver`, one can also use :func:`repeated_query`. The latter takes
care of re-trying the query up to a configurable amount of times. It will also
automatically call :func:`reconfigure_resolver` (unless a custom resolver has
been set) if a timeout occurs and switch to TCP if problems persist.

.. autofunction:: repeated_query

SRV records
===========

.. autofunction:: find_xmpp_host_addr

.. autofunction:: lookup_srv

.. autofunction:: group_and_order_srv_records


"""

import asyncio
import functools
import itertools
import logging
import random
import threading

import dns
import dns.flags
import dns.resolver

logger = logging.getLogger(__name__)

_state = threading.local()


class ValidationError(Exception):
    pass


def get_resolver():
    """
    Return the thread-local :class:`dns.resolver.Resolver` instance used by
    :mod:`aioxmpp`.
    """

    global _state
    if not hasattr(_state, "resolver"):
        reconfigure_resolver()
    return _state.resolver


class DummyResolver:
    def set_flags(self, *args, **kwargs):
        # noop
        pass

    def query(self, *args, **kwargs):
        raise dns.resolver.NoAnswer


def reconfigure_resolver():
    """
    Reset the resolver configured for this thread to a fresh instance. This
    essentially re-reads the system-wide resolver configuration.

    If a custom resolver has been set using :func:`set_resolver`, the flag
    indicating that no automatic re-configuration shall take place is cleared.
    """

    global _state
    try:
        _state.resolver = dns.resolver.Resolver()
    except dns.resolver.NoResolverConfiguration:
        _state.resolver = DummyResolver()
    _state.overridden_resolver = False


def set_resolver(resolver):
    """
    Replace the current thread-local resolver (which can be accessed using
    :func:`get_resolver`) with `resolver`.

    This also sets an internal flag which prohibits the automatic calling of
    :func:`reconfigure_resolver` from :func:`repeated_query`. To re-allow
    automatic reconfiguration, call :func:`reconfigure_resolver`.
    """

    global _state
    _state.resolver = resolver
    _state.overridden_resolver = True


async def repeated_query(qname, rdtype,
                         nattempts=None,
                         resolver=None,
                         require_ad=False,
                         executor=None):
    """
    Repeatedly fire a DNS query until either the number of allowed attempts
    (`nattempts`) is excedeed or a non-error result is returned (NXDOMAIN is
    a non-error result).

    If `nattempts` is :data:`None`, it is set to 3 if `resolver` is
    :data:`None` and to 2 otherwise. This way, no query is made without a
    possible change to a local parameter. (When using the thread-local
    resolver, it will be re-configured after the first failed query and after
    the second failed query, TCP is used. With a fixed resolver, TCP is used
    after the first failed query.)

    `qname` must be the (IDNA encoded, as :class:`bytes`) name to query,
    `rdtype` the record type to query for. If `resolver` is not :data:`None`,
    it must be a DNSPython :class:`dns.resolver.Resolver` instance; if it is
    :data:`None`, the resolver obtained from :func:`get_resolver` is used.

    If `require_ad` is :data:`True`, the peer resolver is asked to do DNSSEC
    validation and if the AD flag is missing in the response,
    :class:`ValueError` is raised. If `require_ad` is :data:`False`, the
    resolver is asked to do DNSSEC validation nevertheless, but missing
    validation (in constrast to failed validation) is not an error.

    .. note::

       This function modifies the flags of the `resolver` instance, no matter
       if it uses the thread-local resolver instance or the resolver passed as
       an argument.

    If the first query fails and `resolver` is :data:`None` and the
    thread-local resolver has not been overridden with :func:`set_resolver`,
    :func:`reconfigure_resolver` is called and the query is re-attempted
    immediately.

    If the next query after reconfiguration of the resolver (if the
    preconditions for resolver reconfigurations are not met, this applies to
    the first failing query), :func:`repeated_query` switches to TCP.

    If no result is received before the number of allowed attempts is exceeded,
    :class:`TimeoutError` is raised.

    Return the result set or :data:`None` if the domain does not exist.

    This is a coroutine; the query is executed in an `executor` using the
    :meth:`asyncio.BaseEventLoop.run_in_executor` of the current event loop. By
    default, the default executor provided by the event loop is used, but it
    can be overridden using the `executor` argument.

    If the used resolver raises :class:`dns.resolver.NoNameservers`
    (semantically, that no nameserver was able to answer the request), this
    function suspects that DNSSEC validation failed, as responding with
    SERVFAIL is what unbound does. To test that case, a simple check is made:
    the query is repeated, but with a flag set which indicates that we would
    like to do the validation ourselves. If that query succeeds, we assume that
    the error is in fact due to DNSSEC validation failure and raise
    :class:`ValidationError`. Otherwise, the answer is discarded and the
    :class:`~dns.resolver.NoNameservers` exception is treated as normal
    timeout. If the exception re-occurs in the second query, it is re-raised,
    as it indicates a serious configuration problem.
    """
    global _state

    loop = asyncio.get_event_loop()

    # tlr = thread-local resolver
    use_tlr = False
    if resolver is None:
        resolver = get_resolver()
        use_tlr = not _state.overridden_resolver

    if nattempts is None:
        if use_tlr:
            nattempts = 3
        else:
            nattempts = 2

    if nattempts <= 0:
        raise ValueError("query cannot succeed with non-positive amount "
                         "of attempts")

    qname = qname.decode("ascii")

    def handle_timeout():
        nonlocal use_tlr, resolver, use_tcp
        if use_tlr and i == 0:
            reconfigure_resolver()
            resolver = get_resolver()
        else:
            use_tcp = True

    use_tcp = False
    for i in range(nattempts):
        resolver.set_flags(dns.flags.RD | dns.flags.AD)
        try:
            answer = await loop.run_in_executor(
                executor,
                functools.partial(
                    resolver.query,
                    qname,
                    rdtype,
                    tcp=use_tcp
                )
            )

            if require_ad and not (answer.response.flags & dns.flags.AD):
                raise ValueError("DNSSEC validation not available")
        except (TimeoutError, dns.resolver.Timeout):
            handle_timeout()
            continue
        except (dns.resolver.NoAnswer, dns.resolver.NXDOMAIN):
            return None
        except (dns.resolver.NoNameservers):
            # make sure we have the correct config
            if use_tlr and i == 0:
                reconfigure_resolver()
                resolver = get_resolver()
                continue
            resolver.set_flags(dns.flags.RD | dns.flags.AD | dns.flags.CD)
            try:
                await loop.run_in_executor(
                    executor,
                    functools.partial(
                        resolver.query,
                        qname,
                        rdtype,
                        tcp=use_tcp,
                        raise_on_no_answer=False
                    ))
            except (dns.resolver.Timeout, TimeoutError):
                handle_timeout()
                continue
            except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
                pass
            raise ValidationError(
                "nameserver error, most likely DNSSEC validation failed",
            )
        break
    else:
        raise TimeoutError()

    return answer


async def lookup_srv(domain: bytes, service: str, transport: str = "tcp",
                     **kwargs):
    """
    Query the DNS for SRV records describing how the given `service` over the
    given `transport` is implemented for the given `domain`. `domain` must be
    an IDNA-encoded :class:`bytes` object; `service` must be a normal
    :class:`str`.

    Keyword arguments are passed to :func:`repeated_query`.

    Return a list of tuples ``(prio, weight, (hostname, port))``, where
    `hostname` is a IDNA-encoded :class:`bytes` object containing the hostname
    obtained from the SRV record. The other fields are also as obtained from
    the SRV records. The trailing dot is stripped from the `hostname`.

    If the DNS query returns an empty result, :data:`None` is returned. If any
    of the found SRV records has the root zone (``.``) as `hostname`, this
    indicates that the service is not available at the given `domain` and
    :class:`ValueError` is raised.
    """

    record = b".".join([
        b"_" + service.encode("ascii"),
        b"_" + transport.encode("ascii"),
        domain])

    answer = await repeated_query(
        record,
        dns.rdatatype.SRV,
        **kwargs)

    if answer is None:
        return None

    items = [
        (rec.priority, rec.weight, (str(rec.target), rec.port))
        for rec in answer
    ]

    for i, (prio, weight, (host, port)) in enumerate(items):
        if host == ".":
            raise ValueError(
                "protocol {!r} over {!r} not supported at {!r}".format(
                    service,
                    transport,
                    domain
                )
            )

        items[i] = (prio, weight, (
            host.rstrip(".").encode("ascii"),
            port))

    return items


async def lookup_tlsa(hostname, port, transport="tcp", require_ad=True,
                      **kwargs):
    """
    Query the DNS for TLSA records describing the certificates and/or keys to
    expect when contacting `hostname` at the given `port` over the given
    `transport`. `hostname` must be an IDNA-encoded :class:`bytes` object.

    The keyword arguments are passed to :func:`repeated_query`; `require_ad`
    defaults to :data:`True` here.

    Return a list of tuples ``(usage, selector, mtype, cert)`` which contains
    the information from the TLSA records.

    If no data is returned by the query, :data:`None` is returned instead.
    """
    record = b".".join([
        b"_" + str(port).encode("ascii"),
        b"_" + transport.encode("ascii"),
        hostname
    ])

    answer = await repeated_query(
        record,
        dns.rdatatype.TLSA,
        require_ad=require_ad,
        **kwargs)

    if answer is None:
        return None

    items = [
        (rec.usage, rec.selector, rec.mtype, rec.cert)
        for rec in answer
    ]

    return items


def group_and_order_srv_records(all_records, rng=None):
    """
    Order a list of SRV record information (as returned by :func:`lookup_srv`)
    and group and order them as specified by the RFC.

    Return an iterable, yielding each ``(hostname, port)`` tuple inside the
    SRV records in the order specified by the RFC. For hosts with the same
    priority, the given `rng` implementation is used (if none is given, the
    :mod:`random` module is used).
    """
    rng = rng or random

    all_records.sort(key=lambda x: x[:2])

    for priority, records in itertools.groupby(
            all_records,
            lambda x: x[0]):

        records = list(records)
        total_weight = sum(
            weight
            for _, weight, _ in records)

        while records:
            if len(records) == 1:
                yield records[0][-1]
                break

            value = rng.randint(0, total_weight)
            running_weight_sum = 0
            for i, (_, weight, addr) in enumerate(records):
                running_weight_sum += weight
                if running_weight_sum >= value:
                    yield addr
                    del records[i]
                    total_weight -= weight
                    break


async def find_xmpp_host_addr(loop, domain, attempts=3):
    domain = domain.encode("IDNA")

    items = await lookup_srv(
        service="xmpp-client",
        domain=domain,
        nattempts=attempts
    )

    if items is not None:
        return items

    return [(0, 0, (domain, 5222))]


async def find_xmpp_host_tlsa(loop, domain, attempts=3, require_ad=True):
    domain = domain.encode("IDNA")

    items = await lookup_tlsa(
        domain=domain,
        port=5222,
        nattempts=attempts,
        require_ad=require_ad
    )

    if items is not None:
        return items
    return []