File: _stream.py

package info (click to toggle)
python-socks 2.7.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 544 kB
  • sloc: python: 5,195; sh: 8; makefile: 3
file content (91 lines) | stat: -rw-r--r-- 2,753 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import asyncio
import ssl

from .... import _abc as abc

# Default value of ``max_bytes`` for ``AsyncioSocketStream.read``: the largest
# number of bytes a single read asks the underlying reader for (64 KiB).
DEFAULT_RECEIVE_SIZE = 65536


class AsyncioSocketStream(abc.AsyncSocketStream):
    """Socket stream built on an asyncio ``StreamReader``/``StreamWriter`` pair.

    Reads are delegated to the reader, writes to the writer, and the
    connection can be upgraded to TLS through :meth:`start_tls`.
    """

    _loop: asyncio.AbstractEventLoop
    _reader: asyncio.StreamReader
    _writer: asyncio.StreamWriter

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ):
        """Remember the event loop and the reader/writer pair being wrapped."""
        self._loop, self._reader, self._writer = loop, reader, writer

    async def write_all(self, data):
        """Hand ``data`` to the writer and wait until it has been drained."""
        sink = self._writer
        sink.write(data)
        await sink.drain()

    async def read(self, max_bytes=DEFAULT_RECEIVE_SIZE):
        """Return whatever the reader yields for a read of up to ``max_bytes``."""
        chunk = await self._reader.read(max_bytes)
        return chunk

    async def read_exact(self, n):
        """Return the result of the reader's ``readexactly(n)``."""
        payload = await self._reader.readexactly(n)
        return payload

    async def start_tls(
        self,
        hostname: str,
        ssl_context: ssl.SSLContext,
        ssl_handshake_timeout=None,
    ) -> 'AsyncioSocketStream':
        """Upgrade the connection to TLS and return the stream to use afterwards.

        If the writer offers its own ``start_tls`` (Python>=3.11) it is upgraded
        in place and ``self`` is returned. Otherwise the upgrade goes through
        the event loop and a new ``AsyncioSocketStream`` is returned.

        Args:
            hostname: Forwarded as ``server_hostname``.
            ssl_context: SSL context handed to the underlying ``start_tls`` call.
            ssl_handshake_timeout: Forwarded unchanged to the underlying
                ``start_tls`` call.
        """
        if not hasattr(self._writer, 'start_tls'):
            return await self._start_tls_via_loop(
                hostname,
                ssl_context,
                ssl_handshake_timeout,
            )

        # Python>=3.11: the writer can upgrade itself.
        await self._writer.start_tls(
            ssl_context,
            server_hostname=hostname,
            ssl_handshake_timeout=ssl_handshake_timeout,
        )
        return self

    async def _start_tls_via_loop(
        self,
        hostname,
        ssl_context,
        ssl_handshake_timeout,
    ):
        """Upgrade through ``loop.start_tls`` and wrap the result in a new stream."""
        tls_reader = asyncio.StreamReader()
        tls_protocol = asyncio.StreamReaderProtocol(tls_reader)

        tls_transport = await self._loop.start_tls(
            self._writer.transport,  # type: ignore
            tls_protocol,
            ssl_context,
            server_side=False,
            server_hostname=hostname,
            ssl_handshake_timeout=ssl_handshake_timeout,
        )

        # Note: an explicit ``tls_reader.set_transport(tls_transport)`` call is
        # left disabled here.

        # Let the protocol know it is now attached to a TLS connection.
        # See: https://github.com/encode/httpx/issues/859
        tls_protocol.connection_made(tls_transport)

        tls_writer = asyncio.StreamWriter(
            tls_transport,
            tls_protocol,
            tls_reader,
            self._loop,
        )
        upgraded = AsyncioSocketStream(self._loop, tls_reader, tls_writer)

        # The new stream has its own StreamReader/StreamWriter, so hold on to
        # the previous stream (and through it the old reader/writer) to keep
        # them from being garbage collected and closed while still in use.
        upgraded._inner = self  # type: ignore # pylint:disable=W0212,W0201
        return upgraded

    async def close(self):
        """Close the writer, then abort its transport."""
        sink = self._writer
        sink.close()
        sink.transport.abort()  # noqa

    @property
    def reader(self):
        """The wrapped ``StreamReader``."""
        return self._reader  # pragma: no cover

    @property
    def writer(self):
        """The wrapped ``StreamWriter``."""
        return self._writer  # pragma: no cover