1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212
|
import argparse
import asyncio
import gc
import os.path
import pathlib
import socket
import ssl
# Verbosity switch read by echo_server(), echo_client() and
# echo_client_streams(): when truthy they print connection events.
# The __main__ block sets it to 1 for --print and resets it to 0 when the
# event loop exposes print_debug_info().
PRINT = 0
async def echo_server(loop, address, unix):
    """Accept connections on *address* and echo them via ``loop.sock_*`` calls.

    Args:
        loop: Event loop providing ``sock_accept`` and ``create_task``.
        address: Value handed to ``socket.bind`` -- a filesystem path when
            *unix* is true, otherwise a ``(host, port)`` tuple.
        unix: When true, listen on an ``AF_UNIX`` socket instead of
            ``AF_INET``.

    The accept loop never returns normally; it ends only when an exception
    (including task cancellation) propagates, at which point the listening
    socket is closed.
    """
    family = socket.AF_UNIX if unix else socket.AF_INET
    listener = socket.socket(family, socket.SOCK_STREAM)

    # The event loop keeps only weak references to tasks, so hold strong
    # references to the per-connection handlers until each one finishes.
    handlers = set()

    # Enter the context manager before configuring the socket so that a
    # failing setsockopt()/bind()/listen() does not leak the descriptor.
    with listener:
        if not unix:
            listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listener.bind(address)
        listener.listen(5)
        # loop.sock_accept() expects a non-blocking socket.
        listener.setblocking(False)
        if PRINT:
            print('Server listening at', address)

        while True:
            client, addr = await loop.sock_accept(listener)
            if PRINT:
                print('Connection from', addr)
            handler = loop.create_task(echo_client(loop, client))
            handlers.add(handler)
            # Drop the reference as soon as the handler completes.
            handler.add_done_callback(handlers.discard)
async def echo_client(loop, client):
    """Echo every chunk received on *client* back to it until end of stream.

    Args:
        loop: Event loop providing ``sock_recv`` and ``sock_sendall``.
        client: Connected socket; it is closed when this coroutine leaves
            the ``with`` block.
    """
    try:
        client.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    except (OSError, NameError):
        # Best effort: carry on without TCP_NODELAY if it cannot be set
        # (presumably the case for non-TCP sockets -- confirm).
        pass

    with client:
        chunk = await loop.sock_recv(client, 1000000)
        # An empty read means the peer closed the connection.
        while chunk:
            await loop.sock_sendall(client, chunk)
            chunk = await loop.sock_recv(client, 1000000)

    if PRINT:
        print('Connection closed')
async def echo_client_streams(reader, writer):
    """Echo handler for the asyncio streams API.

    Args:
        reader: Stream reader for the accepted connection.
        writer: Stream writer for the accepted connection; always closed
            before this coroutine finishes, even when an error occurs.
    """
    try:
        sock = writer.get_extra_info('socket')
        try:
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        except (OSError, NameError):
            # Best effort: carry on without TCP_NODELAY if it cannot be set.
            pass
        if PRINT:
            print('Connection from', sock.getpeername())

        while True:
            data = await reader.read(1000000)
            if not data:
                # Empty read: the peer closed its side of the connection.
                break
            writer.write(data)
            # Apply flow control: without drain() a slow reader lets the
            # writer's buffer grow without bound.
            await writer.drain()

        if PRINT:
            print('Connection closed')
    finally:
        # Release the transport on every exit path, including read/write
        # errors such as a connection reset.
        writer.close()
class EchoProtocol(asyncio.Protocol):
    """Callback-style protocol that writes each received chunk straight back."""

    def connection_made(self, transport):
        # Keep the transport so data_received() can reply on it.
        self.transport = transport

    def data_received(self, data):
        # Echo the chunk unchanged.
        self.transport.write(data)

    def connection_lost(self, exc):
        # Drop the transport reference once the connection is gone.
        self.transport = None
class EchoBufferedProtocol(asyncio.BufferedProtocol):
    """Echo protocol that receives into one reusable 256 KiB buffer."""

    def connection_made(self, transport):
        self.transport = transport
        # Single receive buffer, handed out again on every get_buffer() call.
        self.buffer = bytearray(256 * 1024)

    def get_buffer(self, sizehint):
        # sizehint is ignored; the same fixed-size buffer is always returned.
        return self.buffer

    def buffer_updated(self, nbytes):
        # Slicing the bytearray copies the received bytes on purpose, so the
        # outgoing data won't be wrongly updated by the next read into the
        # shared buffer.
        received = self.buffer[:nbytes]
        self.transport.write(received)

    def connection_lost(self, exc):
        # Drop the transport reference once the connection is gone.
        self.transport = None
