1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240
|
# frozen_string_literal: true
require 'forwardable'
module JWT
  module JWK
    # JWK representation for Elliptic Curve (EC) keys (RFC 7518, section 6.2).
    #
    # Backed by an OpenSSL::PKey::EC. The JWK parameters describing the key are
    # kty, crv, x, y and, for private keys, d. When the JWK is built from an
    # OpenSSL key, x, y and d are stored base64url-encoded.
    class EC < KeyBase # rubocop:disable Metrics/ClassLength
      KTY = 'EC'
      KTYS = [KTY, OpenSSL::PKey::EC, JWT::JWK::EC].freeze
      # Base argument for OpenSSL::BN.new / BN#to_s: 2 means big-endian binary octets.
      BINARY = 2
      EC_PUBLIC_KEY_ELEMENTS = %i[kty crv x y].freeze
      EC_PRIVATE_KEY_ELEMENTS = %i[d].freeze
      EC_KEY_ELEMENTS = (EC_PRIVATE_KEY_ELEMENTS + EC_PUBLIC_KEY_ELEMENTS).freeze
      ZERO_BYTE = "\0".b.freeze

      # JWK "crv" names mapped to the OpenSSL names of the same curves.
      # See https://tools.ietf.org/html/rfc5480#section-2.1.1.1 for some
      # pointers on the different names used for common curves.
      OPENSSL_CURVE_NAMES = {
        'P-256' => 'prime256v1',
        'P-384' => 'secp384r1',
        'P-521' => 'secp521r1',
        'P-256K' => 'secp256k1'
      }.freeze
      # Reverse mapping: OpenSSL curve name => JWK "crv" name.
      JWK_CURVE_NAMES = OPENSSL_CURVE_NAMES.invert.freeze
      private_constant :OPENSSL_CURVE_NAMES, :JWK_CURVE_NAMES

      # @param key [OpenSSL::PKey::EC, JWT::JWK::EC, Hash] the key material; a Hash
      #   holds the JWK parameters (its keys are converted to symbols)
      # @param params [Hash, String, nil] extra JWK parameters; a String is treated
      #   as the kid for backwards compatibility
      # @param options [Hash] passed through to KeyBase
      # @raise [ArgumentError] for an unsupported +key+ type, or when +params+ tries
      #   to set kty, crv, x, y or d
      # @raise [JWT::JWKError] when kty is not "EC" or crv, x or y is missing
      def initialize(key, params = nil, options = {})
        params ||= {}
        # For backwards compatibility when kid was a String
        params = { kid: params } if params.is_a?(String)
        key_params = extract_key_params(key)
        params = params.transform_keys(&:to_sym)
        check_jwk_params!(key_params, params)
        super(options, key_params.merge(params))
      end

      # @return [OpenSSL::PKey::EC] the underlying OpenSSL key
      def keypair
        ec_key
      end

      # @return [Boolean] whether the underlying key holds a private scalar
      def private?
        ec_key.private_key?
      end

      # @return [OpenSSL::PKey::EC] the key used for signing
      def signing_key
        ec_key
      end

      # @return [OpenSSL::PKey::EC] the key used for verification
      def verify_key
        ec_key
      end

      # NOTE(review): this returns the same OpenSSL::PKey::EC object as #keypair,
      # which still contains the private scalar when the JWK has "d".
      # @return [OpenSSL::PKey::EC]
      def public_key
        ec_key
      end

      # @return [Hash{Symbol=>Object}] the public members kty, crv, x and y
      #   (presumably used for RFC 7638 thumbprints; confirm against callers)
      def members
        EC_PUBLIC_KEY_ELEMENTS.each_with_object({}) { |i, h| h[i] = self[i] }
      end

      # @param options [Hash] pass +include_private: true+ to keep "d"
      # @return [Hash] a copy of the JWK parameters. "d" is removed unless the key
      #   is private and +options[:include_private]+ is exactly +true+.
      def export(options = {})
        exported = parameters.clone
        exported.reject! { |k, _| EC_PRIVATE_KEY_ELEMENTS.include? k } unless private? && options[:include_private] == true
        exported
      end

      # @return [String] SHA-256 hex digest of the DER encoding of
      #   SEQUENCE { INTEGER x, INTEGER y }, built from the public point
      def key_digest
        _crv, x_octets, y_octets = keypair_components(ec_key)
        sequence = OpenSSL::ASN1::Sequence([OpenSSL::ASN1::Integer.new(OpenSSL::BN.new(x_octets, BINARY)),
                                            OpenSSL::ASN1::Integer.new(OpenSSL::BN.new(y_octets, BINARY))])
        OpenSSL::Digest::SHA256.hexdigest(sequence.to_der)
      end

      # Delegates to KeyBase#[]= after rejecting writes to key material.
      # @raise [ArgumentError] when +key+ is kty, crv, x, y or d
      def []=(key, value)
        raise ArgumentError, 'cannot overwrite cryptographic key attributes' if EC_KEY_ELEMENTS.include?(key.to_sym)

        super
      end

      # @return the KeyBase result when "alg" is set; otherwise the ECDSA
      #   algorithm registered for the key's curve
      # @raise [JWT::JWKError] when "alg" is absent and "crv" is unknown
      def jwa
        return super if self[:alg]

        curve_name = self.class.to_openssl_curve(self[:crv])
        JWA.resolve(JWA::Ecdsa.curve_by_name(curve_name)[:algorithm])
      end

      private

      # Lazily builds and memoizes the OpenSSL key from the JWK parameters.
      # When the JWK was constructed from an OpenSSL key, that object is reused.
      def ec_key
        @ec_key ||= create_ec_key(self[:crv], self[:x], self[:y], self[:d])
      end

      # Normalizes the constructor's +key+ argument into a Hash of JWK parameters.
      # @raise [ArgumentError] for unsupported key types
      def extract_key_params(key)
        case key
        when JWT::JWK::EC
          key.export(include_private: true)
        when OpenSSL::PKey::EC # Accept OpenSSL key as input
          @ec_key = key # Preserve the object to avoid recreation
          parse_ec_key(key)
        when Hash
          key.transform_keys(&:to_sym)
        else
          raise ArgumentError, 'key must be of type OpenSSL::PKey::EC or Hash with key parameters'
        end
      end

      # Validates the key parameters and rejects extra params that would override them.
      def check_jwk_params!(key_params, params)
        raise ArgumentError, 'cannot overwrite cryptographic key attributes' unless (EC_KEY_ELEMENTS & params.keys).empty?
        raise JWT::JWKError, "Incorrect 'kty' value: #{key_params[:kty]}, expected #{KTY}" unless key_params[:kty] == KTY
        raise JWT::JWKError, 'Key format is invalid for EC' unless key_params[:crv] && key_params[:x] && key_params[:y]
      end

      # Splits the public point of +ec_keypair+ into its JWK curve name and the
      # fixed-length, big-endian X and Y coordinate octets.
      #
      # The point is always serialized in uncompressed SEC1 form (0x04 || X || Y).
      # Point#to_bn follows the group's point_conversion_form, so for a
      # :compressed group it would yield a compressed point that cannot be split.
      # @return [Array(String, String, String)] crv, x octets, y octets
      # @raise [JWT::JWKError] for curves without a JWK name
      def keypair_components(ec_keypair)
        group = ec_keypair.group
        crv = JWK_CURVE_NAMES[group.curve_name]
        raise JWT::JWKError, "Unsupported curve '#{group.curve_name}'" unless crv

        coordinate_length = coordinate_byte_length(group)
        encoded_point = ec_keypair.public_key.to_octet_string(:uncompressed)
        x_octets = encoded_point.byteslice(1, coordinate_length)
        y_octets = encoded_point.byteslice(1 + coordinate_length, coordinate_length)
        [crv, x_octets, y_octets]
      end

      # Size in bytes of one affine coordinate on +group+'s curve
      # (32 for P-256/P-256K, 48 for P-384, 66 for P-521).
      def coordinate_byte_length(group)
        (group.degree + 7) / 8
      end

      # Prepends 0-bytes to +bytes+ until it is +length+ bytes long.
      # Longer input is returned unchanged.
      def left_pad(bytes, length)
        missing = length - bytes.bytesize
        missing.positive? ? (ZERO_BYTE * missing) + bytes : bytes
      end

      # base64url-encodes +octets+; returns nil when +octets+ is nil.
      def encode_octets(octets)
        return unless octets

        ::JWT::Base64.url_encode(octets)
      end

      # Converts an OpenSSL EC key into a Hash of JWK parameters.
      # "d" is present only for keys that have a private scalar.
      def parse_ec_key(key)
        crv, x_octets, y_octets = keypair_components(key)
        # RFC 7518 section 6.2.2.1: "d" MUST be ceiling(log2(n)/8) octets long
        # (n = curve order). BN#to_s drops leading zero bytes, so restore them.
        d_octets = key.private_key&.to_s(BINARY)
        d_octets = left_pad(d_octets, key.group.order.num_bytes) if d_octets
        {
          kty: KTY,
          crv: crv,
          x: encode_octets(x_octets),
          y: encode_octets(y_octets),
          d: encode_octets(d_octets)
        }.compact
      end

      # Builds the public point from the JWK "crv", "x" and "y" parameters.
      # @raise [JWT::JWKError] for an unknown "crv"
      def create_point(jwk_crv, jwk_x, jwk_y)
        group = OpenSSL::PKey::EC::Group.new(EC.to_openssl_curve(jwk_crv))
        coordinate_length = coordinate_byte_length(group)
        x_octets = decode_octets(jwk_x, coordinate_length)
        y_octets = decode_octets(jwk_y, coordinate_length)
        # The details of the `Point` instantiation are covered in:
        # - https://docs.ruby-lang.org/en/2.4.0/OpenSSL/PKey/EC.html
        # - https://www.openssl.org/docs/manmaster/man3/EC_POINT_new.html
        # - https://tools.ietf.org/html/rfc5480#section-2.2
        # - https://www.secg.org/SEC1-Ver-1.0.pdf
        # Section 2.3.3 of the last of these references specifies that the
        # encoding of an uncompressed point consists of the byte `0x04` followed
        # by the x value then the y value.
        OpenSSL::PKey::EC::Point.new(
          group,
          OpenSSL::BN.new([0x04, x_octets, y_octets].pack('Ca*a*'), BINARY)
        )
      end

      if ::JWT.openssl_3?
        # Builds the OpenSSL key from the JWK parameters. With OpenSSL 3.0,
        # OpenSSL::PKey objects are immutable, so a private key is assembled as an
        # RFC 5915 ECPrivateKey DER structure and parsed, instead of setting
        # private_key on a public key.
        def create_ec_key(jwk_crv, jwk_x, jwk_y, jwk_d)
          point = create_point(jwk_crv, jwk_x, jwk_y)
          return ::JWT::JWA::Ecdsa.create_public_key_from_point(point) unless jwk_d

          # https://datatracker.ietf.org/doc/html/rfc5915.html
          # ECPrivateKey ::= SEQUENCE {
          #   version        INTEGER { ecPrivkeyVer1(1) } (ecPrivkeyVer1),
          #   privateKey     OCTET STRING,
          #   parameters [0] ECParameters {{ NamedCurve }} OPTIONAL,
          #   publicKey  [1] BIT STRING OPTIONAL
          # }
          # RFC 5915 requires privateKey to be ceiling(log2(n)/8) octets long
          # (n = curve order), so the scalar is left-padded to that size.
          scalar_octets = OpenSSL::BN.new(decode_octets(jwk_d), BINARY).to_s(BINARY)
          private_octets = left_pad(scalar_octets, point.group.order.num_bytes)
          sequence = OpenSSL::ASN1::Sequence([
            OpenSSL::ASN1::Integer(1),
            OpenSSL::ASN1::OctetString(private_octets),
            OpenSSL::ASN1::ObjectId(point.group.curve_name, 0, :EXPLICIT),
            OpenSSL::ASN1::BitString(point.to_octet_string(:uncompressed), 1, :EXPLICIT)
          ])
          OpenSSL::PKey::EC.new(sequence.to_der)
        end
      else
        # Builds the OpenSSL key from the JWK parameters. If "d" is present, it is
        # set as the private scalar on the public key created from the point.
        def create_ec_key(jwk_crv, jwk_x, jwk_y, jwk_d)
          point = create_point(jwk_crv, jwk_x, jwk_y)
          ::JWT::JWA::Ecdsa.create_public_key_from_point(point).tap do |key|
            key.private_key = OpenSSL::BN.new(decode_octets(jwk_d), BINARY) if jwk_d
          end
        end
      end

      # Decodes a base64url-encoded big-endian value (a coordinate or the private scalar).
      #
      # Some base64 encoders omit leading 0-bytes of the X or Y coordinate. This is
      # known to have happened to JWKs exported from Java and Flutter/Dart (iOS and
      # Android) applications. The missing bytes must be 0: with any other byte the
      # point is no longer on the curve, and OpenSSL reports that separately.
      #
      # @param base64_encoded_coordinate [String]
      # @param byte_length [Integer, nil] expected size in bytes. When given, the
      #   value is left-padded with 0-bytes up to that size, which restores any
      #   number of stripped bytes. When nil, only a single stripped byte is
      #   restored, detected by an odd byte size (e.g. 31 instead of 32 for P-256,
      #   65 instead of 66 for P-521).
      # @return [String] the decoded bytes
      def decode_octets(base64_encoded_coordinate, byte_length = nil)
        bytes = ::JWT::Base64.url_decode(base64_encoded_coordinate)
        return left_pad(bytes, byte_length) if byte_length

        bytes.bytesize.odd? ? ZERO_BYTE + bytes : bytes
      end

      class << self
        # Builds an EC JWK from +jwk_data+ (anything accepted by #initialize).
        def import(jwk_data)
          new(jwk_data)
        end

        # Translates a JWK curve name into the OpenSSL name of the same curve.
        # @param crv [String] "P-256", "P-384", "P-521" or "P-256K"
        # @return [String] the OpenSSL curve name
        # @raise [JWT::JWKError] for any other value
        def to_openssl_curve(crv)
          OPENSSL_CURVE_NAMES.fetch(crv) { raise JWT::JWKError, 'Invalid curve provided' }
        end
      end
    end
  end
end
|