File: _proxy.py

package info (click to toggle)
python-socks 2.7.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 544 kB
  • sloc: python: 5,195; sh: 8; makefile: 3
file content (135 lines) | stat: -rw-r--r-- 4,094 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import ssl
from typing import Any, Optional

import trio

from ._connect import connect_tcp
from ._stream import TrioSocketStream
from .._resolver import Resolver

from ...._types import ProxyType
from ...._helpers import parse_proxy_url
from ...._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError

from ...._protocols.errors import ReplyError
from ...._connectors.factory_async import create_connector

# Fallback used by TrioProxy.connect() when the caller passes timeout=None.
# The value is handed to trio.fail_after(), which interprets it as seconds.
DEFAULT_TIMEOUT = 60


class TrioProxy:
    """Proxy client for trio that opens a stream to a destination via a proxy.

    The proxy server itself is reached either with a direct TCP connection or,
    when ``forward`` is given, through another ``TrioProxy`` (proxy chaining).
    """

    def __init__(
        self,
        proxy_type: ProxyType,
        host: str,
        port: int,
        username: Optional[str] = None,
        password: Optional[str] = None,
        rdns: Optional[bool] = None,
        proxy_ssl: Optional[ssl.SSLContext] = None,
        forward: Optional['TrioProxy'] = None,
    ):
        """Store the proxy settings and create the resolver.

        Args:
            proxy_type: Kind of proxy; passed to ``create_connector``.
            host: Host of the proxy server.
            port: Port of the proxy server.
            username: Optional username handed to the connector.
            password: Optional password handed to the connector.
            rdns: Passed to the connector unchanged (presumably toggles
                remote DNS resolution -- confirm against the connectors).
            proxy_ssl: When set, the stream to the proxy server is upgraded
                with TLS (using ``host`` as the hostname) before the proxy
                handshake.
            forward: Optional proxy through which this proxy is reached.
        """
        self._proxy_type = proxy_type
        self._proxy_host = host
        self._proxy_port = port
        self._username = username
        self._password = password
        self._rdns = rdns

        self._proxy_ssl = proxy_ssl
        self._forward = forward

        self._resolver = Resolver()

    async def connect(
        self,
        dest_host: str,
        dest_port: int,
        dest_ssl: Optional[ssl.SSLContext] = None,
        timeout: Optional[float] = None,
        **kwargs: Any,
    ) -> TrioSocketStream:
        """Connect to ``dest_host:dest_port`` through the proxy.

        Args:
            dest_host: Destination host requested from the proxy.
            dest_port: Destination port requested from the proxy.
            dest_ssl: When set, the tunnelled stream is upgraded with TLS
                (using ``dest_host`` as the hostname) after the proxy
                handshake.
            timeout: Overall time limit in seconds for the whole operation;
                ``None`` selects ``DEFAULT_TIMEOUT``.
            **kwargs: Only ``local_addr`` is read (passed on to
                ``connect_tcp``); any other keyword is ignored.

        Returns:
            The stream connected to the destination.

        Raises:
            ProxyTimeoutError: The operation did not finish within ``timeout``.
            ProxyConnectionError: The TCP connection to the proxy failed.
            ProxyError: The proxy answered with an error reply.
        """
        if timeout is None:
            timeout = DEFAULT_TIMEOUT

        local_addr = kwargs.get('local_addr')
        try:
            with trio.fail_after(timeout):
                return await self._connect(
                    dest_host=dest_host,
                    dest_port=dest_port,
                    dest_ssl=dest_ssl,
                    local_addr=local_addr,
                    timeout=timeout,
                )
        except trio.TooSlowError as e:
            raise ProxyTimeoutError(f'Proxy connection timed out: {timeout}') from e

    async def _connect(
        self,
        dest_host: str,
        dest_port: int,
        dest_ssl: Optional[ssl.SSLContext] = None,
        local_addr: Optional[str] = None,
        timeout: Optional[float] = None,
    ) -> TrioSocketStream:
        """Open the stream to the proxy and negotiate the tunnel.

        ``timeout`` is not enforced here (``connect`` does that); it is only
        handed to the forward proxy so that hop does not fall back to
        ``DEFAULT_TIMEOUT`` and cut a longer caller-supplied limit short.
        """
        if self._forward is None:
            try:
                stream = await connect_tcp(
                    host=self._proxy_host,
                    port=self._proxy_port,
                    local_addr=local_addr,
                )
            except OSError as e:
                raise ProxyConnectionError(
                    e.errno,
                    "Couldn't connect to proxy"
                    f" {self._proxy_host}:{self._proxy_port} [{e.strerror}]",
                ) from e
        else:
            # The forward chain opens the only real TCP socket, so the
            # caller's timeout and local_addr must travel down to it.
            stream = await self._forward.connect(
                dest_host=self._proxy_host,
                dest_port=self._proxy_port,
                timeout=timeout,
                local_addr=local_addr,
            )

        try:
            if self._proxy_ssl is not None:
                # TLS towards the proxy server itself.
                stream = await stream.start_tls(
                    hostname=self._proxy_host,
                    ssl_context=self._proxy_ssl,
                )

            connector = create_connector(
                proxy_type=self._proxy_type,
                username=self._username,
                password=self._password,
                rdns=self._rdns,
                resolver=self._resolver,
            )
            await connector.connect(
                stream=stream,
                host=dest_host,
                port=dest_port,
            )

            if dest_ssl is not None:
                # TLS towards the destination, inside the established tunnel.
                stream = await stream.start_tls(
                    hostname=dest_host,
                    ssl_context=dest_ssl,
                )
        except ReplyError as e:
            await stream.close()
            # Chain explicitly, like the other re-raises in this class.
            raise ProxyError(e, error_code=e.error_code) from e
        except BaseException:  # trio.Cancelled...
            # Shield the close so it still runs when this task is already
            # cancelled (for example by the timeout in connect()).
            with trio.CancelScope(shield=True):
                await stream.close()
            raise

        return stream

    @classmethod
    def create(cls, *args, **kwargs):  # for backward compatibility
        """Return ``cls(*args, **kwargs)``; an alias for the constructor."""
        return cls(*args, **kwargs)

    @classmethod
    def from_url(cls, url: str, **kwargs) -> 'TrioProxy':
        """Build a proxy from a proxy URL.

        The result of ``parse_proxy_url(url)`` is unpacked as the leading
        positional constructor arguments; ``kwargs`` are passed through as
        keyword arguments.
        """
        url_args = parse_proxy_url(url)
        return cls(*url_args, **kwargs)