File: auth.py

package info (click to toggle)
python-asyncprawcore 3.0.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 4,328 kB
  • sloc: python: 2,224; makefile: 4
file content (478 lines) | stat: -rw-r--r-- 19,339 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
"""Provides Authentication and Authorization classes."""

from __future__ import annotations

import inspect
import time
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, Callable

from aiohttp import ClientRequest
from aiohttp.helpers import BasicAuth
from yarl import URL

from . import const
from .codes import codes
from .exceptions import InvalidInvocation, OAuthException, ResponseException

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator, Awaitable

    from aiohttp import ClientResponse

    from asyncprawcore.requestor import Requestor


class BaseAuthenticator(ABC):
    """Hold the OAuth2 client credentials shared by every authenticator type."""

    def __init__(
        self,
        requestor: Requestor,
        client_id: str,
        redirect_uri: str | None = None,
    ) -> None:
        """Store the requestor and client credentials used for later requests.

        :param requestor: An instance of :class:`.Requestor`.
        :param client_id: The OAuth2 client ID to use with the session.
        :param redirect_uri: The redirect URI exactly as specified in your OAuth
            application settings on Reddit. It must be supplied in order to call
            :meth:`.authorize_url` or the :meth:`~.Authorizer.authorize` method of
            the :class:`.Authorizer` class (default: ``None``).

        """
        self.redirect_uri = redirect_uri
        self.client_id = client_id
        self._requestor = requestor

    @abstractmethod
    def _auth(self) -> BasicAuth:
        """Return the HTTP basic-auth credentials attached to token requests."""

    @asynccontextmanager
    async def _post(self, *, url: str, **data: Any) -> AsyncGenerator[ClientResponse]:
        """POST ``data`` as form fields to ``url`` and yield the response.

        The fields are sent sorted by name, together with the credentials from
        :meth:`._auth` and a ``Connection: close`` header.

        :raises: :class:`.ResponseException` if the response status is not OK.

        """
        credentials = self._auth()
        form_fields = sorted(data.items())
        pending_request = self._requestor.request(
            "POST",
            url,
            auth=credentials,
            data=form_fields,
            headers={"Connection": "close"},
        )
        async with pending_request as response:
            if response.status == codes["ok"]:
                yield response
            else:
                raise ResponseException(response)

    def authorize_url(self, duration: str, scopes: list[str], state: str, implicit: bool = False) -> str:
        """Build the URL a user visits out-of-band to grant access to the application.

        :param duration: Either ``"permanent"`` or ``"temporary"``. ``"temporary"``
            authorizations generate access tokens that last only 1 hour.
            ``"permanent"`` authorizations additionally generate a refresh token
            that can be indefinitely used to generate new hour-long access tokens.
            Only ``"temporary"`` is accepted when ``implicit`` is ``True``.
        :param scopes: A list of OAuth scopes to request authorization for.
        :param state: A string that will be reflected in the callback to
            ``redirect_uri``. Elements must be printable ASCII characters in the
            range ``0x20`` through ``0x7E`` inclusive. This value should be
            temporarily unique to the client for whom the URL was generated.
        :param implicit: Use the implicit grant flow (default: ``False``). This flow
            is only available for ``UntrustedAuthenticators``.

        :returns: URL to be used out-of-band for granting access to your application.

        :raises: :class:`.InvalidInvocation` if ``redirect_uri`` is not provided, if
            ``implicit`` is ``True`` and an authenticator other than
            :class:`.UntrustedAuthenticator` is used, or ``implicit`` is ``True`` and
            ``duration`` is ``"permanent"``.

        """
        if self.redirect_uri is None:
            reason = "redirect URI not provided"
            raise InvalidInvocation(reason)
        if implicit:
            # The implicit flow has two extra preconditions, checked in this order.
            if not isinstance(self, UntrustedAuthenticator):
                reason = "Only UntrustedAuthenticator instances can use the implicit grant flow."
                raise InvalidInvocation(reason)
            if duration != "temporary":
                reason = "The implicit grant flow only supports temporary access tokens."
                raise InvalidInvocation(reason)

        response_type = "token" if implicit else "code"
        query = {
            "client_id": self.client_id,
            "duration": duration,
            "redirect_uri": self.redirect_uri,
            "response_type": response_type,
            "scope": " ".join(scopes),
            "state": state,
        }
        authorization_endpoint = self._requestor.reddit_url + const.AUTHORIZATION_PATH
        # The request is never sent; it is built only to obtain the encoded URL.
        prepared = ClientRequest("GET", URL(authorization_endpoint), params=query)
        return str(prepared.url)

    async def revoke_token(self, token: str, token_type: str | None = None) -> None:
        """Ask Reddit to revoke the provided token.

        :param token: The access or refresh token to revoke.
        :param token_type: When provided, hint to Reddit what the token type is for a
            possible efficiency gain. The value can be either ``"access_token"`` or
            ``"refresh_token"``.

        """
        # The hint field is only included when the caller supplied one.
        hint = {} if token_type is None else {"token_type_hint": token_type}
        revoke_endpoint = self._requestor.reddit_url + const.REVOKE_TOKEN_PATH
        async with self._post(url=revoke_endpoint, token=token, **hint):
            pass  # Only success matters; the response body is ignored.


