File: dns.py

package info (click to toggle)
firefox 149.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 4,767,760 kB
  • sloc: cpp: 7,416,064; javascript: 6,752,859; ansic: 3,774,850; python: 1,250,473; xml: 641,578; asm: 439,191; java: 186,617; sh: 56,634; makefile: 18,856; objc: 13,092; perl: 12,763; pascal: 5,960; yacc: 4,583; cs: 3,846; lex: 1,720; ruby: 1,002; php: 436; lisp: 258; awk: 105; sql: 66; sed: 53; csh: 10; exp: 6
file content (139 lines) | stat: -rw-r--r-- 4,234 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from __future__ import annotations

import logging
import threading
from copy import copy
from itertools import product
from typing import TYPE_CHECKING

from dnslib import QTYPE, RCODE, RR, A
from dnslib.label import DNSLabel
from dnslib.server import BaseResolver, DNSLogger, DNSServer

if TYPE_CHECKING:
    from typing import Any, Callable

    from dnslib.dns import DNSRecord
    from dnslib.server import DNSHandler
    from wptserve.config import Config

# getLogger() is called without a name, so this is the root logger rather
# than a module-specific one; every class below logs through it.
logger = logging.getLogger()


class Resolver(BaseResolver):  # type: ignore[misc]
    """Resolver that points a fixed set of host names at one address.

    An ``A`` record with ``destination`` as its data is created for every
    entry of ``allowed_hosts``. Queries that match no record, or whose name
    contains one of ``unknown_labels``, are answered with NXDOMAIN.
    """

    def __init__(
        self,
        allowed_hosts: set[str],
        destination: str,
        unknown_labels: set[str] | None = None,
        ttl: int = 3600,
    ) -> None:
        """Build the in-memory zone.

        :param allowed_hosts: Host names to serve; a trailing dot is appended
            to any name that lacks one. Names are later compared with
            ``matchGlob``, so wildcard entries such as ``*.example.`` appear
            to be supported.
        :param destination: Address placed in the ``A`` record of every host.
        :param unknown_labels: Labels that force NXDOMAIN whenever one of them
            occurs anywhere in the queried name; ``None`` or empty disables
            the check.
        :param ttl: TTL applied to every record.
        """
        super().__init__()
        self.unknown_labels = unknown_labels

        # Each zone entry is (record name, record type name, record).
        entries = []
        for host in allowed_hosts:
            fqdn = host if host.endswith(".") else host + "."
            record = RR(fqdn, rtype=QTYPE.A, rdata=A(destination), ttl=ttl)
            entries.append((record.rname, QTYPE[record.rtype], record))
        self.zone = entries

    def resolve(self, request: DNSRecord, handler: DNSHandler) -> DNSRecord:
        """Answer ``request`` from the zone.

        :param request: The incoming query.
        :param handler: Unused; part of the resolver call signature.
        :returns: A reply holding every matching record, or an NXDOMAIN reply
            when nothing matches or the name contains an unknown label.
        """
        reply = request.reply()
        qname = request.q.qname
        qtype = QTYPE[request.q.qtype]

        if self.unknown_labels:
            # Compare label by label so that an unknown label is detected at
            # any position within the queried name.
            name_parts = {
                DNSLabel(part) for part in str(qname).rstrip(".").split(".")
            }
            for blocked in self.unknown_labels:
                if DNSLabel(blocked) in name_parts:
                    reply.header.rcode = RCODE.NXDOMAIN
                    return reply

        matching = [
            record
            for pattern, record_type, record in self.zone
            if qname.matchGlob(pattern) and qtype == record_type
        ]
        for record in matching:
            # Answer under the name that was actually asked for, not the
            # (possibly wildcard) name stored in the zone.
            answer = copy(record)
            answer.rname = qname
            reply.add_answer(answer)

        if not matching:
            reply.header.rcode = RCODE.NXDOMAIN
        return reply


