File: test_client_channel.py

package info (click to toggle)
python-grpclib 0.4.9-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 676 kB
  • sloc: python: 6,864; makefile: 2
file content (109 lines) | stat: -rw-r--r-- 3,609 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import ssl
import asyncio
import tempfile
import contextlib
from unittest.mock import patch, ANY

import pytest
import certifi
from h2.connection import H2Connection

import grpclib.client
from grpclib.client import Channel, Handler
from grpclib.config import Configuration
from grpclib.protocol import H2Protocol
from grpclib.testing import ChannelFor

from dummy_pb2 import DummyRequest, DummyReply
from dummy_grpc import DummyServiceStub
from test_functional import DummyService
from stubs import TransportStub


@pytest.mark.asyncio
async def test_concurrent_connect(loop):
    """Concurrent calls on a not-yet-connected channel use one connection.

    Five ``UnaryUnary`` calls are started at the same time on a fresh
    ``Channel``.  ``loop.create_connection`` is patched with a slow fake that
    returns the protocol of a ``ChannelFor`` test channel serving
    ``DummyService``.  The test checks that every call gets a ``pong`` reply
    and that the connection factory was awaited exactly once, with
    ``127.0.0.1:50051`` and no SSL.
    """
    num_calls = 5
    requests = [DummyRequest(value="ping") for _ in range(num_calls)]
    expected = [DummyReply(value="pong") for _ in range(num_calls)]

    client = Channel()
    stub = DummyServiceStub(client)

    async with ChannelFor([DummyService()]) as server_channel:

        async def fake_create_connection(*args, **kwargs):
            # The delay gives all pending calls a chance to reach the
            # connect step before the first connection attempt finishes.
            await asyncio.sleep(0.01)
            return None, server_channel._protocol

        with patch.object(loop, "create_connection") as create_conn_mock:
            create_conn_mock.side_effect = fake_create_connection
            pending = [
                loop.create_task(stub.UnaryUnary(request))
                for request in requests
            ]
            replies = await asyncio.gather(*pending)

    assert replies == expected
    create_conn_mock.assert_awaited_once_with(
        ANY,
        "127.0.0.1",
        50051,
        ssl=None,
        server_hostname=None,
    )


@pytest.mark.asyncio
async def test_default_ssl_context():
    """``Channel(ssl=True)`` produces an SSL context with or without certifi.

    First, with certifi importable, ``certifi.where()`` must be consulted
    exactly once.  Then, with ``sys.modules["certifi"]`` set to ``None``,
    ``import certifi`` fails and any call to ``certifi.where`` would raise
    ``AssertionError``.  The channel must still end up with a truthy
    ``_ssl``, presumably built from system defaults (not visible here).
    """
    bundle_path = certifi.where()
    with patch.object(certifi, "where", return_value=bundle_path) as where_mock:
        with_certifi = Channel(ssl=True)
        assert with_certifi._ssl
        where_mock.assert_called_once()

    # A ``None`` entry in ``sys.modules`` makes ``import certifi`` raise
    # ImportError, simulating an environment without certifi installed.
    forbid_where = patch.object(certifi, "where", side_effect=AssertionError)
    hide_certifi = patch.dict("sys.modules", {"certifi": None})
    with forbid_where, hide_certifi:
        without_certifi = Channel(ssl=True)
        assert without_certifi._ssl


@pytest.mark.asyncio
async def test_ssl_target_name_override(loop):
    """``ssl_target_name_override`` is forwarded as ``server_hostname``.

    ``loop.create_connection`` is replaced with a fake that returns an
    ``H2Protocol`` attached to a ``TransportStub``, so the test does not need
    a real server.  After connecting, the patched factory must have been
    awaited once with the channel's SSL context and
    ``server_hostname="example.com"``.
    """
    config = Configuration(ssl_target_name_override="example.com")

    async def fake_create_connection(*args, **kwargs):
        conn = H2Connection()
        stub_transport = TransportStub(conn)
        fake_protocol = H2Protocol(
            Handler(), config.__for_test__(), conn.config,
        )
        fake_protocol.connection_made(stub_transport)
        return None, fake_protocol

    with patch.object(loop, "create_connection") as create_conn_mock:
        create_conn_mock.side_effect = fake_create_connection
        async with Channel(ssl=True, config=config) as channel:
            await channel.__connect__()
            create_conn_mock.assert_awaited_once_with(
                ANY,
                ANY,
                ANY,
                ssl=channel._ssl,
                server_hostname="example.com",
            )


@pytest.mark.asyncio
async def test_default_verify_paths():
    """``Channel`` accepts an ``ssl.DefaultVerifyPaths`` as its ``ssl`` value.

    ``SSL_CERT_FILE`` and ``SSL_CERT_DIR`` are pointed at a temporary file
    and directory, so ``ssl.get_default_verify_paths()`` reports them.  The
    test then checks that these paths are passed, positionally and followed
    by ``None``, to ``SSLContext.load_verify_locations``.
    """
    with contextlib.ExitStack() as stack:
        cafile = stack.enter_context(tempfile.NamedTemporaryFile()).name
        capath = stack.enter_context(tempfile.TemporaryDirectory())
        load_mock = stack.enter_context(
            patch.object(ssl.SSLContext, "load_verify_locations"),
        )
        # Environment must be patched before the verify paths are queried.
        stack.enter_context(
            patch.dict("os.environ", SSL_CERT_FILE=cafile, SSL_CERT_DIR=capath),
        )

        verify_paths = ssl.get_default_verify_paths()
        channel = Channel(ssl=verify_paths)

        assert channel._ssl
        load_mock.assert_called_once_with(cafile, capath, None)
        # Sanity check: these are the env variables OpenSSL actually reads.
        assert verify_paths.openssl_cafile_env == "SSL_CERT_FILE"
        assert verify_paths.openssl_capath_env == "SSL_CERT_DIR"


@pytest.mark.asyncio
async def test_no_ssl_support():
    """A missing ``ssl`` module only breaks channels that request SSL.

    ``grpclib.client._ssl`` is patched to ``None``, presumably simulating an
    interpreter built without ``ssl``.  A plain ``Channel()`` must still be
    constructible.  ``Channel(ssl=True)`` must raise ``RuntimeError`` whose
    message matches ``"SSL is not supported"``.
    """
    with patch.object(grpclib.client, "_ssl", None):
        Channel()  # no SSL requested: must not raise
        with pytest.raises(RuntimeError, match="SSL is not supported"):
            Channel(ssl=True)