class BaseAuthorizer:
    """Superclass for OAuth2 authorization tokens and scopes.

    Subclasses restrict which authenticators they accept by overriding
    ``AUTHENTICATOR_CLASS`` with a class or a tuple of classes.

    """

    AUTHENTICATOR_CLASS: tuple | type = BaseAuthenticator

    def __init__(self, authenticator: BaseAuthenticator) -> None:
        """Represent a single authorization to Reddit's API.

        :param authenticator: An instance of :class:`.BaseAuthenticator`.

        :raises: :class:`.InvalidInvocation` if ``authenticator`` is not an instance
            of ``AUTHENTICATOR_CLASS``.

        """
        self._authenticator = authenticator
        self._clear_access_token()
        self._validate_authenticator()

    def _clear_access_token(self) -> None:
        """Forget the current access token and its scopes."""
        # Annotation only: this declares the attribute's type and assigns nothing.
        # ``is_valid`` checks ``access_token`` first, so a stale or missing
        # expiration value is never consulted once the token has been cleared.
        self._expiration_timestamp_ns: int
        self.access_token: str | None = None
        self.scopes: set[str] | None = None

    async def _request_token(self, **data: Any) -> None:
        """Request a token from Reddit and record it on this authorizer.

        :param data: Form fields to send to the access token endpoint.

        :raises: :class:`.OAuthException` if the response payload contains an
            ``"error"`` key.

        """
        url = self._authenticator._requestor.reddit_url + const.ACCESS_TOKEN_PATH
        # Measured before the request so that network latency can only make the
        # locally computed expiration earlier than the real one, never later.
        pre_request_timestamp_ns = time.monotonic_ns()
        async with self._authenticator._post(url=url, **data) as response:
            payload = await response.json()
        if "error" in payload:  # Why are these OKAY responses?
            raise OAuthException(response, payload["error"], payload.get("error_description"))

        # Treat the token as expired 10 seconds early. Adding the margin instead
        # would keep ``is_valid`` returning ``True`` after the server-side expiry.
        self._expiration_timestamp_ns = pre_request_timestamp_ns + (payload["expires_in"] - 10) * const.NANOSECONDS
        self.access_token = payload["access_token"]
        if "refresh_token" in payload:
            self.refresh_token = payload["refresh_token"]
        self.scopes = set(payload["scope"].split(" "))

    def _validate_authenticator(self) -> None:
        """Raise :class:`.InvalidInvocation` for an unsupported authenticator type."""
        if isinstance(self._authenticator, self.AUTHENTICATOR_CLASS):
            return
        # ``AUTHENTICATOR_CLASS`` is either a single class or a tuple of classes.
        if isinstance(self.AUTHENTICATOR_CLASS, type):
            expected = self.AUTHENTICATOR_CLASS.__name__
        else:
            expected = " or ".join(i.__name__ for i in self.AUTHENTICATOR_CLASS)
        msg = f"Must use an authenticator of type {expected}."
        raise InvalidInvocation(msg)

    def is_valid(self) -> bool:
        """Return whether the :class:`.Authorizer` is ready to authorize requests.

        A ``True`` return value does not guarantee that the ``access_token`` is actually
        valid on the server side.

        """
        return self.access_token is not None and time.monotonic_ns() < self._expiration_timestamp_ns

    async def revoke(self) -> None:
        """Revoke the current Authorization.

        :raises: :class:`.InvalidInvocation` if there is no access token to revoke.

        """
        if self.access_token is None:
            msg = "no token available to revoke"
            raise InvalidInvocation(msg)

        await self._authenticator.revoke_token(self.access_token, "access_token")
        self._clear_access_token()


