1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
|
package provisioner
import (
"encoding/json"
"math/rand"
"regexp"
"strconv"
"sync"
"time"
"github.com/pkg/errors"
"go.step.sm/crypto/jose"
)
// Cache-timing defaults for key sets fetched from a JWKS URI.
const (
	// defaultCacheAge is the cache lifetime used when the Cache-Control
	// header is absent or carries no usable max-age value.
	defaultCacheAge = 12 * time.Hour
	// defaultCacheJitter bounds the random reduction of the cache lifetime
	// for lifetimes longer than an hour, and for a zero lifetime.
	defaultCacheJitter = time.Hour
)
// maxAgeRegex captures the numeric argument of a Cache-Control max-age
// directive. Cache directive names are compared case-insensitively
// (RFC 9111, section 5.2), hence the (?i) flag; without it a header such as
// "Max-Age=60" would be ignored and the default cache age used instead.
var maxAgeRegex = regexp.MustCompile(`(?i)max-age=(\d+)`)
// keyStore caches a JSON Web Key Set fetched from a remote URI and tracks
// when it must be fetched again. The embedded RWMutex guards keySet, expiry
// and jitter: reload replaces them under the write lock and Get reads them
// under the read lock. client and uri are set once by newKeyStore and read
// without locking.
type keyStore struct {
sync.RWMutex
// client performs the HTTP GET of uri.
client HTTPClient
// uri is the location the key set is fetched from.
uri string
// keySet is the most recently successfully fetched key set.
keySet jose.JSONWebKeySet
// expiry is the instant after which Get triggers a reload.
expiry time.Time
// jitter is the upper bound of the random reduction applied to the cache
// age when expiry was computed.
jitter time.Duration
}
// newKeyStore creates a keyStore for the key set published at uri and
// performs the initial fetch through client. If that fetch fails, no store
// is returned and the fetch error is passed through unchanged.
func newKeyStore(client HTTPClient, uri string) (*keyStore, error) {
	keys, age, err := getKeysFromJWKsURI(client, uri)
	if err != nil {
		return nil, err
	}
	ks := &keyStore{
		client: client,
		uri:    uri,
		keySet: keys,
	}
	// The expiry is derived from the jitter, so compute the jitter first.
	ks.jitter = getCacheJitter(age)
	ks.expiry = getExpirationTime(age, ks.jitter)
	return ks, nil
}
// Get returns the keys in the cached key set whose key ID is kid. If the
// cache has expired, it first attempts a synchronous reload. When that
// reload fails, the previously cached keys are searched instead.
func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) {
	ks.RLock()
	if !time.Now().After(ks.expiry) {
		defer ks.RUnlock()
		return ks.keySet.Key(kid)
	}
	// Expired: release the read lock so reload can take the write lock.
	ks.RUnlock()
	ks.reload()
	ks.RLock()
	defer ks.RUnlock()
	return ks.keySet.Key(kid)
}
// reload fetches the key set from ks.uri again. On success it swaps in the
// new keys together with a freshly computed jitter and expiry. On failure
// the error is discarded and the store is left untouched, so an expired
// store keeps its old keys and the next Get tries again.
func (ks *keyStore) reload() {
	keys, age, err := getKeysFromJWKsURI(ks.client, ks.uri)
	if err != nil {
		return
	}
	ks.Lock()
	defer ks.Unlock()
	ks.keySet = keys
	ks.jitter = getCacheJitter(age)
	ks.expiry = getExpirationTime(age, ks.jitter)
}
// getKeysFromJWKsURI fetches the JSON Web Key Set published at uri using
// client. It returns the decoded set together with the cache lifetime taken
// from the response's Cache-Control header (see getCacheAge).
//
// Non-2xx responses are rejected before decoding. An error body such as
// {"error": "..."} would otherwise decode without error into an empty key
// set, and a reload would then silently replace the cached keys with nothing.
func getKeysFromJWKsURI(client HTTPClient, uri string) (jose.JSONWebKeySet, time.Duration, error) {
	var keys jose.JSONWebKeySet
	resp, err := client.Get(uri)
	if err != nil {
		return keys, 0, errors.Wrapf(err, "failed to connect to %s", uri)
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return keys, 0, errors.Errorf("error reading %s: unexpected status code %d", uri, resp.StatusCode)
	}
	if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil {
		return keys, 0, errors.Wrapf(err, "error reading %s", uri)
	}
	return keys, getCacheAge(resp.Header.Get("cache-control")), nil
}
// getCacheAge returns the cache lifetime given by the first max-age directive
// in a Cache-Control header value. It falls back to defaultCacheAge when:
//   - the header is empty or has no max-age directive;
//   - the value does not parse as an int64;
//   - the value is too large to be represented as a time.Duration.
func getCacheAge(cacheControl string) time.Duration {
	// Largest number of whole seconds that fits in a time.Duration.
	const maxAgeSeconds = (1<<63 - 1) / int64(time.Second)

	if cacheControl == "" {
		return defaultCacheAge
	}
	match := maxAgeRegex.FindStringSubmatch(cacheControl)
	if len(match) != 2 {
		return defaultCacheAge
	}
	seconds, err := strconv.ParseInt(match[1], 10, 64)
	// Beyond maxAgeSeconds the multiplication below would overflow and wrap
	// to a meaningless (possibly negative) age, so treat it like a parse error.
	if err != nil || seconds > maxAgeSeconds {
		return defaultCacheAge
	}
	return time.Duration(seconds) * time.Second
}
// getCacheJitter returns the upper bound of the random reduction that
// getExpirationTime applies to a cache lifetime of age. Non-zero ages up to
// one hour, including negative ones, use a third of age. Longer ages and a
// zero age use defaultCacheJitter.
func getCacheJitter(age time.Duration) time.Duration {
	if age != 0 && age <= time.Hour {
		return age / 3
	}
	// A zero age still gets a non-zero jitter to avoid a 0 jitter. Its exact
	// value does not matter, because such an entry rotates on every Get request.
	return defaultCacheJitter
}
// getExpirationTime returns the instant at which a cache entry with lifetime
// age should be refreshed. The result is the current time truncated to whole
// seconds, plus the absolute value of the remaining age.
//
// When both age and jitter are positive, a random amount in [0, jitter) is
// subtracted from age first, which spreads out the expirations.
func getExpirationTime(age, jitter time.Duration) time.Time {
	// rand.Int63n panics when its argument is <= 0, so only apply a positive jitter.
	if age > 0 && jitter > 0 {
		n := rand.Int63n(int64(jitter)) //nolint:gosec // not used for cryptographic security
		age -= time.Duration(n)
	}
	return time.Now().Truncate(time.Second).Add(abs(age))
}
// abs returns the absolute value of n. The most negative Duration has no
// positive counterpart: negating it overflows back to itself. It is
// therefore saturated to the largest Duration, matching time.Duration.Abs.
func abs(n time.Duration) time.Duration {
	const (
		minDuration time.Duration = -1 << 63
		maxDuration time.Duration = 1<<63 - 1
	)
	switch {
	case n >= 0:
		return n
	case n == minDuration:
		return maxDuration
	default:
		return -n
	}
}
|