File: requests.py

package info (click to toggle)
python-proton-core 0.7.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 568 kB
  • sloc: python: 3,672; makefile: 19
file content (156 lines) | stat: -rw-r--r-- 5,586 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
Copyright (c) 2023 Proton AG

This file is part of Proton.

Proton is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Proton is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with ProtonVPN.  If not, see <https://www.gnu.org/licenses/>.
"""
import io

import requests

from ..formdata import FormData
from ..exceptions import *
from .base import Transport, RawResponse

import json

# HTTP status code for "304 Not Modified"; such responses carry no body to decode.
NOT_MODIFIED = 304

class RequestsTransport(Transport):
    """Simple transport based on the ``requests`` library.

    It is not advised to use in production: HTTP calls are made synchronously
    through a ``requests.Session``, so :meth:`async_api_request` blocks the
    running event loop for the duration of each request.
    """
    def __init__(self, session, requests_session: requests.Session = None):
        """
        :param session: the Proton session this transport serves (forwarded
            to the ``Transport`` base class).
        :param requests_session: optional pre-configured ``requests.Session``;
            a fresh one is created when omitted.
        """
        super().__init__(session)

        self._s = requests_session or requests.Session()

    @classmethod
    def _get_priority(cls):
        """Return this transport's priority value (3).

        ``requests`` is imported at module level, so whenever this class can be
        loaded the library is available; no runtime import check is needed.
        """
        return 3

    def _build_raw(self, ret):
        """Wrap a ``requests`` response into a :class:`RawResponse`.

        Status code and headers are always copied. A 304 Not Modified response
        is returned with nothing else filled in. Otherwise the raw body is
        stored in ``data`` and, unless the response is declared as
        ``application/octet-stream``, the body is also decoded as a Proton API
        JSON payload into ``json``.

        :raises ProtonAPIError: if a non-binary body is not a valid, successful
            Proton API JSON reply.
        """
        response = RawResponse(ret.status_code,
                               tuple(ret.headers.items()), None, None)

        if ret.status_code == NOT_MODIFIED:
            return response

        response.data = ret.content

        # The header may be missing, or carry parameters ("...; charset=...");
        # compare only the bare media type instead of indexing the header.
        content_type = ret.headers.get('content-type', '')
        media_type = content_type.split(';', 1)[0].strip().lower()
        if media_type != 'application/octet-stream':
            response.json = self._parse_json(ret, allow_unmodified=True)

        return response

    def _parse_json(self, ret, allow_unmodified=False):
        """Decode a Proton API JSON response and check its ``Code`` field.

        :param ret: the ``requests`` response to decode.
        :param allow_unmodified: when true, a 304 Not Modified response yields
            ``None`` instead of being decoded.
        :returns: the decoded JSON object, or ``None`` as described above.
        :raises ProtonAPIError: if the body is not valid JSON, is not a JSON
            object, or its ``Code`` is neither 1000 nor 1001.
        """
        if allow_unmodified and ret.status_code == NOT_MODIFIED:
            return None

        try:
            ret_json = ret.json()
        except ValueError as e:
            # Depending on the requests version and whether simplejson is
            # installed, this may be json.JSONDecodeError, the simplejson
            # variant, or requests.exceptions.JSONDecodeError; all of them
            # derive from ValueError.
            raise ProtonAPIError(ret.status_code, dict(ret.headers), {}) from e

        if not isinstance(ret_json, dict):
            # Valid JSON but not an object: there is no Code to inspect.
            raise ProtonAPIError(ret.status_code, dict(ret.headers), {})

        if ret_json.get('Code') not in (1000, 1001):
            raise ProtonAPIError(ret.status_code, dict(ret.headers), ret_json)

        return ret_json

    def _refresh_requests_session_headers(self):
        """Sync the session-wide ``requests`` headers with the Proton session.

        Authentication headers are removed when the Proton session is not
        authenticated, so a token set by an earlier request is not sent along
        with later unauthenticated ones.
        """
        headers = self._s.headers
        headers['x-pm-appversion'] = self._session.appversion
        headers['User-Agent'] = self._session.user_agent

        if self._session.authenticated:
            headers['x-pm-uid'] = self._session.UID
            headers['Authorization'] = 'Bearer ' + self._session.AccessToken
        else:
            headers.pop('x-pm-uid', None)
            headers.pop('Authorization', None)

    def _get_requests_method(self, method, has_payload):
        """Return the ``requests.Session`` method to call for ``method``.

        :param method: HTTP method name (case-insensitive), or ``None`` to use
            GET when there is no payload and POST otherwise.
        :param has_payload: whether a JSON or form body is being sent.
        :raises ValueError: if ``method`` is not one of get/post/put/delete/patch.
        """
        if method is None:
            return self._s.post if has_payload else self._s.get

        fct = {
            'get': self._s.get,
            'post': self._s.post,
            'put': self._s.put,
            'delete': self._s.delete,
            'patch': self._s.patch
        }.get(method.lower())

        if fct is None:
            raise ValueError("Unknown method: {}".format(method))
        return fct

    async def async_api_request(
        self, endpoint,
        jsondata=None, data=None, additional_headers=None,
        method=None, params=None, return_raw=False
    ):
        """Send a request to the Proton API and return its decoded response.

        Note: the HTTP call itself is made synchronously with ``requests``, so
        this coroutine blocks the event loop until the call completes.

        :param endpoint: path appended to the environment's ``http_base_url``.
        :param jsondata: optional body serialized as JSON.
        :param data: optional :class:`FormData`; plain fields are sent as form
            data and file-like fields as multipart file uploads.
        :param additional_headers: optional per-request headers.
        :param method: HTTP method name (case-insensitive). When omitted, GET
            is used if there is no ``jsondata``/``data`` and POST otherwise.
        :param params: optional query-string parameters.
        :param return_raw: when true, return a :class:`RawResponse` instead of
            the decoded JSON.
        :raises ValueError: for an unsupported ``method``.
        :raises ProtonAPINotReachable: on connection errors or timeouts.
        :raises ProtonAPIUnexpectedError: on any other exception while sending.
        :raises ProtonAPIError: when the reply is not a successful API response.
        """
        self._refresh_requests_session_headers()
        fct = self._get_requests_method(method, bool(jsondata or data))

        data_dict = self._get_requests_data(data) if data else None
        files_dict = self._get_requests_files(data) if data else None
        try:
            ret = fct(
                self._environment.http_base_url + endpoint,
                headers=additional_headers,
                json=jsondata,
                data=data_dict,
                files=files_dict,
                params=params
            )
        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e:
            raise ProtonAPINotReachable(e) from e
        except Exception as e:
            # Any other failure (including urllib3/requests HTTP errors, which
            # are Exception subclasses) is reported as unexpected.
            raise ProtonAPIUnexpectedError(e) from e

        if return_raw:
            return self._build_raw(ret)

        return self._parse_json(ret)

    @staticmethod
    def _get_requests_data(form_data: FormData) -> dict:
        """
        Converts the FormData instance to a dict that can be passed
        as the data parameter in requests (e.g. `requests.post(url, data=data)`).

        File-like fields (``io.IOBase`` instances) are skipped; use
        `_get_requests_files` for those.
        """
        return {
            field.name: field.value
            for field in form_data.fields if not isinstance(field.value, io.IOBase)
        }

    @staticmethod
    def _get_requests_files(form_data: FormData) -> dict:
        """
        Extracts the file-like fields (``io.IOBase`` instances) to a dict that
        can be passed as the `files` parameter in requests
        (e.g. `requests.post(url, files=files)`).

        Each value is a ``(filename, fileobj, content_type)`` 3-tuple.
        """
        # From https://requests.readthedocs.io/en/latest/api/#requests.request:
        # files – (optional) Dictionary of 'name': file-like-objects
        # (or {'name': file-tuple}) for multipart encoding upload. file-tuple
        # can be a 2-tuple ('filename', fileobj), 3-tuple ('filename', fileobj, 'content_type')
        # or a 4-tuple ('filename', fileobj, 'content_type', custom_headers),
        # where 'content-type' is a string defining the content type of the
        # given file and custom_headers a dict-like object containing additional
        # headers to add for the file.
        return {
            field.name: (field.filename, field.value, field.content_type)
            for field in form_data.fields if isinstance(field.value, io.IOBase)
        }