File: keystore.go

package info (click to toggle)
golang-github-smallstep-certificates 0.28.4-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,684 kB
  • sloc: sh: 367; makefile: 129
file content (127 lines) | stat: -rw-r--r-- 2,819 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package provisioner

import (
	"encoding/json"
	"math/rand"
	"regexp"
	"strconv"
	"sync"
	"time"

	"github.com/pkg/errors"
	"go.step.sm/crypto/jose"
)

// Fallback values for the JWKs cache.
const (
	// defaultCacheAge is the cache lifetime used when the response carries no
	// usable max-age directive.
	defaultCacheAge = 12 * time.Hour
	// defaultCacheJitter is the jitter used for cache ages above one hour or
	// equal to zero.
	defaultCacheJitter = time.Hour
)

// maxAgeRegex captures the decimal seconds of a Cache-Control max-age
// directive.
var maxAgeRegex = regexp.MustCompile(`max-age=(\d+)`)

// keyStore caches the JSON Web Key Set downloaded from a JWKs URI and
// re-downloads it once expiry has passed (see Get and reload). The embedded
// RWMutex guards keySet, expiry and jitter: Get reads them under the read
// lock and reload replaces them under the write lock.
type keyStore struct {
	sync.RWMutex
	client HTTPClient         // client used to fetch the key set from uri
	uri    string             // JWKs endpoint the key set is downloaded from
	keySet jose.JSONWebKeySet // most recently fetched key set
	expiry time.Time          // once time.Now() is after this, Get triggers a reload
	jitter time.Duration      // upper bound of the random offset subtracted from the cache age
}

// newKeyStore builds a keyStore for uri by performing an initial fetch with
// client. Unlike later reloads, a failure here is returned to the caller and
// no store is created. The expiry is derived from the response's cache age
// minus a random jitter.
func newKeyStore(client HTTPClient, uri string) (*keyStore, error) {
	keySet, maxAge, err := getKeysFromJWKsURI(client, uri)
	if err != nil {
		return nil, err
	}

	ks := &keyStore{
		client: client,
		uri:    uri,
		keySet: keySet,
	}
	ks.jitter = getCacheJitter(maxAge)
	ks.expiry = getExpirationTime(maxAge, ks.jitter)
	return ks, nil
}

// Get returns the keys of the cached key set that match kid, as reported by
// keySet.Key. If the cache has expired, it first calls reload. reload keeps
// the previous key set when the refresh fails.
func (ks *keyStore) Get(kid string) (keys []jose.JSONWebKey) {
	ks.RLock()
	if now := time.Now(); now.After(ks.expiry) {
		// reload takes the write lock, so the read lock has to be dropped
		// first to avoid deadlocking on ourselves; it is re-taken for the
		// lookup below.
		ks.RUnlock()
		ks.reload()
		ks.RLock()
	}
	keys = ks.keySet.Key(kid)
	ks.RUnlock()
	return keys
}

// reload fetches the key set from ks.uri again. On success, it replaces the
// cached keys and recomputes jitter and expiry under the write lock.
// Fetch errors are discarded: the previous key set stays in place and expiry
// is not advanced, so the next Get tries again.
// NOTE(review): while the endpoint keeps failing, every Get performs a
// blocking fetch; consider a retry backoff.
func (ks *keyStore) reload() {
	fresh, maxAge, err := getKeysFromJWKsURI(ks.client, ks.uri)
	if err != nil {
		return
	}

	ks.Lock()
	ks.keySet = fresh
	ks.jitter = getCacheJitter(maxAge)
	ks.expiry = getExpirationTime(maxAge, ks.jitter)
	ks.Unlock()
}

// getKeysFromJWKsURI downloads and decodes the JSON Web Key Set served at uri.
// It returns the key set together with the cache lifetime that getCacheAge
// derives from the response's Cache-Control header.
//
// Responses with a non-2xx status are rejected with an error. Otherwise an
// error body would be treated as a key set: a JSON error document such as
// {"error":"..."} decodes without error into an empty set, and reload would
// then replace the cached keys with nothing.
func getKeysFromJWKsURI(client HTTPClient, uri string) (jose.JSONWebKeySet, time.Duration, error) {
	var keys jose.JSONWebKeySet
	resp, err := client.Get(uri)
	if err != nil {
		return keys, 0, errors.Wrapf(err, "failed to connect to %s", uri)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return keys, 0, errors.Errorf("error reading %s: unexpected status code %d", uri, resp.StatusCode)
	}
	if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil {
		return keys, 0, errors.Wrapf(err, "error reading %s", uri)
	}
	return keys, getCacheAge(resp.Header.Get("cache-control")), nil
}

// getCacheAge returns how long a key set may be cached, based on the first
// max-age directive in the given Cache-Control header value. It returns
// defaultCacheAge when the header is empty or has no max-age directive.
//
// A max-age too large for time.Duration is clamped to the largest whole
// number of seconds a Duration can hold (about 292 years). This also covers
// values with too many digits for int64. RFC 9111 §1.2.2 prescribes this for
// overflowing delta-seconds, and it keeps the seconds-to-nanoseconds
// multiplication from wrapping into an arbitrary or negative duration.
func getCacheAge(cacheControl string) time.Duration {
	// Largest whole-second value whose nanosecond count still fits in int64.
	const maxSeconds = int64((1<<63 - 1) / time.Second)

	if cacheControl == "" {
		return defaultCacheAge
	}
	match := maxAgeRegex.FindStringSubmatch(cacheControl)
	if len(match) != 2 {
		return defaultCacheAge
	}

	seconds, err := strconv.ParseInt(match[1], 10, 64)
	if err != nil {
		numErr, ok := err.(*strconv.NumError)
		if !ok || numErr.Err != strconv.ErrRange {
			return defaultCacheAge
		}
		// Too many digits for int64: treat it as "as large as possible".
		seconds = maxSeconds
	}
	if seconds > maxSeconds {
		seconds = maxSeconds
	}
	return time.Duration(seconds) * time.Second
}

// getCacheJitter returns the upper bound of the random offset that
// getExpirationTime subtracts from a cache age. Ages above one hour, and an
// age of exactly zero, use defaultCacheJitter. Any other age uses a third of
// itself.
func getCacheJitter(age time.Duration) time.Duration {
	// A zero age would otherwise give a zero jitter. Its exact value is
	// irrelevant, because a zero-age entry expires immediately and is reloaded
	// on the next Get.
	if age == 0 || age > time.Hour {
		return defaultCacheJitter
	}
	return age / 3
}

// getExpirationTime returns the instant at which a key set cached for age
// should be refreshed. The result is the current time truncated to the
// second, plus age minus a random offset in [0, jitter). The offset is
// presumably there to spread reloads over time.
//
// When age is not positive, no offset is applied and |age| is added. A
// non-positive jitter is also ignored: rand.Int63n panics for n <= 0, so
// the guard keeps this helper safe for any input pair.
func getExpirationTime(age, jitter time.Duration) time.Time {
	if age > 0 && jitter > 0 {
		n := rand.Int63n(int64(jitter)) //nolint:gosec // not used for cryptographic security
		age -= time.Duration(n)
	}
	return time.Now().Truncate(time.Second).Add(abs(age))
}

// abs returns the magnitude of n: n itself when it is zero or positive,
// otherwise its negation.
func abs(n time.Duration) time.Duration {
	if n >= 0 {
		return n
	}
	return -n
}