File: auto.py

package info (click to toggle)
python-proton-core 0.4.0-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 540 kB
  • sloc: python: 3,574; makefile: 15
file content (132 lines) | stat: -rw-r--r-- 5,097 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""
Copyright (c) 2023 Proton AG

This file is part of Proton.

Proton is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Proton is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with ProtonVPN.  If not, see <https://www.gnu.org/licenses/>.
"""
from asyncio import transports, TimeoutError
from typing import List
from unittest.mock import Mock
from urllib.parse import urlparse
import json, base64, struct, time, asyncio, random, itertools

from ..exceptions import *
from .base import Transport
from .aiohttp import AiohttpTransport
from .alternativerouting import AlternativeRoutingTransport
from ..api import sync_wrapper


class AutoTransport(Transport):
    """Transport that delegates to the first candidate transport that works.

    Candidates are probed by requesting ``/tests/ping`` through each of them.
    The first one answering ``{"Code": 1000}`` becomes the current transport
    and is used for API requests until one of them times out, at which point
    it is dropped and the candidates are probed again.
    """

    # We assume that a given transport fails after that number of seconds
    TRANSPORT_TIMEOUT = 15

    @classmethod
    def _get_priority(cls):
        """Return the priority of this transport (the constant 100)."""
        # NOTE(review): presumably consumed by the transport loader to rank
        # implementations -- confirm against the caller.
        return 100

    def __init__(self, session, transport_choices: List[Transport] = None, transport_timeout: int = None):
        """Set up the transport without selecting a candidate yet.

        :param session: session object, handed to the base class and to every
            candidate transport that gets instantiated.
        :param transport_choices: list of ``(delay, transport_class)`` pairs;
            ``delay`` is the number of seconds to wait before probing that
            class. Falls back to Aiohttp (no delay) and alternative routing
            (5 seconds delay) when empty or ``None``.
        :param transport_timeout: number of seconds granted to a ping or to an
            API request; falls back to ``TRANSPORT_TIMEOUT`` when falsy.
        """
        super().__init__(session)

        self._current_transport = None
        self._transport_choices = transport_choices or [
            (0, AiohttpTransport),
            (5, AlternativeRoutingTransport)
        ]
        self._transport_timeout = transport_timeout or self.TRANSPORT_TIMEOUT

    @property
    def is_available(self) -> bool:
        """``True`` when a working transport is currently selected."""
        return self._current_transport is not None

    @property
    def transport_choices(self):
        """List of ``(delay, transport_class)`` candidates."""
        return self._transport_choices

    @transport_choices.setter
    def transport_choices(self, newvalue):
        """Replace the candidates, keeping them sorted by increasing delay.

        :param newvalue: iterable of ``(delay, transport_class)`` pairs.
        :raises TypeError: if an entry is not a subclass of ``Transport``;
            the previous candidates are left untouched in that case.
        """
        validated = []
        for timeout, cls in newvalue:
            # Candidates are instantiated later on with ``cls(session)``, so
            # they have to be Transport *classes*, not instances.
            if not (isinstance(cls, type) and issubclass(cls, Transport)):
                raise TypeError("Transports should be a subclass of Transport")
            validated.append((timeout, cls))
        validated.sort(key=lambda x: x[0])
        # Assign only once everything is validated, so that a bad entry does
        # not leave a partially filled list behind.
        self._transport_choices = validated

    async def _ping_via_transport(self, timeout, transport):
        """Probe ``transport`` by requesting ``/tests/ping`` through it.

        :param timeout: number of seconds to sleep before sending the ping
            (despite its name, this is a start delay and not a deadline).
        :param transport: transport instance to probe.
        :return: ``transport`` itself when the ping answered ``{"Code": 1000}``.
        :raises ProtonAPINotReachable: if the ping did not complete within the
            configured transport timeout.
        :raises ProtonAPINotAvailable: if the ping returned anything else than
            ``{"Code": 1000}``.
        """
        await asyncio.sleep(timeout)
        ping_url = "/tests/ping"
        try:
            result = await asyncio.wait_for(transport.async_api_request(ping_url), self._transport_timeout)
        except asyncio.TimeoutError as error:
            raise ProtonAPINotReachable(
                f"{type(transport).__name__} transport not available: unable to reach {ping_url}"
            ) from error
        if result != {"Code": 1000}:
            raise ProtonAPINotAvailable(
                f"{type(transport).__name__} transport received unexpected response from {ping_url}:\n"
                f"{result}"
            )
        return transport

    async def find_available_transport(self):
        """Probe all candidates concurrently and select the first that works.

        On success the winning transport instance is stored as the current
        transport. Probes that are still running when this method exits, for
        whatever reason, are cancelled.

        :raises ProtonAPINotReachable: if no candidate answered the ping
            successfully.
        """
        pending = []
        results_ok = []
        try:
            for timeout, cls in self._transport_choices:
                transport = cls(self._session)
                pending.append(asyncio.create_task(self._ping_via_transport(timeout, transport)))

            # Monotonic clock: the deadline must not move if the wall clock is adjusted
            final_timestamp = time.monotonic() + self._transport_timeout
            while pending and not results_ok:
                # Past the deadline we keep polling every 0.1s: each probe
                # applies the transport timeout to its own ping request.
                done, pending = await asyncio.wait(
                    pending,
                    timeout=max(0.1, final_timestamp - time.monotonic()),
                    return_when=asyncio.FIRST_COMPLETED
                )
                for task in done:
                    try:
                        results_ok.append(task.result())
                    except (ProtonAPINotAvailable, ProtonAPINotReachable):
                        # That means that we weren't able to get to the API
                        # through this transport (wasn't reachable or was
                        # mitm'ed): just wait for the other candidates.
                        continue
                    # Any other exception is unexpected, we might want to
                    # understand what is going on: let it propagate.
        finally:
            # Runs on success, on failure and on cancellation alike, so that
            # no probe is left running in the background.
            for task in pending:
                task.cancel()

        if not results_ok:
            raise ProtonAPINotReachable("No working transports found")

        self._current_transport = results_ok[0]

    async def async_api_request(
        self, endpoint,
        jsondata=None, data=None, additional_headers=None, method=None, params=None,
        return_raw=False
    ):
        """Perform an API request through the currently selected transport.

        A transport is selected first if there is none. When the request times
        out, the current transport is dropped and the request is retried with
        a freshly selected one, for up to 3 attempts in total. All arguments
        are forwarded unchanged to the selected transport.

        :return: whatever the selected transport returns for this request.
        :raises ProtonAPINotReachable: if every attempt timed out, or if no
            working transport could be found.
        """
        tries_left = 3
        while tries_left > 0:
            tries_left -= 1
            if self._current_transport is None:
                await self.find_available_transport()

            try:
                return await asyncio.wait_for(
                    self._current_transport.async_api_request(
                        endpoint, jsondata, data, additional_headers, method, params,
                        return_raw=return_raw
                    ),
                    self._transport_timeout
                )
            except asyncio.TimeoutError:
                # Reset transport
                self._current_transport = None

        raise ProtonAPINotReachable("Timeout accessing the API")  # we should not reach that point except in case of Timeout