1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
|
package jwk
import (
"context"
"fmt"
"io"
"math"
"os"
"strconv"
"sync"
"sync/atomic"
"github.com/lestrrat-go/httprc"
)
// Fetcher is implemented by anything that can produce a Set on request.
// An implementation receives a context, a string identifying what to
// fetch (presumably a URL, as with the package-level `Fetch` function),
// and zero or more FetchOption values.
type Fetcher interface {
	Fetch(ctx context.Context, u string, options ...FetchOption) (Set, error)
}
// FetchFunc adapts an ordinary function with the matching signature so it
// can be used wherever a Fetcher is expected.
type FetchFunc func(ctx context.Context, u string, options ...FetchOption) (Set, error)

// Fetch satisfies Fetcher by calling the underlying function with the
// received arguments, unchanged, and returning its results as-is.
func (fn FetchFunc) Fetch(ctx context.Context, u string, options ...FetchOption) (Set, error) {
	set, err := fn(ctx, u, options...)
	return set, err
}
// Package-level state backing the `Fetch` function.
var (
	// globalFetcher is the shared fetcher. It stays nil until a default
	// one is created by getGlobalFetcher or one is supplied through
	// SetGlobalFetcher.
	globalFetcher httprc.Fetcher

	// muGlobalFetcher serializes assignments to globalFetcher.
	muGlobalFetcher sync.Mutex

	// fetcherChanged is a flag accessed through sync/atomic. It is raised
	// (set to 1) at package init and by SetGlobalFetcher, and cleared
	// (set to 0) by getGlobalFetcher.
	fetcherChanged uint32
)

// init raises fetcherChanged so the flag starts out in the "changed" state.
func init() {
	atomic.StoreUint32(&fetcherChanged, 1)
}
// getGlobalFetcher returns the package-wide fetcher used by `Fetch`,
// creating a default one on demand.
//
// When no fetcher is currently set, a new one is built with
// httprc.NewFetcher using context.Background(). Its worker count is taken
// from the JWK_FETCHER_WORKER_COUNT environment variable; a value that is
// missing, unparsable, or less than 1 falls back to 3 workers, and a value
// larger than math.MaxInt is clamped to math.MaxInt.
//
// The mutex is always acquired before globalFetcher is read. Reading it
// without the lock (based only on the fetcherChanged flag) is a data race
// with SetGlobalFetcher, and could hand a stale or nil fetcher to the
// caller when the two run concurrently.
func getGlobalFetcher() httprc.Fetcher {
	muGlobalFetcher.Lock()
	defer muGlobalFetcher.Unlock()

	if globalFetcher == nil {
		var nworkers int
		v := os.Getenv(`JWK_FETCHER_WORKER_COUNT`)
		if c, err := strconv.ParseInt(v, 10, 64); err == nil {
			// Guard the int64 -> int conversion on platforms where int
			// is narrower than 64 bits.
			if c > math.MaxInt {
				nworkers = math.MaxInt
			} else {
				nworkers = int(c)
			}
		}
		if nworkers < 1 {
			nworkers = 3
		}
		globalFetcher = httprc.NewFetcher(context.Background(), httprc.WithFetcherWorkerCount(nworkers))
	}

	// Lower the flag now that a usable fetcher is in place.
	atomic.StoreUint32(&fetcherChanged, 0)
	return globalFetcher
}
// SetGlobalFetcher allows users to specify a custom global fetcher,
// which is used by the `Fetch` function. Assigning `nil` forces
// the default fetcher to be (re)created when the next call to
// `jwk.Fetch` occurs.
//
// You only need to call this function when you want to
// either change the fetching behavior (for example, you want to change
// how the default whitelist is handled), or when you want to control
// the lifetime of the global fetcher, for example for tests
// that require a clean shutdown.
//
// If you do use this function to set a custom fetcher, and you
// control its termination, make sure that you call `jwk.SetGlobalFetcher()`
// one more time (possibly with `nil`) to assign a valid fetcher.
// Otherwise, once the fetcher is invalidated, subsequent calls to `jwk.Fetch`
// may hang, causing very hard to debug problems.
//
// If you are sure you no longer need `jwk.Fetch` after terminating the
// fetcher, then the above caution is not necessary.
func SetGlobalFetcher(f httprc.Fetcher) {
	muGlobalFetcher.Lock()
	defer muGlobalFetcher.Unlock()

	// Raise the changed flag before storing the new value, and do both
	// while holding the lock, so the flag is never left lowered while a
	// replaced (possibly nil) fetcher is already visible.
	atomic.StoreUint32(&fetcherChanged, 1)
	globalFetcher = f
}
// Fetch fetches a JWK resource specified by a URL. The URL must
// point to a resource that is supported by `net/http`.
//
// Options that are also ParseOptions are passed on to `Parse`. The HTTP
// client and fetch whitelist options are translated into the corresponding
// httprc fetch options. Any other option is ignored.
//
// On success the response body is read in full and the parsed Set is
// returned. A fetch or read failure is returned wrapped with the URL.
//
// If you are using the same `jwk.Set` for long periods of time during
// the lifecycle of your program, and would like to periodically refresh the
// contents of the object with the data at the remote resource,
// consider using `jwk.Cache`, which automatically refreshes
// jwk.Set objects asynchronously.
//
// Please note that underneath the `jwk.Fetch` function, it uses a global
// object that spawns goroutines that are present until the go runtime
// exits. Initially this global variable is uninitialized, but upon
// calling `jwk.Fetch` once, it is initialized and goroutines are spawned.
// If you want to control the lifetime of these goroutines, you can
// call `jwk.SetGlobalFetcher` with a custom fetcher which is tied to
// a `context.Context` object that you can control.
func Fetch(ctx context.Context, u string, options ...FetchOption) (Set, error) {
	var hrfopts []httprc.FetchOption
	var parseOptions []ParseOption
	for _, option := range options {
		// Parse options are set aside for the Parse call at the end.
		if parseOpt, ok := option.(ParseOption); ok {
			parseOptions = append(parseOptions, parseOpt)
			continue
		}

		//nolint:forcetypeassert
		switch option.Ident() {
		case identHTTPClient{}:
			hrfopts = append(hrfopts, httprc.WithHTTPClient(option.Value().(HTTPClient)))
		case identFetchWhitelist{}:
			hrfopts = append(hrfopts, httprc.WithWhitelist(option.Value().(httprc.Whitelist)))
		}
	}

	res, err := getGlobalFetcher().Fetch(ctx, u, hrfopts...)
	if err != nil {
		return nil, fmt.Errorf(`failed to fetch %q: %w`, u, err)
	}
	// Schedule the close before touching the body so it is released on
	// every path out of this function.
	defer res.Body.Close()

	buf, err := io.ReadAll(res.Body)
	if err != nil {
		return nil, fmt.Errorf(`failed to read response body for %q: %w`, u, err)
	}
	return Parse(buf, parseOptions...)
}
|