File: testing.py

package info (click to toggle)
python-grpclib 0.4.8-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 484 kB
  • sloc: python: 3,370; makefile: 2
file content (138 lines) | stat: -rw-r--r-- 4,024 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import asyncio

from types import TracebackType
from typing import TYPE_CHECKING, Collection, Optional, Type

from .client import Channel
from .server import Server
from .protocol import H2Protocol
from .encoding.base import CodecBase, StatusDetailsCodecBase

if TYPE_CHECKING:
    from ._typing import IServable  # noqa


class _Server(asyncio.AbstractServer):
    """Bare-bones :py:class:`asyncio.AbstractServer` used as a placeholder.

    :py:class:`ChannelFor` stores an instance of this class in the private
    ``_server`` attribute of the server under test. Only :py:meth:`close`
    and :py:meth:`wait_closed` are usable, and both do nothing; every other
    method raises :py:class:`NotImplementedError`.
    """

    def get_loop(self) -> asyncio.AbstractEventLoop:
        """Unsupported, always raises :py:class:`NotImplementedError`."""
        raise NotImplementedError()

    def is_serving(self) -> bool:
        """Unsupported, always raises :py:class:`NotImplementedError`."""
        raise NotImplementedError()

    async def start_serving(self) -> None:
        """Unsupported, always raises :py:class:`NotImplementedError`."""
        raise NotImplementedError()

    async def serve_forever(self) -> None:
        """Unsupported, always raises :py:class:`NotImplementedError`."""
        raise NotImplementedError()

    def close(self) -> None:
        """Do nothing and return ``None``."""
        return None

    async def wait_closed(self) -> None:
        """Complete immediately with ``None``."""
        return None


class _InMemoryTransport(asyncio.Transport):
    """Transport which passes written data directly to a peer protocol.

    :py:meth:`write` does not deliver the data synchronously: it schedules
    a callback on the event loop, and that callback calls ``data_received``
    on the protocol given to the constructor.
    """

    def __init__(
        self,
        protocol: H2Protocol,
    ) -> None:
        """
        :param protocol: protocol which receives everything written to this
            transport
        """
        super().__init__()
        self._protocol = protocol
        # The loop is captured once, at construction time.
        self._loop = asyncio.get_event_loop()

    def _write_soon(self, data: bytes) -> None:
        """Deliver ``data`` to the peer, unless its connection is closing."""
        peer = self._protocol
        if peer.connection.is_closing():
            # Data written towards a closing connection is dropped.
            return
        peer.data_received(data)

    def write(self, data: bytes) -> None:
        """Schedule delivery of ``data`` to the peer; empty data is ignored.
        """
        if not data:
            return
        self._loop.call_soon(self._write_soon, data)

    def is_closing(self) -> bool:
        """Always report the transport as not closing."""
        return False

    def close(self) -> None:
        """Do nothing and return ``None``."""
        return None


class ChannelFor:
    """Async context manager which provides a
    :py:class:`~grpclib.client.Channel` connected to a
    :py:class:`~grpclib.server.Server` by a pair of in-memory transports

    Example:

    .. code-block:: python3

        class Greeter(GreeterBase):
            ...

        greeter = Greeter()

        async with ChannelFor([greeter]) as channel:
            stub = GreeterStub(channel)
            response = await stub.SayHello(HelloRequest(name='Dr. Strange'))
            assert response.message == 'Hello, Dr. Strange!'
    """
    def __init__(
        self,
        services: Collection['IServable'],
        codec: Optional[CodecBase] = None,
        status_details_codec: Optional[StatusDetailsCodecBase] = None,
    ) -> None:
        """
        :param services: list of services you want to test

        :param codec: instance of a codec to encode and decode messages,
            if omitted ``ProtoCodec`` is used by default

        :param status_details_codec: instance of a status details codec to
            encode and decode error details in a trailing metadata, if omitted
            ``ProtoStatusDetailsCodec`` is used by default
        """
        # Nothing is created here; the server and the channel are built
        # when the context is entered.
        self._services = services
        self._codec = codec
        self._status_details_codec = status_details_codec

    async def __aenter__(self) -> Channel:
        """Create the server and the channel and connect them to each other

        :return: :py:class:`~grpclib.client.Channel`
        """
        # Server side: the server gets a placeholder ``_Server`` object and
        # a fresh "closed" future assigned to its private attributes, and
        # one server-side protocol is created from its factory.
        server = self._server = Server(
            self._services,
            codec=self._codec,
            status_details_codec=self._status_details_codec,
        )
        server._server = _Server()
        server._server_closed_fut = server._loop.create_future()
        server_protocol = self._server_protocol = server._protocol_factory()

        # Client side: the channel shares both codecs with the server and
        # gets its protocol assigned directly.
        channel = self._channel = Channel(
            codec=self._codec,
            status_details_codec=self._status_details_codec,
        )
        channel._protocol = channel._protocol_factory()
        client_protocol = channel._protocol

        # Cross-wire both sides: whatever one protocol writes is received
        # by the other one.
        client_protocol.connection_made(
            _InMemoryTransport(server_protocol)
        )
        server_protocol.connection_made(
            _InMemoryTransport(client_protocol)
        )
        return channel

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Disconnect and close the channel first, then the server

        The exception arguments are not inspected and nothing is returned,
        so an exception raised inside the ``async with`` block propagates.
        """
        channel = self._channel
        client_protocol = channel._protocol
        assert client_protocol is not None
        client_protocol.connection_lost(None)
        channel.close()

        server = self._server
        self._server_protocol.connection_lost(None)
        server.close()
        await server.wait_closed()