File: aiohttp.py

package info (click to toggle)
python-proton-core 0.4.0-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 540 kB
  • sloc: python: 3,574; makefile: 15
file content (179 lines) | stat: -rw-r--r-- 6,776 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""
Copyright (c) 2023 Proton AG

This file is part of Proton.

Proton is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Proton is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with ProtonVPN.  If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations

from proton.session.formdata import FormData
from .. import Session
from ..exceptions import *
from .base import Transport, RawResponse

import json, base64, asyncio, aiohttp, hashlib
from OpenSSL import crypto
from typing import Iterable, Union, Optional

# HTTP status "304 Not Modified": the response carries no body to decode.
NOT_MODIFIED: int = 304

# It's stupid, but we have to inherit from aiohttp.Fingerprint to trigger the correct logic in aiohttp
class AiohttpCertkeyFingerprint(aiohttp.Fingerprint):
    def __init__(self, fingerprints: Optional[Iterable[Union[bytes, str]]]) -> None:
        if fingerprints is not None:
            self._fingerprints = []
            for fp in fingerprints:
                if type(fp) == str:
                    self._fingerprints.append(base64.b64decode(fp))
                else:
                    self._fingerprints.append(fp)
        else:
            self._fingerprints = None

    def check(self, transport: asyncio.Transport) -> None:
        if not transport.get_extra_info("sslcontext"):
            return
        # Can't check anything if we don't have fingerprints
        if self._fingerprints is None:
            return
        sslobj = transport.get_extra_info("ssl_object")
        cert = sslobj.getpeercert(binary_form=True)

        cert_obj = crypto.load_certificate(crypto.FILETYPE_ASN1, cert)
        pubkey_obj = cert_obj.get_pubkey()
        pubkey = crypto.dump_publickey(crypto.FILETYPE_ASN1, pubkey_obj)

        pubkey_hash = hashlib.sha256(pubkey).digest()

        if pubkey_hash not in self._fingerprints:
            # Dump certificate, so we can diagnose if needed with:
            # base64 -d|openssl x509 -text -inform DER 
            raise ProtonAPINotReachable(f"TLS pinning verification failed: {base64.b64encode(cert)}")


class AiohttpTransport(Transport):
    def __init__(self, session: Session, form_data_transformer: FormDataTransformer = None):
        super().__init__(session)
        self._form_data_transformer = form_data_transformer or FormDataTransformer()

    @classmethod
    def _get_priority(cls):
        return 10

    @property
    def tls_pinning_hashes(self):
        return self._environment.tls_pinning_hashes

    @property
    def http_base_url(self):
        return self._environment.http_base_url

    async def async_api_request(
        self, endpoint,
        jsondata=None, data=None, additional_headers=None,
        method=None, params=None, return_raw=False
    ) -> dict | RawResponse:
        if self.tls_pinning_hashes is not None:
            ssl_specs = AiohttpCertkeyFingerprint(self.tls_pinning_hashes)
        else:
            # Validate SSL normally if we didn't have fingerprints
            import ssl
            ssl_specs = ssl.create_default_context()
            ssl_specs.verify_mode = ssl.CERT_REQUIRED

        headers = {
            'x-pm-appversion': self._session.appversion,
            'User-Agent': self._session.user_agent,
        }
        if self._session.authenticated:
            headers['x-pm-uid'] = self._session.UID
            headers['Authorization'] = 'Bearer ' + self._session.AccessToken
        headers.update(self._environment.http_extra_headers)

        async with aiohttp.ClientSession(headers=headers) as s:
            # If we don't have an explicit method, default to get if there's no data, post otherwise
            if method is None:
                if not jsondata and not data:
                    fct = s.get
                else:
                    fct = s.post
            else:
                fct = {
                    'get': s.get,
                    'post': s.post,
                    'put': s.put,
                    'delete': s.delete,
                    'patch': s.patch
                }.get(method.lower())

                if fct is None:
                    raise ValueError("Unknown method: {}".format(method))

            form_data = self._form_data_transformer.to_aiohttp_form_data(data) if data else None

            try:
                async with fct(
                        self.http_base_url + endpoint, headers=additional_headers,
                        json=jsondata, data=form_data, params=params, ssl=ssl_specs
                ) as ret:
                    if return_raw:
                        return RawResponse(ret.status, tuple(ret.headers.items()),
                                           await self._parse_json(ret, allow_unmodified=True))

                    ret_json = await self._parse_json(ret)

                return ret_json
            except aiohttp.ClientError as e:
                raise ProtonAPINotReachable("Connection error.") from e
            except asyncio.TimeoutError as e:
                raise ProtonAPINotReachable("Timeout error.") from e
            except ProtonAPINotReachable:
                raise
            except ProtonAPIError:
                raise
            except Exception as e:
                raise ProtonAPIUnexpectedError(e)

    async def _parse_json(self, ret, allow_unmodified=False):
        if allow_unmodified and ret.status == NOT_MODIFIED:
            return None

        if ret.headers['content-type'] != 'application/json':
            raise ProtonAPINotReachable("API returned non-json results")
        try:
            ret_json = await ret.json()
        except json.decoder.JSONDecodeError:
            raise ProtonAPIError(ret.status, dict(ret.headers), {})

        if ret_json['Code'] not in [1000, 1001]:
            raise ProtonAPIError(ret.status, dict(ret.headers), ret_json)

        return ret_json


class FormDataTransformer:
    @staticmethod
    def to_aiohttp_form_data(form_data: FormData) -> aiohttp.FormData:
        """
        Converts proton.session.data.FormData into aiohttp.FormData.
        https://docs.aiohttp.org/en/stable/client_reference.html#formdata
        """
        result = aiohttp.FormData()
        for field in form_data.fields:
            result.add_field(
                name=field.name, value=field.value,
                content_type=field.content_type, filename=field.filename
            )
        return result