async def print_debug(loop):
    """Clear the screen and print the loop's debug info every 0.5 seconds.

    Args:
        loop: Event loop object exposing ``print_debug_info()``.

    Loops forever; it stops only when an exception (including task
    cancellation) propagates.
    """
    clear_screen = chr(27) + "[2J"  # ESC [ 2 J -- clear screen
    refresh_interval = 0.5
    while True:
        print(clear_screen)
        loop.print_debug_info()
        await asyncio.sleep(refresh_interval)
if __name__ == '__main__':
    # Command-line entry point: choose an event loop and one of three echo
    # implementations (loop.sock_* calls, asyncio streams, or protocol
    # classes), then serve until the process is interrupted.
    parser = argparse.ArgumentParser()
    parser.add_argument('--uvloop', default=False, action='store_true')
    parser.add_argument('--streams', default=False, action='store_true')
    parser.add_argument('--proto', default=False, action='store_true')
    parser.add_argument('--addr', default='127.0.0.1:25000', type=str)
    parser.add_argument('--print', default=False, action='store_true')
    parser.add_argument('--ssl', default=False, action='store_true')
    parser.add_argument('--buffered', default=False, action='store_true')
    args = parser.parse_args()

    if args.uvloop:
        import uvloop
        loop = uvloop.new_event_loop()
        print('using UVLoop')
    else:
        loop = asyncio.new_event_loop()
        print('using asyncio loop')

    asyncio.set_event_loop(loop)
    loop.set_debug(False)

    # The event loop keeps only weak references to tasks; hold strong
    # references here so background tasks cannot be garbage-collected.
    background_tasks = []

    if args.print:
        PRINT = 1

    if hasattr(loop, 'print_debug_info'):
        background_tasks.append(loop.create_task(print_debug(loop)))
        # print_debug() clears the screen on every refresh, so connection
        # messages are switched back off (presumably to keep the debug view
        # readable -- confirm).
        PRINT = 0

    unix = False
    if args.addr.startswith('file:'):
        unix = True
        addr = args.addr[len('file:'):]
        # Remove a leftover socket file from a previous run. Trying the
        # removal directly avoids the check-then-remove race.
        try:
            os.remove(addr)
        except FileNotFoundError:
            pass
    else:
        # Split on the last colon only, so the host part may contain colons.
        host, port = args.addr.rsplit(':', 1)
        addr = (host, int(port))

    print('serving on: {}'.format(addr))

    server_context = None
    if args.ssl:
        print('with SSL')
        # PROTOCOL_TLS_SERVER replaces the deprecated generic PROTOCOL_TLS
        # and is available on every Python that has asyncio.BufferedProtocol.
        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        # Resolve __file__ first: on older interpreters it can be relative,
        # in which case walking up parents would never leave '.'.
        certs_dir = (pathlib.Path(__file__).resolve().parent.parent.parent /
                     'tests' / 'certs')
        server_context.load_cert_chain(certs_dir / 'ssl_cert.pem',
                                       certs_dir / 'ssl_key.pem')
        # Client certificates are not requested or verified.
        server_context.check_hostname = False
        server_context.verify_mode = ssl.CERT_NONE

    if args.streams:
        if args.proto:
            print('cannot use --streams and --proto simultaneously')
            raise SystemExit(1)

        if args.buffered:
            print('cannot use --streams and --buffered simultaneously')
            raise SystemExit(1)

        print('using asyncio/streams')
        if unix:
            coro = asyncio.start_unix_server(echo_client_streams,
                                             addr,
                                             ssl=server_context)
        else:
            coro = asyncio.start_server(echo_client_streams,
                                        *addr,
                                        ssl=server_context)
        # Keep the server object referenced for the lifetime of the loop.
        srv = loop.run_until_complete(coro)
    elif args.proto:
        # --streams is known to be unset here (handled by the branch above).
        if args.buffered:
            print('using buffered protocol')
            protocol = EchoBufferedProtocol
        else:
            print('using simple protocol')
            protocol = EchoProtocol

        if unix:
            coro = loop.create_unix_server(protocol, addr,
                                           ssl=server_context)
        else:
            coro = loop.create_server(protocol, *addr,
                                      ssl=server_context)
        # Keep the server object referenced for the lifetime of the loop.
        srv = loop.run_until_complete(coro)
    else:
        if args.ssl:
            print('cannot use SSL for loop.sock_* methods')
            raise SystemExit(1)

        print('using sock_recv/sock_sendall')
        background_tasks.append(
            loop.create_task(echo_server(loop, addr, unix)))

    try:
        loop.run_forever()
    finally:
        if hasattr(loop, 'print_debug_info'):
            # Collect garbage first so the final report reflects live objects.
            gc.collect()
            print(chr(27) + "[2J")
            loop.print_debug_info()

        loop.close()
|