1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266
|
import binascii
import json
from collections.abc import Iterable, Mapping
from jose import jwk
from jose.backends.base import Key
from jose.constants import ALGORITHMS
from jose.exceptions import JWSError, JWSSignatureError
from jose.utils import base64url_decode, base64url_encode
def sign(payload, key, headers=None, algorithm=ALGORITHMS.HS256):
    """Produce a signed JWS string for the given payload.

    Args:
        payload (str or dict): The content to sign.
        key (str or dict): The key used to produce the signature. Can be
            an individual JWK or a JWK set.
        headers (dict, optional): Extra header fields. They are merged on
            top of the default headers, so they win on any clash.
        algorithm (str, optional): The signing algorithm. Defaults to HS256.

    Returns:
        str: The encoded header, payload and signature joined by dots.

    Raises:
        JWSError: If the algorithm is not supported or signing fails.

    Examples:

        >>> jws.sign({'a': 'b'}, 'secret', algorithm='HS256')
        'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhIjoiYiJ9.jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8'

    """
    # Reject unknown algorithms before doing any encoding work.
    if algorithm not in ALGORITHMS.SUPPORTED:
        raise JWSError("Algorithm %s not supported." % algorithm)

    return _sign_header_and_claims(
        _encode_header(algorithm, additional_headers=headers),
        _encode_payload(payload),
        algorithm,
        key,
    )
def verify(token, key, algorithms, verify=True):
    """Check the signature of a JWS string and return its payload.

    Args:
        token (str): The signed JWS to check.
        key (str or dict): The key to check the signature against. Can be
            an individual JWK or a JWK set.
        algorithms (str or list): The algorithms that are acceptable for
            this token.
        verify (bool, optional): When false, the signature check is skipped
            and the decoded payload is returned as-is. Defaults to True.

    Returns:
        The decoded payload of the token.

    Raises:
        JWSError: If the token cannot be parsed or its signature is rejected.

    Examples:

        >>> token = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhIjoiYiJ9.jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8'
        >>> jws.verify(token, 'secret', algorithms='HS256')

    """
    header, payload, signing_input, signature = _load(token)

    # Caller explicitly opted out of signature checking.
    if not verify:
        return payload

    _verify_signature(signing_input, header, signature, key, algorithms)
    return payload
def get_unverified_header(token):
    """Decode and return the token's header without checking the signature.

    Args:
        token (str): A signed JWS to read the header from.

    Returns:
        dict: The decoded header of the token.

    Raises:
        JWSError: If the token cannot be decoded.
    """
    # _load returns (header, payload, signing_input, signature).
    return _load(token)[0]
def get_unverified_headers(token):
    """Decode and return the token's header without checking the signature.

    Kept for backwards compatibility; it only delegates to
    get_unverified_header().

    Args:
        token (str): A signed JWS to read the header from.

    Returns:
        dict: The decoded header of the token.

    Raises:
        JWSError: If the token cannot be decoded.
    """
    headers = get_unverified_header(token)
    return headers
def get_unverified_claims(token):
    """Decode and return the token's payload without checking the signature.

    Args:
        token (str): A signed JWS to read the claims from.

    Returns:
        The base64url-decoded payload segment of the token.

    Raises:
        JWSError: If the token cannot be decoded.
    """
    # _load returns (header, payload, signing_input, signature).
    return _load(token)[1]
def _encode_header(algorithm, additional_headers=None):
    """Build the JWS header and return it base64url-encoded.

    Args:
        algorithm (str): Value stored under the "alg" header field.
        additional_headers (dict, optional): Extra fields merged over the
            defaults ("typ" and "alg"), overriding them on clash.

    Returns:
        The result of base64url-encoding the compact, key-sorted JSON header.
    """
    header = dict(typ="JWT", alg=algorithm)
    if additional_headers:
        header.update(additional_headers)

    # Compact separators and sorted keys give a stable serialization.
    serialized = json.dumps(header, sort_keys=True, separators=(",", ":"))
    return base64url_encode(serialized.encode("utf-8"))
def _encode_payload(payload):
    """Return the base64url-encoded form of a JWS payload.

    Args:
        payload (Mapping, str or bytes): A mapping is serialized to compact
            JSON and encoded as UTF-8; text is encoded as UTF-8; any other
            value is handed to ``base64url_encode`` unchanged.

    Returns:
        Whatever ``base64url_encode`` returns for the prepared payload.
    """
    if isinstance(payload, Mapping):
        try:
            payload = json.dumps(
                payload,
                separators=(",", ":"),
            ).encode("utf-8")
        except ValueError:
            # Deliberately best-effort: keep the payload untouched so any
            # failure surfaces from the encoder call below.
            pass
    elif isinstance(payload, str):
        # sign() advertises "str or dict" payloads, so text must become
        # bytes before encoding.
        # NOTE(review): presumably base64url_encode only accepts bytes-like
        # input -- confirm against jose.utils.
        payload = payload.encode("utf-8")
    return base64url_encode(payload)
