1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181
|
# Copyright (c) 2019, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''Asyncio protocol abstraction.'''
# Names exported by `from <this module> import *`.  Both are defined at the
# bottom of this module.
__all__ = ('connect_rs', 'serve_rs')
import asyncio
from functools import partial
from aiorpcx.curio import Event, timeout_after, TaskTimeout
from aiorpcx.session import RPCSession, SessionBase, SessionKind
from aiorpcx.util import NetAddress
class ConnectionLostError(Exception):
    '''Handed to the framer's fail() by RSTransport.connection_lost().

    RSTransport.process_messages() catches and discards it.'''
class RSTransport(asyncio.Protocol):
    '''An asyncio protocol that joins a raw socket transport to a session.

    Bytes received from asyncio are passed to the session and to the framer.  A
    background task runs the session's process_messages() loop, which obtains
    messages through receive_message().  The session sends with write() and
    shuts the connection down with close() or abort().
    '''

    def __init__(self, session_factory, framer, kind):
        '''Record the configuration; the connection itself is made later.

        session_factory - called with this transport in connection_made() to
                          create the session
        framer - the framer to use; if false the session's default_framer()
                 is used instead
        kind - stored as self.kind; this module passes SessionKind.CLIENT or
               SessionKind.SERVER
        '''
        self.session_factory = session_factory
        self.loop = asyncio.get_event_loop()
        self.session = None
        self.kind = kind
        self._proxy = None
        self._asyncio_transport = None
        self._remote_address = None
        self._framer = framer
        # Cleared when the send socket is full
        self._can_send = Event()
        self._can_send.set()
        # Set when process_messages() finishes
        self._closed_event = Event()
        self._process_messages_task = None

    async def process_messages(self):
        '''Run the session's message loop until it ends.

        On exit, however it happens, the closed event is set and the session's
        connection_lost() is awaited.'''
        try:
            await self.session.process_messages(self.receive_message)
        except ConnectionLostError:
            # connection_lost() fails the framer with this error; presumably it
            # then surfaces here through receive_message().  Not an error.
            pass
        finally:
            self._closed_event.set()
            await self.session.connection_lost()

    async def receive_message(self):
        '''Return the next message from the framer.'''
        return await self._framer.receive_message()

    def connection_made(self, transport):
        '''Called by asyncio when a connection is established.'''
        self._asyncio_transport = transport
        # If the Socks proxy was used then _proxy and _remote_address are already set
        if self._proxy is None:
            # This would throw if called on a closed SSL transport.  Fixed in asyncio in
            # Python 3.6.1 and 3.5.4
            peername = transport.get_extra_info('peername')
            # get_extra_info() returns None when the peer name is unavailable;
            # leave _remote_address unset rather than failing the callback.
            if peername:
                self._remote_address = NetAddress(peername[0], peername[1])
        self.session = self.session_factory(self)
        self._framer = self._framer or self.session.default_framer()
        self._process_messages_task = self.loop.create_task(self.process_messages())

    def connection_lost(self, exc):
        '''Called by asyncio when the connection closes.

        Tear down things done in connection_made.'''
        # This check works around a uvloop bug; see
        # https://github.com/MagicStack/uvloop/issues/246
        if not self._asyncio_transport:
            return
        # Release waiting tasks
        self._can_send.set()
        self._framer.fail(ConnectionLostError())

    def data_received(self, data):
        '''Called by asyncio when a message comes in.'''
        self.session.data_received(data)
        self._framer.received_bytes(data)

    def pause_writing(self):
        '''Called by asyncio when the send buffer is full.'''
        if not self.is_closing():
            self._can_send.clear()
            # Stop reading too while writes are blocked
            self._asyncio_transport.pause_reading()

    def resume_writing(self):
        '''Called by asyncio when the send buffer has room.'''
        if not self._can_send.is_set():
            self._can_send.set()
            self._asyncio_transport.resume_reading()

    # API exposed to session
    async def write(self, message):
        '''Frame and send message once sending is permitted.

        The message is silently dropped if the connection is closing.'''
        await self._can_send.wait()
        if not self.is_closing():
            framed_message = self._framer.frame(message)
            self._asyncio_transport.write(framed_message)

    async def close(self, force_after):
        '''Close the connection and return when closed.

        force_after is passed to timeout_after(); if the closed event has not
        been set when that times out, the connection is aborted.  Does nothing
        if no connection was made.'''
        if self._asyncio_transport:
            self._asyncio_transport.close()
            try:
                async with timeout_after(force_after):
                    await self._closed_event.wait()
            except TaskTimeout:
                await self.abort()
                await self._closed_event.wait()

    async def abort(self):
        '''Abort the underlying asyncio transport, if there is one.'''
        if self._asyncio_transport:
            self._asyncio_transport.abort()

    def is_closing(self):
        '''Return True if the connection is closing, closed or was never made.'''
        transport = self._asyncio_transport
        if transport is None:
            # Nothing to write to; also avoids an AttributeError on None
            return True
        return self._closed_event.is_set() or transport.is_closing()

    def proxy(self):
        '''Return the proxy recorded on this transport, or None.'''
        return self._proxy

    def remote_address(self):
        '''Return the remote address recorded on this transport, or None.'''
        return self._remote_address
