File: pytest_tornasync.py

package info (click to toggle)
pytest-jupyter 0.10.1-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 240 kB
  • sloc: python: 644; sh: 23; makefile: 14
file content (100 lines) | stat: -rw-r--r-- 2,873 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""Vendored fork of pytest_tornasync from
  https://github.com/eukaryote/pytest-tornasync/blob/9f1bdeec3eb5816e0183f975ca65b5f6f29fbfbb/src/pytest_tornasync/plugin.py
"""
import asyncio
from contextlib import closing

try:
    import tornado.ioloop
    import tornado.testing
    from tornado.simple_httpclient import SimpleAsyncHTTPClient
except ImportError:
    SimpleAsyncHTTPClient = object  # type:ignore[assignment,misc]

import pytest

# mypy: disable-error-code="no-untyped-call"
# Bring in local plugins.
from pytest_jupyter.jupyter_core import *  # noqa: F403


@pytest.fixture()
def io_loop(jp_asyncio_loop):
    """Get the current tornado event loop."""
    return tornado.ioloop.IOLoop.current()


@pytest.fixture()
def http_server(jp_asyncio_loop, http_server_port, jp_web_app):
    """Start a tornado HTTP server that listens on all available interfaces."""

    async def get_server():
        """Get a server asynchronously."""
        server = tornado.httpserver.HTTPServer(jp_web_app)
        server.add_socket(http_server_port[0])
        return server

    server = jp_asyncio_loop.run_until_complete(get_server())
    yield server
    server.stop()

    if hasattr(server, "close_all_connections"):
        try:
            jp_asyncio_loop.run_until_complete(server.close_all_connections())
        except asyncio.TimeoutError:
            pass

    http_server_port[0].close()


@pytest.fixture()
def http_server_port():
    """
    Port used by `http_server`.
    """
    return tornado.testing.bind_unused_port()


@pytest.fixture()
def http_server_client(http_server, jp_asyncio_loop):
    """
    Create an asynchronous HTTP client that can fetch from `http_server`.
    """

    async def get_client():
        """Get a client."""
        return AsyncHTTPServerClient(http_server=http_server)

    client = jp_asyncio_loop.run_until_complete(get_client())
    with closing(client) as context:
        yield context


class AsyncHTTPServerClient(SimpleAsyncHTTPClient):
    """An async http server client."""

    def initialize(self, *, http_server=None):
        """Initialize the client."""
        super().initialize()
        self._http_server = http_server

    def fetch(self, path, **kwargs):
        """
        Fetch `path` from test server, passing `kwargs` to the `fetch`
        of the underlying `SimpleAsyncHTTPClient`.
        """
        return super().fetch(self.get_url(path), **kwargs)

    def get_protocol(self):
        """Get the protocol for the client."""
        return "http"

    def get_http_port(self):
        """Get a port for the client."""
        for sock in self._http_server._sockets.values():
            return sock.getsockname()[1]
        return None

    def get_url(self, path):
        """Get the url for the client."""
        return f"{self.get_protocol()}://127.0.0.1:{self.get_http_port()}{path}"