1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
|
import logging
from ssl import SSLContext
from typing import Any, Dict, Optional, Union
import websockets
from websockets import ClientConnection
from websockets.datastructures import Headers, HeadersLike
from ...exceptions import TransportConnectionFailed, TransportProtocolError
from .connection import AdapterConnection
# Logger used by this adapter module (name is fixed, not derived from __name__).
log: logging.Logger = logging.getLogger("gql.transport.common.adapters.websockets")
class WebSocketsAdapter(AdapterConnection):
    """AdapterConnection implementation using the websockets library."""

    def __init__(
        self,
        url: str,
        *,
        headers: Optional[HeadersLike] = None,
        ssl: Union[SSLContext, bool] = False,
        connect_args: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Initialize the transport with the given parameters.

        :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'.
        :param headers: HTTP headers passed as ``additional_headers`` to
            ``websockets.connect`` when opening the connection.
        :param ssl: An ``SSLContext`` to use for the connection, or ``True`` to
            force TLS with default settings. When falsy (the default ``False``),
            TLS with default settings is used for ``wss`` URLs and no TLS is
            used otherwise; ``False`` does NOT disable encryption on ``wss``.
        :param connect_args: Other parameters forwarded to
            `websockets.connect <https://websockets.readthedocs.io/en/stable/\
reference/client.html#opening-a-connection>`_
        """
        super().__init__(
            url=url,
            connect_args=connect_args,
        )

        self._headers: Optional[HeadersLike] = headers
        self.ssl = ssl

        # Set by connect(), cleared by close().
        self.websocket: Optional[ClientConnection] = None
        # Handshake response headers, captured by connect(); stays None until then.
        self._response_headers: Optional[Headers] = None

    async def connect(self) -> None:
        """Connect to the WebSocket server.

        Raises:
            TransportConnectionFailed: If ``websockets.connect`` raises.
        """
        assert self.websocket is None

        # An explicit (truthy) ssl setting wins. Otherwise enable TLS with the
        # default context only for wss URLs. URL schemes are case-insensitive
        # (RFC 3986), so compare the lower-cased URL.
        ssl: Optional[Union[SSLContext, bool]]
        if self.ssl:
            ssl = self.ssl
        else:
            ssl = True if self.url.lower().startswith("wss") else None

        # Default arguments for websockets.connect; user-supplied connect_args
        # are merged afterwards so they override these defaults.
        connect_args: Dict[str, Any] = {
            "ssl": ssl,
            "additional_headers": self.headers,
        }
        if self.subprotocols:
            connect_args["subprotocols"] = self.subprotocols
        connect_args.update(self.connect_args)

        try:
            self.websocket = await websockets.connect(self.url, **connect_args)
        except Exception as e:
            raise TransportConnectionFailed("Connect failed") from e

        assert self.websocket.response is not None
        self._response_headers = self.websocket.response.headers

    async def send(self, message: str) -> None:
        """Send a message to the WebSocket server.

        Args:
            message: String message to send.

        Raises:
            TransportConnectionFailed: If the connection is closed or the
                underlying ``send`` raises.
        """
        if self.websocket is None:
            raise TransportConnectionFailed("WebSocket connection is already closed")

        try:
            await self.websocket.send(message)
        except Exception as e:
            raise TransportConnectionFailed(
                f"Error trying to send data: {type(e).__name__}"
            ) from e

    async def receive(self) -> str:
        """Receive the next message from the WebSocket server.

        Returns:
            The text message received.

        Raises:
            TransportConnectionFailed: If the connection is closed or the
                underlying ``recv`` raises.
            TransportProtocolError: If a non-str (binary) message is received.
        """
        # The websocket may already have been closed by another task.
        if self.websocket is None:
            raise TransportConnectionFailed("Connection is already closed")

        try:
            data = await self.websocket.recv()
        except Exception as e:
            raise TransportConnectionFailed(
                f"Error trying to receive data: {type(e).__name__}"
            ) from e

        # websocket.recv() can return either str or bytes; this transport
        # only accepts str.
        if not isinstance(data, str):
            raise TransportProtocolError("Binary data received in the websocket")
        return data

    async def close(self) -> None:
        """Close the WebSocket connection, if one is open."""
        if self.websocket:
            # Clear the attribute before awaiting close() so that send() and
            # receive() running in other tasks see the connection as closed.
            websocket = self.websocket
            self.websocket = None
            await websocket.close()

    @property
    def headers(self) -> HeadersLike:
        """Get the request headers given at initialization.

        These are passed as ``additional_headers`` to ``websockets.connect``.

        Returns:
            The configured headers, or an empty dict if none (or empty
            headers) were given.
        """
        return self._headers or {}

    @property
    def response_headers(self) -> Dict[str, str]:
        """Get the response headers from the WebSocket connection.

        Returns:
            Dictionary of the handshake response headers captured by
            ``connect()``; empty before a successful connect. With repeated
            header names, the last value wins.
        """
        if self._response_headers:
            return dict(self._response_headers.raw_items())
        return {}
|