File: websockets.py

package info (click to toggle)
python-gql 4.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,900 kB
  • sloc: python: 21,677; makefile: 54
file content (156 lines) | stat: -rw-r--r-- 5,051 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import logging
from ssl import SSLContext
from typing import Any, Dict, Optional, Union

import websockets
from websockets import ClientConnection
from websockets.datastructures import Headers, HeadersLike

from ...exceptions import TransportConnectionFailed, TransportProtocolError
from .connection import AdapterConnection

log = logging.getLogger("gql.transport.common.adapters.websockets")


class WebSocketsAdapter(AdapterConnection):
    """AdapterConnection implementation using the websockets library."""

    def __init__(
        self,
        url: str,
        *,
        headers: Optional[HeadersLike] = None,
        ssl: Union[SSLContext, bool] = False,
        connect_args: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Initialize the transport with the given parameters.

        :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'.
        :param headers: HTTP headers sent with the WebSocket opening handshake
            (forwarded as ``additional_headers`` to websockets.connect).
        :param ssl: ssl_context of the connection. When truthy it is forwarded
            as-is. When falsy (the default), ``ssl=True`` is used for wss URLs and
            ``ssl=None`` otherwise, so use a ws:// URL for an unencrypted connection.
        :param connect_args: Other parameters forwarded to
            `websockets.connect <https://websockets.readthedocs.io/en/stable/reference/\
            client.html#opening-a-connection>`_
        """
        super().__init__(
            url=url,
            connect_args=connect_args,
        )

        self._headers: Optional[HeadersLike] = headers
        self.ssl = ssl

        # Set by connect(), cleared by close().
        self.websocket: Optional[ClientConnection] = None
        # Handshake response headers captured by connect().
        self._response_headers: Optional[Headers] = None

    async def connect(self) -> None:
        """Open the WebSocket connection and record the handshake response headers.

        Arguments from ``connect_args`` override the defaults computed here
        (``ssl``, ``additional_headers`` and, if set, ``subprotocols``).

        Raises:
            TransportConnectionFailed: If websockets.connect raises any exception.
        """

        assert self.websocket is None

        ssl: Optional[Union[SSLContext, bool]]
        if self.ssl:
            ssl = self.ssl
        else:
            # URL schemes are case-insensitive and websockets accepts e.g.
            # "WSS://..."; a case-sensitive test would pass ssl=None for such a
            # secure URL, which websockets refuses for wss URIs.
            ssl = True if self.url.lower().startswith("wss") else None

        # Set default arguments used in the websockets.connect call
        connect_args: Dict[str, Any] = {
            "ssl": ssl,
            "additional_headers": self.headers,
        }

        if self.subprotocols:
            connect_args["subprotocols"] = self.subprotocols

        # Adding custom parameters passed from init
        connect_args.update(self.connect_args)

        # Connection to the specified url
        try:
            self.websocket = await websockets.connect(self.url, **connect_args)
        except Exception as e:
            raise TransportConnectionFailed("Connect failed") from e

        assert self.websocket.response is not None

        self._response_headers = self.websocket.response.headers

    async def send(self, message: str) -> None:
        """Send message to the WebSocket server.

        Args:
            message: String message to send

        Raises:
            TransportConnectionFailed: If the connection is closed or sending fails
        """
        if self.websocket is None:
            raise TransportConnectionFailed("WebSocket connection is already closed")

        try:
            await self.websocket.send(message)
        except Exception as e:
            raise TransportConnectionFailed(
                f"Error trying to send data: {type(e).__name__}"
            ) from e

    async def receive(self) -> str:
        """Receive message from the WebSocket server.

        Returns:
            String message received

        Raises:
            TransportConnectionFailed: If the connection is closed or receiving fails
            TransportProtocolError: If binary data is received
        """
        # It is possible that the websocket has been already closed in another task
        if self.websocket is None:
            raise TransportConnectionFailed("Connection is already closed")

        # Wait for the next websocket frame. Can raise ConnectionClosed
        try:
            data = await self.websocket.recv()
        except Exception as e:
            raise TransportConnectionFailed(
                f"Error trying to receive data: {type(e).__name__}"
            ) from e

        # websocket.recv() can return either str or bytes
        # In our case, we should receive only str here
        if not isinstance(data, str):
            raise TransportProtocolError("Binary data received in the websocket")

        return data

    async def close(self) -> None:
        """Close the WebSocket connection, if one is open."""
        websocket = self.websocket
        if websocket is not None:
            # Clear the attribute before awaiting the close so that send() and
            # receive() called meanwhile raise TransportConnectionFailed instead
            # of using a connection that is being closed.
            self.websocket = None
            await websocket.close()

    @property
    def headers(self) -> HeadersLike:
        """Get the HTTP headers sent with the opening handshake request.

        Returns:
            The headers given to ``__init__``, or an empty dict when none
            (or an empty collection) were given.
        """
        return self._headers if self._headers else {}

    @property
    def response_headers(self) -> Dict[str, str]:
        """Get the response headers from the WebSocket opening handshake.

        Returns:
            Dictionary of the handshake response headers recorded by connect(),
            or an empty dict if connect() has not succeeded yet. When a header
            name appears several times, only its last value is kept.
        """
        if self._response_headers:
            return dict(self._response_headers.raw_items())
        return {}