1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
|
"""Authorization provider for the Logi Circle API wrapper"""
# coding: utf-8
# vim:sw=4:ts=4:et:
import os
import logging
import pickle
from urllib.parse import urlencode
import aiohttp
import asyncio
from .const import AUTH_BASE, AUTH_ENDPOINT, TOKEN_ENDPOINT
from .exception import AuthorizationFailed, NotAuthorized, SessionInvalidated
_LOGGER = logging.getLogger(__name__)


class AuthProvider():
    """OAuth2 client for the Logi Circle API.

    Performs the authorization-code and refresh-token grants against the
    Logi Circle token endpoint, owns the aiohttp session used for those
    requests, and persists the granted tokens to a pickle cache file keyed
    by client ID.
    """

    def __init__(self, client_id, client_secret, redirect_uri, scopes, cache_file, logi_base):
        """Create the provider and load any previously cached tokens.

        :param client_id: OAuth2 client ID.
        :param client_secret: OAuth2 client secret.
        :param redirect_uri: Redirect URI sent with authorization requests.
        :param scopes: Scope value sent with the authorization URL.
        :param cache_file: Path of the pickle file tokens are persisted to.
        :param logi_base: Owning API object; this class reads its
            ``subscriptions`` and sets its ``is_connected`` attribute.
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.redirect_uri = redirect_uri
        self.scopes = scopes
        self.cache_file = cache_file
        self.logi = logi_base
        # Token responses for every client ID in the cache, keyed by client ID.
        self.tokens = self._read_token()
        # Set after a 4xx from the token endpoint; further attempts are refused.
        self.invalid = False
        self.session = None
        # Serialises token requests so concurrent callers don't each hit the endpoint.
        self._lock = asyncio.Lock()

    @property
    def authorized(self):
        """Checks if the current client ID has a refresh token"""
        return self.client_id in self.tokens and 'refresh_token' in self.tokens[self.client_id]

    @property
    def authorize_url(self):
        """Returns the authorization URL for the Logi Circle API"""
        # NOTE(review): client_secret ends up in a URL the user opens in a
        # browser, where it can leak via history, proxies or logs. RFC 6749
        # does not define it for the authorization request; it is kept here in
        # case the Logi API requires it. Confirm before removing.
        query_string = {"response_type": "code",
                        "client_id": self.client_id,
                        "client_secret": self.client_secret,
                        "redirect_uri": self.redirect_uri,
                        "scope": self.scopes}
        return '%s?%s' % (AUTH_BASE + AUTH_ENDPOINT, urlencode(query_string))

    @property
    def refresh_token(self):
        """The refresh token granted by the Logi Circle API for the current client ID."""
        if not self.authorized:
            return None
        return self.tokens[self.client_id].get('refresh_token')

    @property
    def access_token(self):
        """The access token granted by the Logi Circle API for the current client ID."""
        if not self.authorized:
            return None
        return self.tokens[self.client_id].get('access_token')

    async def authorize(self, code):
        """Request a bearer token with the supplied authorization code.

        :raises AuthorizationFailed: if the token endpoint rejects the request.
        :raises SessionInvalidated: if a previous request got a 4xx response.
        """
        authorize_payload = {"grant_type": "authorization_code",
                             "code": code,
                             "redirect_uri": self.redirect_uri,
                             "client_id": self.client_id,
                             "client_secret": self.client_secret}
        await self._authenticate(authorize_payload)

    async def clear_authorization(self):
        """Logs out and clears all persisted tokens for this client ID."""
        await self.close()
        self.tokens[self.client_id] = {}
        self._save_token()

    async def refresh(self):
        """Use the persisted refresh token to request a new access token.

        :raises NotAuthorized: if no refresh token is stored for this client ID.
        :raises AuthorizationFailed: if the token endpoint rejects the request.
        :raises SessionInvalidated: if a previous request got a 4xx response.
        """
        if not self.authorized:
            raise NotAuthorized(
                'No refresh token is available for client ID %s' % (self.client_id))
        refresh_payload = {"grant_type": "refresh_token",
                           "refresh_token": self.refresh_token,
                           "client_id": self.client_id,
                           "client_secret": self.client_secret}
        _LOGGER.debug("Refreshing access token for client %s", self.client_id)
        await self._authenticate(refresh_payload)

    async def close(self):
        """Closes the aiohttp session and marks the API as disconnected.

        Open websocket subscriptions are invalidated. They close themselves
        when they next process a frame, and a single warning is logged.
        """
        unclosed = False
        for subscription in self.logi.subscriptions:
            if subscription.opened:
                # Signal subscription to close itself when the next frame is processed.
                subscription.invalidate()
                unclosed = True
        if unclosed:
            _LOGGER.warning('One or more WS connections have not been closed.')
        if isinstance(self.session, aiohttp.ClientSession):
            await self.session.close()
        self.session = None
        self.logi.is_connected = False

    async def _authenticate(self, payload):
        """Request or refresh the access token with Logi Circle.

        On success the whole JSON token response is stored under the current
        client ID and persisted.

        :param payload: Form data posted to the token endpoint.
        :raises SessionInvalidated: if an earlier request got a 4xx response.
        :raises AuthorizationFailed: on a non-OK status or a non-JSON body.
        """
        if self.invalid:
            raise SessionInvalidated('Logi API session invalidated due to 4xx exception refreshing token')
        if self._lock.locked():
            # Another coroutine is already authenticating: wait for it to
            # finish, then rely on its result instead of issuing a second request.
            async with self._lock:
                _LOGGER.debug("Concurrent request to authenticate client ID %s ignored", self.client_id)
                return
        async with self._lock:
            _LOGGER.debug("Authenticating client ID %s", self.client_id)
            session = await self.get_session()
            async with session.post(AUTH_BASE + TOKEN_ENDPOINT, data=payload) as req:
                try:
                    response = await req.json()
                except aiohttp.ContentTypeError:
                    # Body is not JSON; include the raw text in the error.
                    response = await req.text()
                    self.logi.is_connected = False
                    if 400 <= req.status < 500:
                        self.invalid = True
                    if req.status >= 400:
                        raise AuthorizationFailed("Non-OK code %s returned: %s" % (req.status, response))
                    raise AuthorizationFailed("Unexpected content type from Logi API: %s" % (response))

                if req.status >= 400:
                    self.logi.is_connected = False
                    if 400 <= req.status < 500:
                        # Client errors (bad code / revoked token) won't fix
                        # themselves, so refuse further attempts.
                        self.invalid = True
                    error_message = "Non-OK code %s returned" % (req.status)
                    if isinstance(response, dict):
                        error_message = response.get("error_description", error_message)
                    raise AuthorizationFailed(error_message)

                # Authorization succeeded. Persist the refresh and access tokens.
                _LOGGER.debug("Successfully authenticated client ID %s", self.client_id)
                self.logi.is_connected = True
                self.invalid = False
                self.tokens[self.client_id] = response
                self._save_token()

    async def get_session(self):
        """Returns the aiohttp session.

        A new session is created if none exists or the current one has been
        closed.
        """
        if not isinstance(self.session, aiohttp.ClientSession) or self.session.closed:
            self.session = aiohttp.ClientSession()
            self.logi.is_connected = True
        return self.session

    def _save_token(self):
        """Dump all cached tokens into the pickle cache file. Returns True."""
        with open(self.cache_file, 'wb') as pickle_db:
            pickle.dump(self.tokens, pickle_db)
        return True

    def _read_token(self):
        """Read cached tokens from the pickle cache file.

        Returns an empty dict if the file does not exist, or if it is empty or
        truncated (e.g. after an interrupted write).
        """
        filename = self.cache_file
        if not os.path.isfile(filename):
            return {}
        # NOTE: unpickling can execute arbitrary code, so the cache file must
        # only be writable by this application.
        try:
            with open(filename, 'rb') as pickle_db:
                return pickle.load(pickle_db)
        except (EOFError, pickle.UnpicklingError) as err:
            _LOGGER.warning("Ignoring unreadable token cache %s: %s", filename, err)
            return {}
|