1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276
|
package jws
import (
"context"
"fmt"
"net/url"
"sync"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
)
// KeyProvider supplies the key(s) used to sign or verify a payload.
// More than one `jws.KeyProvider` may be passed to `jws.Verify()` or
// `jws.Sign()`.
//
// `jws.Sign()` only accepts static key providers, created via
// `jws.WithKey()`. `jws.Verify()` accepts `jws.WithKey()`,
// `jws.WithKeySet()`, `jws.WithVerifyAuto()`, and `jws.WithKeyProvider()`.
//
// Knowing how providers are consulted is crucial to understanding how
// this package works.
//
// Signing is straightforward: `jws.Sign()` creates one signature for each
// provided key.
//
// Verification is a bit more involved, because there are cases where the
// keys to verify with have to be computed, deduced, or guessed.
//
// `jws.Verify()` starts by collecting the KeyProviders from the option
// list that the user provided (presented in pseudocode):
//
//	keyProviders := filterKeyProviders(options)
//
// A JWS message may contain multiple signatures, so for each signature
// every KeyProvider is asked for the key(s) to use on that signature:
//
//	for sig in msg.Signatures {
//	  for kp in keyProviders {
//	    kp.FetchKeys(ctx, sink, sig, msg)
//	    ...
//	  }
//	}
//
// The `sink` argument handed to the KeyProvider is temporary storage for
// keys (either a jwk.Key or a "raw" key). It is the `KeyProvider`'s job to
// send keys into the `sink`.
//
// When called, the `KeyProvider` created by `jws.WithKey()` always sends
// the same key, the one created by `jws.WithKeySet()` sends keys that
// match a particular `kid` and `alg`, and the one created by
// `jws.WithVerifyAuto()` fetches a JWK from the `jku` URL. Finally,
// `jws.WithKeyProvider()` lets you run arbitrary logic to provide keys.
// A custom `KeyProvider` should perform whatever checks or retrieval it
// needs, and then send the key(s) to the sink:
//
//	sink.Key(alg, key)
//
// The collected keys are then retrieved and tried against each signature
// until a match is found:
//
//	keys := sink.Keys()
//	for key in keys {
//	  if givenSignature == makeSignature(key, payload, ...) {
//	    return OK
//	  }
//	}
type KeyProvider interface {
	// FetchKeys sends the key(s) to try against sig, one of the
	// signatures in msg, into sink.
	FetchKeys(ctx context.Context, sink KeySink, sig *Signature, msg *Message) error
}
// KeySink is the storage into which `jws.KeyProvider` objects deposit
// the keys they provide.
type KeySink interface {
	// Key hands key over to the sink, paired with the signature
	// algorithm alg.
	Key(alg jwa.SignatureAlgorithm, key interface{})
}
// algKeyPair is one entry stored by algKeySink: a key together with the
// algorithm it was submitted with.
//
// The field order matters: algKeySink.Key builds this struct with a
// positional literal.
type algKeyPair struct {
	// alg is the algorithm passed alongside the key.
	alg jwa.KeyAlgorithm
	// key is the key material exactly as handed to the sink.
	key interface{}
}
// algKeySink is a KeySink that accumulates every submitted key in a list.
type algKeySink struct {
	// mu guards list.
	mu sync.Mutex
	// list holds the submitted pairs in the order they arrived.
	list []algKeyPair
}

// Key appends the alg/key pair to the sink's list while holding the lock.
func (s *algKeySink) Key(alg jwa.SignatureAlgorithm, key interface{}) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.list = append(s.list, algKeyPair{alg: alg, key: key})
}
// staticKeyProvider is a KeyProvider that offers one fixed algorithm/key
// pair, whatever signature or message it is asked about.
type staticKeyProvider struct {
	// alg is the signature algorithm sent along with key.
	alg jwa.SignatureAlgorithm
	// key is the key sent to the sink.
	key interface{}
}

