1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452
|
#-------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
#--------------------------------------------------------------------------
# pylint: disable=super-init-not-called,no-self-use
import datetime
import logging
import time
from uamqp import Session, c_uamqp, compat, constants, errors, utils
from uamqp.constants import TransportType
from .common import _SASL, AMQPAuth
# Module-level logger used to record CBS session shutdown and token negotiation events.
_logger = logging.getLogger(__name__)
class TokenRetryPolicy(object):
    """Retry policy applied to PUT-token requests during CBS authentication.

    :param retries: How many times a failed PUT-token request is retried.
     This count excludes the initial attempt. Defaults to 3.
    :type retries: int
    :param backoff: Delay, in milliseconds, to wait between successive
     retry attempts. It is stored on the instance converted to seconds.
     Defaults to 0.
    :type backoff: int
    """

    def __init__(self, retries=3, backoff=0):
        # The retry count is kept exactly as supplied by the caller.
        self.retries = retries
        # Convert the millisecond delay to seconds (the unit time.sleep expects).
        backoff_ms = float(backoff)
        self.backoff = backoff_ms / 1000
class CBSAuthMixin(object):
    """Mixin to handle sending and refreshing CBS auth tokens.

    The mixin reads the following attributes, which the concrete auth class
    must set: ``audience``, ``token_type``, ``token``, ``expires_at``,
    ``timeout``, ``retries``, ``_retry_policy``, ``_refresh_window`` and
    ``_prev_token``.
    """

    def update_token(self):
        """Update a token that is about to expire. This is specific
        to a particular token type, and therefore must be implemented
        in a child class.

        :raises: ~uamqp.errors.TokenExpired always, in this base implementation.
        """
        raise errors.TokenExpired(
            "Unable to refresh token - no refresh logic implemented.")

    def create_authenticator(self, connection, debug=False, **kwargs):
        """Create the AMQP session and the CBS channel with which
        to negotiate the token.

        :param connection: The underlying AMQP connection on which
         to create the session.
        :type connection: ~uamqp.connection.Connection
        :param debug: Whether to emit network trace logging events for the
         CBS session. Default is `False`. Logging events are set at INFO level.
        :type debug: bool
        :raises: ~uamqp.errors.AMQPConnectionError if constructing the CBS
         token auth raises ValueError (the new session is destroyed first).
        :rtype: uamqp.c_uamqp.CBSTokenAuth
        """
        self._connection = connection
        self._session = Session(connection, **kwargs)
        try:
            self._cbs_auth = c_uamqp.CBSTokenAuth(
                self.audience,
                self.token_type,
                self.token,
                int(self.expires_at),
                self._session._session,  # pylint: disable=protected-access
                self.timeout,
                self._connection.container_id,
                self._refresh_window
            )
            self._cbs_auth.set_trace(debug)
        except ValueError:
            # Don't leak the session opened just above.
            self._session.destroy()
            raise errors.AMQPConnectionError(
                "Unable to open authentication session on connection {}.\n"
                "Please confirm target hostname exists: {}".format(connection.container_id, connection.hostname))
        return self._cbs_auth

    def close_authenticator(self):
        """Close the CBS auth channel and session."""
        _logger.info("Shutting down CBS session on connection: %r.", self._connection.container_id)
        try:
            _logger.debug("Unlocked CBS to close on connection: %r.", self._connection.container_id)
            self._cbs_auth.destroy()
            _logger.info("Auth closed, destroying session on connection: %r.", self._connection.container_id)
            self._session.destroy()
        finally:
            _logger.info("Finished shutting down CBS session on connection: %r.", self._connection.container_id)

    def handle_token(self):
        """This function is called periodically to check the status of the current
        token if there is one, and request a new one if needed.
        If the token request fails, it will be retried according to the retry policy.
        A token refresh will be attempted if the token will expire soon.

        This function will return a tuple of two booleans. The first represents whether
        the token authentication has not completed within its given timeout window. The
        second indicates whether the token negotiation is still in progress.
        ``(False, False)`` is returned immediately if the connection is closing or
        has errored, and ``(None, None)`` is returned if the connection lock could
        not be acquired in time.

        :raises: ~uamqp.errors.AuthenticationException if the token authentication fails.
        :raises: ~uamqp.errors.TokenAuthFailure if Put-Token keeps failing after
         all retries in the retry policy are exhausted.
        :raises: ~uamqp.errors.TokenExpired if the token has expired and cannot be
         refreshed.
        :rtype: tuple[bool, bool]
        """
        # pylint: disable=protected-access
        timeout = False
        in_progress = False
        # Only release the connection lock if this call actually acquired it.
        # Releasing after a timed-out lock() could release a lock owned by
        # another caller. NOTE(review): assumes lock() raises
        # compat.TimeoutException without acquiring the lock.
        lock_acquired = False
        try:
            self._connection.lock()
            lock_acquired = True
            if self._connection._closing or self._connection._error:
                return timeout, in_progress
            auth_status = self._cbs_auth.get_status()
            auth_status = constants.CBSAuthStatus(auth_status)
            if auth_status == constants.CBSAuthStatus.Error:
                if self.retries >= self._retry_policy.retries:  # pylint: disable=no-member
                    _logger.warning("Authentication Put-Token failed. Retries exhausted.")
                    raise errors.TokenAuthFailure(*self._cbs_auth.get_failure_info())
                error_code, error_description = self._cbs_auth.get_failure_info()
                _logger.info("Authentication status: %r, description: %r", error_code, error_description)
                _logger.info("Authentication Put-Token failed. Retrying.")
                self.retries += 1  # pylint: disable=no-member
                time.sleep(self._retry_policy.backoff)
                self._cbs_auth.authenticate()
                in_progress = True
            elif auth_status == constants.CBSAuthStatus.Failure:
                raise errors.AuthenticationException("Failed to open CBS authentication link.")
            elif auth_status == constants.CBSAuthStatus.Expired:
                raise errors.TokenExpired("CBS Authentication Expired.")
            elif auth_status == constants.CBSAuthStatus.Timeout:
                timeout = True
            elif auth_status == constants.CBSAuthStatus.InProgress:
                in_progress = True
            elif auth_status == constants.CBSAuthStatus.RefreshRequired:
                _logger.info("Token on connection %r will expire soon - attempting to refresh.",
                             self._connection.container_id)
                self.update_token()
                # Only push the token to the CBS channel if it actually changed.
                if self.token != self._prev_token:
                    self._cbs_auth.refresh(self.token, int(self.expires_at))
                else:
                    _logger.info(
                        "The newly acquired token on connection %r is the same as the previous one,"
                        " will keep attempting to refresh",
                        self._connection.container_id
                    )
            elif auth_status == constants.CBSAuthStatus.Idle:
                self._cbs_auth.authenticate()
                in_progress = True
            elif auth_status != constants.CBSAuthStatus.Ok:
                raise errors.AuthenticationException("Invalid auth state.")
        except compat.TimeoutException:
            _logger.debug("CBS auth timed out while waiting for lock acquisition.")
            return None, None
        except ValueError as e:
            # Raised e.g. by CBSAuthStatus() for a status value it doesn't know.
            raise errors.AuthenticationException(
                "Token authentication failed: {}".format(e))
        finally:
            if lock_acquired:
                self._connection.release()
        return timeout, in_progress

    def _set_expiry(self, expires_at, expires_in):
        """Derive whichever of ``expires_at`` / ``expires_in`` is missing.

        When ``expires_at`` is given, it takes precedence and ``expires_in``
        is recomputed from it.

        :param expires_at: Absolute expiry as seconds since epoch, or None.
        :type expires_at: float
        :param expires_in: Remaining token lifetime, or None.
        :type expires_in: ~datetime.timedelta
        :raises: ValueError if neither value is given, or if ``expires_at``
         is less than one second in the future.
        :rtype: tuple[float, ~datetime.timedelta]
        """
        if not expires_at and not expires_in:
            raise ValueError("Must specify either 'expires_at' or 'expires_in'.")
        if not expires_at:
            # total_seconds() covers the whole duration; timedelta.seconds is only
            # the 0-86399 seconds component and would silently drop whole days.
            expires_at = time.time() + expires_in.total_seconds()
        else:
            expires_in_seconds = expires_at - time.time()
            if expires_in_seconds < 1:
                raise ValueError("Token has already expired.")
            expires_in = datetime.timedelta(seconds=expires_in_seconds)
        return expires_at, expires_in
