File: proxy.py

Package: mautrix-python 0.20.7-1
from __future__ import annotations

from typing import Awaitable, Callable, TypeVar
import asyncio
import json
import logging
import time
import urllib.request

from aiohttp import ClientConnectionError
from yarl import URL

from mautrix.util.logging import TraceLogger

try:
    from aiohttp_socks import ProxyConnectionError, ProxyError, ProxyTimeoutError
except ImportError:
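    # aiohttp_socks is an optional dependency: define stand-in exception classes
    # here so that RETRYABLE_PROXY_EXCEPTIONS can be built without it.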

    class ProxyError(Exception):
        pass

    ProxyConnectionError = ProxyTimeoutError = ProxyError

RETRYABLE_PROXY_EXCEPTIONS = (
    ProxyError,
    ProxyTimeoutError,
    ProxyConnectionError,
    ClientConnectionError,
    ConnectionError,
    asyncio.TimeoutError,
)


class ProxyHandler:
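    """Tracks the proxy URL that the bridge should use for outgoing connections.

    The URL is fetched either from an external API endpoint (when ``api_url`` is
    set) or from the environment via ``urllib.request.getproxies``.
    """
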
    current_proxy_url: str | None = None
    log = logging.getLogger("mau.proxy")

    def __init__(self, api_url: str | None) -> None:
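        """Create a handler. ``api_url`` is the optional endpoint to fetch proxy URLs from."""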
        self.api_url = api_url

    def get_proxy_url_from_api(self, reason: str | None = None) -> str | None:
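        """Synchronously fetch a new proxy URL from the configured API endpoint.

        ``reason`` is passed along as a query parameter. If the request fails, the
        current proxy URL is returned unchanged.
        """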
        assert self.api_url is not None

        api_url = str(URL(self.api_url).update_query({"reason": reason} if reason else {}))

        # NOTE: urllib.request is used deliberately so that the whole bridge blocks until the proxy change has been applied
        request = urllib.request.Request(api_url, method="GET")
        self.log.debug("Requesting proxy from: %s", api_url)

        try:
            with urllib.request.urlopen(request) as f:
                response = json.loads(f.read().decode())
        except Exception:
            self.log.exception("Failed to retrieve proxy from API")
            return self.current_proxy_url
        else:
            return response["proxy_url"]

    def update_proxy_url(self, reason: str | None = None) -> bool:
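        """Refresh the proxy URL from the API (if configured) or from the environment.

        Returns ``True`` if the proxy URL changed, ``False`` otherwise.
        """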
        old_proxy = self.current_proxy_url
        new_proxy = None

        if self.api_url is not None:
            new_proxy = self.get_proxy_url_from_api(reason)
        else:
            new_proxy = urllib.request.getproxies().get("http")

        if old_proxy != new_proxy:
            self.log.debug("Set new proxy URL: %s", new_proxy)
            self.current_proxy_url = new_proxy
            return True

        self.log.debug("Got same proxy URL: %s", new_proxy)
        return False

    def get_proxy_url(self) -> str | None:
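        """Return the current proxy URL, fetching one first if none is cached yet."""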
        if not self.current_proxy_url:
            self.update_proxy_url()

        return self.current_proxy_url


T = TypeVar("T")


async def proxy_with_retry(
    name: str,
    func: Callable[[], Awaitable[T]],
    logger: TraceLogger,
    proxy_handler: ProxyHandler,
    on_proxy_change: Callable[[], Awaitable[None]],
    max_retries: int = 10,
    min_wait_seconds: int = 0,
    max_wait_seconds: int = 60,
    multiply_wait_seconds: int = 10,
    retryable_exceptions: tuple[type[Exception], ...] = RETRYABLE_PROXY_EXCEPTIONS,
    reset_after_seconds: int | None = None,
) -> T:
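    """Await ``func`` and retry it when a retryable (proxy or connection) error occurs.

    The wait between attempts grows by ``multiply_wait_seconds`` per consecutive error,
    clamped to ``[min_wait_seconds, max_wait_seconds]``. Starting with the second
    consecutive failure, the proxy URL is refreshed via ``proxy_handler`` and
    ``on_proxy_change`` is awaited if the URL actually changed. Once the number of
    consecutive failures exceeds ``max_retries``, the exception is re-raised. If
    ``reset_after_seconds`` is set, the error counter is reset when errors are spaced
    further apart than that.
    """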
    errors = 0
    last_error = 0

    while True:
        try:
            return await func()
        except retryable_exceptions as e:
            errors += 1
            if errors > max_retries:
                raise
            wait = errors * multiply_wait_seconds
            wait = max(wait, min_wait_seconds)
            wait = min(wait, max_wait_seconds)
            logger.warning(
                "%s while trying to %s, retrying in %d seconds",
                e.__class__.__name__,
                name,
                wait,
            )
            if errors > 1 and proxy_handler.update_proxy_url(
                f"{e.__class__.__name__} while trying to {name}"
            ):
                await on_proxy_change()

            # If sufficient time has passed since the previous error, reset the
            # error count. Useful for long-running tasks with rare failures.
            if reset_after_seconds is not None:
                now = time.time()
                if last_error and now - last_error > reset_after_seconds:
                    errors = 0
                last_error = now

            # Wait before retrying, as announced in the warning above.
            await asyncio.sleep(wait)
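
# Illustrative usage sketch (not part of this module's API): shows how a bridge
# might wrap a flaky network call with proxy_with_retry. The ``fetch_messages``
# coroutine and ``reconnect_client`` callback below are hypothetical placeholders.
#
#     proxy_handler = ProxyHandler(api_url=None)  # fall back to environment proxies
#     log = logging.getLogger("mau.example")
#
#     async def fetch_messages() -> list[str]:
#         ...  # some network call that may raise a retryable exception
#
#     async def reconnect_client() -> None:
#         ...  # rebuild the HTTP client using proxy_handler.get_proxy_url()
#
#     messages = await proxy_with_retry(
#         "fetch messages",
#         fetch_messages,
#         logger=log,
#         proxy_handler=proxy_handler,
#         on_proxy_change=reconnect_client,
#         max_retries=5,
#     )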