1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255
|
package rp
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sync"
jose "github.com/go-jose/go-jose/v4"
"github.com/zitadel/oidc/v3/pkg/client"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
// NewRemoteKeySet returns an oidc.KeySet that loads its keys from jwksURL using
// the given HTTP client.
//
// Every function in opts is applied, in order, to the new key set before it
// is returned (see SkipRemoteCheck for an example option).
func NewRemoteKeySet(client *http.Client, jwksURL string, opts ...func(*remoteKeySet)) oidc.KeySet {
	set := new(remoteKeySet)
	set.httpClient = client
	set.jwksURL = jwksURL
	for i := range opts {
		opts[i](set)
	}
	return set
}
// SkipRemoteCheck returns an option for NewRemoteKeySet that suppresses the
// lookup of new remote keys when signature validation fails with a cached key
// and neither the JWT nor that cached key carries a kid.
//
// This can save unnecessary round trips when the JWT does not contain a kid
// header and there is only a single remote key. With the option set, such a
// validation failure is returned directly; remote keys are still fetched when
// the cache is empty or no single cached key matches.
func SkipRemoteCheck() func(set *remoteKeySet) {
	enable := func(set *remoteKeySet) {
		set.skipRemoteCheck = true
	}
	return enable
}
// remoteKeySet verifies JWS signatures with keys fetched from a remote JWKS
// endpoint and keeps the most recently fetched keys in memory.
type remoteKeySet struct {
// jwksURL is the URL the key set is requested from.
jwksURL string
// httpClient performs the request against jwksURL.
httpClient *http.Client
// defaultAlg is used by VerifySignature when no algorithm could be read from the JWS.
defaultAlg string
// skipRemoteCheck is enabled by SkipRemoteCheck; see exactMatch for its effect.
skipRemoteCheck bool
// mu guards the fields declared below it.
mu sync.Mutex
// inflight suppresses parallel execution of updateKeys and allows
// multiple goroutines to wait for its result.
inflight *inflight
// cachedKeys holds the keys stored by the last successful updateKeys run.
// No expiry is tracked; the slice is replaced as a whole on each update.
cachedKeys []jose.JSONWebKey
}
// inflight is used to wait on some in-flight request from multiple goroutines.
type inflight struct {
// doneCh is closed by done to signal that keys and err have been recorded.
doneCh chan struct{}
// keys and err hold the outcome passed to done; result returns them.
keys []jose.JSONWebKey
err error
}
// newInflight creates an inflight request whose done channel is still open.
func newInflight() *inflight {
	pending := new(inflight)
	pending.doneCh = make(chan struct{})
	return pending
}
// wait exposes the receive-only side of the done channel. Any number of
// goroutines may receive on it; once a receive completes, the inflight request
// has finished and result() may be inspected.
func (i *inflight) wait() <-chan struct{} {
	var finished <-chan struct{} = i.doneCh
	return finished
}
// done stores the outcome of the inflight request and then closes the done
// channel, which tells waiting goroutines that the outcome is safe to read.
//
// It must be called by a single goroutine only: a second call would close
// the already closed channel.
func (i *inflight) done(keys []jose.JSONWebKey, err error) {
	i.keys, i.err = keys, err
	close(i.doneCh)
}
// result returns the keys and error recorded by done. It must not be called
// before a receive on the wait() channel has completed.
func (i *inflight) result() ([]jose.JSONWebKey, error) {
	keys, err := i.keys, i.err
	return keys, err
}
// VerifySignature verifies the signature of jws and returns its payload.
//
// The key ID and algorithm are taken from oidc.GetKeyIDAndAlg; an empty
// algorithm is replaced by r.defaultAlg. Cached keys are tried first. The
// remote key set is only consulted when the cached attempt yields neither a
// payload nor an error.
func (r *remoteKeySet) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) ([]byte, error) {
	ctx, span := client.Tracer.Start(ctx, "VerifySignature")
	defer span.End()

	kid, algorithm := oidc.GetKeyIDAndAlg(jws)
	if algorithm == "" {
		algorithm = r.defaultAlg
	}

	cached, cacheErr := r.verifySignatureCached(jws, kid, algorithm)
	switch {
	case cached != nil:
		// Verified with a cached key.
		return cached, nil
	case cacheErr != nil:
		// An exactly matching cached key rejected the signature.
		return nil, cacheErr
	default:
		// Nothing conclusive in the cache, ask the remote endpoint.
		return r.verifySignatureRemote(ctx, jws, kid, algorithm)
	}
}
// verifySignatureCached tries to verify jws with a key from the cache.
//
// If exactly one cached key matches, the signature is verified with it and the
// payload is returned on success.
//
// An error is returned only when verification fails and the cached key is an
// exact match (see exactMatch), that is if either:
//   - both kid (JWT and JWK) are empty and skipRemoteCheck is set to true
//   - or both kid are equal and not empty
//
// In every other case (empty cache, no or multiple matching keys, inexact
// match) it returns nil, nil so that the caller loads the remote keys.
func (r *remoteKeySet) verifySignatureCached(jws *jose.JSONWebSignature, keyID, alg string) ([]byte, error) {
	cached := r.keysFromCache()
	if len(cached) == 0 {
		return nil, nil
	}

	candidate, findErr := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, cached...)
	if findErr != nil {
		// no key / multiple found, try with remote keys
		return nil, nil //nolint:nilerr
	}

	payload, verifyErr := jws.Verify(&candidate)
	switch {
	case payload != nil:
		return payload, nil
	case r.exactMatch(candidate.KeyID, keyID):
		return nil, fmt.Errorf("signature verification failed: %w", verifyErr)
	default:
		// no exact key match, try getting better match with remote keys
		return nil, nil
	}
}
// exactMatch reports whether the key ID of a JWK (jwkID) and of a JWS (jwsID)
// identify the same key.
//
// Different IDs never match and equal non-empty IDs always match. Two empty
// IDs only count as a match when skipRemoteCheck is enabled.
func (r *remoteKeySet) exactMatch(jwkID, jwsID string) bool {
	if jwkID != jwsID {
		return false
	}
	if jwkID == "" {
		return r.skipRemoteCheck
	}
	return true
}
// verifySignatureRemote loads the key set via keysFromRemote, selects the
// signing key matching keyID and alg and verifies jws with it.
//
// It returns the verified payload, or an error if the keys cannot be fetched,
// no single matching key is found or the signature is invalid.
func (r *remoteKeySet) verifySignatureRemote(ctx context.Context, jws *jose.JSONWebSignature, keyID, alg string) ([]byte, error) {
	ctx, span := client.Tracer.Start(ctx, "verifySignatureRemote")
	defer span.End()

	remoteKeys, fetchErr := r.keysFromRemote(ctx)
	if fetchErr != nil {
		return nil, fmt.Errorf("unable to fetch key for signature validation: %w", fetchErr)
	}

	match, findErr := oidc.FindMatchingKey(keyID, oidc.KeyUseSignature, alg, remoteKeys...)
	if findErr != nil {
		return nil, fmt.Errorf("unable to validate signature: %w", findErr)
	}

	payload, verifyErr := jws.Verify(&match)
	if verifyErr == nil {
		return payload, nil
	}
	return nil, fmt.Errorf("signature verification failed: %w", verifyErr)
}
// keysFromCache returns the currently cached keys. The slice header is read
// while holding r.mu; it is nil until a remote fetch has succeeded.
func (r *remoteKeySet) keysFromCache() (keys []jose.JSONWebKey) {
	r.mu.Lock()
	keys = r.cachedKeys
	r.mu.Unlock()
	return keys
}
// keysFromRemote returns the outcome of a remote key fetch, starting one if
// none is currently running.
//
// Concurrent callers share a single fetch: the caller that finds no inflight
// request creates one and launches updateKeys (with its own ctx) in a new
// goroutine, every other caller waits on that same request. A caller stops
// waiting as soon as its ctx is done and then receives ctx.Err().
func (r *remoteKeySet) keysFromRemote(ctx context.Context) ([]jose.JSONWebKey, error) {
	ctx, span := client.Tracer.Start(ctx, "keysFromRemote")
	defer span.End()

	// The inflight field may only be inspected while holding the lock.
	r.mu.Lock()
	pending := r.inflight
	if pending == nil {
		// No request is running: register a new one before starting the
		// goroutine that owns it. That goroutine nils the inflight field
		// again once it is done.
		pending = newInflight()
		r.inflight = pending
		go r.updateKeys(ctx)
	}
	r.mu.Unlock()

	select {
	case <-pending.wait():
		return pending.result()
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// updateKeys fetches the remote keys, hands the outcome to the goroutines
// waiting on the current inflight request and, if the fetch succeeded, replaces
// the cached keys.
//
// It is launched by keysFromRemote in its own goroutine. Before returning it
// resets r.inflight so that a later call can start a new fetch.
func (r *remoteKeySet) updateKeys(ctx context.Context) {
	ctx, span := client.Tracer.Start(ctx, "updateKeys")
	defer span.End()

	// Waiters are released first; the cache is updated afterwards under the lock.
	fetched, fetchErr := r.fetchRemoteKeys(ctx)
	r.inflight.done(fetched, fetchErr)

	r.mu.Lock()
	defer r.mu.Unlock()

	// Free inflight so a different request can run.
	r.inflight = nil
	if fetchErr != nil {
		// Keep the previously cached keys on failure.
		return
	}
	r.cachedKeys = fetched
}
// fetchRemoteKeys performs a GET request against r.jwksURL with r.httpClient
// and decodes the response into a jsonWebKeySet, whose keys are returned.
//
// The request is bound to ctx. Errors from building or executing the request
// are wrapped with %w, so the underlying cause stays reachable for callers
// through errors.Is and errors.As; the error text itself is unchanged.
func (r *remoteKeySet) fetchRemoteKeys(ctx context.Context) ([]jose.JSONWebKey, error) {
	ctx, span := client.Tracer.Start(ctx, "fetchRemoteKeys")
	defer span.End()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, r.jwksURL, nil)
	if err != nil {
		return nil, fmt.Errorf("oidc: can't create request: %w", err)
	}

	keySet := new(jsonWebKeySet)
	if err = httphelper.HttpRequest(r.httpClient, req, keySet); err != nil {
		return nil, fmt.Errorf("oidc: failed to get keys: %w", err)
	}
	return keySet.Keys, nil
}
// jsonWebKeySet is a defined type with jose.JSONWebKeySet as its underlying
// type (it is not a type alias). Its UnmarshalJSON method skips keys that fail
// to decode, for example because of an unknown key type (kty).
type jsonWebKeySet jose.JSONWebKeySet
// UnmarshalJSON overrides the default jose.JSONWebKeySet decoding so that keys
// which cannot be parsed (for example because of an unknown key type, kty) are
// skipped instead of failing the whole key set.
//
// An error is returned only if data is not a valid key set document. On
// success k.Keys is replaced by the keys decoded from data, so decoding into a
// reused value does not accumulate keys from an earlier call.
func (k *jsonWebKeySet) UnmarshalJSON(data []byte) (err error) {
	var raw rawJSONWebKeySet
	if err = json.Unmarshal(data, &raw); err != nil {
		return err
	}

	var keys []jose.JSONWebKey
	for _, rawKey := range raw.Keys {
		var webKey jose.JSONWebKey
		if webKey.UnmarshalJSON(rawKey) != nil {
			// Deliberately ignored: an undecodable key must not
			// invalidate the remaining keys of the set.
			continue
		}
		keys = append(keys, webKey)
	}
	k.Keys = keys
	return nil
}
// rawJSONWebKeySet captures every entry of the "keys" array as raw JSON so
// that jsonWebKeySet.UnmarshalJSON can decode the entries one by one.
type rawJSONWebKeySet struct {
Keys []json.RawMessage `json:"keys"`
}
|