1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175
|
module JSON
  # A JSON Web Key held as a Hash with indifferent (String/Symbol) key access.
  # Provides a digest-based thumbprint and conversion back to an OpenSSL key
  # (kty "RSA" / "EC") or the stored symmetric secret (kty "oct").
  class JWK < ActiveSupport::HashWithIndifferentAccess
    # Raised for an unsupported key type, EC curve, or digest argument.
    class UnknownAlgorithm < JWT::Exception; end

    # JWK "crv" values accepted by #to_ec_key, mapped to OpenSSL curve names.
    EC_CURVE_NAMES = {
      'P-256': 'prime256v1',
      'P-384': 'secp384r1',
      'P-521': 'secp521r1',
      secp256k1: 'secp256k1'
    }.freeze
    private_constant :EC_CURVE_NAMES

    # Builds a JWK from one of:
    # * an OpenSSL::PKey::RSA or OpenSSL::PKey::EC, converted via its
    #   #to_jwk(ex_params) (presumably provided by this library's OpenSSL
    #   extensions — not visible here);
    # * any other OpenSSL::PKey::PKey, which is rejected;
    # * a String, stored as a symmetric secret: { k: params, kty: :oct };
    # * anything else (normally a Hash of JWK members), passed to the Hash
    #   constructor.
    # For the String and "anything else" forms, ex_params is merged on top.
    # If no "kid" is set afterwards, a default one is derived from #thumbprint
    # on a best-effort basis.
    #
    # @param params [OpenSSL::PKey::PKey, String, Hash]
    # @param ex_params [Hash] extra members to add
    # @raise [UnknownAlgorithm] for an unsupported OpenSSL key class
    def initialize(params = {}, ex_params = {})
      case params
      when OpenSSL::PKey::RSA, OpenSSL::PKey::EC
        super params.to_jwk(ex_params)
      when OpenSSL::PKey::PKey
        raise UnknownAlgorithm, 'Unknown Key Type'
      when String
        super(k: params, kty: :oct)
        merge! ex_params
      else
        super params
        merge! ex_params
      end
      calculate_default_kid if self[:kid].blank?
    end

    # @return [String] the media type used for a serialized JWK
    def content_type
      'application/jwk+json'
    end

    # Hashes the JSON form of #normalize (required members only, keys in
    # lexicographic order) and returns it base64url-encoded without padding.
    #
    # @param digest [OpenSSL::Digest, String, Symbol] a digest instance, or an
    #   algorithm name handed to OpenSSL::Digest.new (an unrecognised name
    #   raises whatever OpenSSL::Digest.new raises)
    # @return [String]
    # @raise [UnknownAlgorithm] when digest is of any other type, or (via
    #   #normalize) when the key type is unsupported
    def thumbprint(digest = OpenSSL::Digest::SHA256.new)
      digest = case digest
               when OpenSSL::Digest then digest
               when String, Symbol then OpenSSL::Digest.new(digest.to_s)
               else raise UnknownAlgorithm, 'Unknown Digest Algorithm'
               end
      Base64.urlsafe_encode64 digest.digest(normalize.to_json), padding: false
    end

    # Converts this JWK to a usable key.
    #
    # @return [OpenSSL::PKey::RSA, OpenSSL::PKey::EC, Object] an OpenSSL key
    #   for RSA/EC; for oct, the stored "k" value exactly as held (it is not
    #   base64url-decoded here)
    # @raise [UnknownAlgorithm] for an unsupported key type
    def to_key
      case
      when rsa?
        to_rsa_key
      when ec?
        to_ec_key
      when oct?
        self[:k]
      else
        raise UnknownAlgorithm, 'Unknown Key Type'
      end
    end

    # @return [Boolean] whether "kty" is RSA (String or Symbol)
    def rsa?
      self[:kty]&.to_sym == :RSA
    end

    # @return [Boolean] whether "kty" is EC (String or Symbol)
    def ec?
      self[:kty]&.to_sym == :EC
    end

    # @return [Boolean] whether "kty" is oct (String or Symbol)
    def oct?
      self[:kty]&.to_sym == :oct
    end

    # The members used for the thumbprint, in lexicographic key order.
    #
    # @return [Hash] RSA: e, kty, n; EC: crv, kty, x, y; oct: k, kty
    # @raise [UnknownAlgorithm] for an unsupported key type
    def normalize
      case
      when rsa?
        {
          e: self[:e],
          kty: self[:kty],
          n: self[:n]
        }
      when ec?
        {
          crv: self[:crv],
          kty: self[:kty],
          x: self[:x],
          y: self[:y]
        }
      when oct?
        {
          k: self[:k],
          kty: self[:kty]
        }
      else
        raise UnknownAlgorithm, 'Unknown Key Type'
      end
    end

    private

    # Sets "kid" to #thumbprint. Deliberately best effort: a JWK that cannot
    # be thumbprinted (e.g. unknown kty) is still constructible, just without
    # a default kid.
    def calculate_default_kid
      self[:kid] = thumbprint
    rescue StandardError
      # intentionally ignored — see method comment
    end

    # Base64url-decodes the RSA members into bignums and builds PKCS#1 DER.
    # The result is an RSAPublicKey (n, e), or an RSAPrivateKey
    # (version 0, n, e, d, p, q, dp, dq, qi) when every private/CRT member
    # is present.
    def to_rsa_key
      e, n, d, p, q, dp, dq, qi = %i[e n d p q dp dq qi].map do |key|
        OpenSSL::BN.new(Base64.urlsafe_decode64(self[key]), 2) if self[key]
      end

      components =
        if d && p && q && dp && dq && qi
          [0, n, e, d, p, q, dp, dq, qi]
        else
          [n, e]
        end
      # A single Sequence: wrapping it in another Sequence only produced the
      # same DER because ASN1 Constructive values are flattened via #each.
      sequence = OpenSSL::ASN1::Sequence(components.map { |value| OpenSSL::ASN1::Integer(value) })
      OpenSSL::PKey::RSA.new(sequence.to_der)
    end

    # Builds an OpenSSL EC key from crv/x/y (and d when present).
    #
    # @raise [UnknownAlgorithm] when "crv" is missing or not supported
    def to_ec_key
      curve_name = EC_CURVE_NAMES.fetch(self[:crv]&.to_sym) { raise UnknownAlgorithm, 'Unknown EC Curve' }
      x, y, d = %i[x y d].map do |key|
        Base64.urlsafe_decode64(self[key]) if self[key]
      end

      # Uncompressed point encoding: 0x04 || X || Y.
      point = OpenSSL::PKey::EC::Point.new(
        OpenSSL::PKey::EC::Group.new(curve_name),
        OpenSSL::BN.new(['04' + x.unpack('H*').first + y.unpack('H*').first].pack('H*'), 2)
      )
      public_octets = point.to_octet_string(:uncompressed)

      der_structure =
        if d
          # ECPrivateKey: version 1, private key octets, [0] curve, [1] point.
          # The decoded d bytes go in unchanged; round-tripping them through
          # OpenSSL::BN would strip leading zero bytes and yield an octet
          # string shorter than the fixed length RFC 5915 specifies.
          OpenSSL::ASN1::Sequence([
            OpenSSL::ASN1::Integer(1),
            OpenSSL::ASN1::OctetString(d),
            OpenSSL::ASN1::ObjectId(curve_name, 0, :EXPLICIT),
            OpenSSL::ASN1::BitString(public_octets, 1, :EXPLICIT)
          ])
        else
          # Public key: SEQUENCE { SEQUENCE { id-ecPublicKey, curve }, point }.
          OpenSSL::ASN1::Sequence([
            OpenSSL::ASN1::Sequence([
              OpenSSL::ASN1::ObjectId('id-ecPublicKey'),
              OpenSSL::ASN1::ObjectId(curve_name)
            ]),
            OpenSSL::ASN1::BitString(public_octets)
          ])
        end
      OpenSSL::PKey::EC.new(der_structure.to_der)
    end
  end
end
|