File: key_provider.go

package info (click to toggle)
golang-github-lestrrat-go-jwx 2.1.4-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,872 kB
  • sloc: sh: 222; makefile: 86; perl: 62
file content (276 lines) | stat: -rw-r--r-- 8,346 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
package jws

import (
	"context"
	"fmt"
	"net/url"
	"sync"

	"github.com/lestrrat-go/jwx/v2/jwa"
	"github.com/lestrrat-go/jwx/v2/jwk"
)

// KeyProvider supplies the key(s) used to sign or verify a JWS payload.
// Any number of `jws.KeyProvider`s may be passed to `jws.Verify()` or
// `jws.Sign()`.
//
// `jws.Sign()` only accepts static key providers, created via `jws.WithKey()`.
// `jws.Verify()` accepts `jws.WithKey()`, `jws.WithKeySet()`,
// `jws.WithVerifyAuto()`, and `jws.WithKeyProvider()`.
//
// Understanding this mechanism is essential to understanding the package.
//
// `jws.Sign()` is simple: one signature is created for each provided key.
//
// `jws.Verify()` is more involved. In some situations you want to compute,
// deduce, or guess which keys should be used for verification.
//
// First, `jws.Verify()` gathers the KeyProviders from the options supplied
// by the user (in pseudocode):
//
//	keyProviders := filterKeyProviders(options)
//
// A JWS message may carry more than one signature. For every signature, each
// KeyProvider is asked for the key(s) to try against it:
//
//	for sig in msg.Signatures {
//	  for kp in keyProviders {
//	    kp.FetchKeys(ctx, sink, sig, msg)
//	    ...
//	  }
//	}
//
// The `sink` argument is temporary storage for keys (either a jwk.Key or a
// "raw" key). Pushing keys into it is the `KeyProvider`'s job.
//
// The providers behave as follows:
//   - `jws.WithKey()` always sends the same key.
//   - `jws.WithKeySet()` sends the keys that match a particular `kid` and `alg`.
//   - `jws.WithVerifyAuto()` fetches a JWK set from the `jku` URL.
//   - `jws.WithKeyProvider()` lets you run arbitrary logic to provide keys.
//
// A custom `KeyProvider` should perform whatever checks or key retrieval it
// needs, then send the resulting key(s) to the sink:
//
//	sink.Key(alg, key)
//
// Those keys are then tried against each signature until one matches:
//
//	keys := sink.Keys()
//	for key in keys {
//	  if givenSignature == makeSignature(key, payload, ...) {
//	    return OK
//	  }
//	}
type KeyProvider interface {
	// FetchKeys sends the key(s) to be tried against sig, one of the
	// signatures contained in msg, into sink.
	FetchKeys(ctx context.Context, sink KeySink, sig *Signature, msg *Message) error
}

// KeySink is the destination to which `jws.KeyProvider` implementations
// deliver their keys.
type KeySink interface {
	// Key registers key as a candidate to be used with the signature
	// algorithm alg.
	Key(alg jwa.SignatureAlgorithm, key interface{})
}

// algKeyPair couples a key with the algorithm it should be used with.
type algKeyPair struct {
	alg jwa.KeyAlgorithm
	key interface{}
}

// algKeySink is a KeySink that accumulates every key it receives, in the
// order the keys arrive. mu serializes appends to list.
type algKeySink struct {
	mu   sync.Mutex
	list []algKeyPair
}

// Key appends the (alg, key) pair to the sink's list while holding the lock.
func (s *algKeySink) Key(alg jwa.SignatureAlgorithm, key interface{}) {
	pair := algKeyPair{alg: alg, key: key}

	s.mu.Lock()
	defer s.mu.Unlock()
	s.list = append(s.list, pair)
}

// staticKeyProvider is the KeyProvider that hands out one fixed
// (algorithm, key) pair, regardless of the signature or message being
// processed.
type staticKeyProvider struct {
	alg jwa.SignatureAlgorithm
	key interface{}
}

// FetchKeys sends the configured key, paired with its configured algorithm,
// to sink. It always returns nil.
func (kp *staticKeyProvider) FetchKeys(_ context.Context, sink KeySink, _ *Signature, _ *Message) error {
	alg, key := kp.alg, kp.key
	sink.Key(alg, key)
	return nil
}

