File: event_loop.py

package info (click to toggle)
pytorch-geometric 2.7.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 14,172 kB
  • sloc: python: 144,911; sh: 247; cpp: 27; makefile: 18; javascript: 16
file content (103 lines) | stat: -rw-r--r-- 3,309 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import asyncio
import atexit
import logging
from threading import BoundedSemaphore, Thread
from typing import Callable, Optional

import torch

# Based on graphlearn-for-pytorch repository python/distributed/event_loop.py
# https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/
# LICENSE: Apache v2


def to_asyncio_future(future: torch.futures.Future) -> asyncio.futures.Future:
    r"""Convert a :class:`torch.futures.Future` to a :obj:`asyncio` future.

    The returned future is bound to the event loop running in the calling
    thread (or, when called outside a coroutine, to the thread's current
    event loop). Once :obj:`future` completes, its result or exception is
    transferred to the :obj:`asyncio` future on that loop.

    Args:
        future (torch.futures.Future): The :obj:`torch` future to wrap.

    Returns:
        asyncio.Future: A future that resolves with the result of
        :obj:`future`, or raises the exception raised by its ``wait()``.
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # Not called from within a coroutine: keep the historical behavior of
        # binding to the thread's current event loop.
        loop = asyncio.get_event_loop()
    asyncio_future = loop.create_future()

    # The asyncio future may already be done (e.g. it was cancelled because
    # the awaiting task was cancelled) by the time the torch future finishes.
    # Setting a result on a done future raises `InvalidStateError`, so check
    # first. These helpers run on `loop`'s thread.
    def _set_result(result):
        if not asyncio_future.done():
            asyncio_future.set_result(result)

    def _set_exception(exc):
        if not asyncio_future.done():
            asyncio_future.set_exception(exc)

    def on_done(*_):
        # Torch may invoke this callback from an arbitrary thread, hence the
        # thread-safe hand-off to the asyncio loop.
        try:
            result = future.wait()
        except Exception as e:
            loop.call_soon_threadsafe(_set_exception, e)
        else:
            loop.call_soon_threadsafe(_set_result, result)

    future.add_done_callback(on_done)

    return asyncio_future


class ConcurrentEventLoop:
    r"""Concurrent event loop context.

    Owns a private :obj:`asyncio` event loop that runs on a background daemon
    thread (see :meth:`start_loop`), and lets synchronous code submit
    coroutines to it. A bounded semaphore caps the number of in-flight
    coroutines: submitting blocks once :obj:`concurrency` tasks are pending.

    Args:
        concurrency: max processing concurrency.
    """
    def __init__(self, concurrency: int):
        self._concurrency = concurrency
        self._sem = BoundedSemaphore(concurrency)
        self._loop = asyncio.new_event_loop()
        self._runner_t = Thread(target=self._run_loop)
        self._runner_t.daemon = True

        def cleanup():
            # Pending tasks can only finish while the loop thread runs.
            # Waiting on them otherwise (e.g. `start_loop()` was never called)
            # would block interpreter shutdown forever.
            if not self._runner_t.is_alive():
                return
            self.wait_all()
            # `loop.stop()` is not thread-safe: schedule it on the loop itself
            # so that a loop blocked waiting for events wakes up and stops.
            self._loop.call_soon_threadsafe(self._loop.stop)
            self._runner_t.join(timeout=1)
            logging.debug(f'{self}: Closed `ConcurrentEventLoop`')

        atexit.register(cleanup)

    def start_loop(self):
        r"""Starts the background thread running the event loop (if it is not
        already running)."""
        if not self._runner_t.is_alive():
            self._runner_t.start()

    def wait_all(self):
        r"""Wait for all pending tasks to be finished."""
        # Holding every slot at once means no task submitted via
        # `add_task`/`run_task` is still in flight; then hand them back.
        for _ in range(self._concurrency):
            self._sem.acquire()
        for _ in range(self._concurrency):
            self._sem.release()

    def add_task(self, coro, callback: Optional[Callable] = None):
        r"""Schedules a coroutine to run asynchronously on the event loop.

        Blocks while :obj:`concurrency` tasks are already in flight.

        Args:
            coro: The coroutine object to run.
            callback (callable, optional): The callback function applied on the
                returned results after the coroutine task is finished.
                (default: :obj:`None`)

        Note that any result returned by :obj:`callback` will be ignored.
        Exceptions raised by the coroutine or the callback are logged, not
        propagated.
        """
        def on_done(f: asyncio.futures.Future):
            try:
                res = f.result()
                if callback is not None:
                    callback(res)
            except Exception as e:
                logging.error(f"Coroutine task failed with error: {e}")
            finally:
                # Always give the slot back, or `wait_all()` blocks forever.
                self._sem.release()

        self._sem.acquire()
        try:
            fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
        except BaseException:
            # Scheduling failed (e.g. `coro` is not a coroutine object), so
            # `on_done` will never run: release the slot here instead.
            self._sem.release()
            raise
        fut.add_done_callback(on_done)

    def run_task(self, coro):
        r"""Runs a coroutine on the event loop and blocks until it finishes.

        Args:
            coro: The coroutine object to run.

        Returns:
            The value returned by :obj:`coro`. Exceptions raised by the
            coroutine propagate to the caller.
        """
        with self._sem:
            fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
            return fut.result()

    def _run_loop(self):
        # Thread target: drives the loop until `stop()` is scheduled on it.
        self._loop.run_forever()