// FetchKeys sends the provider's single alg/key pair to sink. The context,
// signature and message are ignored, and the call never fails.
func (kp *staticKeyProvider) FetchKeys(_ context.Context, sink KeySink, _ *Signature, _ *Message) error {
	alg, key := kp.alg, kp.key
	sink.Key(alg, key)
	return nil
}
// keySetProvider is a KeyProvider that offers keys taken from a jwk.Set.
type keySetProvider struct {
	// set is the key set that keys are selected from.
	set jwk.Set

	// requireKid is true if `kid` must be specified.
	requireKid bool

	// useDefault is true if the first key should be used iff there's
	// exactly one key in set.
	useDefault bool

	// inferAlgorithm is true if the algorithm should be inferred from
	// the key type.
	inferAlgorithm bool

	// multipleKeysPerKeyID is true if we should attempt to match multiple
	// keys per key ID. If false we assume that only one key exists for a
	// given key ID.
	multipleKeysPerKeyID bool
}
// selectKey decides whether key is a candidate for verifying sig and, if
// so, sends it to sink with the algorithm(s) to try.
//
// The rules, in order:
//   - A key whose usage is set to something other than signatures is
//     silently skipped.
//   - A key that declares its own algorithm is sent with that algorithm;
//     an error is returned if it is not a valid signature algorithm.
//   - Otherwise, if kp.inferAlgorithm is set, the algorithms are derived
//     from the key via AlgorithmsForKey. When the signature's protected
//     headers name an algorithm, only that one is accepted (an error is
//     returned if it is not among the inferred ones); when they do not,
//     the key is sent once per inferred algorithm.
//   - Otherwise nothing is sent and nil is returned.
func (kp *keySetProvider) selectKey(sink KeySink, key jwk.Key, sig *Signature, _ *Message) error {
	switch key.KeyUsage() {
	case "", jwk.ForSignature.String():
		// Unrestricted, or explicitly meant for signatures: keep going.
	default:
		return nil
	}

	if keyAlg := key.Algorithm(); keyAlg.String() != "" {
		var sigAlg jwa.SignatureAlgorithm
		if err := sigAlg.Accept(keyAlg); err != nil {
			return fmt.Errorf(`invalid signature algorithm %s: %w`, key.Algorithm(), err)
		}

		sink.Key(sigAlg, key)
		return nil
	}

	if !kp.inferAlgorithm {
		return nil
	}

	candidates, err := AlgorithmsForKey(key)
	if err != nil {
		return fmt.Errorf(`failed to get a list of signature methods for key type %s: %w`, key.KeyType(), err)
	}

	wanted := sig.ProtectedHeaders().Algorithm()
	if wanted == "" {
		// The message gives no hint, so every inferred algorithm is offered.
		for _, candidate := range candidates {
			sink.Key(candidate, key)
		}
		return nil
	}

	// The message names an `alg`: it has to be one of the inferred ones.
	for _, candidate := range candidates {
		if candidate == wanted {
			sink.Key(candidate, key)
			return nil
		}
	}
	return fmt.Errorf(`algorithm in the message does not match any of the inferred algorithms`)
}
// FetchKeys sends the keys in kp.set that are candidates for verifying sig
// to sink. The context is not used.
//
// When kp.requireKid is false, every key in the set is offered through
// selectKey and per-key errors are ignored; the call always returns nil.
//
// When kp.requireKid is true:
//   - If the protected headers carry no "kid", an error is returned unless
//     kp.useDefault is set and the set holds exactly one key, in which
//     case that key is used. An empty set is reported as an error.
//   - If kp.multipleKeysPerKeyID is false, the single key found by
//     LookupKeyID is used; a missing key is an error, and so is any error
//     from selectKey.
//   - If kp.multipleKeysPerKeyID is true, every key whose ID equals the
//     "kid" is offered. Keys rejected by selectKey are skipped, and an
//     error is returned only if no key was accepted.
func (kp *keySetProvider) FetchKeys(_ context.Context, sink KeySink, sig *Signature, msg *Message) error {
	if !kp.requireKid {
		// No key ID required: just try all keys.
		for i := 0; i < kp.set.Len(); i++ {
			key, ok := kp.set.Key(i)
			if !ok {
				continue
			}
			// Best effort: a key that cannot be used is skipped rather
			// than failing the whole lookup.
			_ = kp.selectKey(sink, key, sig, msg)
		}
		return nil
	}

	wantedKid := sig.ProtectedHeaders().KeyID()
	if wantedKid == "" {
		// If the kid is NOT specified... kp.useDefault needs to be true,
		// and the JWKS must have exactly one key in it.
		if !kp.useDefault {
			return fmt.Errorf(`failed to find matching key: no key ID ("kid") specified in token`)
		}
		if kp.set.Len() > 1 {
			return fmt.Errorf(`failed to find matching key: no key ID ("kid") specified in token but multiple keys available in key set`)
		}

		// An empty set yields no key at index 0. Passing that nil key on
		// to selectKey would panic, so it is reported as an error instead.
		key, ok := kp.set.Key(0)
		if !ok {
			return fmt.Errorf(`failed to find matching key: no key ID ("kid") specified in token and key set is empty`)
		}
		return kp.selectKey(sink, key, sig, msg)
	}

	// Otherwise we better be able to look up the key.
	// <= v2.0.3 backwards compatible case: only match a single key
	// whose key ID matches `wantedKid`.
	if !kp.multipleKeysPerKeyID {
		key, ok := kp.set.LookupKeyID(wantedKid)
		if !ok {
			return fmt.Errorf(`failed to find key with key ID %q in key set`, wantedKid)
		}
		return kp.selectKey(sink, key, sig, msg)
	}

	// multipleKeysPerKeyID is true: attempt all keys whose key ID matches
	// wantedKid, and keep going after a hit so that every key sharing the
	// ID gets tried.
	found := false
	for i := 0; i < kp.set.Len(); i++ {
		key, ok := kp.set.Key(i)
		if !ok || key.KeyID() != wantedKid {
			continue
		}
		if err := kp.selectKey(sink, key, sig, msg); err != nil {
			continue
		}
		found = true
	}
	if !found {
		return fmt.Errorf(`failed to find key with key ID %q in key set`, wantedKid)
	}
	return nil
}
// jkuProvider is a KeyProvider that obtains keys by fetching the key set
// located at the URL in the signature's "jku" protected header.
type jkuProvider struct {
	// fetcher performs the retrieval of the key set.
	fetcher jwk.Fetcher
	// options are forwarded to every fetcher.Fetch call.
	options []jwk.FetchOption
}
// FetchKeys downloads the key set referenced by the "jku" protected header
// of sig and sends the key matching the header's "kid" to sink.
//
// An error is returned when the "kid" or "jku" header is missing, when the
// URL cannot be parsed or its scheme is not "https", when the fetch fails,
// or when no signature algorithms can be determined for the key. A key set
// that simply lacks the requested "kid" is not an error: nothing is sent.
//
// The key is sent at most once, paired with the first inferred algorithm
// that is compatible with the "alg" header (any algorithm is compatible
// when the header is absent).
func (kp jkuProvider) FetchKeys(ctx context.Context, sink KeySink, sig *Signature, _ *Message) error {
	hdrs := sig.ProtectedHeaders()

	kid := hdrs.KeyID()
	if kid == "" {
		return fmt.Errorf(`use of "jku" requires that the payload contain a "kid" field in the protected header`)
	}

	// errors here can't be reliably passed to the consumers.
	// it's unfortunate, but if you need this control, you are
	// going to have to write your own fetcher
	jku := hdrs.JWKSetURL()
	if jku == "" {
		return fmt.Errorf(`use of "jku" field specified, but the field is empty`)
	}

	parsed, err := url.Parse(jku)
	if err != nil {
		return fmt.Errorf(`failed to parse "jku": %w`, err)
	}
	if parsed.Scheme != "https" {
		return fmt.Errorf(`url in "jku" must be HTTPS`)
	}

	set, err := kp.fetcher.Fetch(ctx, jku, kp.options...)
	if err != nil {
		return fmt.Errorf(`failed to fetch %q: %w`, jku, err)
	}

	key, found := set.LookupKeyID(kid)
	if !found {
		// It is not an error if the key with the kid doesn't exist
		return nil
	}

	candidates, err := AlgorithmsForKey(key)
	if err != nil {
		return fmt.Errorf(`failed to get a list of signature methods for key type %s: %w`, key.KeyType(), err)
	}

	wanted := hdrs.Algorithm()
	for _, candidate := range candidates {
		// With an "alg" field in the JWS we can only proceed if the
		// inferred algorithm matches it; without one, the first
		// inferred algorithm is taken.
		if wanted == "" || wanted == candidate {
			sink.Key(candidate, key)
			return nil
		}
	}
	return nil
}
// KeyProviderFunc adapts a plain function to the KeyProvider interface,
// which makes it convenient for building ad-hoc `KeyProvider` instances.
type KeyProviderFunc func(context.Context, KeySink, *Signature, *Message) error

// FetchKeys invokes the underlying function with the given arguments and
// returns its result unchanged.
func (fn KeyProviderFunc) FetchKeys(ctx context.Context, sink KeySink, sig *Signature, msg *Message) error {
	return fn(ctx, sink, sig, msg)
}
|