class RSClient:
    '''Async context manager that opens a client connection.

    Entering the context connects (through the proxy if one was given,
    otherwise through the event loop) and returns the session created for the
    connection.  Leaving the context closes that session.
    '''

    def __init__(self, host=None, port=None, proxy=None, *, framer=None, **kwargs):
        '''Record connection parameters; no connection is made here.

        host, port - passed to the connector's create_connection()
        proxy - if true, its create_connection() is used instead of the loop's
        framer - passed to RSTransport
        kwargs - 'session_factory' (default RPCSession) is removed and used to
                 build sessions; 'loop', if present, selects the event loop;
                 the rest are forwarded to create_connection()
        '''
        session_factory = kwargs.pop('session_factory', RPCSession)
        self.protocol_factory = partial(RSTransport, session_factory, framer,
                                        SessionKind.CLIENT)
        self.host = host
        self.port = port
        self.proxy = proxy
        self.session = None
        # Only look up the default event loop if the caller did not supply one
        loop = kwargs.get('loop')
        self.loop = asyncio.get_event_loop() if loop is None else loop
        self.kwargs = kwargs

    async def create_connection(self):
        '''Initiate a connection.

        Returns whatever the connector's create_connection() returns.'''
        connector = self.proxy or self.loop
        kwargs = self.kwargs
        if connector is self.loop:
            # An event loop's create_connection() has no 'loop' parameter, so
            # forwarding it would raise TypeError.  A proxy still receives the
            # keyword arguments unchanged.
            kwargs = {key: value for key, value in kwargs.items() if key != 'loop'}
        return await connector.create_connection(
            self.protocol_factory, self.host, self.port, **kwargs)

    async def __aenter__(self):
        '''Connect and return the session of the new connection.'''
        _transport, protocol = await self.create_connection()
        self.session = protocol.session
        assert isinstance(self.session, SessionBase)
        return self.session

    async def __aexit__(self, exc_type, exc_value, traceback):
        '''Close the session when the context is left.'''
        await self.session.close()
async def serve_rs(session_factory, host=None, port=None, *, framer=None, loop=None, **kwargs):
    '''Start a server whose connections are handled by RSTransport.

    Each connection gets an RSTransport built from session_factory and framer
    with kind SessionKind.SERVER.  host, port and any extra keyword arguments
    are passed to the loop's create_server(), whose result is returned.  If
    loop is false the current event loop is used.
    '''
    if not loop:
        loop = asyncio.get_event_loop()
    make_protocol = partial(RSTransport, session_factory, framer, SessionKind.SERVER)
    server = await loop.create_server(make_protocol, host, port, **kwargs)
    return server
# Public alias: calling connect_rs(...) constructs an RSClient.
connect_rs = RSClient
|