class SASTokenAuth(AMQPAuth, CBSAuthMixin):
    """CBS authentication using SAS tokens.

    :param audience: The token audience field. For SAS tokens
     this is usually the URI.
    :type audience: str or bytes
    :param uri: The AMQP endpoint URI. This must be provided as
     a decoded string.
    :type uri: str
    :param token: The SAS token.
    :type token: str or bytes
    :param expires_in: The total remaining seconds until the token
     expires.
    :type expires_in: ~datetime.timedelta
    :param expires_at: The timestamp at which the SAS token will expire
     formatted as seconds since epoch.
    :type expires_at: float
    :param username: The SAS token username, also referred to as the key
     name or policy name. This can optionally be encoded into the URI.
    :type username: str
    :param password: The SAS token password, also referred to as the key.
     This can optionally be encoded into the URI.
    :type password: str
    :param port: The TLS port - default for AMQP is 5671.
    :type port: int
    :param timeout: The timeout in seconds in which to negotiate the token.
     The default value is 10 seconds.
    :type timeout: float
    :param retry_policy: The retry policy for the PUT token request. The default
     retry policy has 3 retries.
    :type retry_policy: ~uamqp.authentication.cbs_auth.TokenRetryPolicy
    :param verify: The path to a user-defined certificate.
    :type verify: str
    :param token_type: The type field of the token request.
     Default value is `b"servicebus.windows.net:sastoken"`.
    :type token_type: bytes
    :param http_proxy: HTTP proxy configuration. This should be a dictionary with
     the following keys present: 'proxy_hostname' and 'proxy_port'. Additional optional
     keys are 'username' and 'password'.
    :type http_proxy: dict
    :param transport_type: The transport protocol type - default is ~uamqp.TransportType.Amqp.
     ~uamqp.TransportType.AmqpOverWebsocket is applied when http_proxy is set or the
     transport type is explicitly requested.
    :type transport_type: ~uamqp.TransportType
    :param encoding: The encoding to use if hostname is provided as a str.
     Default is 'UTF-8'.
    :type encoding: str
    :keyword int refresh_window: The time in seconds before the token expiration
     time to start the process of token refresh.
     Default value is 10% of the remaining seconds until the token expires.
    :keyword str custom_endpoint_hostname: A hostname to connect to instead of
     the one parsed from ``uri``.
    """

    def __init__(self, audience, uri, token,
                 expires_in=None,
                 expires_at=None,
                 username=None,
                 password=None,
                 port=None,
                 timeout=10,
                 retry_policy=TokenRetryPolicy(),
                 verify=None,
                 token_type=b"servicebus.windows.net:sastoken",
                 http_proxy=None,
                 transport_type=TransportType.Amqp,
                 encoding='UTF-8',
                 **kwargs):  # pylint: disable=no-member
        self._retry_policy = retry_policy
        self._encoding = encoding
        self._refresh_window = kwargs.pop("refresh_window", 0)
        self._prev_token = None
        self.uri = uri
        parsed = compat.urlparse(uri)  # pylint: disable=no-member
        self.cert_file = verify
        # A custom endpoint hostname, when given, overrides the URI's host.
        self.hostname = (kwargs.get("custom_endpoint_hostname") or parsed.hostname).encode(self._encoding)
        # Credentials embedded in the URI are used unless explicit ones are passed.
        self.username = compat.unquote_plus(parsed.username) if parsed.username else None  # pylint: disable=no-member
        self.password = compat.unquote_plus(parsed.password) if parsed.password else None  # pylint: disable=no-member
        self.username = username or self.username
        self.password = password or self.password
        self.audience = self._encode(audience)
        self.token_type = self._encode(token_type)
        self.token = self._encode(token)
        self.expires_at, self.expires_in = self._set_expiry(expires_at, expires_in)
        self.timeout = timeout
        self.retries = 0
        self.sasl = _SASL()
        self.set_io(self.hostname, port, http_proxy, transport_type)

    def update_token(self):
        """If a username and password are present - attempt to use them to
        request a fresh SAS token.

        The new token is valid for ``self.expires_in`` from now. The previous
        token is kept in ``_prev_token``.

        :raises: ~uamqp.errors.TokenExpired if no username or password is available.
        """
        if not self.username or not self.password:
            raise errors.TokenExpired("Unable to refresh token - no username or password.")
        encoded_uri = compat.quote_plus(self.uri).encode(self._encoding)  # pylint: disable=no-member
        encoded_key = compat.quote_plus(self.username).encode(self._encoding)  # pylint: disable=no-member
        # total_seconds() keeps whole days; timedelta.seconds would drop them.
        self.expires_at = time.time() + self.expires_in.total_seconds()
        self._prev_token = self.token
        self.token = utils.create_sas_token(
            encoded_key,
            self.password.encode(self._encoding),
            encoded_uri,
            self.expires_in)

    @classmethod
    def from_shared_access_key(
            cls,
            uri,
            key_name,
            shared_access_key,
            expiry=None,
            port=None,
            timeout=10,
            retry_policy=TokenRetryPolicy(),
            verify=None,
            http_proxy=None,
            transport_type=TransportType.Amqp,
            encoding='UTF-8',
            **kwargs):
        """Attempt to create a CBS token session using a Shared Access Key such
        as is used to connect to Azure services.

        :param uri: The AMQP endpoint URI. This must be provided as
         a decoded string.
        :type uri: str
        :param key_name: The SAS token username, also referred to as the key
         name or policy name.
        :type key_name: str
        :param shared_access_key: The SAS token password, also referred to as the key.
        :type shared_access_key: str
        :param expiry: The lifetime in seconds for the generated token. Default is 1 hour.
        :type expiry: int
        :param port: The TLS port - default for AMQP is 5671.
        :type port: int
        :param timeout: The timeout in seconds in which to negotiate the token.
         The default value is 10 seconds.
        :type timeout: float
        :param retry_policy: The retry policy for the PUT token request. The default
         retry policy has 3 retries.
        :type retry_policy: ~uamqp.authentication.cbs_auth.TokenRetryPolicy
        :param verify: The path to a user-defined certificate.
        :type verify: str
        :param http_proxy: HTTP proxy configuration. This should be a dictionary with
         the following keys present: 'proxy_hostname' and 'proxy_port'. Additional optional
         keys are 'username' and 'password'.
        :type http_proxy: dict
        :param transport_type: The transport protocol type - default is ~uamqp.TransportType.Amqp.
         ~uamqp.TransportType.AmqpOverWebsocket is applied when http_proxy is set or the
         transport type is explicitly requested.
        :type transport_type: ~uamqp.TransportType
        :param encoding: The encoding to use if hostname is provided as a str.
         Default is 'UTF-8'.
        :type encoding: str
        :keyword int refresh_window: The time in seconds before the token expiration
         time to start the process of token refresh.
         Default value is 10% of the remaining seconds until the token expires.
        :keyword str custom_endpoint_hostname: A hostname to connect to instead of
         the one parsed from ``uri``.
        """
        expires_in = datetime.timedelta(seconds=expiry or constants.AUTH_EXPIRATION_SECS)
        encoded_uri = compat.quote_plus(uri).encode(encoding)  # pylint: disable=no-member
        encoded_key = compat.quote_plus(key_name).encode(encoding)  # pylint: disable=no-member
        # total_seconds() keeps whole days; timedelta.seconds would drop them.
        expires_at = time.time() + expires_in.total_seconds()
        token = utils.create_sas_token(
            encoded_key,
            shared_access_key.encode(encoding),
            encoded_uri,
            expires_in)
        return cls(
            uri, uri, token,
            expires_in=expires_in,
            expires_at=expires_at,
            username=key_name,
            password=shared_access_key,
            port=port,
            timeout=timeout,
            retry_policy=retry_policy,
            verify=verify,
            http_proxy=http_proxy,
            transport_type=transport_type,
            encoding=encoding,
            # Forward the documented refresh_window keyword (previously dropped);
            # 0 matches the default __init__ applies when it is absent.
            refresh_window=kwargs.pop("refresh_window", 0),
            custom_endpoint_hostname=kwargs.pop("custom_endpoint_hostname", None))
