1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
|
import asyncio
import ssl
from typing import Any, Optional, Tuple
import warnings
import sys
from ...._types import ProxyType
from ...._helpers import parse_proxy_url
from ...._errors import ProxyConnectionError, ProxyTimeoutError, ProxyError
from ...._protocols.errors import ReplyError
from ...._connectors.factory_async import create_connector
from .._resolver import Resolver
from ._stream import AsyncioSocketStream
from ._connect import connect_tcp
if sys.version_info >= (3, 11):
import asyncio as async_timeout # pylint:disable=reimported
else:
import async_timeout
# Timeout (in seconds) applied by AsyncioProxy.connect() when the caller
# passes timeout=None.
DEFAULT_TIMEOUT = 60
class AsyncioProxy:
    """Asyncio client that opens a stream to a destination through a proxy.

    The proxy itself may be reached directly over TCP or, when ``forward`` is
    given, through another ``AsyncioProxy`` (proxy chaining).
    """

    def __init__(
        self,
        proxy_type: ProxyType,
        host: str,
        port: int,
        username: Optional[str] = None,
        password: Optional[str] = None,
        rdns: Optional[bool] = None,
        proxy_ssl: Optional[ssl.SSLContext] = None,
        forward: Optional['AsyncioProxy'] = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ):
        """Store the proxy settings and create the resolver.

        Args:
            proxy_type: Proxy protocol; handed to ``create_connector``.
            host: Proxy server host.
            port: Proxy server port.
            username: Optional credential handed to the connector.
            password: Optional credential handed to the connector.
            rdns: Handed to the connector unchanged (presumably selects
                remote DNS resolution -- confirm against the connectors).
            proxy_ssl: When set, TLS is started on the stream to the proxy
                before the proxy negotiation.
            forward: Optional proxy through which this proxy is reached.
            loop: Deprecated; the current event loop is used when omitted.
        """
        if loop is not None:  # pragma: no cover
            warnings.warn(
                'The loop argument is deprecated and scheduled for removal in the future.',
                DeprecationWarning,
                stacklevel=2,
            )

        if loop is None:
            loop = asyncio.get_event_loop()

        self._loop = loop

        self._proxy_type = proxy_type
        self._proxy_host = host
        self._proxy_port = port
        self._username = username
        self._password = password
        self._rdns = rdns
        self._proxy_ssl = proxy_ssl
        self._forward = forward

        self._resolver = Resolver(loop=loop)

    async def connect(
        self,
        dest_host: str,
        dest_port: int,
        dest_ssl: Optional[ssl.SSLContext] = None,
        timeout: Optional[float] = None,
        **kwargs: Any,
    ) -> AsyncioSocketStream:
        """Connect to ``dest_host:dest_port`` through the proxy.

        Args:
            dest_host: Destination host requested from the proxy.
            dest_port: Destination port requested from the proxy.
            dest_ssl: When set, TLS is started towards the destination once
                the proxy negotiation has completed.
            timeout: Overall time limit in seconds; ``DEFAULT_TIMEOUT`` is
                used when ``None``.
            **kwargs: Only ``local_addr`` is read; other keys are ignored.

        Returns:
            The connected stream.

        Raises:
            ProxyTimeoutError: The whole operation exceeded ``timeout``.
            ProxyConnectionError: The TCP connection to the proxy failed.
            ProxyError: The proxy negotiation raised ``ReplyError``.
        """
        if timeout is None:
            timeout = DEFAULT_TIMEOUT

        local_addr = kwargs.get('local_addr')
        try:
            async with async_timeout.timeout(timeout):
                return await self._connect(
                    dest_host=dest_host,
                    dest_port=dest_port,
                    dest_ssl=dest_ssl,
                    local_addr=local_addr,
                )
        except asyncio.TimeoutError as e:
            raise ProxyTimeoutError(f'Proxy connection timed out: {timeout}') from e

    async def _connect(
        self,
        dest_host: str,
        dest_port: int,
        dest_ssl: Optional[ssl.SSLContext] = None,
        local_addr: Optional[Tuple[str, int]] = None,
    ) -> AsyncioSocketStream:
        """Open the stream to the proxy and negotiate the tunnel (no timeout).

        The stream is closed before any exception raised after it was opened
        is propagated.
        """
        if self._forward is None:
            try:
                stream = await connect_tcp(
                    host=self._proxy_host,
                    port=self._proxy_port,
                    loop=self._loop,
                    local_addr=local_addr,
                )
            except OSError as e:
                raise ProxyConnectionError(
                    e.errno,
                    "Couldn't connect to proxy"
                    f" {self._proxy_host}:{self._proxy_port} [{e.strerror}]",
                ) from e
        else:
            # Reach this proxy through the forward proxy. local_addr is
            # handed down so that the hop which actually opens the TCP
            # socket can bind to it instead of silently dropping it.
            # NOTE: no timeout is passed, so the forward proxy applies its
            # own default, nested inside the caller's timeout.
            stream = await self._forward.connect(
                dest_host=self._proxy_host,
                dest_port=self._proxy_port,
                local_addr=local_addr,
            )

        try:
            if self._proxy_ssl is not None:
                stream = await stream.start_tls(
                    hostname=self._proxy_host,
                    ssl_context=self._proxy_ssl,
                )

            connector = create_connector(
                proxy_type=self._proxy_type,
                username=self._username,
                password=self._password,
                rdns=self._rdns,
                resolver=self._resolver,
            )
            await connector.connect(
                stream=stream,
                host=dest_host,
                port=dest_port,
            )

            if dest_ssl is not None:
                stream = await stream.start_tls(
                    hostname=dest_host,
                    ssl_context=dest_ssl,
                )
        except ReplyError as e:
            await stream.close()
            # Chain explicitly, like the other error translations in this class.
            raise ProxyError(e, error_code=e.error_code) from e
        except (asyncio.CancelledError, Exception):
            # CancelledError is listed separately because it is not an
            # Exception subclass on newer Python versions; the stream must
            # be closed on cancellation as well.
            await stream.close()
            raise

        return stream

    @classmethod
    def create(cls, *args: Any, **kwargs: Any) -> 'AsyncioProxy':  # for backward compatibility
        """Return ``cls(*args, **kwargs)``; kept as an alias of the constructor."""
        return cls(*args, **kwargs)

    @classmethod
    def from_url(cls, url: str, **kwargs: Any) -> 'AsyncioProxy':
        """Build a proxy from a URL.

        The values returned by ``parse_proxy_url(url)`` are passed as the
        leading positional constructor arguments; ``kwargs`` are passed
        through as keyword arguments.
        """
        url_args = parse_proxy_url(url)
        return cls(*url_args, **kwargs)
|