# Copyright 2014-2021 The aiosmtpd Developers
# SPDX-License-Identifier: Apache-2.0
import asyncio
import errno
import os
import ssl
import sys
import threading
import time
from abc import ABCMeta, abstractmethod
from contextlib import ExitStack
from pathlib import Path
from socket import AF_INET6, SOCK_STREAM, create_connection, has_ipv6
from socket import socket as makesock
from socket import timeout as socket_timeout
try:
from socket import AF_UNIX
except ImportError: # pragma: on-not-win32
AF_UNIX = None
from typing import Any, Coroutine, Dict, Optional, Union
if sys.version_info >= (3, 8):
from typing import Literal # pragma: py-lt-38
else: # pragma: py-ge-38
from typing_extensions import Literal
from warnings import warn
from public import public
from aiosmtpd.smtp import SMTP
# Alias for asyncio's concrete server class; used in this module to annotate
# the object obtained by running the _create_server() coroutine to completion.
AsyncServer = asyncio.base_events.Server
# Default value (in seconds) of the `ready_timeout` parameter accepted by
# BaseThreadedController; the AIOSMTPD_CONTROLLER_TIMEOUT environment variable,
# when set, is used instead.
DEFAULT_READY_TIMEOUT: float = 5.0
@public
class IP6_IS:
    """errno groups used by get_localhost() to interpret a failed bind to ``::1``."""

    # The errno.E* constants apparently adapt to the host OS, so on Windows they
    # automatically map to the matching WSAE* constants.

    # Failing with one of these means IPv6 cannot be used
    NO = {errno.EAFNOSUPPORT, errno.EADDRNOTAVAIL}
    # Failing with one of these still means IPv6 is supported
    YES = {errno.EADDRINUSE}
def _has_ipv6() -> bool:
# Helper function to assist in mocking
return has_ipv6
@public
def get_localhost() -> Literal["::1", "127.0.0.1"]:
    """
    Return the numeric loopback address, preferring IPv6 when it is usable.

    :return: ``"::1"`` when IPv6 is available, ``"127.0.0.1"`` otherwise
    :raises OSError: if probing ``::1`` fails with an errno that is listed in
        neither ``IP6_IS.NO`` nor ``IP6_IS.YES``
    """
    # Ref:
    #  - https://github.com/urllib3/urllib3/pull/611#issuecomment-100954017
    #  - https://github.com/python/cpython/blob/ :
    #    - v3.6.13/Lib/test/support/__init__.py#L745-L758
    #    - v3.9.1/Lib/test/support/socket_helper.py#L124-L137
    if not _has_ipv6():
        # socket.has_ipv6 only describes the running Python's IPv6 support, not
        # the system's. Still, without support in Python itself there is no
        # point in probing any further.
        return "127.0.0.1"
    try:
        with makesock(AF_INET6, SOCK_STREAM) as probe:
            probe.bind(("::1", 0))
    except OSError as bind_error:
        code = bind_error.errno
        if code in IP6_IS.NO:
            return "127.0.0.1"
        if code in IP6_IS.YES:
            # These errors should never happen here, but if they do they still
            # indicate that IPv6 is supported
            return "::1"
        # Any other kind of error MUST propagate so that it can be inspected
        raise
    else:
        # Binding ::1 (on a random unused port) succeeded, so IPv6 is definitely
        # supported
        return "::1"
