1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
|
from typing import Any, Optional
import warnings
import trio
from ..._types import ProxyType
from ..._helpers import parse_proxy_url
from ..._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
from ._stream import TrioSocketStream
from ._resolver import Resolver
from ._connect import connect_tcp
from ..._protocols.errors import ReplyError
from ..._connectors.factory_async import create_connector
# Fallback time limit handed to trio.fail_after() by TrioProxy.connect()
# when the caller passes timeout=None.
DEFAULT_TIMEOUT = 60
class TrioProxy:
    """Proxy client built on trio.

    Opens a TCP connection to the configured proxy (or reuses a
    caller-supplied socket), wraps it in a ``TrioSocketStream`` and runs the
    connector returned by ``create_connector`` for the configured proxy type
    to reach the requested destination.
    """

    def __init__(
        self,
        proxy_type: ProxyType,
        host: str,
        port: int,
        username: Optional[str] = None,
        password: Optional[str] = None,
        rdns: Optional[bool] = None,
    ):
        """Store the proxy settings and create a ``Resolver``.

        Args:
            proxy_type: Proxy kind; forwarded to ``create_connector``.
            host: Address of the proxy server.
            port: Port of the proxy server.
            username: Optional credential forwarded to the connector.
            password: Optional credential forwarded to the connector.
            rdns: Forwarded unchanged to ``create_connector``.
        """
        self._proxy_type = proxy_type
        self._proxy_host = host
        self._proxy_port = port
        self._password = password
        self._username = username
        self._rdns = rdns

        self._resolver = Resolver()

    async def connect(
        self,
        dest_host: str,
        dest_port: int,
        timeout: Optional[float] = None,
        **kwargs: Any,
    ) -> trio.socket.SocketType:
        """Connect to ``dest_host``:``dest_port`` through the proxy.

        Args:
            dest_host: Destination host handed to the connector.
            dest_port: Destination port handed to the connector.
            timeout: Limit passed to ``trio.fail_after``; ``None`` selects
                ``DEFAULT_TIMEOUT``.
            **kwargs: Only two keys are read. ``_socket`` is a deprecated
                way to supply an existing socket instead of opening one to
                the proxy. ``local_addr`` is forwarded to ``connect_tcp``.
                Any other key is ignored.

        Returns:
            The socket the connector was run over.

        Raises:
            ProxyTimeoutError: The whole operation exceeded ``timeout``.
            ProxyConnectionError: Connecting to the proxy raised ``OSError``.
            ProxyError: The connector raised ``ReplyError``.
        """
        if timeout is None:
            timeout = DEFAULT_TIMEOUT

        _socket = kwargs.get('_socket')
        if _socket is not None:
            warnings.warn(
                "The '_socket' argument is deprecated and will be removed in the future",
                DeprecationWarning,
                stacklevel=2,
            )

        local_addr = kwargs.get('local_addr')

        try:
            # A single deadline covers both the TCP connect and the
            # connector exchange performed by _connect().
            with trio.fail_after(timeout):
                return await self._connect(
                    dest_host=dest_host,
                    dest_port=dest_port,
                    _socket=_socket,
                    local_addr=local_addr,
                )
        except trio.TooSlowError as e:
            raise ProxyTimeoutError('Proxy connection timed out: {}'.format(timeout)) from e

    async def _connect(
        self,
        dest_host: str,
        dest_port: int,
        _socket=None,
        local_addr=None,
    ) -> trio.socket.SocketType:
        """Open (or reuse) the proxy socket and run the connector over it.

        On any failure after the stream has been created, the stream is
        closed before the exception propagates.

        Raises:
            ProxyConnectionError: ``connect_tcp`` raised ``OSError``.
            ProxyError: The connector raised ``ReplyError``.
        """
        if _socket is None:
            try:
                _socket = await connect_tcp(
                    host=self._proxy_host,
                    port=self._proxy_port,
                    local_addr=local_addr,
                )
            except OSError as e:
                msg = 'Could not connect to proxy {}:{} [{}]'.format(
                    self._proxy_host,
                    self._proxy_port,
                    e.strerror,
                )
                raise ProxyConnectionError(e.errno, msg) from e

        stream = TrioSocketStream(sock=_socket)

        try:
            connector = create_connector(
                proxy_type=self._proxy_type,
                username=self._username,
                password=self._password,
                rdns=self._rdns,
                resolver=self._resolver,
            )
            await connector.connect(
                stream=stream,
                host=dest_host,
                port=dest_port,
            )
            return _socket
        except ReplyError as e:
            await stream.close()
            # Chain explicitly, like the other translated errors in this
            # class, so the ReplyError is recorded as the direct cause.
            raise ProxyError(e, error_code=e.error_code) from e
        except BaseException:  # trio.Cancelled...
            # Shield the cleanup so the close can still run when this task
            # is already cancelled (e.g. by the deadline set in connect()).
            with trio.CancelScope(shield=True):
                await stream.close()
            raise

    @property
    def proxy_host(self):
        """Proxy host given at construction."""
        return self._proxy_host

    @property
    def proxy_port(self):
        """Proxy port given at construction."""
        return self._proxy_port

    @classmethod
    def create(cls, *args, **kwargs):  # for backward compatibility
        """Alias for calling the class directly with the same arguments."""
        return cls(*args, **kwargs)

    @classmethod
    def from_url(cls, url: str, **kwargs) -> 'TrioProxy':
        """Build an instance from a proxy URL.

        The values returned by ``parse_proxy_url(url)`` are passed as
        positional arguments to the constructor; ``kwargs`` are passed
        through as keyword arguments.
        """
        url_args = parse_proxy_url(url)
        return cls(*url_args, **kwargs)
|