1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
|
"""
Copyright (c) 2023 Proton AG
This file is part of Proton.
Proton is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
Proton is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with ProtonVPN. If not, see <https://www.gnu.org/licenses/>.
"""
from dataclasses import dataclass
from typing import Awaitable, List
import aiohttp
from ..exceptions import *
from .aiohttp import AiohttpTransport
import json, base64, struct, time, asyncio, random, itertools
from urllib.parse import urlparse
from ..api import sync_wrapper
from .utils.dns import DNSParser, DNSResponseError
@dataclass
class AlternativeRoutingDNSQueryAnswer:
    """One record returned by a successful DNS lookup of the alternative
    routing server: the domain to use and the moment it stops being valid.
    """

    # Absolute timestamp (same clock as ``time.time()``) after which the
    # record is considered expired.
    expiration_time: float
    # Domain name of the alternative routing server.
    domain: str
class AlternativeRoutingTransport(AiohttpTransport):
    """Transport that reaches the API through an alternative routing domain.

    The alternative domain is discovered by sending a DNS TXT query, over
    HTTPS, to the resolvers listed in ``DNS_PROVIDERS``. API requests are then
    delegated to :class:`AiohttpTransport`, with ``http_base_url`` pointing at
    the known route that expires last.
    """

    # Each entry is (IPv4 addresses, IPv6 addresses, URL path of the DNS endpoint)
    DNS_PROVIDERS = [
        # dns.google
        (("8.8.4.4", "8.8.8.8"), ("2001:4860:4860::8844", "2001:4860:4860::8888"), '/dns-query'),
        # dns11.quad9.net
        (("149.112.112.11", "9.9.9.11"), ("2620:fe::fe:11", "2620:fe::11"), '/dns-query'),
    ]

    # NOTE(review): these two structs are not referenced inside this class;
    # they are kept because they are public class attributes.
    STRUCT_REPLY_COUNTS = struct.Struct('>HHHH')
    STRUCT_REC_FORMAT = struct.Struct('>HHIH')

    # Delay (in seconds) between two successive DNS requests
    DELAY_DNS_REQUEST = 2
    # Overall time budget (in seconds) for resolving the alternative routes
    TIMEOUT_DNS_REQUEST = 10

    @classmethod
    def _get_priority(cls):
        """Return the priority value of this transport (constant ``5``)."""
        return 5

    def __init__(self, session):
        """Initialise the transport with an empty list of known routes.

        :param session: forwarded unchanged to the parent transport.
        """
        super().__init__(session)
        # AlternativeRoutingDNSQueryAnswer items, latest expiration first
        self._alternative_routes = []

    @classmethod
    def _compute_ar_domain(cls, host):
        """Build the DNS name to query for the alternative route of ``host``.

        :param host: ASCII host name of the regular API endpoint.
        :return: ``bytes`` made of ``d``, the unpadded base32 encoding of
            ``host`` and the ``.protonpro.xyz`` suffix.
        :raises UnicodeEncodeError: if ``host`` is not pure ASCII.
        """
        encoded_host = base64.b32encode(host.encode('ascii')).strip(b'=')
        return b'd' + encoded_host + b".protonpro.xyz"

    async def _async_dns_query(
        self, domain, dns_server_ip, dns_server_path, delay=0
    ) -> List[AlternativeRoutingDNSQueryAnswer]:
        """Ask one DNS-over-HTTPS server for the alternative routes of ``domain``.

        :param domain: host name of the regular API endpoint.
        :param dns_server_ip: address of the DNS server, as it must appear in
            the URL (IPv6 addresses already wrapped in brackets).
        :param dns_server_path: URL path of the DNS endpoint on that server.
        :param delay: seconds to wait before sending the request.
        :return: one answer per record returned by the DNS parser.
        :raises ProtonAPINotReachable: if the reply can't be parsed.
        """
        if delay > 0:
            await asyncio.sleep(delay)

        ardomain = self._compute_ar_domain(domain)
        dns_request = DNSParser.build_query(ardomain, qtype=16, qclass=1)  # TXT IN

        doh_url = f'https://{dns_server_ip}{dns_server_path}'
        async with aiohttp.ClientSession() as session:
            async with session.post(
                doh_url,
                headers=[("Content-Type", "application/dns-message")],
                data=dns_request
            ) as r:
                reply_data = await r.content.read()

        try:
            dns_answers = DNSParser.parse(reply_data)
        except DNSResponseError as e:
            # Keep the parser error attached as the cause, for debugging
            raise ProtonAPINotReachable(str(e)) from e

        now = time.time()
        # The parser yields (TTL, data) tuples
        return [
            AlternativeRoutingDNSQueryAnswer(
                expiration_time=now + rec_ttl,
                domain=rec_val
            )
            for rec_ttl, rec_val in dns_answers
        ]

    @property
    def _http_domain(self):
        """Network location (``netloc``) of the parent transport's base URL."""
        return urlparse(super().http_base_url).netloc

    async def _get_alternative_routes(self):
        """Resolve alternative routes and merge them into the known ones.

        The DNS servers are queried in random order, IPv4 and IPv6
        simultaneously, each pair being delayed by ``DELAY_DNS_REQUEST``
        seconds more than the previous one. The first successful answer wins;
        the whole resolution is bounded by ``TIMEOUT_DNS_REQUEST`` seconds.

        If nothing could be resolved but routes are already known, the known
        routes are kept untouched.

        :raises ProtonAPINotAvailable: if any query raised it.
        :raises ProtonAPINotReachable: if nothing could be resolved and no
            route is known yet.
        """
        # We generate a random list of dns servers,
        # we query them following that order, simultaneously on IPv4/IPv6
        choices_ipv4 = []
        choices_ipv6 = []
        for dns_server_ipv4s, dns_server_ipv6s, dns_server_path in self.DNS_PROVIDERS:
            choices_ipv4.extend((ip, dns_server_path) for ip in dns_server_ipv4s)
            choices_ipv6.extend((ip, dns_server_path) for ip in dns_server_ipv6s)
        random.shuffle(choices_ipv4)
        random.shuffle(choices_ipv6)

        http_domain = self._http_domain
        pending = []
        server_pairs = itertools.zip_longest(choices_ipv4, choices_ipv6, fillvalue=None)
        for i, (ipv4, ipv6) in enumerate(server_pairs):
            delay = i * self.DELAY_DNS_REQUEST
            # No point in scheduling a query that would start after the deadline
            if delay > self.TIMEOUT_DNS_REQUEST:
                break
            if ipv4 is not None:
                pending.append(asyncio.create_task(
                    self._async_dns_query(http_domain, ipv4[0], ipv4[1], delay=delay)
                ))
            if ipv6 is not None:
                # IPv6 literals must be wrapped in brackets inside a URL
                pending.append(asyncio.create_task(
                    self._async_dns_query(http_domain, f'[{ipv6[0]}]', ipv6[1], delay=delay)
                ))

        results_ok = []
        results_fail = []
        proton_api_not_available_errors = []

        # Monotonic clock: the deadline must not move if the wall clock does
        deadline = time.monotonic() + self.TIMEOUT_DNS_REQUEST
        while pending and not results_ok:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                # asyncio.wait() simply returns on timeout, so the deadline
                # has to be enforced here or we'd poll until every task ends.
                break
            done, pending = await asyncio.wait(
                pending, timeout=remaining, return_when=asyncio.FIRST_COMPLETED
            )
            for task in done:
                try:
                    results_ok += task.result()
                except ProtonAPINotAvailable as e:
                    # That means that we were able to do a resolution, but it explicitly failed
                    proton_api_not_available_errors.append(e)
                except Exception as e:
                    results_fail.append(e)

        if pending:
            for task in pending:
                task.cancel()
            # Let the cancelled queries unwind, so that their HTTP sessions
            # are closed before we return.
            await asyncio.gather(*pending, return_exceptions=True)

        if proton_api_not_available_errors:
            raise proton_api_not_available_errors[0]

        if not results_ok:
            if self._alternative_routes:
                # We have routes, but we were not able to resolve new ones. Just keep the old ones
                return
            # No routes, and failed to get new ones
            error = ProtonAPINotReachable("Couldn't resolve any alternative routing names")
            if results_fail:
                # Expose the first underlying failure as the cause
                raise error from results_fail[0]
            raise error

        # Several servers may answer with the same domain: keep a single
        # entry per domain, the one that expires last.
        new_routes = {}
        for answer in results_ok:
            known = new_routes.get(answer.domain)
            if known is None or answer.expiration_time > known.expiration_time:
                new_routes[answer.domain] = answer

        # Drop the old routes that are expired or superseded by the results
        # (we don't want duplicates)
        now = time.time()
        self._alternative_routes = [
            x for x in self._alternative_routes
            if x.domain not in new_routes and x.expiration_time >= now
        ]

        # Add the results
        self._alternative_routes += new_routes.values()

        # Sort them so the route that expires last is on top
        self._alternative_routes.sort(key=lambda x: x.expiration_time, reverse=True)

    @property
    def http_base_url(self):
        """Base URL of the API on the first (latest-expiring) known route.

        The path of the parent transport's base URL is preserved.

        :raises ProtonAPINotReachable: if no route is known.
        """
        if not self._alternative_routes:
            raise ProtonAPINotReachable("AlternativeRouting transport doesn't have any route")

        path = urlparse(super().http_base_url).path
        return f'https://{self._alternative_routes[0].domain}{path}'

    @property
    def tls_pinning_hashes(self):
        """TLS pinning hashes of the environment for alternative routing."""
        return self._environment.tls_pinning_hashes_ar

    async def async_api_request(
        self, endpoint,
        jsondata=None, data=None, additional_headers=None,
        method=None, params=None, return_raw=False
    ):
        """Do an API request through the alternative route.

        The routes are resolved first when none is known or when the first
        one is expired. All the arguments are forwarded unchanged to the
        parent transport, whose result is returned.

        :raises ProtonAPINotReachable: if no route is known and none could
            be resolved.
        """
        routes = self._alternative_routes
        if not routes or routes[0].expiration_time < time.time():
            await self._get_alternative_routes()

        return await super().async_api_request(
            endpoint, jsondata, data, additional_headers, method,
            params, return_raw=return_raw
        )
|