File: test_timeouts.py

package info (click to toggle)
aiosmtplib 4.0.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 572 kB
  • sloc: python: 5,516; makefile: 20; sh: 6
file content (165 lines) | stat: -rw-r--r-- 5,008 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""
Timeout tests.
"""

import asyncio
import socket
import ssl

import pytest

from aiosmtplib import (
    SMTP,
    SMTPConnectTimeoutError,
    SMTPServerDisconnected,
    SMTPTimeoutError,
)
from aiosmtplib.protocol import SMTPProtocol

from .compat import cleanup_server
from .smtpd import mock_response_delayed_ok, mock_response_delayed_read


@pytest.mark.smtpd_mocks(smtp_EHLO=mock_response_delayed_ok)
async def test_command_timeout_error(smtp_client: SMTP) -> None:
    """An EHLO sent with a zero command timeout raises SMTPTimeoutError.

    The server's EHLO handler is replaced by ``mock_response_delayed_ok``
    (presumably a reply sent after a delay), so a 0.0 second timeout cannot
    be met.
    """
    await smtp_client.connect()

    zero_timeout = 0.0
    with pytest.raises(SMTPTimeoutError):
        await smtp_client.ehlo(hostname="example.com", timeout=zero_timeout)


@pytest.mark.smtpd_mocks(smtp_DATA=mock_response_delayed_ok)
async def test_data_timeout_error(smtp_client: SMTP) -> None:
    """DATA sent with a zero timeout raises SMTPTimeoutError.

    The envelope (EHLO, MAIL, RCPT) is completed first so that DATA is the
    command that hits the patched, presumably delayed, server handler.
    """
    sender = "j@example.com"
    recipient = "test@example.com"
    body = "HELLO WORLD"

    # Drive the session up to the point where DATA is accepted.
    await smtp_client.connect()
    await smtp_client.ehlo()
    await smtp_client.mail(sender)
    await smtp_client.rcpt(recipient)

    with pytest.raises(SMTPTimeoutError):
        await smtp_client.data(body, timeout=0.0)


@pytest.mark.smtpd_mocks(_handle_client=mock_response_delayed_ok)
async def test_timeout_error_on_connect(smtp_client: SMTP) -> None:
    """A zero connect timeout raises SMTPTimeoutError and leaves no connection.

    After the failure, the client must hold neither a transport nor a
    protocol instance.
    """
    zero_timeout = 0.0
    with pytest.raises(SMTPTimeoutError):
        await smtp_client.connect(timeout=zero_timeout)

    # Nothing half-open may be left behind on the client.
    assert smtp_client.transport is None
    assert smtp_client.protocol is None


@pytest.mark.smtpd_mocks(_handle_client=mock_response_delayed_read)
async def test_timeout_on_initial_read(smtp_client: SMTP) -> None:
    """A slow server greeting makes connect() raise SMTPTimeoutError."""
    # The timeout is strictly positive so the TCP connect itself can finish;
    # the expectation is that the timeout trips while reading the greeting.
    read_timeout = 0.01
    with pytest.raises(SMTPTimeoutError):
        await smtp_client.connect(timeout=read_timeout)


@pytest.mark.smtpd_mocks(smtp_STARTTLS=mock_response_delayed_ok)
async def test_timeout_on_starttls(smtp_client: SMTP) -> None:
    """STARTTLS sent with a zero timeout raises SMTPTimeoutError.

    The session is connected and greeted with EHLO first. The server's
    STARTTLS handler is patched with ``mock_response_delayed_ok``.
    """
    await smtp_client.connect()
    await smtp_client.ehlo()

    zero_timeout = 0.0
    with pytest.raises(SMTPTimeoutError):
        await smtp_client.starttls(timeout=zero_timeout)


async def test_protocol_read_response_with_timeout_times_out(
    echo_server: asyncio.AbstractServer,
    hostname: str,
    echo_server_port: int,
) -> None:
    """``SMTPProtocol.read_response`` with a zero timeout raises SMTPTimeoutError.

    The test writes nothing to the ``echo_server`` fixture, so no reply is
    expected to be available within the zero timeout. The error message is
    checked verbatim.
    """
    event_loop = asyncio.get_running_loop()

    connect_future = event_loop.create_connection(
        SMTPProtocol, host=hostname, port=echo_server_port
    )
    transport, protocol = await asyncio.wait_for(connect_future, timeout=1.0)

    # Close the client connection even if the expected timeout does not
    # happen (pytest.raises then fails), so a failing run does not leak it.
    try:
        with pytest.raises(SMTPTimeoutError) as exc:
            await protocol.read_response(timeout=0.0)  # type: ignore
    finally:
        transport.close()

    assert str(exc.value) == "Timed out waiting for server response"


async def test_connect_timeout_error(hostname: str, unused_tcp_port: int) -> None:
    """A zero client-wide timeout makes connect() raise SMTPConnectTimeoutError.

    The exception message must name both the host and the port that were
    targeted.
    """
    expected_message = f"Timed out connecting to {hostname} on port {unused_tcp_port}"
    client = SMTP(hostname=hostname, port=unused_tcp_port, timeout=0.0)

    with pytest.raises(SMTPConnectTimeoutError) as exc:
        await client.connect()

    assert str(exc.value) == expected_message


async def test_server_disconnected_error_after_connect_timeout(
    hostname: str,
    unused_tcp_port: int,
    sender_str: str,
    recipient_str: str,
    message_str: str,
) -> None:
    """After a timed-out connect, sendmail() raises SMTPServerDisconnected."""
    client = SMTP(hostname=hostname, port=unused_tcp_port)

    # The connect attempt is made with a zero timeout and must time out.
    with pytest.raises(SMTPConnectTimeoutError):
        await client.connect(timeout=0.0)

    # The client must not act as if it were connected after that failure.
    recipients = [recipient_str]
    with pytest.raises(SMTPServerDisconnected):
        await client.sendmail(sender_str, recipients, message_str)


async def test_protocol_timeout_on_starttls(
    bind_address: str,
    hostname: str,
    client_tls_context: ssl.SSLContext,
) -> None:
    """``SMTPProtocol.start_tls`` raises SMTPTimeoutError when the peer stalls.

    The throwaway server accepts the TCP connection, then only sleeps and
    never writes anything. The TLS upgrade therefore cannot finish within the
    tiny timeout.
    """
    event_loop = asyncio.get_running_loop()

    async def client_connected(
        reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        # Hold the connection open without ever sending a byte.
        await asyncio.sleep(1.0)

    server = await asyncio.start_server(
        client_connected, host=bind_address, port=0, family=socket.AF_INET
    )
    try:
        server_port = server.sockets[0].getsockname()[1] if server.sockets else 0

        connect_future = event_loop.create_connection(
            SMTPProtocol, host=hostname, port=server_port
        )
        transport, protocol = await asyncio.wait_for(connect_future, timeout=1.0)

        try:
            with pytest.raises(SMTPTimeoutError):
                # STARTTLS timeout must be > 0
                await protocol.start_tls(client_tls_context, timeout=0.00001)  # type: ignore
        finally:
            # The client transport used to be discarded and never closed;
            # close it explicitly (closing an already-closed transport is a no-op).
            transport.close()
    finally:
        # Always tear the server down, even if an assertion above failed.
        server.close()
        await cleanup_server(server)


async def test_protocol_connection_aborted_on_starttls(
    hostname: str,
    smtpd_server_port: int,
    client_tls_context: ssl.SSLContext,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A ConnectionAbortedError from ``loop.start_tls`` surfaces as SMTPTimeoutError.

    The running loop's ``start_tls`` is monkeypatched to raise immediately.
    The test asserts that ``SMTPProtocol.start_tls`` reports this as an
    SMTPTimeoutError.
    """
    event_loop = asyncio.get_running_loop()

    connect_future = event_loop.create_connection(
        SMTPProtocol, host=hostname, port=smtpd_server_port
    )
    transport, protocol = await asyncio.wait_for(connect_future, timeout=1.0)

    def mock_start_tls(*args: object, **kwargs: object) -> None:
        # Simulate the connection being aborted during the TLS upgrade.
        raise ConnectionAbortedError("Connection was aborted")

    monkeypatch.setattr(event_loop, "start_tls", mock_start_tls)

    # Close the client connection even if the expected exception is not
    # raised, so a failing run does not leak the transport.
    try:
        with pytest.raises(SMTPTimeoutError):
            await protocol.start_tls(client_tls_context)
    finally:
        transport.close()