from __future__ import annotations

import concurrent.futures as cf
import weakref

from tlz import merge
from tornado import gen

from dask.utils import parse_timedelta

from distributed.metrics import time
from distributed.utils import TimeoutError, sync


@gen.coroutine
def _cascade_future(future, cf_future):
    """
    Coroutine that waits on a Dask future, then transmits its outcome to
    cf_future.
    """
    result = yield future._result(raiseit=False)
    status = future.status
    if status == "finished":
        cf_future.set_result(result)
    elif status == "cancelled":
        cf_future.cancel()
        # Necessary for wait() and as_completed() to wake up
        cf_future.set_running_or_notify_cancel()
    else:
        try:
            # The task failed: result is an (exc_type, exception, traceback) triple
            typ, exc, tb = result
            raise exc.with_traceback(tb)
        except BaseException as exc:
            cf_future.set_exception(exc)
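

def _example_cancelled_cf_future_wakes_waiters():
    # Illustrative sketch only, not used by this module: it shows why
    # _cascade_future calls set_running_or_notify_cancel() after cancelling the
    # concurrent.futures Future.  wait() and as_completed() only treat a
    # cancelled future as done once it has been notified.
    f = cf.Future()
    assert f.cancel()  # a pending Future can always be cancelled
    f.set_running_or_notify_cancel()  # notify waiters of the cancellation
    done, _not_done = cf.wait([f], timeout=0)
    assert f in done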


@gen.coroutine
def _wait_on_futures(futures):
    for fut in futures:
        try:
            yield fut
        except Exception:
            pass


class ClientExecutor(cf.Executor):
    """
    A concurrent.futures Executor that executes tasks on a dask.distributed Client.
    """

    _allowed_kwargs = frozenset(
        ["pure", "workers", "resources", "allow_other_workers", "retries"]
    )

    def __init__(self, client, **kwargs):
        sk = set(kwargs)
        if not sk <= self._allowed_kwargs:
            raise TypeError(
                "unsupported arguments to ClientExecutor: %s"
                % sorted(sk - self._allowed_kwargs)
            )
        self._client = client
        self._futures = weakref.WeakSet()
        self._shutdown = False
        self._kwargs = kwargs

    def _wrap_future(self, future):
        """
        Wrap a distributed Future in a concurrent.futures Future.
        """
        cf_future = cf.Future()

        # Support cancelling the task through .cancel() on the concurrent.futures Future
        def cf_callback(cf_future):
            if cf_future.cancelled() and future.status != "cancelled":
                future.cancel()

        cf_future.add_done_callback(cf_callback)

        self._client.loop.add_callback(_cascade_future, future, cf_future)
        return cf_future

    def submit(self, fn, *args, **kwargs):
        """Submits a callable to be executed with the given arguments.

        Schedules the callable to be executed as ``fn(*args, **kwargs)``
        and returns a Future instance representing the execution of the callable.

        Returns
        -------
        A Future representing the given call.
        """
        if self._shutdown:
            raise RuntimeError("cannot schedule new futures after shutdown")
        future = self._client.submit(fn, *args, **merge(self._kwargs, kwargs))
        self._futures.add(future)
        return self._wrap_future(future)

    def map(self, fn, *iterables, **kwargs):
        """Returns an iterator equivalent to ``map(fn, *iterables)``.

        Parameters
        ----------
        fn : A callable that will take as many arguments as there are
            passed iterables.
        iterables : One iterable for each parameter to *fn*.
        timeout : The maximum number of seconds to wait. If None, then there
            is no limit on the wait time.
        chunksize : ignored.

        Returns
        -------
        An iterator equivalent to ``map(fn, *iterables)``, but the calls may
        be evaluated out-of-order.

        Raises
        ------
        concurrent.futures.TimeoutError:
            If the entire result iterator could not be generated before the given
            timeout.
        Exception:
            If ``fn(*args)`` raises for any values.
        """
        timeout = kwargs.pop("timeout", None)
        if timeout is not None:
            timeout = parse_timedelta(timeout)
            end_time = timeout + time()
        if "chunksize" in kwargs:
            del kwargs["chunksize"]
        if kwargs:
            raise TypeError("unexpected arguments to map(): %s" % sorted(kwargs))

        fs = self._client.map(fn, *iterables, **self._kwargs)

        # The iterator below relies on fs being an iterator itself, and not just
        # an iterable (such as a list), so that the futures not yet consumed can
        # be cancelled in the finally block.
        fs = iter(fs)

        # Yield must be hidden in a closure so that the tasks are submitted
        # before the first iterator value is required.
        def result_iterator():
            try:
                for future in fs:
                    self._futures.add(future)
                    if timeout is not None:
                        try:
                            yield future.result(end_time - time())
                        except TimeoutError:
                            raise cf.TimeoutError
                    else:
                        yield future.result()
            finally:
                remaining = list(fs)
                self._futures.update(remaining)
                self._client.cancel(remaining)

        return result_iterator()
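
    # Note on ``map`` timeouts (illustrative only; assumes an executor obtained
    # via ``Client.get_executor``):
    #
    #     ex = client.get_executor()
    #     for value in ex.map(str, range(3), timeout="5s"):
    #         print(value)
    #
    # The timeout bounds the total time spent iterating over the results: each
    # ``future.result()`` call above is given the remaining budget
    # (``end_time - time()``), and exhausting the budget raises
    # ``concurrent.futures.TimeoutError``.  Because of ``parse_timedelta``, the
    # timeout may be a number of seconds or a dask-style string such as ``"5s"``.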

    def shutdown(self, wait=True):
        """Clean up the resources associated with the Executor.

        It is safe to call this method several times. No other methods can be
        called after this one.

        Parameters
        ----------
        wait : If True, then shutdown will not return until all running
            futures have finished executing.  If False, all running
            futures are cancelled immediately.
        """
        if not self._shutdown:
            self._shutdown = True
            fs = list(self._futures)
            if wait:
                sync(self._client.loop, _wait_on_futures, fs)
            else:
                self._client.cancel(fs)
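

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module API.  It assumes a local
    # in-process cluster can be started and that ``Client.get_executor`` is
    # used to obtain a ClientExecutor bound to that client.
    from distributed import Client

    with Client(processes=False) as client:
        executor = client.get_executor()
        future = executor.submit(pow, 2, 10)
        print(future.result())  # 1024
        print(list(executor.map(abs, [-1, -2, 3])))  # [1, 2, 3]
        executor.shutdown(wait=True)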