class JWTTokenAuth(AMQPAuth, CBSAuthMixin):
    """CBS authentication backed by JWT tokens obtained from a callback.

    :param audience: The token audience field. For JWT tokens
     this is usually the URI.
    :type audience: str or bytes
    :param uri: The AMQP endpoint URI. This must be provided as
     a decoded string.
    :type uri: str
    :param get_token: Callback used both to obtain the initial token and to
     refresh it. Each call should return an object exposing ``token`` and
     ``expires_on`` attributes for a valid JWT.
    :type get_token: callable object
    :param expires_in: The total remaining seconds until the token
     expires - default for JWT token generated by AAD is 3600s (1 hour).
    :type expires_in: ~datetime.timedelta
    :param expires_at: The timestamp at which the JWT token will expire
     formatted as seconds since epoch.
    :type expires_at: float
    :param port: The TLS port - default for AMQP is 5671.
    :type port: int
    :param timeout: The timeout in seconds in which to negotiate the token.
     The default value is 10 seconds.
    :type timeout: float
    :param retry_policy: The retry policy for the PUT token request. The default
     retry policy has 3 retries.
    :type retry_policy: ~uamqp.authentication.cbs_auth.TokenRetryPolicy
    :param verify: The path to a user-defined certificate.
    :type verify: str
    :param token_type: The type field of the token request.
     Default value is `b"jwt"`.
    :type token_type: bytes
    :param http_proxy: HTTP proxy configuration. This should be a dictionary with
     the following keys present: 'proxy_hostname' and 'proxy_port'. Additional optional
     keys are 'username' and 'password'.
    :type http_proxy: dict
    :param transport_type: The transport protocol type - default is ~uamqp.TransportType.Amqp.
     ~uamqp.TransportType.AmqpOverWebsocket is applied when http_proxy is set or the
     transport type is explicitly requested.
    :type transport_type: ~uamqp.TransportType
    :param encoding: The encoding to use if hostname is provided as a str.
     Default is 'UTF-8'.
    :type encoding: str
    :keyword int refresh_window: The time in seconds before the token expiration
     time to start the process of token refresh.
     Default value is 10% of the remaining seconds until the token expires.
    :raises: ValueError if ``get_token`` is not callable.
    """

    def __init__(self, audience, uri,
                 get_token,
                 expires_in=datetime.timedelta(seconds=constants.AUTH_EXPIRATION_SECS),
                 expires_at=None,
                 port=None,
                 timeout=10,
                 retry_policy=TokenRetryPolicy(),
                 verify=None,
                 token_type=b"jwt",
                 http_proxy=None,
                 transport_type=TransportType.Amqp,
                 encoding='UTF-8',
                 **kwargs):  # pylint: disable=no-member
        # Negotiation settings shared with CBSAuthMixin.
        self._retry_policy = retry_policy
        self._encoding = encoding
        self._refresh_window = kwargs.pop("refresh_window", 0)
        self._prev_token = None

        # Endpoint details.
        self.uri = uri
        parsed_uri = compat.urlparse(uri)  # pylint: disable=no-member
        self.cert_file = verify
        host = kwargs.get("custom_endpoint_hostname") or parsed_uri.hostname
        self.hostname = host.encode(self._encoding)

        # Token source: validated before anything token-related is stored.
        if not (get_token and callable(get_token)):
            raise ValueError("get_token must be a callable object.")
        self.get_token = get_token
        self.audience = self._encode(audience)
        self.token_type = self._encode(token_type)
        # No token yet - the first one is fetched in create_authenticator().
        self.token = None
        self.expires_at, self.expires_in = self._set_expiry(expires_at, expires_in)

        self.timeout = timeout
        self.retries = 0
        self.sasl = _SASL()
        self.set_io(self.hostname, port, http_proxy, transport_type)

    def create_authenticator(self, connection, debug=False, **kwargs):
        """Fetch a token via ``get_token`` and then open the CBS session/channel.

        Arguments are passed through unchanged to
        ``CBSAuthMixin.create_authenticator``.
        """
        self.update_token()
        base = super(JWTTokenAuth, self)
        return base.create_authenticator(connection, debug, **kwargs)

    def update_token(self):
        """Replace the current token with a new one from ``get_token``.

        The outgoing token is remembered in ``_prev_token`` so callers can tell
        whether the refresh actually produced a different token.
        """
        fresh = self.get_token()
        self.expires_at = fresh.expires_on
        outgoing = self.token
        self._prev_token = outgoing
        self.token = self._encode(fresh.token)
|