// keySetProvider is the KeyProvider backed by a jwk.Set. The flags below
// control which keys from the set are offered for a given signature.
type keySetProvider struct {
	set jwk.Set

	// requireKid: the signature's protected header must carry a `kid`.
	requireKid bool
	// useDefault: when no `kid` is present, fall back to the set's key,
	// intended for sets holding exactly one key.
	useDefault bool
	// inferAlgorithm: when a key has no `alg`, derive candidate algorithms
	// from the key type.
	inferAlgorithm bool
	// multipleKeysPerKeyID: try every key that shares the wanted key ID.
	// When false, a key ID is assumed to identify at most one key.
	multipleKeysPerKeyID bool
}

// selectKey decides which (algorithm, key) pairs, if any, key contributes to
// sink for the signature sig:
//
//   - A key whose `use` is set to anything other than signing is skipped.
//   - A key with an explicit `alg` is sent with exactly that algorithm.
//   - Otherwise, when inferAlgorithm is enabled, candidate algorithms are
//     derived from the key type. If the signature's protected header names
//     an `alg`, only that algorithm is sent, and it is an error if it is not
//     among the candidates. Without a header `alg`, every candidate is sent.
//
// Keys that match none of the above are ignored without error.
func (kp *keySetProvider) selectKey(sink KeySink, key jwk.Key, sig *Signature, _ *Message) error {
	usage := key.KeyUsage()
	if usage != "" && usage != jwk.ForSignature.String() {
		// Not a signing key (e.g. "enc"): skip it quietly.
		return nil
	}

	if keyAlg := key.Algorithm(); keyAlg.String() != "" {
		var sigAlg jwa.SignatureAlgorithm
		if err := sigAlg.Accept(keyAlg); err != nil {
			return fmt.Errorf(`invalid signature algorithm %s: %w`, key.Algorithm(), err)
		}
		sink.Key(sigAlg, key)
		return nil
	}

	if !kp.inferAlgorithm {
		return nil
	}

	candidates, err := AlgorithmsForKey(key)
	if err != nil {
		return fmt.Errorf(`failed to get a list of signature methods for key type %s: %w`, key.KeyType(), err)
	}

	headerAlg := sig.ProtectedHeaders().Algorithm()
	if headerAlg == "" {
		// The message has no `alg` to narrow things down, so offer every candidate.
		for _, candidate := range candidates {
			sink.Key(candidate, key)
		}
		return nil
	}

	// The message names an `alg`: it must be one of the inferred candidates.
	for _, candidate := range candidates {
		if candidate == headerAlg {
			sink.Key(candidate, key)
			return nil
		}
	}
	return fmt.Errorf(`algorithm in the message does not match any of the inferred algorithms`)
}

// FetchKeys sends the keys from kp.set that should be tried for sig into
// sink.
//
// When requireKid is not set, every key in the set is passed to selectKey on
// a best-effort basis. Keys that selectKey rejects are skipped.
//
// When requireKid is set:
//   - If the signature's protected header has no `kid`, the lookup fails
//     unless useDefault is set and the set holds exactly one key. In that
//     case the one key is used.
//   - Otherwise the key(s) whose ID equals the `kid` are used. If
//     multipleKeysPerKeyID is false, only the single key found by
//     LookupKeyID is used; if true, every key with that ID is tried.
//   - It is an error if no usable key with that ID is found.
func (kp *keySetProvider) FetchKeys(_ context.Context, sink KeySink, sig *Signature, msg *Message) error {
	if !kp.requireKid {
		// Best effort: a key that selectKey rejects (invalid alg, header alg
		// mismatch, ...) must not prevent the remaining keys from being tried.
		for i := 0; i < kp.set.Len(); i++ {
			key, ok := kp.set.Key(i)
			if !ok {
				continue
			}
			_ = kp.selectKey(sink, key, sig, msg)
		}
		return nil
	}

	wantedKid := sig.ProtectedHeaders().KeyID()
	if wantedKid == "" {
		// Without a `kid`, the only acceptable fallback is the set's sole key,
		// and only when useDefault was requested.
		if !kp.useDefault {
			return fmt.Errorf(`failed to find matching key: no key ID ("kid") specified in token`)
		}
		if kp.set.Len() > 1 {
			return fmt.Errorf(`failed to find matching key: no key ID ("kid") specified in token but multiple keys available in key set`)
		}
		// An empty set yields no key. Passing nil on to selectKey would
		// panic when it calls methods on the key.
		key, ok := kp.set.Key(0)
		if !ok {
			return fmt.Errorf(`failed to find matching key: no key ID ("kid") specified in token and key set is empty`)
		}
		return kp.selectKey(sink, key, sig, msg)
	}

	if !kp.multipleKeysPerKeyID {
		// <= v2.0.3 backwards compatible case: only match a single key
		// whose key ID matches `wantedKid`.
		key, ok := kp.set.LookupKeyID(wantedKid)
		if !ok {
			return fmt.Errorf(`failed to find key with key ID %q in key set`, wantedKid)
		}
		return kp.selectKey(sink, key, sig, msg)
	}

	// Try every key whose ID matches wantedKid. Keep going after a hit so
	// that all keys sharing the same ID get a chance.
	found := false
	for i := 0; i < kp.set.Len(); i++ {
		key, ok := kp.set.Key(i)
		if !ok || key.KeyID() != wantedKid {
			continue
		}
		if err := kp.selectKey(sink, key, sig, msg); err == nil {
			found = true
		}
	}
	if !found {
		return fmt.Errorf(`failed to find key with key ID %q in key set`, wantedKid)
	}
	return nil
}

