1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
|
"""
Copyright (c) 2023 Proton AG
This file is part of Proton.
Proton is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
Proton is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with ProtonVPN. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
from proton.session.formdata import FormData
from .. import Session
from ..exceptions import *
from .base import Transport, RawResponse
import json, base64, asyncio, aiohttp, hashlib
from OpenSSL import crypto
from typing import Iterable, Union, Optional
# HTTP status code for "304 Not Modified". _parse_json returns None instead of
# parsing the body when it sees this status and allow_unmodified=True.
NOT_MODIFIED = 304
# We have to inherit from aiohttp.Fingerprint to trigger the correct logic in
# aiohttp: its post-handshake check() hook is only used for ssl= values that
# are Fingerprint instances. Unlike the base class, this pins on the SHA-256 of
# the peer's DER-encoded public key rather than on the whole certificate.
class AiohttpCertkeyFingerprint(aiohttp.Fingerprint):
    """TLS public-key pinning check, passed to aiohttp as the ``ssl=`` argument.

    ``super().__init__`` is intentionally not called: this class keeps its own
    list of accepted hashes in ``self._fingerprints``.
    """

    def __init__(self, fingerprints: Optional[Iterable[Union[bytes, str]]]) -> None:
        """
        :param fingerprints: accepted SHA-256 public-key hashes, either raw
            bytes or base64-encoded strings; ``None`` disables the check.
        """
        if fingerprints is None:
            self._fingerprints = None
            return
        # Normalise every entry to raw digest bytes so it can be compared
        # directly with hashlib's digest() output in check().
        self._fingerprints = [
            base64.b64decode(fp) if isinstance(fp, str) else fp
            for fp in fingerprints
        ]

    def check(self, transport: asyncio.Transport) -> None:
        """Verify the peer's public key against the pinned hashes.

        Does nothing for non-TLS transports or when no fingerprints were given.

        :param transport: the connected transport handed over by aiohttp.
        :raises ProtonAPINotReachable: if the SHA-256 of the peer's DER-encoded
            public key is not one of the pinned hashes.
        """
        if not transport.get_extra_info("sslcontext"):
            return
        # Can't check anything if we don't have fingerprints
        if self._fingerprints is None:
            return

        sslobj = transport.get_extra_info("ssl_object")
        cert = sslobj.getpeercert(binary_form=True)
        cert_obj = crypto.load_certificate(crypto.FILETYPE_ASN1, cert)
        pubkey = crypto.dump_publickey(crypto.FILETYPE_ASN1, cert_obj.get_pubkey())
        pubkey_hash = hashlib.sha256(pubkey).digest()

        if pubkey_hash not in self._fingerprints:
            # aiohttp's connector only closes the transport itself when check()
            # raises its own ServerFingerprintMismatch; close it here so the
            # socket is not leaked when we raise our own error instead.
            transport.close()
            # Dump the certificate as plain base64 text (no b'' wrapper), so we
            # can diagnose if needed with:
            #   base64 -d | openssl x509 -text -inform DER
            cert_b64 = base64.b64encode(cert).decode("ascii")
            raise ProtonAPINotReachable(f"TLS pinning verification failed: {cert_b64}")
class AiohttpTransport(Transport):
    """Transport that sends Proton API requests with aiohttp.

    Every call to :meth:`async_api_request` opens a fresh
    ``aiohttp.ClientSession``. TLS public-key pinning is used when the
    environment provides pinning hashes; otherwise certificates are validated
    with a default SSL context.
    """

    def __init__(self, session: Session, form_data_transformer: Optional[FormDataTransformer] = None):
        """
        :param session: session whose app version, user agent and credentials
            are sent with every request.
        :param form_data_transformer: converts proton ``FormData`` into
            ``aiohttp.FormData``; a default :class:`FormDataTransformer` is used
            when omitted.
        """
        super().__init__(session)
        self._form_data_transformer = form_data_transformer or FormDataTransformer()

    @classmethod
    def _get_priority(cls):
        # Priority of this transport; how it is ranked against other transports
        # is decided by the Transport base class (not visible here).
        return 10

    @property
    def tls_pinning_hashes(self):
        """Public-key pinning hashes from the environment, or None to validate normally."""
        return self._environment.tls_pinning_hashes

    @property
    def http_base_url(self):
        """Base URL that API endpoints are appended to."""
        return self._environment.http_base_url

    async def async_api_request(
        self, endpoint,
        jsondata=None, data=None, additional_headers=None,
        method=None, params=None, return_raw=False
    ) -> dict | RawResponse:
        """Send a request to the Proton API and return its decoded response.

        :param endpoint: path appended to :attr:`http_base_url`.
        :param jsondata: object sent as a JSON body.
        :param data: proton ``FormData`` sent as a form body.
        :param additional_headers: per-request headers, sent on top of the
            session-level headers.
        :param method: HTTP verb (case-insensitive). When omitted, GET is used
            if there is no body and POST otherwise.
        :param params: query-string parameters.
        :param return_raw: if True, return a :class:`RawResponse` (status,
            headers, JSON body or None on 304) instead of only the JSON body.
        :returns: the decoded JSON body, or a :class:`RawResponse`.
        :raises ValueError: if ``method`` is not get/post/put/delete/patch.
        :raises ProtonAPINotReachable: on connection errors, timeouts, TLS
            pinning failures or responses that are not declared as JSON.
        :raises ProtonAPIError: when the body is not valid JSON or carries an
            error ``Code``.
        :raises ProtonAPIUnexpectedError: for any other failure.
        """
        if self.tls_pinning_hashes is not None:
            ssl_specs = AiohttpCertkeyFingerprint(self.tls_pinning_hashes)
        else:
            # Validate SSL normally if we didn't have fingerprints
            import ssl
            ssl_specs = ssl.create_default_context()
            ssl_specs.verify_mode = ssl.CERT_REQUIRED

        headers = {
            'x-pm-appversion': self._session.appversion,
            'User-Agent': self._session.user_agent,
        }
        if self._session.authenticated:
            headers['x-pm-uid'] = self._session.UID
            headers['Authorization'] = 'Bearer ' + self._session.AccessToken
        headers.update(self._environment.http_extra_headers)

        async with aiohttp.ClientSession(headers=headers) as s:
            # If we don't have an explicit method, default to get if there's no data, post otherwise
            if method is None:
                fct = s.get if not jsondata and not data else s.post
            else:
                fct = {
                    'get': s.get,
                    'post': s.post,
                    'put': s.put,
                    'delete': s.delete,
                    'patch': s.patch
                }.get(method.lower())
                if fct is None:
                    raise ValueError("Unknown method: {}".format(method))

            form_data = self._form_data_transformer.to_aiohttp_form_data(data) if data else None

            try:
                async with fct(
                    self.http_base_url + endpoint, headers=additional_headers,
                    json=jsondata, data=form_data, params=params, ssl=ssl_specs
                ) as ret:
                    if return_raw:
                        return RawResponse(ret.status, tuple(ret.headers.items()),
                                           await self._parse_json(ret, allow_unmodified=True))
                    return await self._parse_json(ret)
            except aiohttp.ClientError as e:
                raise ProtonAPINotReachable("Connection error.") from e
            except asyncio.TimeoutError as e:
                raise ProtonAPINotReachable("Timeout error.") from e
            except (ProtonAPINotReachable, ProtonAPIError):
                # Already meaningful to callers: propagate unchanged.
                raise
            except Exception as e:
                raise ProtonAPIUnexpectedError(e) from e

    async def _parse_json(self, ret, allow_unmodified=False):
        """Decode and validate the JSON body of an API response.

        :param ret: the aiohttp response object.
        :param allow_unmodified: if True, a 304 Not Modified response yields
            None instead of being parsed.
        :returns: the decoded JSON body (or None, see ``allow_unmodified``).
        :raises ProtonAPINotReachable: if the response is not declared as JSON.
        :raises ProtonAPIError: if the body is not valid JSON, or its ``Code``
            is neither 1000 nor 1001.
        """
        if allow_unmodified and ret.status == NOT_MODIFIED:
            return None

        # Compare only the media type: a valid JSON response may carry
        # parameters (e.g. "application/json; charset=utf-8"), and a missing
        # header must be reported as non-JSON rather than escape as KeyError.
        content_type = ret.headers.get('content-type', '').split(';', 1)[0].strip().lower()
        if content_type != 'application/json':
            raise ProtonAPINotReachable("API returned non-json results")

        try:
            ret_json = await ret.json()
        except json.decoder.JSONDecodeError:
            raise ProtonAPIError(ret.status, dict(ret.headers), {})

        if ret_json['Code'] not in (1000, 1001):
            raise ProtonAPIError(ret.status, dict(ret.headers), ret_json)
        return ret_json
class FormDataTransformer:
    """Converts proton form data into the form data type aiohttp expects."""

    @staticmethod
    def to_aiohttp_form_data(form_data: FormData) -> aiohttp.FormData:
        """
        Converts proton.session.formdata.FormData into aiohttp.FormData.
        https://docs.aiohttp.org/en/stable/client_reference.html#formdata

        Each entry of ``form_data.fields`` is copied with its name, value,
        content type and filename.

        :param form_data: proton form data to convert.
        :returns: an equivalent ``aiohttp.FormData`` instance.
        """
        result = aiohttp.FormData()
        for field in form_data.fields:
            result.add_field(
                name=field.name, value=field.value,
                content_type=field.content_type, filename=field.filename
            )
        return result
|