1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
|
import inspect
import time
try:
import graphlib
except ImportError:
from . import vendored_graphlib as graphlib
import asyncio
class Registry:
    """Registry of functions that are wired together by parameter name.

    Every function is stored under a name (its ``__name__`` unless another
    name is given).  The parameter names of a registered function are its
    dependencies: a parameter that matches another registered name receives
    that function's result; any other parameter is filled from the values
    supplied by the caller (``**kwargs`` of :meth:`resolve`, ``results`` of
    :meth:`resolve_multi`).  Both plain functions and coroutine functions
    can be registered.

    A value supplied by the caller for a name always wins: the function
    registered under that name is not executed for that call.
    """

    def __init__(self, *fns, parallel=True, timer=None):
        """Create a registry, optionally pre-populated with ``fns``.

        Args:
            *fns: Functions to register under their ``__name__``.
            parallel: If true, functions whose dependencies are ready are
                run as concurrent asyncio tasks; otherwise they are awaited
                one at a time in topological order.
            timer: Optional callable invoked as ``timer(name, start, end)``
                after each registered function finishes.  ``name`` is the
                function's ``__name__``; ``start`` and ``end`` are
                ``time.perf_counter()`` readings.
        """
        self._registry = {}
        self._graph = None
        self._reversed = None
        self.parallel = parallel
        self.timer = timer
        for fn in fns:
            self.register(fn)

    @classmethod
    def from_dict(cls, d, parallel=True, timer=None):
        """Build a registry from a ``{name: function}`` mapping."""
        instance = cls(parallel=parallel, timer=timer)
        for key, fn in d.items():
            instance.register(fn, name=key)
        return instance

    def register(self, fn, *, name=None):
        """Register ``fn`` under ``name`` (default: ``fn.__name__``).

        Registering under an existing name replaces the previous entry.
        """
        self._registry[name or fn.__name__] = fn
        # Both lookups below are derived from the registry, so drop them and
        # let the properties rebuild them lazily.
        self._graph = None
        self._reversed = None

    def _make_time_logger(self, awaitable):
        """Wrap ``awaitable`` in a coroutine that reports its duration.

        The returned coroutine awaits ``awaitable``, calls
        ``self.timer(awaitable.__name__, start, end)`` and passes the
        result through unchanged.
        """

        async def inner():
            start = time.perf_counter()
            result = await awaitable
            end = time.perf_counter()
            self.timer(awaitable.__name__, start, end)
            return result

        return inner()

    @property
    def graph(self):
        """Mapping of registered name -> set of that function's parameter names."""
        if self._graph is None:
            self._graph = {
                key: set(inspect.signature(fn).parameters.keys())
                for key, fn in self._registry.items()
            }
        return self._graph

    @property
    def reversed(self):
        """Mapping of registered function -> the name it is registered under."""
        if self._reversed is None:
            self._reversed = {fn: name for name, fn in self._registry.items()}
        return self._reversed

    async def resolve(self, fn, **kwargs):
        """Resolve one function (and everything it depends on).

        Args:
            fn: A registered name, a registered function, or an unregistered
                callable whose parameters should be resolved by name.
            **kwargs: Values made available by name to every function that
                has a parameter of that name.

        Returns:
            The result of ``fn``.

        Raises:
            KeyError: If ``fn`` is a name that is neither registered nor
                supplied, or if an unregistered callable has a parameter
                that is neither registered nor supplied.
        """
        if isinstance(fn, str):
            name = fn
        else:
            # It's a function - is it a registered one?
            name = self.reversed.get(fn)
            if name is None:
                # Not registered: introspect its parameters here, resolve
                # the ones that were not supplied, then call it directly.
                params = inspect.signature(fn).parameters.keys()
                to_resolve = {p for p in params if p not in kwargs}
                resolved = await self.resolve_multi(to_resolve, results=kwargs)
                result = fn(**{param: resolved[param] for param in params})
                if asyncio.iscoroutine(result):
                    result = await result
                return result

        results = await self.resolve_multi([name], results=kwargs)
        return results[name]

    def _plan(self, names, results=None):
        """Return a ``TopologicalSorter`` covering ``names`` and their dependencies.

        Names that already have a value in ``results`` are not expanded, so
        the functions they would otherwise depend on are left out of the plan.
        """
        if results is None:
            results = {}

        ts = graphlib.TopologicalSorter()
        done = set(results)
        # A requested name that already has a value needs no work at all.
        to_do = {name for name in names if name not in done}
        while to_do:
            item = to_do.pop()
            # Unregistered names have no entry in the graph: treat as leaves.
            dependencies = self.graph.get(item) or set()
            ts.add(item, *dependencies)
            done.add(item)
            # Add any not-done dependencies to the queue
            to_do.update(dep for dep in dependencies if dep not in done)

        return ts

    def _get_awaitable(self, name, results):
        """Return an awaitable that runs the function registered as ``name``.

        The function is called with the entries of ``results`` whose keys
        match its parameter names.  Plain functions are wrapped in a
        coroutine function so every registered function can be awaited.
        """
        fn = self._registry[name]
        wanted = self.graph[name]
        call_kwargs = {key: value for key, value in results.items() if key in wanted}

        if inspect.iscoroutinefunction(fn):
            aw = fn(**call_kwargs)
        else:

            async def _awaitable(*args, **kwargs):
                return fn(*args, **kwargs)

            # Keep the original name so the timer reports something useful.
            _awaitable.__name__ = fn.__name__
            aw = _awaitable(**call_kwargs)

        if self.timer:
            aw = self._make_time_logger(aw)
        return aw

    async def _execute_sequential(self, results, ts):
        """Run the plan one function at a time, storing outputs in ``results``."""
        for name in ts.static_order():
            # Skip caller-supplied values (they take precedence over the
            # registered function) and plain parameter names.
            if name in results or name not in self._registry:
                continue
            results[name] = await self._get_awaitable(name, results)

    async def _execute_parallel(self, results, ts):
        """Run the plan with one asyncio task per ready function.

        Outputs are stored in ``results``.  Each finished task schedules
        whatever its completion made ready.
        """
        ts.prepare()
        tasks = []

        def schedule():
            # Marking a skipped node as done can make further nodes ready
            # immediately, so keep asking until nothing new is handed out.
            # Without this loop a function whose dependencies are all
            # skipped nodes would never be started.
            while True:
                ready = ts.get_ready()
                if not ready:
                    break
                for name in ready:
                    # A node is handed out exactly once, so a name already
                    # present in results here was supplied by the caller.
                    if name in results or name not in self._registry:
                        ts.done(name)
                    else:
                        tasks.append(asyncio.create_task(worker(name)))

        async def worker(name):
            results[name] = await self._get_awaitable(name, results)
            ts.done(name)
            schedule()

        schedule()
        while tasks:
            # Workers append follow-up tasks while we wait, so take the
            # current batch and loop until no new tasks show up.
            batch = tasks[:]
            tasks.clear()
            await asyncio.gather(*batch)

    async def resolve_multi(self, names, results=None):
        """Resolve several names at once.

        Args:
            names: Iterable of names to resolve.
            results: Optional dict of already-known values.  It is updated
                in place with every value computed along the way.

        Returns:
            The ``results`` dict (a new dict if none was passed in).
        """
        if results is None:
            results = {}

        ts = self._plan(names, results)

        if self.parallel:
            await self._execute_parallel(results, ts)
        else:
            await self._execute_sequential(results, ts)

        return results
|