// jkuProvider is the KeyProvider that downloads the JWK set named by a
// signature's `jku` protected header. It selects the key whose ID matches the
// header's `kid`.
type jkuProvider struct {
	fetcher jwk.Fetcher
	options []jwk.FetchOption
}

// FetchKeys performs the following steps:
//   - requires a `kid` and a non-empty HTTPS `jku` in sig's protected headers;
//   - fetches the JWK set from the `jku` URL;
//   - looks up the key with that `kid`.
//
// The key is sent to sink paired with the first algorithm inferred for its
// key type that agrees with the header's `alg`. When the header has no
// `alg`, the first inferred algorithm is used.
//
// A fetched set that does not contain the `kid` is not an error; in that case
// no key is sent.
//
// NOTE(review): unlike keySetProvider.selectKey, this does not consult the
// fetched key's own `use` or `alg` members. Confirm that this is intended.
func (kp jkuProvider) FetchKeys(ctx context.Context, sink KeySink, sig *Signature, _ *Message) error {
	headers := sig.ProtectedHeaders()

	kid := headers.KeyID()
	if kid == "" {
		return fmt.Errorf(`use of "jku" requires that the payload contain a "kid" field in the protected header`)
	}

	// Errors here can't be reliably passed to the consumers. It's
	// unfortunate, but if you need this control, you are going to have to
	// write your own fetcher.
	jku := headers.JWKSetURL()
	if jku == "" {
		return fmt.Errorf(`use of "jku" field specified, but the field is empty`)
	}

	parsed, err := url.Parse(jku)
	if err != nil {
		return fmt.Errorf(`failed to parse "jku": %w`, err)
	}
	if parsed.Scheme != "https" {
		return fmt.Errorf(`url in "jku" must be HTTPS`)
	}

	remoteSet, err := kp.fetcher.Fetch(ctx, jku, kp.options...)
	if err != nil {
		return fmt.Errorf(`failed to fetch %q: %w`, jku, err)
	}

	key, found := remoteSet.LookupKeyID(kid)
	if !found {
		// A missing kid is not treated as an error: nothing is sent.
		return nil
	}

	inferred, err := AlgorithmsForKey(key)
	if err != nil {
		return fmt.Errorf(`failed to get a list of signature methods for key type %s: %w`, key.KeyType(), err)
	}

	headerAlg := headers.Algorithm()
	for _, candidate := range inferred {
		// If the JWS names an `alg`, only that algorithm may be used.
		// Otherwise the first inferred algorithm wins.
		if headerAlg == "" || headerAlg == candidate {
			sink.Key(candidate, key)
			return nil
		}
	}
	return nil
}

// KeyProviderFunc adapts a plain function into a KeyProvider. This is
// convenient for building ad-hoc `KeyProvider` instances without declaring a
// new type.
type KeyProviderFunc func(context.Context, KeySink, *Signature, *Message) error

// FetchKeys satisfies KeyProvider by calling the underlying function with the
// same arguments and returning its result unchanged.
func (f KeyProviderFunc) FetchKeys(ctx context.Context, sink KeySink, sig *Signature, msg *Message) error {
	return f(ctx, sink, sig, msg)
}