def _server_to_client_ssl_ctx(server_ctx: ssl.SSLContext) -> ssl.SSLContext:
"""
Given an SSLContext object with TLS_SERVER_PROTOCOL return a client
context that can connect to the server.
"""
client_ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)
client_ctx.options = server_ctx.options
client_ctx.check_hostname = False
# We do not verify the ssl cert for the server here simply because this
# is a local connection to poke at the server for it to do its lazy
# initialization sequence. The only purpose of this client context
# is to make a connection to the *local* server created using the same
# code. That is also the reason why we disable cert verification below
# and the flake8 check for the same.
client_ctx.verify_mode = ssl.CERT_NONE # noqa: DUO122
return client_ctx
class _FakeServer(asyncio.StreamReaderProtocol):
"""
Returned by _factory_invoker() in lieu of an SMTP instance in case
factory() failed to instantiate an SMTP instance.
"""
def __init__(self, loop: asyncio.AbstractEventLoop):
# Imitate what SMTP does
super().__init__(
asyncio.StreamReader(loop=loop),
client_connected_cb=self._client_connected_cb,
loop=loop,
)
def _client_connected_cb(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
pass
@public
class BaseController(metaclass=ABCMeta):
    """
    Abstract foundation shared by all controllers.

    Keeps the handler, the event loop, the SSL context and the keyword
    arguments that factory() forwards to SMTP, and wraps factory() so that
    errors raised while instantiating SMTP are recorded instead of being lost.
    """

    # Object returned by factory(); assigned by _factory_invoker()
    smtpd = None
    # Server object and the coroutine that creates it; assigned by subclasses
    server: Optional[AsyncServer] = None
    server_coro: Optional[Coroutine] = None
    # Exception recorded by _factory_invoker() (threaded subclasses also store
    # server start-up errors here); reset by _cleanup()
    _thread_exception: Optional[Exception] = None
    # Set once _factory_invoker() has finished, whether it succeeded or not
    _factory_invoked: Optional[threading.Event] = None

    def __init__(
        self,
        handler: Any,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        *,
        ssl_context: Optional[ssl.SSLContext] = None,
        # SMTP parameters
        server_hostname: Optional[str] = None,
        **SMTP_parameters,
    ):
        """
        :param handler: handler object passed as first argument to SMTP
        :param loop: event loop to use; a new one is created when None
        :param ssl_context: SSL context handed to the listener by subclasses
        :param server_hostname: if truthy, forwarded to SMTP as ``hostname``
        :param SMTP_parameters: extra keyword arguments forwarded to SMTP. The
            deprecated ``server_kwargs`` key may hold a dict of such arguments;
            explicitly given keyword arguments override its entries.
        """
        self.handler = handler
        if loop is None:
            self.loop = asyncio.new_event_loop()
        else:
            self.loop = loop
        self.ssl_context = ssl_context
        self.SMTP_kwargs: Dict[str, Any] = {}
        if "server_kwargs" in SMTP_parameters:
            warn(
                "server_kwargs will be removed in version 2.0. "
                "Just specify the keyword arguments to forward to SMTP "
                "as kwargs to this __init__ method.",
                DeprecationWarning,
            )
            # Take a copy: the dict belongs to the caller, and the update() /
            # setdefault() calls below must not leak settings back into it (or
            # into another controller built from the same dict).
            legacy_kwargs = SMTP_parameters.pop("server_kwargs")
            self.SMTP_kwargs = dict(legacy_kwargs or {})
        self.SMTP_kwargs.update(SMTP_parameters)
        if server_hostname:
            self.SMTP_kwargs["hostname"] = server_hostname
        # Emulate previous behavior of defaulting enable_SMTPUTF8 to True
        # It actually conflicts with SMTP class's default, but the reasoning is
        # discussed in the docs.
        self.SMTP_kwargs.setdefault("enable_SMTPUTF8", True)
        #
        self._factory_invoked = threading.Event()

    def factory(self):
        """Subclasses can override this to customize the handler/server creation."""
        return SMTP(self.handler, **self.SMTP_kwargs)

    def _factory_invoker(self) -> Union[SMTP, _FakeServer]:
        """
        Wraps factory() to catch exceptions during instantiation.

        :return: the object produced by factory(), or a _FakeServer if factory()
            raised or returned None (the error is kept in _thread_exception)
        """
        try:
            self.smtpd = self.factory()
            if self.smtpd is None:
                raise RuntimeError("factory() returned None")
            return self.smtpd
        except Exception as err:
            self._thread_exception = err
            return _FakeServer(self.loop)
        finally:
            # Signal completion on success and on failure alike
            self._factory_invoked.set()

    @abstractmethod
    def _create_server(self) -> Coroutine:
        """
        Overridden by subclasses to actually perform the async binding to the
        listener endpoint. When overridden, MUST refer the _factory_invoker() method.
        """
        raise NotImplementedError

    def _cleanup(self):
        """Reset internal variables to prevent contamination"""
        self._thread_exception = None
        self._factory_invoked.clear()
        self.server_coro = None
        self.server = None
        self.smtpd = None

    def cancel_tasks(self, stop_loop: bool = True):
        """
        Convenience method to stop the loop and cancel all tasks.
        Use loop.call_soon_threadsafe() to invoke this.

        :param stop_loop: when True, also call stop() on the loop
        """
        if stop_loop:  # pragma: nobranch
            self.loop.stop()
        try:
            _all_tasks = asyncio.all_tasks  # pytype: disable=module-attr
        except AttributeError:  # pragma: py-gt-36
            # Fallback for Python versions lacking asyncio.all_tasks
            _all_tasks = asyncio.Task.all_tasks  # pytype: disable=attribute-error
        for task in _all_tasks(self.loop):
            # This needs to be invoked in a thread-safe way
            task.cancel()
@public
class BaseThreadedController(BaseController, metaclass=ABCMeta):
    """
    Base for controllers that run the asyncio event loop in a separate thread.

    start() launches the thread and waits until the server is confirmed to be
    responding; stop() stops the loop and joins the thread.
    """

    _thread: Optional[threading.Thread] = None
    _thread_exception: Optional[Exception] = None

    def __init__(
        self,
        handler: Any,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        *,
        ready_timeout: float = DEFAULT_READY_TIMEOUT,
        ssl_context: Optional[ssl.SSLContext] = None,
        # SMTP parameters
        server_hostname: Optional[str] = None,
        **SMTP_parameters,
    ):
        """
        :param handler: handler object forwarded to the base class
        :param loop: event loop to use; a new one is created when None
        :param ready_timeout: seconds start() is allowed to wait for the server;
            the AIOSMTPD_CONTROLLER_TIMEOUT environment variable, if set, is
            used instead
        :param ssl_context: SSL context for the listener
        :param server_hostname: forwarded to the base class
        :param SMTP_parameters: extra keyword arguments forwarded to SMTP
        """
        super().__init__(
            handler,
            loop,
            ssl_context=ssl_context,
            server_hostname=server_hostname,
            **SMTP_parameters,
        )
        self.ready_timeout = float(
            os.getenv("AIOSMTPD_CONTROLLER_TIMEOUT", ready_timeout)
        )

    @abstractmethod
    def _trigger_server(self):
        """
        Overridden by subclasses to trigger asyncio to actually initialize the SMTP
        class (it's lazy initialization, done only on initial connection).
        """
        raise NotImplementedError

    def _run(self, ready_event: threading.Event) -> None:
        """
        Thread body: create the server, then run the loop until it is stopped.

        :param ready_event: set once the outcome of server creation is known,
            i.e. when the loop starts running with the server bound, or
            immediately if creating the server failed (the error is then
            available in _thread_exception)
        """
        asyncio.set_event_loop(self.loop)
        try:
            # Need to do two-step assignments here to ensure IDEs can properly
            # detect the types of the vars. Cannot use `assert isinstance`, because
            # Python 3.6 in asyncio debug mode has a bug wherein CoroWrapper is not
            # an instance of Coroutine
            self.server_coro = self._create_server()
            srv: AsyncServer = self.loop.run_until_complete(self.server_coro)
            self.server = srv
        except Exception as error:  # pragma: on-wsl
            # Usually will enter this part only if create_server() cannot bind to the
            # specified host:port.
            #
            # Somehow WSL 1.0 (Windows Subsystem for Linux) allows multiple
            # listeners on one port?!
            # That is why we add "pragma: on-wsl" there, so this block will not affect
            # coverage on WSL 1.0.
            self._thread_exception = error
            # Wake start() up right away; without this it would sit out the
            # whole ready_timeout before noticing an error we already have.
            ready_event.set()
            return
        self.loop.call_soon(ready_event.set)
        self.loop.run_forever()
        # We reach this point when loop is ended (by external code)
        # Perform some stoppages to ensure endpoint no longer bound.
        self.server.close()
        self.loop.run_until_complete(self.server.wait_closed())
        self.loop.close()
        self.server = None

    def start(self):
        """
        Start a thread and run the asyncio event loop in that thread

        :raises TimeoutError: if the server did not start, or did not respond,
            within ``ready_timeout`` seconds
        :raises RuntimeError: if no SMTP instance exists after start-up
        :raises Exception: whatever error was captured while creating the
            server or while invoking factory()
        """
        assert self._thread is None, "SMTP daemon already running"
        self._factory_invoked.clear()

        ready_event = threading.Event()
        self._thread = threading.Thread(target=self._run, args=(ready_event,))
        self._thread.daemon = True
        self._thread.start()
        # Wait a while until the server is responding.
        start = time.monotonic()
        is_ready = ready_event.wait(self.ready_timeout)
        # A failure inside self._run takes precedence over a plain timeout, so
        # test for that first.
        if self._thread_exception is not None:  # pragma: on-wsl
            # See comment about WSL1.0 in the _run() method
            raise self._thread_exception
        if not is_ready:
            raise TimeoutError(
                "SMTP server failed to start within allotted time. "
                "This might happen if the system is too busy. "
                "Try increasing the `ready_timeout` parameter."
            )
        respond_timeout = self.ready_timeout - (time.monotonic() - start)

        # Apparently create_server invokes factory() "lazily", so exceptions in
        # factory() go undetected. To trigger factory() invocation we need to open
        # a connection to the server and 'exchange' some traffic.
        try:
            self._trigger_server()
        except socket_timeout:
            # A timeout while poking the server is of no concern here; whether
            # factory() actually ran is verified right below. Any other
            # exception propagates to the caller.
            pass
        if not self._factory_invoked.wait(respond_timeout):
            raise TimeoutError(
                "SMTP server started, but not responding within allotted time. "
                "This might happen if the system is too busy. "
                "Try increasing the `ready_timeout` parameter."
            )
        if self._thread_exception is not None:
            raise self._thread_exception

        # Defensive
        if self.smtpd is None:
            raise RuntimeError("Unknown Error, failed to init SMTP server")

    def stop(self, no_assert: bool = False):
        """
        Stop the loop, the tasks in the loop, and terminate the thread as well.

        :param no_assert: when True, do not complain if the thread is not running
        """
        assert no_assert or self._thread is not None, "SMTP daemon not running"
        self.loop.call_soon_threadsafe(self.cancel_tasks)
        if self._thread is not None:
            self._thread.join()
            self._thread = None
        self._cleanup()
@public
class BaseUnthreadedController(BaseController, metaclass=ABCMeta):
    """
    Base for controllers that do not spawn a thread of their own.

    begin() creates the server on the controller's loop; finalize() / end()
    close it again. ``ended`` is a threading.Event set when finalize() is done.
    """

    def __init__(
        self,
        handler: Any,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        *,
        ssl_context: Optional[ssl.SSLContext] = None,
        # SMTP parameters
        server_hostname: Optional[str] = None,
        **SMTP_parameters,
    ):
        super().__init__(
            handler,
            loop,
            ssl_context=ssl_context,
            server_hostname=server_hostname,
            **SMTP_parameters,
        )
        self.ended = threading.Event()

    def begin(self):
        """
        Create the asyncio server task and inject it into the event loop.
        The event loop itself is NOT left running by this method.
        """
        asyncio.set_event_loop(self.loop)
        # The assignment is done in two steps so that IDEs can properly detect
        # the types of the vars. `assert isinstance` cannot be used, because
        # Python 3.6 in asyncio debug mode has a bug wherein CoroWrapper is not
        # an instance of Coroutine
        self.server_coro = self._create_server()
        created: AsyncServer = self.loop.run_until_complete(self.server_coro)
        self.server = created

    async def finalize(self):
        """
        Close the server listener in an orderly fashion.

        NOTE: This is an async method; await it from async code, or use
        loop.create_task() (if the loop is still running), or
        loop.run_until_complete() (if the loop has stopped)
        """
        self.ended.clear()
        listener = self.server
        listener.close()
        await listener.wait_closed()
        self.server_coro.close()
        self._cleanup()
        self.ended.set()

    def end(self):
        """
        Convenience method to asynchronously invoke finalize().

        Consider invoking this through loop.call_soon_threadsafe, especially when
        the loop runs in a different thread. Afterwards you may .wait() on the
        ``ended`` attribute (a threading.Event) to check for completion, if needed.
        """
        self.ended.clear()
        shutdown = self.finalize()
        if self.loop.is_running():
            self.loop.create_task(shutdown)
        else:
            self.loop.run_until_complete(shutdown)
@public
class InetMixin(BaseController, metaclass=ABCMeta):
    """Mixin implementing server creation and triggering for an INET host:port."""

    def __init__(
        self,
        handler: Any,
        hostname: Optional[str] = None,
        port: int = 8025,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        **kwargs,
    ):
        """
        :param handler: handler object forwarded to the base class
        :param hostname: address to listen on; when None, the address returned
            by get_localhost() is used
        :param port: port to listen on
        :param loop: event loop forwarded to the base class
        :param kwargs: remaining keyword arguments forwarded to the base class
        """
        super().__init__(handler, loop, **kwargs)
        self._localhost = get_localhost()
        if hostname is None:
            self.hostname = self._localhost
        else:
            self.hostname = hostname
        self.port = port

    def _create_server(self) -> Coroutine:
        """
        Creates a 'server task' that listens on an INET host:port.
        Does NOT actually start the protocol object itself;
        _factory_invoker() is only called upon first connection attempt.
        """
        listener_coro = self.loop.create_server(
            self._factory_invoker,
            host=self.hostname,
            port=self.port,
            ssl=self.ssl_context,
        )
        return listener_coro

    def _trigger_server(self):
        """
        Opens a socket connection to the newly launched server, wrapping in an SSL
        Context if necessary, and read some data from it to ensure that factory()
        gets invoked.
        """
        # A Falsy self.hostname at this point is most likely "" (bind to all
        # addresses); connecting to localhost should be safe in that case.
        target_host = self.hostname if self.hostname else self._localhost
        with ExitStack() as stack:
            conn = stack.enter_context(
                create_connection((target_host, self.port), 1.0)
            )
            if self.ssl_context:
                tls_ctx = _server_to_client_ssl_ctx(self.ssl_context)
                conn = stack.enter_context(tls_ctx.wrap_socket(conn))
            conn.recv(1024)
@public
class UnixSocketMixin(BaseController, metaclass=ABCMeta):  # pragma: no-unixsock
    """Mixin implementing server creation and triggering for a Unix Socket file."""

    def __init__(
        self,
        handler: Any,
        unix_socket: Union[str, Path],
        loop: Optional[asyncio.AbstractEventLoop] = None,
        **kwargs,
    ):
        """
        :param handler: handler object forwarded to the base class
        :param unix_socket: path of the socket file (stored as a str)
        :param loop: event loop forwarded to the base class
        :param kwargs: remaining keyword arguments forwarded to the base class
        """
        super().__init__(
            handler,
            loop,
            **kwargs,
        )
        self.unix_socket = str(unix_socket)

    def _create_server(self) -> Coroutine:
        """
        Creates a 'server task' that listens on a Unix Socket file.
        Does NOT actually start the protocol object itself;
        _factory_invoker() is only called upon first connection attempt.
        """
        return self.loop.create_unix_server(
            self._factory_invoker,
            path=self.unix_socket,
            ssl=self.ssl_context,
        )

    def _trigger_server(self):
        """
        Opens a socket connection to the newly launched server, wrapping in an SSL
        Context if necessary, and read some data from it to ensure that factory()
        gets invoked.

        :raises socket.timeout: if connecting, the TLS handshake or the read
            takes longer than one second
        """
        with ExitStack() as stk:
            s: makesock = stk.enter_context(makesock(AF_UNIX, SOCK_STREAM))
            # Same 1.0 second limit as the INET variant. Without a timeout,
            # recv() would block forever when the peer never sends anything
            # (e.g. the _FakeServer used after a factory() failure).
            s.settimeout(1.0)
            s.connect(self.unix_socket)
            if self.ssl_context:
                client_ctx = _server_to_client_ssl_ctx(self.ssl_context)
                s = stk.enter_context(client_ctx.wrap_socket(s))
            s.recv(1024)
@public
class Controller(InetMixin, BaseThreadedController):
    """Multithreaded controller that listens on an INET endpoint"""

    def _trigger_server(self):
        # Spell out explicitly which parent's _trigger_server() gets invoked,
        # to prevent confusion. Or so LGTM.com claimed
        InetMixin._trigger_server(self)
@public
class UnixSocketController(  # pragma: no-unixsock
    UnixSocketMixin, BaseThreadedController
):
    """Multithreaded controller that listens on a Unix Socket file"""

    def _trigger_server(self):  # pragma: no-unixsock
        # Spell out explicitly which parent's _trigger_server() gets invoked,
        # to prevent confusion. Or so LGTM.com claimed
        UnixSocketMixin._trigger_server(self)
@public
class UnthreadedController(InetMixin, BaseUnthreadedController):
    """Unthreaded controller that listens on an INET endpoint"""
@public
class UnixSocketUnthreadedController(  # pragma: no-unixsock
    UnixSocketMixin, BaseUnthreadedController
):
    """Unthreaded controller that listens on a Unix Socket file"""
|