File: aiohttp_closed_event.py

package info (click to toggle)
python-gql 4.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,900 kB
  • sloc: python: 21,677; makefile: 54
file content (59 lines) | stat: -rw-r--r-- 1,712 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import asyncio
import functools

from aiohttp import ClientSession


def create_aiohttp_closed_event(session: "ClientSession") -> asyncio.Event:
    """Work around aiohttp issue that doesn't properly close transports on exit.

    See https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-639080209

    Every connection currently listed in ``session.connector._conns`` whose
    transport exposes an ``_ssl_protocol`` attribute gets that protocol's
    ``connection_lost`` and ``eof_received`` replaced on the instance by
    wrappers. The ``connection_lost`` wrapper counts down the number of such
    transports and sets the returned event once all of them have reported
    ``connection_lost``.

    Only connections that exist at call time are patched, so this has to be
    called before the session's connections are closed (e.g. before
    ``session.close()``), and the event awaited afterwards.

    Args:
        session: The aiohttp session whose SSL transports should be tracked.
            ``session.connector`` must not be None.

    Returns:
        An event that will be set once all transports have been properly
        closed. It is returned already set when no SSL transports were found.

    Raises:
        AssertionError: If ``session.connector`` is None.
    """

    ssl_transports = 0
    all_is_lost = asyncio.Event()

    def connection_lost(exc, orig_lost):
        nonlocal ssl_transports

        try:
            orig_lost(exc)
        finally:
            # Count the transport as gone even if the original handler raised,
            # so the event cannot be left unset forever.
            ssl_transports -= 1
            if ssl_transports == 0:
                all_is_lost.set()

    def eof_received(orig_eof_received):
        try:  # pragma: no cover
            # Propagate the original return value: a true result tells the
            # asyncio transport to keep itself open, a false one (incl. None)
            # makes it close. Dropping it would force the transport closed.
            return orig_eof_received()
        except AttributeError:  # pragma: no cover
            # It may happen that eof_received() is called after
            # _app_protocol and _transport are set to None.
            return None

    assert session.connector is not None

    for conn in session.connector._conns.values():
        for handler, _ in conn:
            # Transports without an ``_ssl_protocol`` (presumably non-SSL
            # ones) are not tracked.
            proto = getattr(handler.transport, "_ssl_protocol", None)
            if proto is None:
                continue

            ssl_transports += 1
            orig_lost = proto.connection_lost
            orig_eof_received = proto.eof_received

            # Patch the instance attributes so lookups on this protocol object
            # hit the wrappers, which then delegate to the saved originals.
            proto.connection_lost = functools.partial(
                connection_lost, orig_lost=orig_lost
            )
            proto.eof_received = functools.partial(
                eof_received, orig_eof_received=orig_eof_received
            )

    # Nothing to wait for: signal completion right away.
    if ssl_transports == 0:
        all_is_lost.set()

    return all_is_lost