class TrustedAuthenticator(BaseAuthenticator):
    """Authenticator for web and script type apps, which hold a client secret."""

    RESPONSE_TYPE: str = "code"

    def __init__(
        self,
        requestor: Requestor,
        client_id: str,
        client_secret: str,
        redirect_uri: str | None = None,
    ) -> None:
        """Store the client ID and secret used to authenticate with Reddit's API.

        :param requestor: An instance of :class:`.Requestor`.
        :param client_id: The OAuth2 client ID to use with the session.
        :param client_secret: The OAuth2 client secret to use with the session.
        :param redirect_uri: The redirect URI exactly as specified in your OAuth
            application settings on Reddit. It must be supplied in order to call
            the authenticator's ``authorize_url`` method or the
            :meth:`~.Authorizer.authorize` method of the :class:`.Authorizer` class
            (default: ``None``).

        """
        super().__init__(requestor, client_id=client_id, redirect_uri=redirect_uri)
        self.client_secret = client_secret

    def _auth(self) -> BasicAuth:
        """Return basic-auth credentials made of the client ID and client secret."""
        return BasicAuth(login=self.client_id, password=self.client_secret)


class UntrustedAuthenticator(BaseAuthenticator):
    """Authenticator for installed applications, which have no client secret."""

    def _auth(self) -> BasicAuth:
        """Return basic-auth credentials with the client ID and an empty password."""
        return BasicAuth(login=self.client_id, password="")


