File: utility.py

# utility.py - utility functions, externally usable

import re
import dateutil.parser
import datetime
import copy
import warnings

from mastodon.errors import MastodonAPIError, MastodonIllegalArgumentError, MastodonNotFoundError, MastodonVersionError
from mastodon.compat import IMPL_HAS_BLURHASH, blurhash, IMPL_HAS_GRAPHEME, grapheme
from mastodon.internals import Mastodon as Internals

from mastodon.versions import parse_version_string, max_version, api_version

from typing import Optional, Union, Iterator, Tuple, List
from mastodon.return_types import PaginatableList, PaginationInfo, MediaAttachment
from mastodon.types_base import Entity, try_cast

from ._url_regex import url_regex
import unicodedata


class Mastodon(Internals):
    def set_language(self, lang: str):
        """
        Set the locale Mastodon will use to generate responses. Valid values are all ISO 639-1 (two letter)
        language codes or, for languages that do not have one, ISO 639-3 (three letter) codes. This affects
        some error messages (those related to validation) as well as trends.
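
        Example (illustrative; `api` is a Mastodon instance):

            api.set_language("de")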
        """
        self.lang = lang

    def retrieve_mastodon_version(self) -> str:
        """
        Determine installed Mastodon version and set major, minor and patch (not including RC info) accordingly.

        Returns the version string, possibly including rc info.
        """
        try:
            version_str = self.__normalize_version_string(
                self.__instance()["version"])
            self.__version_check_worked = True
        except Exception:
            # instance() was added in 1.1.0, so our best guess is 1.0.0.
            version_str = "1.0.0"
            self.__version_check_worked = False
        self.mastodon_major, self.mastodon_minor, self.mastodon_patch = parse_version_string(
            version_str)

        # If the instance has an API version, we store that as well.
        # If we have a version >= 4.3.0 but no API version, we throw a warning that this is a Weird Implementation,
        # which might help with adoption of the API versioning or at least give us a better picture of how it is going.
        found_api_version = False
        try:
            instance_v2_info = self.__instance_v2()
            if "api_versions" in instance_v2_info and instance_v2_info["api_versions"]:
                if "mastodon" in instance_v2_info["api_versions"]:
                    self.mastodon_api_version = int(
                        instance_v2_info["api_versions"]["mastodon"])
                    found_api_version = True
        except MastodonNotFoundError:
            pass
        except MastodonVersionError:
            pass

        self.__version_check_tried = True
        if not found_api_version and self.verify_minimum_version("4.3.0", cached=True):
            warnings.warn(
                "Mastodon version is detected as >= 4.3.0, but no API version found. Please report this.")
        return version_str

    def verify_minimum_version(self, version_str: str, cached: bool = False) -> bool:
        """
        Update version info from server and verify that at least the specified version is present.

        If you pass cached=True, the version is not re-fetched as long as a previous check succeeded.

        Returns True if version requirement is satisfied, False if not.
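
        Example (illustrative; `mastodon.example` is a placeholder host):

            api = Mastodon(api_base_url="https://mastodon.example")
            if api.verify_minimum_version("2.9.1"):
                ...  # safe to call endpoints that require at least Mastodon 2.9.1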
        """
        if not cached or not self.__version_check_tried or not self.__version_check_worked:
            self.retrieve_mastodon_version()
        major, minor, patch = parse_version_string(version_str)
        if major > self.mastodon_major:
            return False
        elif major == self.mastodon_major and minor > self.mastodon_minor:
            return False
        elif major == self.mastodon_major and minor == self.mastodon_minor and patch > self.mastodon_patch:
            return False
        return True

    def get_approx_server_time(self) -> datetime.datetime:
        """
        Retrieve the approximate server time.

        The time is parsed from the "Date" response header, if present; no attempt is made to
        compensate for network latency. The result is a naive datetime in local time.
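
        Example (a sketch for estimating the local-to-server clock offset):

            offset = api.get_approx_server_time() - datetime.datetime.now()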
        """
        response = self.__api_request("HEAD", "/", return_response_object=True)
        if 'Date' in response.headers:
            server_time_datetime = dateutil.parser.parse(
                response.headers['Date'])

            # Make sure we're in local time
            epoch_time = self.__datetime_to_epoch(server_time_datetime)
            return datetime.datetime.fromtimestamp(epoch_time)
        else:
            raise MastodonAPIError("No server time in response.")

    ###
    # Blurhash utilities
    ###
    def decode_blurhash(self, media_dict: MediaAttachment, out_size: Tuple[int, int] = (16, 16), size_per_component: bool = True, return_linear: bool = True) -> List[List[List[float]]]:
        """
        Basic media-dict blurhash decoding.

        out_size is the desired result size in pixels, either absolute or per blurhash
        component (this is the default).

        By default, this function will return the image as linear RGB, ready for further
        scaling operations. If you want to display the image directly, set return_linear
        to False.

        Returns the decoded blurhash image as a three-dimensional list: [height][width][3],
        with the last dimension being RGB colours.

        For further info and tips for advanced usage, refer to the documentation for the
        blurhash module: https://github.com/halcy/blurhash-python
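
        Example (a sketch; assumes Pillow and numpy are installed and `media` is a
        MediaAttachment with a blurhash):

            import numpy
            from PIL import Image
            pixels = api.decode_blurhash(media, out_size=(32, 32),
                                         size_per_component=False, return_linear=False)
            Image.fromarray(numpy.array(pixels).astype('uint8')).save("preview.png")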
        """
        if not IMPL_HAS_BLURHASH:
            raise NotImplementedError(
                'To use the blurhash functions, please install the blurhash Python module.')

        # Figure out what size to decode to
        decode_components_x, decode_components_y = blurhash.components(
            media_dict["blurhash"])
        if size_per_component:
            decode_size_x = decode_components_x * out_size[0]
            decode_size_y = decode_components_y * out_size[1]
        else:
            decode_size_x = out_size[0]
            decode_size_y = out_size[1]

        # Decode
        decoded_image = blurhash.decode(
            media_dict["blurhash"], decode_size_x, decode_size_y, linear=return_linear)

        # And that's pretty much it.
        return decoded_image

    ###
    # Pagination
    ###
    def fetch_next(self, previous_page: Union[PaginatableList[Entity], Entity, PaginationInfo]) -> Optional[Union[PaginatableList[Entity], Entity]]:
        """
        Fetches the next page of results of a paginated request. Pass in the
        previous page in its entirety, or the pagination information dict
        returned as a part of that page's last status ('_pagination_next').

        Returns the next page or None if no further data is available.
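
        Example (illustrative; walks the home timeline page by page):

            page = api.timeline_home()
            while page is not None and len(page) > 0:
                handle(page)  # `handle` is a placeholder for your own code
                page = api.fetch_next(page)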
        """
        # Duck typing to keep compat with potentially persisted pagination info
        if hasattr(previous_page, "_pagination_next"):
            params = copy.deepcopy(previous_page._pagination_next)
        elif isinstance(previous_page, dict) and '_pagination_next' in previous_page:
            params = copy.deepcopy(previous_page['_pagination_next'])
        else:
            params = copy.deepcopy(previous_page)

        if params is None:
            return None

        is_pagination_dict = False
        if isinstance(previous_page, dict):
            if all(key in ['_pagination_method', '_pagination_endpoint', 'min_id', 'max_id', 'since_id', 'limit'] for key in previous_page):
                is_pagination_dict = True

        if not "_pagination_method" in params and not "_pagination_endpoint" in params:
            raise MastodonIllegalArgumentError(
                "The passed object is not paginatable")

        method = params.pop('_pagination_method')
        endpoint = params.pop('_pagination_endpoint')
        response_type = params.pop('_mastopy_type', None)

        # If we weren't passed a full page (a list), pagination handling must be forced
        force_pagination = not isinstance(previous_page, list)

        if not is_pagination_dict:
            return self.__api_request(method, endpoint, params, force_pagination=force_pagination, override_type=response_type)
        else:
            return self.__api_request(method, endpoint, params, override_type=response_type)

    def fetch_previous(self, next_page: Union[PaginatableList[Entity], Entity, PaginationInfo]) -> Optional[Union[PaginatableList[Entity], Entity]]:
        """
        Fetches the previous page of results of a paginated request. Pass in the
        current page in its entirety, or the pagination information dict
        returned as a part of that page's first status ('_pagination_prev').

        Returns the previous page or None if no further data is available.
        """
        # Duck typing to keep compat with potentially persisted pagination info
        if hasattr(next_page, "_pagination_prev"):
            params = copy.deepcopy(next_page._pagination_prev)
        elif isinstance(next_page, dict) and '_pagination_prev' in next_page:
            params = copy.deepcopy(next_page['_pagination_prev'])
        else:
            params = copy.deepcopy(next_page)

        if params is None:
            return None

        is_pagination_dict = False
        if isinstance(next_page, dict):
            if all(key in ['_pagination_method', '_pagination_endpoint', 'min_id', 'max_id', 'since_id', 'limit'] for key in next_page):
                is_pagination_dict = True

        if not "_pagination_method" in params and not "_pagination_endpoint" in params:
            raise MastodonIllegalArgumentError(
                "The passed object is not paginatable")

        method = params.pop('_pagination_method')
        endpoint = params.pop('_pagination_endpoint')
        response_type = params.pop('_mastopy_type', None)

        # If we weren't passed a full page (a list), pagination handling must be forced
        force_pagination = not isinstance(next_page, list)

        if not is_pagination_dict:
            return self.__api_request(method, endpoint, params, force_pagination=force_pagination, override_type=response_type)
        else:
            return self.__api_request(method, endpoint, params, override_type=response_type)

    def fetch_remaining(self, first_page: PaginatableList[Entity]) -> PaginatableList[Entity]:
        """
        Fetches all the remaining pages of a paginated request starting from a
        first page and returns the entire set of results (including the first page
        that was passed in) as a big list.

        Be careful, as this might generate a lot of requests, depending on what you are
        fetching, and might cause you to run into rate limits very quickly.

        Does not work with grouped notifications, since they use a somewhat weird, inside-out
        pagination scheme. If you need to access these in a paginated way, use fetch_next and fetch_previous
        directly.
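
        Example (illustrative; `account_id` is a placeholder, and this may issue
        many requests for accounts with many followers):

            followers = api.fetch_remaining(api.account_followers(account_id))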
        """
        first_page = copy.deepcopy(first_page)

        all_pages = []
        current_page = first_page
        while current_page is not None and len(current_page) > 0:
            all_pages.extend(current_page)
            current_page = self.fetch_next(current_page)

        return all_pages

    def get_pagination_info(self, page: PaginatableList[Entity], pagination_direction: str) -> Optional[PaginationInfo]:
        """
        Extracts pagination information from a paginated response.

        Returns a PaginationInfo dictionary containing pagination information, or None if not available.

        The resulting PaginationInfo is best treated as opaque, though it is unlikely to change.
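
        Example (a sketch of persisting pagination state and resuming later):

            page = api.timeline_home()
            info = api.get_pagination_info(page, "next")
            # ... store `info` somewhere, then later:
            next_page = api.fetch_next(info)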
        """
        if hasattr(page, "_pagination_next") and pagination_direction == "next":
            return try_cast(PaginationInfo, page._pagination_next)
        elif hasattr(page, "_pagination_prev") and pagination_direction == "previous":
            return try_cast(PaginationInfo, page._pagination_prev)
        else:
            return None

    def pagination_iterator(self, start_page: Union[PaginatableList[Entity], PaginationInfo], direction: str = "next", return_pagination_info: bool = False) -> Iterator[Entity]:
        """
        Returns an iterator that yields all entries of a paginated request, starting
        from the given start_page (which can also be just a PaginationInfo, in which case
        the first entries yielded come from the result of fetch_next or fetch_previous,
        depending on the direction), fetching new pages as needed and stopping when no
        more pages are available.

        Set direction to "next" to iterate forward, or "previous" to iterate backwards.

        If return_pagination_info is True, the iterator will instead yield tuples of (Entity, PaginationInfo),
        where PaginationInfo is a dictionary containing pagination information for the current page and direction.

        Does not work with grouped notifications, since they use a somewhat weird, inside-out
        pagination scheme. If you need to access these in a paginated way, use fetch_next and fetch_previous
        directly.
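
        Example (illustrative):

            for status in api.pagination_iterator(api.timeline_home()):
                print(status.id)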
        """
        if direction not in ["next", "previous"]:
            raise MastodonIllegalArgumentError(
                "Invalid pagination direction: {}".format(direction))

        # Don't rely on python type info here, this is a Danger Zone. Instead, check for
        # _pagination_endpoint
        if hasattr(start_page, "_pagination_endpoint") or (isinstance(start_page, dict) and '_pagination_endpoint' in start_page):
            current_page = self.fetch_next(
                start_page) if direction == "next" else self.fetch_previous(start_page)
        else:
            current_page = start_page

        while current_page is not None and len(current_page) > 0:
            for entry in current_page:
                if return_pagination_info:
                    yield (entry, self.get_pagination_info(current_page, direction))
                else:
                    yield entry

            if direction == "next":
                current_page = self.fetch_next(current_page)
            else:
                current_page = self.fetch_previous(current_page)

    @staticmethod
    def get_status_length(text: str, spoiler_text: str = "") -> int:
        """
        For a given status `text` and `spoiler_text`, return how many characters this status counts as
        when computing the status length and comparing it against the limit.

        Note that there are other limits you may run into, such as the maximum length of a URL, or the
        maximum length of a username's domain part. But as long as you do *normal* things, this function
        will return the correct length for the status text.
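
        Example (illustrative; URLs count as 23 characters, and the domain part
        of mentions is not counted):

            Mastodon.get_status_length("Look! https://example.com/a/long/path @user@an.instance")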
        """
        if not IMPL_HAS_GRAPHEME:
            raise NotImplementedError(
                'To use the get_status_length function, please install the grapheme Python module.')

        username_regex = re.compile(
            r'(^|[^/\w])@(([a-z0-9_]+)@[a-z0-9\.\-]+[a-z0-9]+)', re.IGNORECASE)

        def countable_text(input_text: str) -> str:
            # Transform text such that it has the correct length for counting
            # post text lengths against the limit
            def _url_repl(m: re.Match) -> str:
                return m.group(2) + ("x" * 23)
            text = url_regex.sub(_url_repl, input_text)
            text = username_regex.sub(r'\1@\3', text)
            return text

        return grapheme.length(countable_text(text)) + grapheme.length(spoiler_text)