1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
|
from cryptography.hazmat.primitives.serialization import BestAvailableEncryption
from cryptography.hazmat.primitives.serialization import Encoding
from cryptography.hazmat.primitives.serialization import NoEncryption
from cryptography.hazmat.primitives.serialization import PrivateFormat
from cryptography.hazmat.primitives.serialization import PublicFormat
from authlib.common.encoding import to_bytes
from ._cryptography_key import load_pem_key
from .base_key import Key
class AsymmetricKey(Key):
    """Base class for asymmetric JSON Web Keys.

    A key instance may hold a raw private key object, a raw public key
    object, and/or a JWK dict, and converts between these forms lazily.
    Concrete subclasses are expected to override the class attributes below
    and the ``dumps_private_key`` / ``dumps_public_key`` /
    ``load_private_key`` / ``load_public_key`` / ``generate_key`` hooks,
    which raise ``NotImplementedError`` here.

    Throughout this class, the presence of a ``"d"`` member in the JWK
    tokens is what marks the key as carrying private material.
    """

    # JWK member names kept when a private key is exported as a public dict
    # (see ``as_dict``).
    PUBLIC_KEY_FIELDS = []
    # JWK member names of the private part; declared for subclasses, not
    # referenced by this base class.
    PRIVATE_KEY_FIELDS = []
    # Raw key types recognised by ``import_key`` / ``validate_raw_key``.
    # ``bytes`` is only a placeholder here; presumably each concrete subclass
    # overrides these with its raw key classes -- confirm in subclasses.
    PRIVATE_KEY_CLS = bytes
    PUBLIC_KEY_CLS = bytes
    # Forwarded to ``load_pem_key``; presumably the OpenSSH public-key line
    # prefix for this key type -- TODO confirm against load_pem_key.
    SSH_PUBLIC_PREFIX = b""

    def __init__(self, private_key=None, public_key=None, options=None):
        """Create a key from optional raw key objects.

        :param private_key: raw private key object, or None
        :param public_key: raw public key object, or None
        :param options: key options, forwarded to ``Key.__init__``
        """
        super().__init__(options)
        self.private_key = private_key
        self.public_key = public_key

    @property
    def public_only(self):
        """True when this key carries no private material.

        The raw private key is checked first; otherwise the JWK tokens are
        inspected for a ``"d"`` member.
        """
        if self.private_key:
            return False
        if "d" in self.tokens:
            return False
        return True

    def get_op_key(self, operation):
        """Get the raw key for the given key_op. This method will also
        check if the given key_op is supported by this key.

        :param operation: key operation value, such as "sign", "encrypt".
        :return: raw key -- the public key for operations listed in
            ``PUBLIC_KEY_OPS``, otherwise the private key.
        """
        self.check_key_op(operation)
        if operation in self.PUBLIC_KEY_OPS:
            return self.get_public_key()
        return self.get_private_key()

    def get_public_key(self):
        """Return the raw public key, deriving it when necessary.

        Preference order: an explicitly set public key; the public half of
        the private key (which may first be loaded from the JWK tokens);
        finally ``self.public_key`` again, since loading public-only tokens
        populates it. May return None if no key material is available.
        """
        if self.public_key:
            return self.public_key

        private_key = self.get_private_key()
        if private_key:
            return private_key.public_key()

        # get_private_key() -> load_raw_key() sets public_key when the
        # tokens describe a public-only key.
        return self.public_key

    def get_private_key(self):
        """Return the raw private key, loading it from the JWK tokens if needed.

        Returns None when the key has no private material.
        """
        if self.private_key:
            return self.private_key

        if self.tokens:
            self.load_raw_key()
        return self.private_key

    def load_raw_key(self):
        """Populate the raw key attribute from the JWK tokens.

        Tokens containing ``"d"`` populate ``private_key``; otherwise
        ``public_key`` is populated.
        """
        if "d" in self.tokens:
            self.private_key = self.load_private_key()
        else:
            self.public_key = self.load_public_key()

    def load_dict_key(self):
        """Populate the JWK dict data from the raw key objects.

        A raw private key, when present, is serialized with its private
        members; otherwise only the public members are dumped.
        """
        if self.private_key:
            self._dict_data.update(self.dumps_private_key())
        else:
            self._dict_data.update(self.dumps_public_key())

    def dumps_private_key(self):
        """Return the JWK members for the private key (subclass hook)."""
        raise NotImplementedError()

    def dumps_public_key(self):
        """Return the JWK members for the public key (subclass hook)."""
        raise NotImplementedError()

    def load_private_key(self):
        """Build the raw private key from the JWK tokens (subclass hook)."""
        raise NotImplementedError()

    def load_public_key(self):
        """Build the raw public key from the JWK tokens (subclass hook)."""
        raise NotImplementedError()

    def as_dict(self, is_private=False, **params):
        """Represent this key as a dict of the JSON Web Key.

        :param is_private: include the private members; requires the key to
            have a ``"d"`` member.
        :param params: extra members merged in last, overriding computed ones
        :return: dict of JWK members. If no ``"kid"`` is present, one is
            filled in from ``self.thumbprint()``.
        :raises ValueError: ``is_private`` is set but the key is public
        """
        tokens = self.tokens
        if is_private and "d" not in tokens:
            raise ValueError("This is a public key")

        kid = tokens.get("kid")
        if "d" in tokens and not is_private:
            # Exporting the public view of a private key: keep only the
            # public members, then restore the identifying ones.
            tokens = {k: tokens[k] for k in tokens if k in self.PUBLIC_KEY_FIELDS}
            tokens["kty"] = self.kty
            if kid:
                tokens["kid"] = kid

        if not kid:
            tokens["kid"] = self.thumbprint()

        tokens.update(params)
        return tokens

    def as_key(self, is_private=False):
        """Represent this key as raw key."""
        if is_private:
            return self.get_private_key()
        return self.get_public_key()

    def as_bytes(self, encoding=None, is_private=False, password=None):
        """Export key into PEM/DER format bytes.

        Private keys are written as PKCS#8, public keys as
        SubjectPublicKeyInfo.

        :param encoding: "PEM" (default when None) or "DER"
        :param is_private: export private key or public key
        :param password: encrypt private key with password; ignored for
            public keys
        :return: bytes
        :raises ValueError: unknown encoding, or a private export was
            requested from a key without private material
        """
        if encoding is None or encoding == "PEM":
            encoding = Encoding.PEM
        elif encoding == "DER":
            encoding = Encoding.DER
        else:
            raise ValueError(f"Invalid encoding: {encoding!r}")

        raw_key = self.as_key(is_private)
        if is_private:
            if not raw_key:
                raise ValueError("This is a public key")
            if password is None:
                encryption_algorithm = NoEncryption()
            else:
                encryption_algorithm = BestAvailableEncryption(to_bytes(password))
            return raw_key.private_bytes(
                encoding=encoding,
                format=PrivateFormat.PKCS8,
                encryption_algorithm=encryption_algorithm,
            )
        return raw_key.public_bytes(
            encoding=encoding,
            format=PublicFormat.SubjectPublicKeyInfo,
        )

    def as_pem(self, is_private=False, password=None):
        """Export the key as PEM bytes (see ``as_bytes``)."""
        return self.as_bytes(is_private=is_private, password=password)

    def as_der(self, is_private=False, password=None):
        """Export the key as DER bytes (see ``as_bytes``)."""
        return self.as_bytes(encoding="DER", is_private=is_private, password=password)

    @classmethod
    def import_dict_key(cls, raw, options=None):
        """Wrap a JWK dict in a key instance without converting it yet.

        ``raw`` is validated by ``check_required_fields`` and stored as the
        key's dict data; the raw key object is loaded lazily by
        ``get_private_key`` / ``get_public_key``.
        """
        cls.check_required_fields(raw)
        key = cls(options=options)
        key._dict_data = raw
        return key

    @classmethod
    def import_key(cls, raw, options=None):
        """Import a key from a key instance, raw key object, JWK dict, or
        serialized key data.

        :param raw: an instance of ``cls`` (returned as-is, with ``options``
            merged into its options), a ``PUBLIC_KEY_CLS`` /
            ``PRIVATE_KEY_CLS`` object, a JWK dict, or serialized data
            handed to ``load_pem_key``
        :param options: key options. For serialized input, a ``"password"``
            entry is used to load the key and is not kept in the key's
            options; it is popped from a copy, so the caller's dict keeps it.
        :return: key instance
        :raises ValueError: serialized data loaded into neither
            ``PUBLIC_KEY_CLS`` nor ``PRIVATE_KEY_CLS`` (``load_pem_key`` may
            raise its own errors as well)
        """
        if isinstance(raw, cls):
            if options is not None:
                raw.options.update(options)
            return raw

        if isinstance(raw, cls.PUBLIC_KEY_CLS):
            return cls(public_key=raw, options=options)
        if isinstance(raw, cls.PRIVATE_KEY_CLS):
            return cls(private_key=raw, options=options)
        if isinstance(raw, dict):
            return cls.import_dict_key(raw, options)

        password = None
        if options is not None:
            # Copy before popping: mutating the caller's dict would silently
            # drop the password for any later import reusing the same options.
            options = dict(options)
            password = options.pop("password", None)

        raw_key = load_pem_key(raw, cls.SSH_PUBLIC_PREFIX, password=password)
        if isinstance(raw_key, cls.PUBLIC_KEY_CLS):
            return cls(public_key=raw_key, options=options)
        if isinstance(raw_key, cls.PRIVATE_KEY_CLS):
            return cls(private_key=raw_key, options=options)
        raise ValueError("Invalid data for importing key")

    @classmethod
    def validate_raw_key(cls, key):
        """True if ``key`` is a raw public or private key of this key type."""
        return isinstance(key, (cls.PUBLIC_KEY_CLS, cls.PRIVATE_KEY_CLS))

    @classmethod
    def generate_key(cls, crv_or_size, options=None, is_private=False):
        """Generate a new key (subclass hook)."""
        raise NotImplementedError()
|