1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
|
# Copyright (c) 2019, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
from functools import partial
try:
import websockets
except ImportError:
websockets = None
from aiorpcx.curio import spawn
from aiorpcx.session import RPCSession, SessionKind
from aiorpcx.util import NetAddress
__all__ = ('serve_ws', 'connect_ws')
class WSTransport:
    '''Implementation of a websocket transport for session.py.

    Wraps a websocket connection object and provides the methods a session
    calls on its transport: write(), close(), abort(), is_closing(),
    proxy() and remote_address().
    '''

    def __init__(self, websocket, session_factory, kind):
        '''Bind the transport to a websocket and create its session.

        websocket        connection object; this class uses its recv(),
                         send(), close() and remote_address members
        session_factory  callable invoked with this transport; the result is
                         stored as self.session
        kind             stored as self.kind; the classmethods below pass
                         SessionKind.SERVER or SessionKind.CLIENT
        '''
        self.websocket = websocket
        self.kind = kind
        # The factory sees a transport that already has websocket and kind set.
        self.session = session_factory(self)
        self.closing = False

    @classmethod
    async def ws_server(cls, session_factory, websocket, _path=None):
        '''Handle one incoming server connection until it finishes.

        Intended to be used, with session_factory pre-bound, as the handler
        given to websockets.serve().  _path is accepted and ignored; it is
        optional so the handler can be called either as
        handler(websocket, path) or as handler(websocket).
        '''
        transport = cls(websocket, session_factory, SessionKind.SERVER)
        await transport.process_messages()

    @classmethod
    async def ws_client(cls, uri, **kwargs):
        '''Connect to uri and return a client transport.

        The optional keyword argument session_factory (default RPCSession)
        is removed from kwargs; the remaining keyword arguments are passed
        to websockets.connect().
        '''
        session_factory = kwargs.pop('session_factory', RPCSession)
        websocket = await websockets.connect(uri, **kwargs)
        return cls(websocket, session_factory, SessionKind.CLIENT)

    async def recv_message(self):
        '''Receive one message, hand it to the session and return it as bytes.'''
        message = await self.websocket.recv()
        # It might be nice to avoid the redundant conversions
        if isinstance(message, str):
            message = message.encode()
        self.session.data_received(message)
        return message

    async def process_messages(self):
        '''Run the session's message processing with recv_message as source.

        websockets.ConnectionClosed is swallowed.  Whether processing ends
        normally or with an exception, session.connection_lost() is awaited.
        '''
        try:
            await self.session.process_messages(self.recv_message)
        except websockets.ConnectionClosed:
            pass
        finally:
            await self.session.connection_lost()

    # API exposed to session
    async def write(self, framed_message):
        '''Send framed_message, as text if it decodes, otherwise unchanged.'''
        # Prefer to send as text
        try:
            framed_message = framed_message.decode()
        except UnicodeDecodeError:
            pass
        await self.websocket.send(framed_message)

    async def close(self, _force_after=0):
        '''Close the connection and return when closed.

        _force_after is accepted but not used by this transport.
        '''
        self.closing = True
        await self.websocket.close()

    async def abort(self):
        '''Abort the connection.  For now this just calls close().'''
        self.closing = True
        await self.close()

    def is_closing(self):
        '''Return True if the connection is closing.'''
        return self.closing

    def proxy(self):
        '''Return None; this transport does not report a proxy.'''
        return None

    def remote_address(self):
        '''Return the peer address as a NetAddress.

        Only the first two items of the websocket's remote_address are used.
        If remote_address is falsy it is returned unchanged.
        '''
        result = self.websocket.remote_address
        if result:
            result = NetAddress(*result[:2])
        return result
class WSClient:
    '''Async context manager that connects a websocket client session.

    Entering the context connects to uri, starts a task that processes
    incoming messages, and yields the transport's session.  Leaving the
    context closes the transport.
    '''

    def __init__(self, uri, **kwargs):
        '''Remember the connection parameters; nothing is connected yet.

        uri      passed to WSTransport.ws_client() on entry
        kwargs   the optional session_factory (default RPCSession) is
                 removed and kept separately; the rest are passed to
                 WSTransport.ws_client() on entry
        '''
        self.uri = uri
        self.session_factory = kwargs.pop('session_factory', RPCSession)
        self.kwargs = kwargs
        self.transport = None
        self.process_messages_task = None

    async def __aenter__(self):
        '''Connect, start message processing and return the session.'''
        # session_factory was popped from kwargs in __init__, so it must be
        # passed explicitly; otherwise ws_client() falls back to its own
        # default and the caller's factory is ignored.
        self.transport = await WSTransport.ws_client(
            self.uri, session_factory=self.session_factory, **self.kwargs)
        self.process_messages_task = await spawn(self.transport.process_messages())
        return self.transport.session

    async def __aexit__(self, exc_type, exc_value, traceback):
        '''Close the transport; exceptions from the body are not suppressed.'''
        await self.transport.close()
        # Invariant check: the message processing task is expected to have
        # finished once the transport has been closed.
        assert self.process_messages_task.done()
def serve_ws(session_factory, *args, **kwargs):
    '''Return websockets.serve() set up to handle connections with WSTransport.

    session_factory is bound as the first argument of WSTransport.ws_server
    to form the connection handler.  All other positional and keyword
    arguments are passed to websockets.serve() unchanged.
    '''
    return websockets.serve(
        partial(WSTransport.ws_server, session_factory), *args, **kwargs)
# connect_ws is the exported name for WSClient (see __all__); calling it
# builds a WSClient, which is used as an async context manager.
connect_ws = WSClient
|