#!/usr/bin/env python3
import logging
import threading
import time
from typing import Optional
import zmq
from parsl import curvezmq
from parsl.addresses import tcp_url
from parsl.errors import InternalConsistencyError
from parsl.executors.high_throughput.errors import (
CommandClientBadError,
CommandClientTimeoutError,
)
logger = logging.getLogger(__name__)
class CommandClient:
    """Request/reply command channel between the client and the interchange.

    A ZMQ ``REQ`` socket is bound on the client side. Each call to
    :meth:`run` sends one command object and returns the reply object.
    Calls are serialized by an internal lock.
    """

    def __init__(self, ip_address, port_range, cert_dir: Optional[str] = None):
        """
        Parameters
        ----------

        ip_address: str
           IP address of the client (where Parsl runs)

        port_range: tuple(int, int)
           Port range for the comms between client and interchange

        cert_dir: str | None
            Path to the certificate directory. Setting this to None will disable encryption.
            default: None

        """
        self.zmq_context = curvezmq.ClientContext(cert_dir)
        self.ip_address = ip_address
        self.port_range = port_range
        # No port chosen yet; create_socket_and_bind() picks one on first use
        # and reuses it on every later re-creation of the socket.
        self.port = None
        self.create_socket_and_bind()
        # Serializes run() so only one request/reply exchange is in flight.
        self._lock = threading.Lock()
        # Set to False once a reply wait times out; run() then refuses to work.
        self.ok = True

    def create_socket_and_bind(self):
        """ Creates socket and binds to a port.

        Upon recreating the socket, we bind to the same port.
        """
        self.zmq_socket = self.zmq_context.socket(zmq.REQ)
        self.zmq_socket.setsockopt(zmq.LINGER, 0)
        if self.port is None:
            # First bind: pick any free port inside the configured range.
            self.port = self.zmq_socket.bind_to_random_port(tcp_url(self.ip_address),
                                                            min_port=self.port_range[0],
                                                            max_port=self.port_range[1])
        else:
            # Re-bind after a socket re-creation: keep the port already chosen.
            self.zmq_socket.bind(tcp_url(self.ip_address, self.port))

    @staticmethod
    def _remaining_timeout_ms(start_time_s, timeout_s):
        """Return the milliseconds left of a timeout budget, never below zero.

        Parameters
        ----------

        start_time_s: float
            ``time.monotonic()`` reading taken when the budget started.

        timeout_s: float
            Total budget in seconds.

        Returns
        -------
        float
            Remaining time in milliseconds, clamped to ``0`` once the budget
            has been used up.
        """
        # pyzmq treats a negative poll timeout as "wait forever", so an
        # already-expired deadline must be clamped to zero rather than
        # passed through as a negative number.
        return max(0.0, start_time_s + timeout_s - time.monotonic()) * 1000

    def run(self, message, max_retries=3, timeout_s=None):
        """ This function needs to be fast at the same time aware of the possibility of
        ZMQ pipes overflowing.

        We could set copy=False and get slightly better latency but this results
        in ZMQ sockets reaching a broken state once there are ~10k tasks in flight.
        This issue can be magnified if each serialized buffer itself is larger.

        Parameters
        ----------

        message: object
            Command object, sent with ``send_pyobj``.

        max_retries: int
            Number of attempts. After a ``zmq.ZMQError`` the ZMQ context is
            recreated, the socket is re-bound and the command is tried again.

        timeout_s: float | None
            Overall time budget in seconds, measured from entry to this
            method, covering both the wait to send and the wait for the
            reply. ``None`` sends and receives without polling.

        Returns
        -------
        The reply object received with ``recv_pyobj``.

        Raises
        ------
        CommandClientBadError
            If an earlier command timed out waiting for its reply.
        CommandClientTimeoutError
            If the budget ran out before the socket was ready to send, or
            before a reply arrived. The latter also marks this client bad.
        InternalConsistencyError
            If a poll returns a value that is neither the requested flag nor 0.
        Exception
            If every attempt failed with ``zmq.ZMQError``.
        """
        if not self.ok:
            raise CommandClientBadError()

        start_time_s = time.monotonic()

        # Sentinel meaning "no reply received yet".
        reply = '__PARSL_ZMQ_PIPES_MAGIC__'
        with self._lock:
            for _ in range(max_retries):
                try:
                    logger.debug("Sending command client command")

                    if timeout_s is not None:
                        poll_result = self.zmq_socket.poll(
                            timeout=self._remaining_timeout_ms(start_time_s, timeout_s),
                            flags=zmq.POLLOUT)
                        if poll_result == zmq.POLLOUT:
                            pass  # this is OK, so continue
                        elif poll_result == 0:
                            raise CommandClientTimeoutError("Waiting for command channel to be ready for a command")
                        else:
                            raise InternalConsistencyError(f"ZMQ poll returned unexpected value: {poll_result}")

                    self.zmq_socket.send_pyobj(message, copy=True)

                    if timeout_s is not None:
                        logger.debug("Polling for command client response or timeout")
                        poll_result = self.zmq_socket.poll(
                            timeout=self._remaining_timeout_ms(start_time_s, timeout_s),
                            flags=zmq.POLLIN)
                        if poll_result == zmq.POLLIN:
                            pass  # this is OK, so continue
                        elif poll_result == 0:
                            # The request has been sent but never answered, so
                            # the client is disabled permanently.
                            # NOTE(review): presumably because the REQ socket
                            # is left mid-exchange - confirm.
                            logger.error("Command timed-out - command client is now bad forever")
                            self.ok = False
                            raise CommandClientTimeoutError("Waiting for a reply from command channel")
                        else:
                            raise InternalConsistencyError(f"ZMQ poll returned unexpected value: {poll_result}")

                    logger.debug("Receiving command client response")
                    reply = self.zmq_socket.recv_pyobj()
                    logger.debug("Received command client response")
                except zmq.ZMQError:
                    logger.exception("Potential ZMQ REQ-REP deadlock caught")
                    logger.info("Trying to reestablish context")
                    self.zmq_context.recreate()
                    self.create_socket_and_bind()
                else:
                    break

        if reply == '__PARSL_ZMQ_PIPES_MAGIC__':
            logger.error("Command channel run retries exhausted. Unable to run command")
            raise Exception("Command Channel retries exhausted")

        return reply

    def close(self):
        """Close the command socket, then terminate the ZMQ context."""
        self.zmq_socket.close()
        self.zmq_context.term()
