##############################################################################
# Copyright 2018 Rigetti Computing
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
import asyncio
import logging
import sys
import time
from typing import Dict, Optional, Union
from warnings import warn
import zmq
import zmq.asyncio
from rpcq._base import to_msgpack, from_msgpack
import rpcq._utils as utils
from rpcq.messages import RPCError, RPCReply
if sys.version_info < (3, 7):
from rpcq.external.dataclasses import dataclass
else:
from dataclasses import dataclass
_log = logging.getLogger(__name__)
# Required values for ZeroMQ curve authentication, in lieu of a TypedDict
@dataclass
class ClientAuthConfig:
    """
    Key material a client needs in order to enable Curve ZeroMQ authentication on its sockets.
    """

    # The client's own Curve key pair
    client_secret_key: bytes
    client_public_key: bytes

    # Public key of the server the client connects to
    server_public_key: bytes
class Client:
    """
    Client that executes methods on a remote server by sending JSON RPC requests to a socket.

    A blocking DEALER socket is connected at construction time and serves :meth:`call`. A second DEALER socket,
    built from a ``zmq.asyncio`` context, is created lazily on first use and serves :meth:`call_async`.
    """

    def __init__(self, endpoint: str, timeout: Optional[float] = None,
                 auth_config: Optional["ClientAuthConfig"] = None):
        """
        Create a client that connects to a server at <endpoint>.

        :param str endpoint: Socket endpoint, e.g. "tcp://localhost:1234"
        :param float timeout: Timeout in seconds for Server response, set to None to disable the timeout
        :param auth_config: The configuration values necessary to enable Curve ZeroMQ authentication.
            These must be provided at instantiation, so they are available when the socket is created.
        """
        # TODO: leaving self.timeout for backwards compatibility; we should move towards using rpc_timeout only
        self.timeout = timeout
        self.rpc_timeout = timeout

        # endpoint and auth config must be set before connecting: _connect_to_socket reads both
        self.endpoint = endpoint
        self._auth_config = auth_config
        self._socket = self._connect_to_socket(zmq.Context(), endpoint)

        # The async socket can't be created yet because it's possible that the current event loop during Client
        # creation is different to the one used later to call a method, so we need to create the socket after the
        # first call and then cache it
        self._async_socket_cache = None

        # Mapping from request id to an event used to wake up the call that's waiting on that request.
        # This is necessary to support parallel, asynchronous calls where we don't know which
        # receive task will receive which reply.
        self._events: Dict[str, asyncio.Event] = {}

        # Cache of replies so that different tasks can share results with each other
        self._replies: Dict[str, Union[RPCReply, RPCError]] = {}

    def __setattr__(self, key, value):
        """
        Ensure the rpc_timeout attribute gets updated together with timeout. Currently keeping self.timeout and
        self.rpc_timeout for backwards compatibility. We should move towards using rpc_timeout only.

        Note the mirroring is one-way: assigning rpc_timeout does not touch timeout.

        :param key: attribute key
        :param value: attribute value
        :return: None
        """
        if key == 'timeout':
            self.rpc_timeout = value
        super().__setattr__(key, value)

    async def call_async(self, method_name: str, *args, rpc_timeout: Optional[float] = None, **kwargs):
        """
        Send JSON RPC request to a backend socket and receive reply (asynchronously)

        :param method_name: Method name
        :param args: Args that will be passed to the remote function
        :param float rpc_timeout: Timeout in seconds for Server response. When None, the Client's rpc_timeout
            attribute is used; if the resulting value is falsy (None or 0) no timeout is applied.
        :param kwargs: Keyword args that will be passed to the remote function
        :return: The ``result`` field of the reply
        :raises TimeoutError: if no reply arrives within the timeout
        :raises utils.RPCError: if the server replied with an error
        """
        # if an rpc_timeout override is not specified, use the one set in the Client attributes
        if rpc_timeout is None:
            rpc_timeout = self.rpc_timeout

        if not rpc_timeout:
            return await self._call_async(method_name, *args, **kwargs)

        # Implementation note: this simply wraps the call in a timeout and converts to the built-in TimeoutError
        try:
            return await asyncio.wait_for(self._call_async(method_name, *args, **kwargs), timeout=rpc_timeout)
        except asyncio.TimeoutError as exc:
            raise TimeoutError(
                f"Timeout on client {self.endpoint}, method name {method_name}, class info: {self}"
            ) from exc

    async def _call_async(self, method_name: str, *args, **kwargs):
        """
        Sends a request to the socket and then waits for the reply.

        To deal with multiple, asynchronous requests we do not expect that the receive reply task
        scheduled from this call is the one that receives this call's reply and instead rely on
        Events to signal across multiple _call_async/_recv_reply tasks.

        :raises utils.RPCError: if the server replied with an error
        """
        request = utils.rpc_request(method_name, *args, **kwargs)
        _log.debug("Sending request: %s", request)

        # setup an event to notify us when the reply is received (potentially by a task scheduled by
        # another call to _call_async). we do this before we send the request to catch the case
        # where the reply comes back before we re-enter this coroutine. A local reference is kept because
        # _recv_reply removes the entry from self._events when it delivers the reply.
        reply_ready = asyncio.Event()
        self._events[request.id] = reply_ready

        # schedule a task to receive the reply to ensure we have a task to receive the reply
        asyncio.ensure_future(self._recv_reply())

        try:
            await self._async_socket.send_multipart([to_msgpack(request)])
            await reply_ready.wait()
            reply = self._replies.pop(request.id)
        finally:
            # If we were cancelled (e.g. by the timeout in call_async) or failed part-way, make sure nothing for
            # this request is left behind; a late reply is then discarded by _recv_reply instead of being cached.
            self._events.pop(request.id, None)
            self._replies.pop(request.id, None)

        if isinstance(reply, RPCError):
            raise utils.RPCError(reply.error)
        return reply.result

    async def _recv_reply(self):
        """
        Helper task to receive a reply, store the result and trigger the associated event.

        Replies whose request id has no waiting event (for instance because the waiting call already
        timed out) are logged and dropped.
        """
        raw_reply, = await self._async_socket.recv_multipart()
        reply = from_msgpack(raw_reply)
        _log.debug("Received reply: %s", reply)

        waiter = self._events.pop(reply.id, None)
        if waiter is None:
            _log.debug('Discarding reply: %s', reply)
            return

        # store the reply before waking the waiter so it is available as soon as the waiter resumes
        self._replies[reply.id] = reply
        waiter.set()

    def call(self, method_name: str, *args, rpc_timeout: Optional[float] = None, **kwargs):
        """
        Send JSON RPC request to a backend socket and receive reply

        Note that this blocks the calling thread on the synchronous socket until the matching reply arrives or
        the timeout elapses. If you would rather run in an async fashion then use .call_async instead

        :param method_name: Method name
        :param args: Args that will be passed to the remote function
        :param float rpc_timeout: Timeout in seconds for Server response, set to None to disable the timeout
            (None falls back to the Client's rpc_timeout attribute)
        :param kwargs: Keyword args that will be passed to the remote function
        :return: The ``result`` field of the reply
        :raises TimeoutError: if no matching reply arrives within the timeout
        :raises utils.RPCError: if the server replied with an error
        """
        # if an rpc_timeout override is not specified, use the one set in the Client attributes
        if rpc_timeout is None:
            rpc_timeout = self.rpc_timeout

        request = utils.rpc_request(method_name, *args, **kwargs)
        # Rather than change the utils.rpc_request interface in a
        # non-BC way, install the timeout here. This timeout is
        # communicated to the server, so that the server can terminate
        # (if it so chooses) requests that will not be received by the
        # client.
        request.client_timeout = rpc_timeout

        _log.debug("Sending request: %s", request)
        self._socket.send_multipart([to_msgpack(request)])

        # Track the deadline on the monotonic clock so that wall-clock adjustments cannot stretch or shrink it
        deadline = None if rpc_timeout is None else time.monotonic() + rpc_timeout

        while True:
            # Need to keep track of the timeout manually in case this loop runs more than once. We hand poll()
            # only the time that is left (in milliseconds); the call to max makes sure we don't send a negative
            # value which would throw an error.
            if deadline is None:
                remaining_ms = None
            else:
                remaining_ms = max((deadline - time.monotonic()) * 1000, 0)

            if self._socket.poll(remaining_ms) == 0:
                raise TimeoutError(f"Timeout on client {self.endpoint}, method name {method_name}, class info: {self}")

            raw_reply, = self._socket.recv_multipart()
            reply = from_msgpack(raw_reply)
            _log.debug("Received reply: %s", reply)

            # there's a possibility that the socket will have some leftover replies from a previous
            # request on it if that .call() was cancelled or timed out. Therefore, we need to discard replies
            # that don't match the request just like in the call_async case.
            if reply.id == request.id:
                break
            _log.debug('Discarding reply: %s', reply)

        for warning in reply.warnings:
            warn(f"{warning.kind}: {warning.body}")

        if isinstance(reply, RPCError):
            raise utils.RPCError(reply.error)
        return reply.result

    def close(self):
        """
        Close the sockets
        """
        self._socket.close()

        # the async socket only exists if call_async has been used
        if self._async_socket_cache:
            self._async_socket_cache.close()
            self._async_socket_cache = None

    def _connect_to_socket(self, context: "zmq.Context", endpoint: str):
        """
        Connect to a DEALER socket at endpoint and turn off lingering.

        :param context: ZMQ Context to use (potentially async)
        :param endpoint: Endpoint
        :return: Connected socket
        """
        socket = context.socket(zmq.DEALER)
        # Curve options are applied before connecting
        self.enable_auth(socket)
        socket.connect(endpoint)
        socket.setsockopt(zmq.LINGER, 0)
        _log.debug("Client connected to endpoint %s", self.endpoint)
        return socket

    @property
    def _async_socket(self):
        """
        Creates a new async socket if one doesn't already exist for this Client
        """
        if not self._async_socket_cache:
            self._async_socket_cache = self._connect_to_socket(zmq.asyncio.Context(), self.endpoint)

        return self._async_socket_cache

    @property
    def auth_configured(self) -> bool:
        """
        Whether an auth configuration was supplied when the Client was created.
        """
        return self._auth_config is not None

    def enable_auth(self, socket=None) -> bool:
        """
        Enables Curve ZeroMQ Authentication if the necessary configuration is present

        :param socket: Socket on which to set the curve key options. It is only touched when an auth
            configuration is present, in which case it must not be None.
        :return: True if the keys were applied to the socket, False if no auth configuration is present
        """
        if not self.auth_configured:
            return False

        socket.curve_secretkey = self._auth_config.client_secret_key
        socket.curve_publickey = self._auth_config.client_public_key
        socket.curve_serverkey = self._auth_config.server_public_key
        return True
|