1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
|
import asyncio
import atexit
import logging
from threading import BoundedSemaphore, Thread
from typing import Callable, Optional
import torch
# Based on graphlearn-for-pytorch repository python/distributed/event_loop.py
# https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/
# LICENSE: Apache v2
def to_asyncio_future(future: torch.futures.Future) -> asyncio.futures.Future:
    r"""Convert a :class:`torch.futures.Future` to a :obj:`asyncio` future.

    The returned asyncio future belongs to the event loop obtained via
    :meth:`asyncio.get_event_loop` at call time, so this is presumably meant
    to be called from code running on that loop (e.g. inside a coroutine).

    Args:
        future: The torch future to wrap.

    Returns:
        An asyncio future that resolves with the torch future's value, or
        fails with the exception raised by :meth:`torch.futures.Future.wait`.
    """
    loop = asyncio.get_event_loop()
    asyncio_future = loop.create_future()

    # These two helpers run on the loop thread. The asyncio future may have
    # been cancelled (e.g. by a timeout around the ``await``) before the torch
    # future completed; setting it again would raise ``InvalidStateError``.
    def _set_result(result):
        if not asyncio_future.done():
            asyncio_future.set_result(result)

    def _set_exception(exc):
        if not asyncio_future.done():
            asyncio_future.set_exception(exc)

    def on_done(*_):
        # The torch callback may fire on a thread other than the loop's, so
        # results are handed over with ``call_soon_threadsafe``.
        try:
            result = future.wait()
        except Exception as e:
            loop.call_soon_threadsafe(_set_exception, e)
        else:
            loop.call_soon_threadsafe(_set_result, result)

    future.add_done_callback(on_done)
    return asyncio_future
class ConcurrentEventLoop:
    r"""Concurrent event loop context.

    Owns a private asyncio event loop that runs forever on a daemon thread
    (started by :meth:`start_loop`). A bounded semaphore caps how many tasks
    submitted through :meth:`add_task` / :meth:`run_task` may be in flight at
    once. A cleanup hook registered with :mod:`atexit` waits for those tasks
    and then stops the loop thread.

    Args:
        concurrency: max processing concurrency.
    """

    def __init__(self, concurrency: int):
        self._concurrency = concurrency
        self._sem = BoundedSemaphore(concurrency)
        self._loop = asyncio.new_event_loop()
        self._runner_t = Thread(target=self._run_loop)
        self._runner_t.daemon = True
        # NOTE: atexit keeps a strong reference to this instance until exit.
        atexit.register(self._cleanup)

    def _cleanup(self):
        r"""Wait for in-flight tasks, then stop and join the loop thread."""
        self.wait_all()
        if self._runner_t.is_alive():
            # ``loop.stop()`` is not thread-safe: called from another thread it
            # only sets a flag and does not wake a loop blocked on I/O, so the
            # join below would just time out. Schedule it on the loop instead.
            self._loop.call_soon_threadsafe(self._loop.stop)
            self._runner_t.join(timeout=1)
            logging.debug(f'{self}: Closed `ConcurrentEventLoop`')

    def start_loop(self):
        r"""Start the loop thread if it is not already running.

        Note that a :class:`threading.Thread` can only be started once, so
        this raises :class:`RuntimeError` if the thread has already finished.
        """
        if not self._runner_t.is_alive():
            self._runner_t.start()

    def wait_all(self):
        r"""Wait for all pending tasks to be finished.

        Acquiring every semaphore slot blocks until all tasks submitted
        through this object have released theirs; the slots are then handed
        back so new tasks can be submitted.
        """
        for _ in range(self._concurrency):
            self._sem.acquire()
        for _ in range(self._concurrency):
            self._sem.release()

    def add_task(self, coro, callback: Optional[Callable] = None):
        r"""Adds an asynchronized coroutine task to run.

        Blocks while ``concurrency`` tasks are already in flight.

        Args:
            coro: The asynchronous coroutine function.
            callback (callable, optional): The callback function applied on the
                returned results after the coroutine task is finished.
                (default: :obj:`None`)
                Note that any result returned by :obj:`callback` will be ignored.

        Raises:
            TypeError: If :obj:`coro` is not a coroutine object (raised by
                :func:`asyncio.run_coroutine_threadsafe`).
        """
        def on_done(f: asyncio.futures.Future):
            try:
                res = f.result()
                if callback is not None:
                    callback(res)
            except Exception as e:
                # Errors from the coroutine and from the callback both end up
                # here; they are logged rather than propagated.
                logging.error(f"Coroutine task failed with error: {e}")
            finally:
                # Always hand the slot back, or ``wait_all`` would deadlock.
                self._sem.release()

        self._sem.acquire()
        try:
            fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
        except BaseException:
            # Submission failed, so ``on_done`` will never run to release the
            # slot acquired above; release it here instead.
            self._sem.release()
            raise
        fut.add_done_callback(on_done)

    def run_task(self, coro):
        r"""Runs a coroutine task synchronously.

        Blocks the calling thread until the coroutine finishes on the loop
        thread. It must therefore not be called from the loop thread itself.

        Args:
            coro: The asynchronous coroutine function.

        Returns:
            The coroutine's result (its exception is re-raised here).
        """
        with self._sem:
            fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
            return fut.result()

    def _run_loop(self):
        self._loop.run_forever()
|