# Test that socket.connect() on a non-blocking socket raises EINPROGRESS,
# and that an immediate write/send/read/recv does the right thing.
import unittest
import errno
import select
import socket
import ssl
# Only mbedTLS supports non-blocking mode.  An mbedTLS-backed ssl module is
# recognised by the presence of the ssl.MBEDTLS_VERSION attribute; the TLS
# tests are skipped when this flag is false.
ssl_supports_nonblocking = hasattr(ssl, "MBEDTLS_VERSION")
# Translate the errno codes this test cares about into their symbolic names.
# Any other value is handed back unchanged so that it can still be printed.
def errno_name(er):
    known_codes = (
        (errno.EAGAIN, "EAGAIN"),
        (errno.EINPROGRESS, "EINPROGRESS"),
    )
    for code, label in known_codes:
        if er == code:
            return label
    return er
# do_connect creates a non-blocking socket, starts a connect to peer_addr and
# checks that the connect raises OSError with errno EINPROGRESS.
# If tls is true the socket is then wrapped by an ssl.SSLContext, with
# handshake passed through as do_handshake_on_connect.
# Returns the plain socket, or the TLS-wrapped socket when tls is true.
def do_connect(self, peer_addr, tls, handshake):
    sock = socket.socket()
    sock.setblocking(False)
    try:
        print("Connecting to", peer_addr)
        sock.connect(peer_addr)
    except OSError as er:
        print("connect:", errno_name(er.errno))
        self.assertEqual(er.errno, errno.EINPROGRESS)
    else:
        # connect() returning normally on a non-blocking socket is a failure
        self.fail()

    if not tls:
        return sock

    # wrap with ssl/tls
    print("wrap socket")
    tls_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    return tls_context.wrap_socket(sock, do_handshake_on_connect=handshake)
# Poll the socket once without waiting and check the reported events:
# exactly one entry whose event mask is POLLOUT when expect_writable is true,
# otherwise no events at all.
def poll(self, s, expect_writable):
    poller = select.poll()
    poller.register(s)
    events = poller.poll(0)  # timeout of 0: do not wait
    print("poll:", events)
    if not expect_writable:
        self.assertEqual(events, [])
        return
    self.assertEqual(len(events), 1)
    self.assertEqual(events[0][1], select.POLLOUT)
# do_test runs the non-blocking transfer checks against a specific peer address.
# Each check opens a fresh connection with do_connect, polls it, and performs
# one operation immediately afterwards:
#   send / recv -> must raise OSError with errno EAGAIN (skipped for TLS sockets)
#   write       -> must return None, or 4 when tls and handshake are both true
#   read        -> must return None
# tls and handshake are forwarded unchanged to do_connect.
# Every socket is closed in a finally clause so that a failed assertion does
# not leak it.
def do_test(self, peer_addr, tls, handshake):
    print()

    # MicroPython plain and TLS sockets have read/write
    hasRW = True

    # MicroPython plain sockets have send/recv
    # MicroPython TLS sockets don't have send/recv
    hasSR = not tls

    # connect + send
    # non-blocking send should raise EAGAIN
    if hasSR:
        s = do_connect(self, peer_addr, tls, handshake)
        try:
            poll(self, s, False)
            with self.assertRaises(OSError) as ctx:
                s.send(b"1234")
            print("send error:", errno_name(ctx.exception.errno))
            self.assertEqual(ctx.exception.errno, errno.EAGAIN)
        finally:
            s.close()

    # connect + write
    # non-blocking write should return None
    if hasRW:
        s = do_connect(self, peer_addr, tls, handshake)
        try:
            poll(self, s, tls and handshake)
            ret = s.write(b"1234")
            print("write:", ret)
            if tls and handshake:
                # the socket is expected to be writable here, so all 4 bytes go out
                self.assertEqual(ret, 4)
            else:
                self.assertIsNone(ret)
        finally:
            s.close()

    # connect + recv
    # non-blocking recv should raise EAGAIN
    if hasSR:
        s = do_connect(self, peer_addr, tls, handshake)
        try:
            poll(self, s, False)
            with self.assertRaises(OSError) as ctx:
                s.recv(10)
            print("recv error:", errno_name(ctx.exception.errno))
            self.assertEqual(ctx.exception.errno, errno.EAGAIN)
        finally:
            s.close()

    # connect + read
    # non-blocking read should return None
    if hasRW:
        s = do_connect(self, peer_addr, tls, handshake)
        try:
            poll(self, s, tls and handshake)
            ret = s.read(10)
            print("read:", ret)
            self.assertIsNone(ret)
        finally:
            s.close()
class Test(unittest.TestCase):
    # The "to_nowhere" tests use a non-existent test IP address
    # (https://tools.ietf.org/html/rfc5737); this way the connect takes
    # forever and we can see EAGAIN/None.
    def test_plain_sockets_to_nowhere(self):
        addr = socket.getaddrinfo("192.0.2.1", 80)[0][-1]
        do_test(self, addr, False, False)

    @unittest.skipIf(not ssl_supports_nonblocking, "SSL doesn't support non-blocking")
    def test_ssl_sockets_to_nowhere(self):
        addr = socket.getaddrinfo("192.0.2.1", 443)[0][-1]
        do_test(self, addr, True, False)

    # The remaining tests run the same checks against micropython.org.
    def test_plain_sockets(self):
        addr = socket.getaddrinfo("micropython.org", 80)[0][-1]
        do_test(self, addr, False, False)

    @unittest.skipIf(not ssl_supports_nonblocking, "SSL doesn't support non-blocking")
    def test_ssl_sockets(self):
        # TLS test with handshake=True
        addr = socket.getaddrinfo("micropython.org", 443)[0][-1]
        do_test(self, addr, True, True)
# Run the test cases when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|