class LoggingDNSLogger(DNSLogger):  # type: ignore[misc]
    """``DNSLogger`` that writes through the module-level ``logger``.

    Output goes to ``logger.debug`` by default; ``log_error`` switches to
    ``logger.error`` for the duration of the call.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward all arguments to ``DNSLogger`` with ``logf=logger.debug``."""
        # Declaration only, for the type checker; the value is presumably
        # assigned by the base class from the ``logf`` argument below.
        self.logf: Callable[[str], None]
        super().__init__(*args, logf=logger.debug, **kwargs)

    def log_error(self, *args: Any, **kwargs: Any) -> None:
        """Log an error via ``logger.error`` instead of ``logger.debug``."""
        previous_sink = self.logf
        self.logf = logger.error
        try:
            super().log_error(*args, **kwargs)
        finally:
            # Restore the original sink even if the base implementation raises.
            self.logf = previous_sink


class DNSServerDaemon:
    """Run a dnslib ``DNSServer`` on a background daemon thread.

    The server answers through a ``Resolver`` whose host list is derived
    from ``config``; see ``__init__`` for how ``wildcards`` shapes that list.
    """

    # Class-level default so that ``stop()`` is safe on an instance that was
    # never started (previously this raised AttributeError).
    started: bool = False

    def __init__(
        self,
        host: str,
        port: int,
        bind_address: bool,
        config: Config,
        wildcards: int | None = None,
        **kwargs: Any,
    ) -> None:
        """Create the resolver and the (not yet running) DNS server.

        :param host: Address to bind to when ``bind_address`` is true.
        :param port: Port to listen on.
        :param bind_address: If false, ``""`` is passed as the bind address
            instead of ``host``.
        :param config: Server configuration. Reads ``server_host`` always;
            ``domains_set`` when ``wildcards`` is ``None``; otherwise
            ``browser_host``, ``alternate_hosts`` and ``not_subdomains``
            (plus ``all_domains_set`` when ``wildcards`` is ``0``).
        :param wildcards: ``None`` serves exactly ``config["domains_set"]``.
            A positive ``n`` serves the browser/alternate hosts plus the
            patterns ``*.<host>``, ``*.*.<host>``, ... up to ``n`` wildcard
            labels. ``0`` derives ``n`` from the largest number of dots in
            any entry of ``config["all_domains_set"]``.
        :param kwargs: Accepted and ignored.
        """
        if wildcards == 0:
            # default=0 keeps an empty domain set from raising ValueError; the
            # resolver then serves only the plain hosts.
            wildcards = max(
                (s.count(".") for s in config["all_domains_set"]), default=0
            )

        if wildcards is not None:
            hosts = {config["browser_host"]} | set(config["alternate_hosts"].values())
            # "*", "*.*", ... one pattern per permitted subdomain depth.
            wildcard_hosts = {".".join(("*",) * i) for i in range(1, wildcards + 1)}
            resolver_hosts = hosts | {
                ".".join(x) + "." for x in product(wildcard_hosts, hosts)
            }
            resolver = Resolver(
                resolver_hosts, config["server_host"], config["not_subdomains"]
            )
        else:
            resolver = Resolver(config["domains_set"], config["server_host"])

        self.server = DNSServer(
            resolver,
            address=host if bind_address else "",
            port=port,
            logger=LoggingDNSLogger(),
        )

        self.server_thread: threading.Thread | None = None

    def start(self) -> None:
        """Start serving on a daemon thread and mark the daemon as started."""
        self.started = True
        # daemon=True so the thread does not keep the process alive on exit;
        # passed to the constructor because Thread.setDaemon() is deprecated.
        self.server_thread = threading.Thread(target=self.server.start, daemon=True)
        self.server_thread.start()

    def stop(self) -> None:
        """Stop the server if it was started and drop the reference to it.

        Has no effect other than clearing ``self.server`` when the daemon is
        not running.
        """
        if self.started:
            assert self.server_thread is not None
            try:
                # NOTE(review): these calls target the dnslib DNSServer
                # wrapper; confirm it exposes shutdown()/server_close().
                # If it does not, the AttributeError handler below silently
                # skips both the shutdown and the join.
                self.server.shutdown()
                self.server.server_close()
                self.server_thread.join()
                self.server_thread = None
            except AttributeError:
                pass
            self.started = False
        self.server = None