class Authorizer(BaseAuthorizer):
    """Manages OAuth2 authorization tokens and scopes."""

    def __init__(
        self,
        authenticator: BaseAuthenticator,
        *,
        post_refresh_callback: (Callable[[Authorizer], Awaitable[None]] | Callable[[Authorizer], None] | None) = None,
        pre_refresh_callback: (Callable[[Authorizer], Awaitable[None]] | Callable[[Authorizer], None] | None) = None,
        refresh_token: str | None = None,
    ) -> None:
        """Represent a single authorization to Reddit's API.

        :param authenticator: An instance of a subclass of :class:`.BaseAuthenticator`.
        :param post_refresh_callback: When a single-argument synchronous or asynchronous
            function is passed, the function will be called after refreshing the
            access and refresh tokens. The argument to the callback is the
            :class:`.Authorizer` instance. This callback can be used to inspect and
            modify the attributes of the :class:`.Authorizer`.
        :param pre_refresh_callback: When a single-argument synchronous or
            asynchronous function is passed, the function will be called prior to
            refreshing the access and refresh tokens. The argument to the callback is
            the :class:`.Authorizer` instance. This callback can be used to inspect
            and modify the attributes of the :class:`.Authorizer`, for example to
            supply ``refresh_token`` before it is used.
        :param refresh_token: Enables the ability to refresh the authorization.

        """
        super().__init__(authenticator)
        self._post_refresh_callback = post_refresh_callback
        self._pre_refresh_callback = pre_refresh_callback
        self.refresh_token = refresh_token

    async def _run_refresh_callback(
        self,
        callback: (Callable[[Authorizer], Awaitable[None]] | Callable[[Authorizer], None] | None),
    ) -> None:
        """Call ``callback`` with this authorizer, awaiting the result if needed."""
        if not callback:
            return
        outcome = callback(self)
        # Inspect the returned value rather than the callable: this also covers
        # async callables that ``inspect.iscoroutinefunction`` does not recognize,
        # such as instances of classes that define an async ``__call__``.
        if inspect.isawaitable(outcome):
            await outcome

    async def authorize(self, code: str) -> None:
        """Obtain and set authorization tokens based on ``code``.

        :param code: The code obtained by an out-of-band authorization request to
            Reddit.

        :raises: :class:`.InvalidInvocation` if the authenticator has no
            ``redirect_uri``.

        """
        if self._authenticator.redirect_uri is None:
            msg = "redirect URI not provided"
            raise InvalidInvocation(msg)
        await self._request_token(
            code=code,
            grant_type="authorization_code",
            redirect_uri=self._authenticator.redirect_uri,
        )

    async def refresh(self) -> None:
        """Obtain a new access token from the refresh_token.

        :raises: :class:`.InvalidInvocation` if no refresh token is available once
            the pre-refresh callback has run.

        """
        # The pre-refresh callback runs before the check below so that it has the
        # chance to set ``refresh_token`` on this instance.
        await self._run_refresh_callback(self._pre_refresh_callback)
        if self.refresh_token is None:
            msg = "refresh token not provided"
            raise InvalidInvocation(msg)
        await self._request_token(grant_type="refresh_token", refresh_token=self.refresh_token)
        await self._run_refresh_callback(self._post_refresh_callback)

    async def revoke(self, only_access: bool = False) -> None:
        """Revoke the current Authorization.

        :param only_access: When explicitly set to ``True``, do not evict the refresh
            token if one is set.

        Revoking a refresh token will in-turn revoke all access tokens associated with
        that authorization.

        """
        if only_access or self.refresh_token is None:
            await super().revoke()
        else:
            await self._authenticator.revoke_token(self.refresh_token, "refresh_token")
            self._clear_access_token()
            self.refresh_token = None


class ImplicitAuthorizer(BaseAuthorizer):
    """Authorizer for tokens obtained through the implicit installed-app flow."""

    AUTHENTICATOR_CLASS = UntrustedAuthenticator

    def __init__(
        self,
        authenticator: UntrustedAuthenticator,
        access_token: str,
        expires_in: int,
        scope: str,
    ) -> None:
        """Record an access token that was already issued by Reddit.

        :param authenticator: An instance of :class:`.UntrustedAuthenticator`.
        :param access_token: The access_token obtained from Reddit via callback to the
            authenticator's ``redirect_uri``.
        :param expires_in: The number of seconds the ``access_token`` is valid for, as
            returned from Reddit via callback to the authenticator's redirect uri.
            No safety margin is applied here, so you may need to subtract an offset
            beforehand to account for the delay between Reddit preparing the
            response and this call being made.
        :param scope: A space-delimited string of Reddit OAuth2 scope names as returned
            from Reddit in the callback to the authenticator's redirect uri.

        """
        super().__init__(authenticator)
        # The lifetime is counted from now, on the monotonic clock.
        created_ns = time.monotonic_ns()
        self._expiration_timestamp_ns = created_ns + expires_in * const.NANOSECONDS
        self.access_token = access_token
        granted_scopes = scope.split(" ")
        self.scopes = set(granted_scopes)


class ReadOnlyAuthorizer(Authorizer):
    """Authorizer for access that is not tied to any Reddit account.

    While the ``"*"`` scope will be available, some endpoints simply will not work due
    to the lack of an associated Reddit account.

    """

    AUTHENTICATOR_CLASS = TrustedAuthenticator

    def __init__(
        self,
        authenticator: BaseAuthenticator,
        scopes: list[str] | None = None,
    ) -> None:
        """Represent a ReadOnly authorization to Reddit's API.

        :param authenticator: An instance of :class:`.TrustedAuthenticator`.
        :param scopes: A list of OAuth scopes to request authorization for (default:
            ``None``). The scope ``"*"`` is requested when the default argument is used.

        """
        super().__init__(authenticator)
        self._scopes = scopes

    async def refresh(self) -> None:
        """Obtain a new ReadOnly access token using the client credentials grant."""
        token_request = {"grant_type": "client_credentials"}
        # The scope field is left out entirely when no scopes were configured.
        if self._scopes:
            token_request["scope"] = " ".join(self._scopes)
        await self._request_token(**token_request)