class TasksOutgoing:
    """Client-side sending end of the task pipe towards the Interchange.

    A ZMQ ``DEALER`` socket is bound to a random port in the given range and
    registered with a poller for write-readiness.
    """

    def __init__(self, ip_address, port_range, cert_dir: Optional[str] = None):
        """
        Parameters
        ----------

        ip_address: str
           IP address of the client (where Parsl runs)

        port_range: tuple(int, int)
           Port range for the comms between client and interchange

        cert_dir: str | None
            Path to the certificate directory. Setting this to None will disable encryption.
            default: None

        """
        self.zmq_context = curvezmq.ClientContext(cert_dir)
        self.zmq_socket = dealer = self.zmq_context.socket(zmq.DEALER)
        dealer.set_hwm(0)

        bind_url = tcp_url(ip_address)
        self.port = dealer.bind_to_random_port(
            bind_url,
            min_port=port_range[0],
            max_port=port_range[1],
        )

        # put() polls this to learn when the socket can accept a message.
        self.poller = zmq.Poller()
        self.poller.register(dealer, zmq.POLLOUT)

    def put(self, message):
        """Send ``message`` once the socket reports it is ready for writing.

        The socket is polled with a timeout that starts at 1 ms and doubles
        after every poll that does not report the socket as writable. The
        method returns only after the message has been handed to the socket.

        copy=False would give slightly better latency, but it results in ZMQ
        sockets reaching a broken state once there are ~10k tasks in flight,
        and larger serialized buffers can make that worse.

        Parameters
        ----------

        message: object
            Object to send with ``send_pyobj``.
        """
        wait_ms = 1
        while True:
            ready = dict(self.poller.poll(timeout=wait_ms))
            if ready.get(self.zmq_socket) != zmq.POLLOUT:
                # Pipe not writable yet: back off and poll again.
                wait_ms *= 2
                logger.debug("Not sending due to non-ready zmq pipe, timeout: {} ms".format(wait_ms))
                continue

            # Copying costs some latency but lowers the risk of ZMQ overflow.
            logger.debug("Sending TasksOutgoing message")
            self.zmq_socket.send_pyobj(message, copy=True)
            logger.debug("Sent TasksOutgoing message")
            return

    def close(self):
        """Close the task socket, then terminate the ZMQ context."""
        self.zmq_socket.close()
        self.zmq_context.term()
class ResultsIncoming:
    """Client-side receiving end of the results pipe from the Interchange.

    A ZMQ ``DEALER`` socket is bound to a random port in the given range and
    registered with a poller for read-readiness.
    """

    def __init__(self, ip_address, port_range, cert_dir: Optional[str] = None):
        """
        Parameters
        ----------

        ip_address: str
           IP address of the client (where Parsl runs)

        port_range: tuple(int, int)
           Port range for the comms between client and interchange

        cert_dir: str | None
            Path to the certificate directory. Setting this to None will disable encryption.
            default: None

        """
        self.zmq_context = curvezmq.ClientContext(cert_dir)
        self.results_receiver = receiver = self.zmq_context.socket(zmq.DEALER)
        receiver.set_hwm(0)

        bind_url = tcp_url(ip_address)
        self.port = receiver.bind_to_random_port(
            bind_url,
            min_port=port_range[0],
            max_port=port_range[1],
        )

        # get() polls this to learn when a message is waiting.
        self.poller = zmq.Poller()
        self.poller.register(receiver, zmq.POLLIN)

    def get(self, timeout_ms=None):
        """Return the next multipart message, or None if none arrives in time.

        Parameters
        ----------

        timeout_ms: int | None
            Poll timeout in milliseconds; passed straight to the poller.

        Returns
        -------
        The result of ``recv_multipart()`` when the socket became readable,
        otherwise ``None``.
        """
        logger.debug("Waiting for ResultsIncoming message")
        events = dict(self.poller.poll(timeout=timeout_ms))
        if events.get(self.results_receiver) != zmq.POLLIN:
            return None

        frames = self.results_receiver.recv_multipart()
        logger.debug("Received ResultsIncoming message")
        return frames

    def close(self):
        """Close the results socket, then terminate the ZMQ context."""
        self.results_receiver.close()
        self.zmq_context.term()
|