def _sign_header_and_claims(encoded_header, encoded_claims, algorithm, key):
    """Sign the encoded header and claims and assemble the final token.

    Args:
        encoded_header (bytes): The base64url-encoded header segment.
        encoded_claims (bytes): The base64url-encoded payload segment.
        algorithm (str): Algorithm used to build a key object when ``key``
            is not already a ``Key`` instance.
        key: A ``Key`` instance, or key material for ``jwk.construct``.

    Returns:
        str: "header.claims.signature" decoded as UTF-8 text.

    Raises:
        JWSError: Wraps any exception raised while building the key or
            producing the signature.
    """
    signing_input = b".".join((encoded_header, encoded_claims))

    try:
        signer = key if isinstance(key, Key) else jwk.construct(key, algorithm)
        signature = signer.sign(signing_input)
    except Exception as e:
        raise JWSError(e)

    segments = (encoded_header, encoded_claims, base64url_encode(signature))
    return b".".join(segments).decode("utf-8")
def _load(jwt):
    """Split a compact JWS into its decoded components.

    Args:
        jwt (str or bytes): The token; text is encoded as UTF-8 first.

    Returns:
        tuple: ``(header, payload, signing_input, signature)`` where
        ``header`` is the parsed JSON object, ``payload`` and ``signature``
        are the base64url-decoded segments, and ``signing_input`` is the raw
        "header.payload" portion of the token.

    Raises:
        JWSError: If the token has too few segments, a segment cannot be
            base64url-decoded, or the header is not a JSON object.
    """
    if isinstance(jwt, str):
        jwt = jwt.encode("utf-8")

    # Only the unpacking can fail with "not enough segments". It is kept
    # apart from decoding because binascii.Error is a ValueError subclass and
    # would otherwise be reported with the wrong message.
    try:
        signing_input, crypto_segment = jwt.rsplit(b".", 1)
        header_segment, claims_segment = signing_input.split(b".", 1)
    except ValueError as exc:
        raise JWSError("Not enough segments") from exc

    try:
        header_data = base64url_decode(header_segment)
    except (TypeError, ValueError) as exc:
        # ValueError (which includes binascii.Error) is caught so no
        # exception that used to become a JWSError can now escape.
        raise JWSError("Invalid header padding") from exc

    try:
        header = json.loads(header_data.decode("utf-8"))
    except ValueError as e:
        raise JWSError("Invalid header string: %s" % e) from e

    if not isinstance(header, Mapping):
        raise JWSError("Invalid header string: must be a json object")

    try:
        payload = base64url_decode(claims_segment)
    except (TypeError, binascii.Error) as exc:
        raise JWSError("Invalid payload padding") from exc

    try:
        signature = base64url_decode(crypto_segment)
    except (TypeError, binascii.Error) as exc:
        raise JWSError("Invalid crypto padding") from exc

    return (header, payload, signing_input, signature)
def _sig_matches_keys(keys, signing_input, signature, alg):
for key in keys:
if not isinstance(key, Key):
key = jwk.construct(key, alg)
try:
if key.verify(signing_input, signature):
return True
except Exception:
pass
return False
def _get_keys(key):
    """Normalize the supplied key material into an iterable of candidates.

    Args:
        key: A ``Key`` instance, a JSON string, a mapping (JWK, JWK set or
            other dict), a list/tuple-like iterable, or a scalar value.

    Returns:
        An iterable of candidate keys to try during verification.
    """
    if isinstance(key, Key):
        return (key,)

    # Best effort: the key may be JSON text. Numbers are kept as strings;
    # anything that does not parse is used unchanged.
    try:
        key = json.loads(key, parse_int=str, parse_float=str)
    except Exception:
        pass

    if isinstance(key, Mapping):
        if "keys" in key:
            # JWK Set per RFC 7517
            return key["keys"]
        if "kty" in key:
            # Individual JWK per RFC 7517
            return (key,)
        # Some other mapping. Firebase uses just dict of kid, cert pairs
        candidates = key.values()
        return candidates if candidates else (key,)

    # Iterable but not text or mapping => list- or tuple-like
    if isinstance(key, Iterable) and not isinstance(key, (str, bytes)):
        return key

    # Scalar value, wrap in tuple.
    return (key,)
def _verify_signature(signing_input, header, signature, key="", algorithms=None):
    """Validate the header's algorithm and check the signature.

    Args:
        signing_input (bytes): The signed portion of the token.
        header (Mapping): The decoded token header; its "alg" entry is read.
        signature (bytes): The decoded signature to check.
        key: Key material; expanded into candidate keys by ``_get_keys``.
        algorithms (str or iterable, optional): The allowed algorithm(s).
            ``None`` disables the allow-list check.

    Raises:
        JWSError: If the header has no algorithm, the algorithm is not
            allowed, or no candidate key verifies the signature.
    """
    alg = header.get("alg")
    if not alg:
        raise JWSError("No algorithm was specified in the JWS header.")

    if algorithms is not None:
        # A bare string must be treated as one algorithm name. A plain
        # ``in`` on a str is a substring test ("HS" would match "HS256")
        # and raises TypeError for a non-string alg taken from the header.
        if isinstance(algorithms, str):
            algorithms = (algorithms,)
        if not isinstance(alg, str) or alg not in algorithms:
            raise JWSError("The specified alg value is not allowed")

    keys = _get_keys(key)
    try:
        if not _sig_matches_keys(keys, signing_input, signature, alg):
            raise JWSSignatureError()
    except JWSSignatureError:
        # Handled before JWSError so a signature mismatch keeps its own
        # message.
        raise JWSError("Signature verification failed.")
    except JWSError:
        raise JWSError("Invalid or unsupported algorithm: %s" % alg)
|