class ScriptAuthorizer(Authorizer):
    """Manages personal-use script type authorizations.

    Only users who are listed as developers for the application will be granted access
    tokens.

    """

    AUTHENTICATOR_CLASS = TrustedAuthenticator

    def __init__(
        self,
        authenticator: BaseAuthenticator,
        username: str | None,
        password: str | None,
        two_factor_callback: Callable | None = None,
        scopes: list[str] | None = None,
    ) -> None:
        """Represent a single personal-use authorization to Reddit's API.

        :param authenticator: An instance of :class:`.TrustedAuthenticator`.
        :param username: The Reddit username of one of the application's developers.
        :param password: The password associated with ``username``.
        :param two_factor_callback: A synchronous or asynchronous zero-argument
            callable that returns OTPs (One-Time Passcodes), also known as 2FA auth
            codes. If it is provided, asyncprawcore calls it on every refresh, and
            awaits the result when that result is awaitable.
        :param scopes: A list of OAuth scopes to request authorization for (default:
            ``None``). The scope ``"*"`` is requested when the default argument is used.

        """
        super().__init__(authenticator)
        self._password = password
        self._scopes = scopes
        self._two_factor_callback = two_factor_callback
        self._username = username

    async def refresh(self) -> None:
        """Obtain a new personal-use script type access token."""
        additional_kwargs = {}
        if self._scopes:
            additional_kwargs["scope"] = " ".join(self._scopes)
        if self._two_factor_callback:
            two_factor_code = self._two_factor_callback()
            # Inspect the returned value rather than the callable: an async callable
            # that ``inspect.iscoroutinefunction`` does not recognize would otherwise
            # put an un-awaited coroutine object in the ``otp`` field.
            if inspect.isawaitable(two_factor_code):
                two_factor_code = await two_factor_code
            # A falsy code (``None`` or ``""``) means no OTP is sent.
            if two_factor_code:
                additional_kwargs["otp"] = two_factor_code
        await self._request_token(
            grant_type="password",
            username=self._username,
            password=self._password,
            **additional_kwargs,
        )


class DeviceIDAuthorizer(BaseAuthorizer):
    """Authorizer for app-only OAuth2 as used by 'installed' applications.

    While the ``"*"`` scope will be available, some endpoints simply will not work due
    to the lack of an associated Reddit account.

    """

    AUTHENTICATOR_CLASS = (TrustedAuthenticator, UntrustedAuthenticator)

    def __init__(
        self,
        authenticator: BaseAuthenticator,
        device_id: str | None = None,
        scopes: list[str] | None = None,
    ) -> None:
        """Represent an app-only OAuth2 authorization for 'installed' apps.

        :param authenticator: An instance of :class:`.UntrustedAuthenticator` or
            :class:`.TrustedAuthenticator`.
        :param device_id: A unique ID (20-30 character ASCII string) (default:
            ``None``). ``"DO_NOT_TRACK_THIS_DEVICE"`` is used in its place when the
            default argument is used. For more information about this parameter, see:
            https://github.com/reddit/reddit/wiki/OAuth2#application-only-oauth
        :param scopes: A list of OAuth scopes to request authorization for (default:
            ``None``). The scope ``"*"`` is requested when the default argument is used.

        """
        super().__init__(authenticator)
        self._scopes = scopes
        self._device_id = "DO_NOT_TRACK_THIS_DEVICE" if device_id is None else device_id

    async def refresh(self) -> None:
        """Obtain a new access token using the installed-client grant."""
        token_request = {
            "grant_type": "https://oauth.reddit.com/grants/installed_client",
            "device_id": self._device_id,
        }
        # The scope field is left out entirely when no scopes were configured.
        if self._scopes:
            token_request["scope"] = " ".join(self._scopes)
        await self._request_token(**token_request)