File: controller.py

package info (click to toggle)
python-aiosmtpd 1.4.3-1.1%2Bdeb12u1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 1,060 kB
  • sloc: python: 7,850; makefile: 158; sh: 5
file content (525 lines) | stat: -rw-r--r-- 18,846 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
# Copyright 2014-2021 The aiosmtpd Developers
# SPDX-License-Identifier: Apache-2.0

import asyncio
import errno
import os
import ssl
import sys
import threading
import time
from abc import ABCMeta, abstractmethod
from contextlib import ExitStack
from pathlib import Path
from socket import AF_INET6, SOCK_STREAM, create_connection, has_ipv6
from socket import socket as makesock
from socket import timeout as socket_timeout

try:
    from socket import AF_UNIX
except ImportError:  # pragma: on-not-win32
    AF_UNIX = None
from typing import Any, Coroutine, Dict, Optional, Union

if sys.version_info >= (3, 8):
    from typing import Literal  # pragma: py-lt-38
else:  # pragma: py-ge-38
    from typing_extensions import Literal
from warnings import warn

from public import public

from aiosmtpd.smtp import SMTP

# Alias for the asyncio server class; used in this module for type annotations.
AsyncServer = asyncio.base_events.Server

# Default value of the ``ready_timeout`` argument of BaseThreadedController,
# which ends up as the timeout given to ``threading.Event.wait()``.
DEFAULT_READY_TIMEOUT: float = 5.0


@public
class IP6_IS:
    """
    errno values that get_localhost() uses to interpret a failed bind to ``::1``.
    """

    # The errno.E* constants apparently adapt to the OS, so on Windows they
    # should automatically map to the corresponding WSAE* constants.

    # A bind failing with one of these is taken as "IPv6 is not available"
    NO = {errno.EAFNOSUPPORT, errno.EADDRNOTAVAIL}
    # A bind failing with one of these is taken as "IPv6 is available"
    YES = {errno.EADDRINUSE}


def _has_ipv6() -> bool:
    """Return ``socket.has_ipv6``; kept as a function so it is easy to mock."""
    return has_ipv6


@public
def get_localhost() -> Literal["::1", "127.0.0.1"]:
    """
    Return the numeric loopback address to use: ``"::1"`` if IPv6 appears to be
    usable, ``"127.0.0.1"`` otherwise.

    :raises OSError: if binding to ``::1`` fails with an errno that is listed in
        neither ``IP6_IS.NO`` nor ``IP6_IS.YES``
    """
    # Ref:
    #  - https://github.com/urllib3/urllib3/pull/611#issuecomment-100954017
    #  - https://github.com/python/cpython/blob/ :
    #    - v3.6.13/Lib/test/support/__init__.py#L745-L758
    #    - v3.9.1/Lib/test/support/socket_helper.py#L124-L137
    if not _has_ipv6():
        # socket.has_ipv6 reflects only this Python build's IPv6 support, not the
        # system's. Without support in Python itself there is nothing more to
        # probe, though.
        return "127.0.0.1"
    try:
        # Probe by binding ::1 on a random unused port
        with makesock(AF_INET6, SOCK_STREAM) as probe:
            probe.bind(("::1", 0))
    except OSError as err:
        if err.errno in IP6_IS.NO:
            return "127.0.0.1"
        if err.errno in IP6_IS.YES:
            # Not expected to ever happen, but such an error still implies that
            # IPv6 is supported
            return "::1"
        # Anything else MUST propagate so it can be inspected
        raise
    else:
        # The bind succeeded, so IPv6 is definitely supported
        return "::1"


def _server_to_client_ssl_ctx(server_ctx: ssl.SSLContext) -> ssl.SSLContext:
    """
    Build a client-side SSLContext able to connect to a server that uses the
    given server-side SSLContext.

    :param server_ctx: The server's context; only its ``options`` are copied
    :return: A new client context with hostname checking and certificate
        verification turned off
    """
    ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
    # The server certificate is deliberately NOT verified: this context exists
    # only to poke the *local* server (created by this very code) so that it goes
    # through its lazy initialization sequence. Hence cert verification is
    # disabled below, and so is the flake8 check complaining about that.
    # check_hostname has to be switched off before verify_mode may be CERT_NONE.
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE    # noqa: DUO122
    ctx.options = server_ctx.options
    return ctx


class _FakeServer(asyncio.StreamReaderProtocol):
    """
    Placeholder protocol that _factory_invoker() returns instead of an SMTP
    instance when factory() failed to instantiate one.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop):
        # Imitate what SMTP does
        reader = asyncio.StreamReader(loop=loop)
        super().__init__(
            reader,
            client_connected_cb=self._client_connected_cb,
            loop=loop,
        )

    def _client_connected_cb(
            self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """Connection callback that intentionally does nothing."""


@public
class BaseController(metaclass=ABCMeta):
    """
    Common base of all controllers.

    Stores the handler, the event loop, the optional SSL context and the keyword
    arguments to forward to SMTP, and provides the factory wrapper that
    subclasses hand over to asyncio when creating the listening server.
    """

    smtpd = None
    server: Optional[AsyncServer] = None
    server_coro: Optional[Coroutine] = None
    _factory_invoked: Optional[threading.Event] = None
    # Written by _factory_invoker() and _cleanup(); declared here so that the
    # attribute exists on every subclass.
    _thread_exception: Optional[Exception] = None

    def __init__(
        self,
        handler: Any,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        *,
        ssl_context: Optional[ssl.SSLContext] = None,
        # SMTP parameters
        server_hostname: Optional[str] = None,
        **SMTP_parameters,
    ):
        """
        :param handler: Handler object that factory() passes to SMTP
        :param loop: Event loop to use; a new one is created if None
        :param ssl_context: Kept as-is, for subclasses to give to the listener
        :param server_hostname: If truthy, forwarded to SMTP as ``hostname``
        :param SMTP_parameters: Further keyword arguments forwarded to SMTP.
            The deprecated ``server_kwargs`` dict, if given, serves as the
            starting point and is overlaid with the other keyword arguments.
        """
        self.handler = handler
        if loop is None:
            self.loop = asyncio.new_event_loop()
        else:
            self.loop = loop
        self.ssl_context = ssl_context
        self.SMTP_kwargs: Dict[str, Any] = {}
        if "server_kwargs" in SMTP_parameters:
            warn(
                "server_kwargs will be removed in version 2.0. "
                "Just specify the keyword arguments to forward to SMTP "
                "as kwargs to this __init__ method.",
                DeprecationWarning,
            )
            # Take a copy: the updates below must not leak into the dict owned
            # by the caller (who might reuse it for another controller).
            self.SMTP_kwargs = dict(SMTP_parameters.pop("server_kwargs"))
        self.SMTP_kwargs.update(SMTP_parameters)
        if server_hostname:
            self.SMTP_kwargs["hostname"] = server_hostname
        # Emulate previous behavior of defaulting enable_SMTPUTF8 to True
        # It actually conflicts with SMTP class's default, but the reasoning is
        # discussed in the docs.
        self.SMTP_kwargs.setdefault("enable_SMTPUTF8", True)
        # Set by _factory_invoker() once factory() has been attempted
        self._factory_invoked = threading.Event()

    def factory(self):
        """Subclasses can override this to customize the handler/server creation."""
        return SMTP(self.handler, **self.SMTP_kwargs)

    def _factory_invoker(self) -> Union[SMTP, _FakeServer]:
        """
        Wraps factory() to catch exceptions during instantiation.

        On success the new instance is stored in ``self.smtpd`` and returned.
        On failure (including factory() returning None) the exception is stored
        in ``self._thread_exception`` and a _FakeServer is returned instead.
        Either way ``self._factory_invoked`` gets set.
        """
        try:
            self.smtpd = self.factory()
            if self.smtpd is None:
                raise RuntimeError("factory() returned None")
            return self.smtpd
        except Exception as err:
            self._thread_exception = err
            return _FakeServer(self.loop)
        finally:
            self._factory_invoked.set()

    @abstractmethod
    def _create_server(self) -> Coroutine:
        """
        Overridden by subclasses to actually perform the async binding to the
        listener endpoint. When overridden, MUST refer the _factory_invoker() method.
        """
        raise NotImplementedError

    def _cleanup(self):
        """Reset internal variables to prevent contamination"""
        self._thread_exception = None
        self._factory_invoked.clear()
        self.server_coro = None
        self.server = None
        self.smtpd = None

    def cancel_tasks(self, stop_loop: bool = True):
        """
        Convenience method to stop the loop and cancel all tasks.
        Use loop.call_soon_threadsafe() to invoke this.

        :param stop_loop: Whether to also call ``stop()`` on the loop
        """
        if stop_loop:  # pragma: nobranch
            self.loop.stop()
        try:
            _all_tasks = asyncio.all_tasks  # pytype: disable=module-attr
        except AttributeError:  # pragma: py-gt-36
            # Fallback for Python versions lacking asyncio.all_tasks
            _all_tasks = asyncio.Task.all_tasks  # pytype: disable=attribute-error
        for task in _all_tasks(self.loop):
            # This needs to be invoked in a thread-safe way
            task.cancel()


@public
class BaseThreadedController(BaseController, metaclass=ABCMeta):
    """
    Base of the controllers that run their asyncio event loop in a separate
    thread, managed through start() and stop().
    """

    _thread: Optional[threading.Thread] = None
    _thread_exception: Optional[Exception] = None

    def __init__(
        self,
        handler: Any,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        *,
        ready_timeout: float = DEFAULT_READY_TIMEOUT,
        ssl_context: Optional[ssl.SSLContext] = None,
        # SMTP parameters
        server_hostname: Optional[str] = None,
        **SMTP_parameters,
    ):
        """
        :param ready_timeout: Time budget used by start() while waiting for the
            server; the AIOSMTPD_CONTROLLER_TIMEOUT environment variable, if
            set, takes precedence over this argument

        All other arguments are passed on to BaseController unchanged.
        """
        super().__init__(
            handler,
            loop,
            ssl_context=ssl_context,
            server_hostname=server_hostname,
            **SMTP_parameters,
        )
        timeout = os.getenv("AIOSMTPD_CONTROLLER_TIMEOUT", ready_timeout)
        self.ready_timeout = float(timeout)

    @abstractmethod
    def _trigger_server(self):
        """
        Overridden by subclasses to trigger asyncio to actually initialize the SMTP
        class (it's lazy initialization, done only on initial connection).
        """
        raise NotImplementedError

    def _run(self, ready_event: threading.Event) -> None:
        """
        Thread target: create the server, signal readiness, run the loop until
        it gets stopped, then close the server and the loop.

        :param ready_event: Set (from within the loop) once the server exists
        """
        asyncio.set_event_loop(self.loop)
        try:
            # Two-step assignment so that IDEs can properly detect the types.
            # `assert isinstance` is not an option, because Python 3.6 in asyncio
            # debug mode has a bug wherein CoroWrapper is not an instance of
            # Coroutine
            self.server_coro = self._create_server()
            created: AsyncServer = self.loop.run_until_complete(self.server_coro)
            self.server = created
        except Exception as exc:  # pragma: on-wsl
            # Usually ends up here only if create_server() cannot bind to the
            # specified host:port.
            #
            # Somehow WSL 1.0 (Windows Subsystem for Linux) allows multiple
            # listeners on one port?!
            # That is why "pragma: on-wsl" is added there, so this block will not
            # affect coverage on WSL 1.0.
            self._thread_exception = exc
            return
        self.loop.call_soon(ready_event.set)
        self.loop.run_forever()
        # Getting here means the loop has been ended (by external code).
        # Wind things down to ensure the endpoint is no longer bound.
        self.server.close()
        self.loop.run_until_complete(self.server.wait_closed())
        self.loop.close()
        self.server = None

    def start(self):
        """
        Start a thread and run the asyncio event loop in that thread

        :raises TimeoutError: if the server is not ready, or factory() has not
            been invoked, within ``ready_timeout``
        :raises RuntimeError: if no SMTP instance exists afterwards
        :raises Exception: whatever got recorded in ``_thread_exception``
        """
        assert self._thread is None, "SMTP daemon already running"
        self._factory_invoked.clear()

        ready_event = threading.Event()
        self._thread = threading.Thread(
            target=self._run, args=(ready_event,), daemon=True
        )
        self._thread.start()
        # Wait a while until the server is responding.
        began = time.monotonic()
        if not ready_event.wait(self.ready_timeout):
            # An exception within self._run also leaves ready_event unset, so
            # check for that first, before resorting to TimeoutError
            if self._thread_exception is not None:  # pragma: on-wsl
                # See comment about WSL1.0 in the _run() method
                raise self._thread_exception
            raise TimeoutError(
                "SMTP server failed to start within allotted time. "
                "This might happen if the system is too busy. "
                "Try increasing the `ready_timeout` parameter."
            )
        # Whatever remains of the time budget
        respond_timeout = self.ready_timeout - (time.monotonic() - began)

        # Apparently create_server invokes factory() "lazily", so exceptions in
        # factory() go undetected. To trigger factory() invocation, a connection
        # to the server must be opened and some traffic 'exchanged'.
        try:
            self._trigger_server()
        except socket_timeout:
            # A timeout experienced by the probe connection is of no concern;
            # every other exception simply propagates.
            pass
        if not self._factory_invoked.wait(respond_timeout):
            raise TimeoutError(
                "SMTP server started, but not responding within allotted time. "
                "This might happen if the system is too busy. "
                "Try increasing the `ready_timeout` parameter."
            )
        if self._thread_exception is not None:
            raise self._thread_exception

        # Defensive
        if self.smtpd is None:
            raise RuntimeError("Unknown Error, failed to init SMTP server")

    def stop(self, no_assert: bool = False):
        """
        Stop the loop, the tasks in the loop, and terminate the thread as well.

        :param no_assert: If True, do not assert that the thread exists
        """
        assert no_assert or self._thread is not None, "SMTP daemon not running"
        self.loop.call_soon_threadsafe(self.cancel_tasks)
        worker = self._thread
        if worker is not None:
            worker.join()
            self._thread = None
        self._cleanup()


@public
class BaseUnthreadedController(BaseController, metaclass=ABCMeta):
    """
    Base of the controllers that do not run the event loop themselves: begin()
    only sets up the server, end() / finalize() tear it down again.
    """

    def __init__(
        self,
        handler: Any,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        *,
        ssl_context: Optional[ssl.SSLContext] = None,
        # SMTP parameters
        server_hostname: Optional[str] = None,
        **SMTP_parameters,
    ):
        """All arguments are passed on to BaseController unchanged."""
        super().__init__(
            handler,
            loop,
            ssl_context=ssl_context,
            server_hostname=server_hostname,
            **SMTP_parameters,
        )
        # Set by finalize() when it has completed
        self.ended = threading.Event()

    def begin(self):
        """
        Sets up the asyncio server task and inject it into the asyncio event loop.
        Does NOT actually start the event loop itself.
        """
        asyncio.set_event_loop(self.loop)
        # Two-step assignment so that IDEs can properly detect the types.
        # `assert isinstance` is not an option, because Python 3.6 in asyncio
        # debug mode has a bug wherein CoroWrapper is not an instance of Coroutine
        self.server_coro = self._create_server()
        created: AsyncServer = self.loop.run_until_complete(self.server_coro)
        self.server = created

    async def finalize(self):
        """
        Perform orderly closing of the server listener.
        NOTE: Being a coroutine function, this must be awaited from async code,
        scheduled with loop.create_task() (if the loop is still running), or
        driven with loop.run_until_complete() (if the loop has stopped).
        """
        self.ended.clear()
        listener = self.server
        listener.close()
        await listener.wait_closed()
        self.server_coro.close()
        self._cleanup()
        self.ended.set()

    def end(self):
        """
        Convenience method to asynchronously invoke finalize().
        Consider using loop.call_soon_threadsafe to invoke this method, especially
        if your loop is running in a different thread. You can afterwards .wait() on
        ended attribute (a threading.Event) to check for completion, if needed.
        """
        self.ended.clear()
        if not self.loop.is_running():
            # Loop is idle, so finalize() can be driven to completion right here
            self.loop.run_until_complete(self.finalize())
            return
        self.loop.create_task(self.finalize())


@public
class InetMixin(BaseController, metaclass=ABCMeta):
    """Mixin implementing server creation and probing for an INET host:port."""

    def __init__(
        self,
        handler: Any,
        hostname: Optional[str] = None,
        port: int = 8025,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        **kwargs,
    ):
        """
        :param handler: Passed on to the next __init__ in the MRO
        :param hostname: Host to listen on; None means the address returned by
            get_localhost()
        :param port: Port to listen on
        :param loop: Passed on to the next __init__ in the MRO
        :param kwargs: Passed on to the next __init__ in the MRO
        """
        super().__init__(
            handler,
            loop,
            **kwargs,
        )
        self._localhost = get_localhost()
        if hostname is None:
            hostname = self._localhost
        self.hostname = hostname
        self.port = port

    def _create_server(self) -> Coroutine:
        """
        Creates a 'server task' that listens on an INET host:port.
        Does NOT actually start the protocol object itself;
        _factory_invoker() is only called upon first connection attempt.
        """
        server_coro = self.loop.create_server(
            self._factory_invoker,
            host=self.hostname,
            port=self.port,
            ssl=self.ssl_context,
        )
        return server_coro

    def _trigger_server(self):
        """
        Opens a socket connection to the newly launched server, wrapping in an SSL
        Context if necessary, and read some data from it to ensure that factory()
        gets invoked.
        """
        # A Falsy self.hostname at this point most likely is "" (bind to all
        # addresses). In such case, it should be safe to connect to localhost.
        address = (self.hostname or self._localhost, self.port)
        with ExitStack() as stack:
            conn = stack.enter_context(create_connection(address, 1.0))
            if self.ssl_context:
                tls_ctx = _server_to_client_ssl_ctx(self.ssl_context)
                conn = stack.enter_context(tls_ctx.wrap_socket(conn))
            conn.recv(1024)


@public
class UnixSocketMixin(BaseController, metaclass=ABCMeta):  # pragma: no-unixsock
    """Mixin implementing server creation and probing for a Unix Socket file."""

    def __init__(
        self,
        handler: Any,
        unix_socket: Union[str, Path],
        loop: Optional[asyncio.AbstractEventLoop] = None,
        **kwargs,
    ):
        """
        :param handler: Passed on to the next __init__ in the MRO
        :param unix_socket: Path of the socket file; stored as a ``str``
        :param loop: Passed on to the next __init__ in the MRO
        :param kwargs: Passed on to the next __init__ in the MRO
        """
        super().__init__(
            handler,
            loop,
            **kwargs,
        )
        self.unix_socket = str(unix_socket)

    def _create_server(self) -> Coroutine:
        """
        Creates a 'server task' that listens on a Unix Socket file.
        Does NOT actually start the protocol object itself;
        _factory_invoker() is only called upon first connection attempt.
        """
        return self.loop.create_unix_server(
            self._factory_invoker,
            path=self.unix_socket,
            ssl=self.ssl_context,
        )

    def _trigger_server(self):
        """
        Opens a socket connection to the newly launched server, wrapping in an SSL
        Context if necessary, and read some data from it to ensure that factory()
        gets invoked.

        :raises socket.timeout: if connecting, the TLS handshake or the read
            takes longer than one second
        """
        with ExitStack() as stk:
            s: makesock = stk.enter_context(makesock(AF_UNIX, SOCK_STREAM))
            # Same 1.0 second limit that InetMixin gives to create_connection().
            # Without it, recv() below would block forever whenever the server
            # side never sends anything (e.g. a _FakeServer after factory()
            # failed), which in turn would hang start().
            s.settimeout(1.0)
            s.connect(self.unix_socket)
            if self.ssl_context:
                client_ctx = _server_to_client_ssl_ctx(self.ssl_context)
                s = stk.enter_context(client_ctx.wrap_socket(s))
            s.recv(1024)


@public
class Controller(InetMixin, BaseThreadedController):
    """
    Provides a multithreaded controller that listens on an INET endpoint.

    Combines InetMixin (host:port listener and connection probe) with
    BaseThreadedController (event loop run in a separate thread).
    """

    def _trigger_server(self):
        """Probe the server using the implementation from InetMixin."""
        # Both base classes define _trigger_server(); naming InetMixin
        # explicitly prevents confusion on which one gets invoked.
        # Or so LGTM.com claimed
        InetMixin._trigger_server(self)


@public
class UnixSocketController(  # pragma: no-unixsock
    UnixSocketMixin, BaseThreadedController
):
    """
    Provides a multithreaded controller that listens on a Unix Socket file.

    Combines UnixSocketMixin (socket file listener and connection probe) with
    BaseThreadedController (event loop run in a separate thread).
    """

    def _trigger_server(self):  # pragma: no-unixsock
        """Probe the server using the implementation from UnixSocketMixin."""
        # Both base classes define _trigger_server(); naming UnixSocketMixin
        # explicitly prevents confusion on which one gets invoked.
        # Or so LGTM.com claimed
        UnixSocketMixin._trigger_server(self)


@public
class UnthreadedController(InetMixin, BaseUnthreadedController):
    """
    Provides an unthreaded controller that listens on an INET endpoint.

    The event loop is not started by this class; see begin() and end() of
    BaseUnthreadedController.
    """


@public
class UnixSocketUnthreadedController(  # pragma: no-unixsock
    UnixSocketMixin, BaseUnthreadedController
):
    """
    Provides an unthreaded controller that listens on a Unix Socket file.

    The event loop is not started by this class; see begin() and end() of
    BaseUnthreadedController.
    """