1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191
|
package workloadapi
import (
"context"
"sync"
"github.com/spiffe/go-spiffe/v2/bundle/jwtbundle"
"github.com/spiffe/go-spiffe/v2/svid/jwtsvid"
"github.com/zeebo/errs"
)
// sourceClient lists the Workload API client operations a watcher's client
// must provide. newWatcher stores the value returned by New in a field of this
// type, so that client satisfies it; presumably the interface exists so that
// another implementation can be injected through watcherConfig.client.
//
// The watcher itself calls the Watch methods and Close; the Fetch methods are
// presumably used by code elsewhere in the package.
type sourceClient interface {
	// Streaming watches; newWatcher runs these in background goroutines.
	WatchX509Context(ctx context.Context, w X509ContextWatcher) error
	WatchJWTBundles(ctx context.Context, w JWTBundleWatcher) error

	// One-shot JWT-SVID fetches.
	FetchJWTSVID(ctx context.Context, params jwtsvid.Params) (*jwtsvid.SVID, error)
	FetchJWTSVIDs(ctx context.Context, params jwtsvid.Params) ([]*jwtsvid.SVID, error)

	// Close releases the client. The watcher only calls it on clients it
	// created itself.
	Close() error
}
// watcherConfig carries the client-related settings for newWatcher.
type watcherConfig struct {
	// client, when non-nil, is used as-is; the watcher does not take
	// ownership of it and does not close it.
	client sourceClient
	// clientOptions are passed to New to create a watcher-owned client when
	// client is nil.
	clientOptions []ClientOption
}
// watcher runs background watches against the Workload API, forwards each
// update to caller-supplied callbacks, and exposes a coalescing "updated"
// signal.
type watcher struct {
	// updatedCh carries update signals. newWatcher creates it with a buffer
	// of one, so at most one signal is pending at a time.
	updatedCh chan struct{}

	// client is the Workload API client used for the watches. ownsClient is
	// true only when the watcher created the client itself, in which case
	// Close also closes it.
	client     sourceClient
	ownsClient bool

	// cancel stops the background watch goroutines; wg tracks them so Close
	// can wait for them to exit.
	cancel func()
	wg     sync.WaitGroup

	// closeMtx guards closed and closeErr, which make Close perform its
	// teardown once and return the same error on subsequent calls.
	closeMtx sync.Mutex
	closed   bool
	closeErr error

	// x509ContextFn receives each X.509 context update. x509ContextSet is
	// closed exactly once (via x509ContextSetOnce) after the first update.
	x509ContextFn      func(*X509Context)
	x509ContextSet     chan struct{}
	x509ContextSetOnce sync.Once

	// jwtBundlesFn receives each JWT bundle set update. jwtBundlesSet is
	// closed exactly once (via jwtBundlesSetOnce) after the first update.
	jwtBundlesFn      func(*jwtbundle.Set)
	jwtBundlesSet     chan struct{}
	jwtBundlesSetOnce sync.Once
}
// newWatcher creates a watcher and blocks until it has received its initial
// data from the Workload API.
//
// If config.client is nil, a client is created with
// New(ctx, config.clientOptions...) and is owned (and later closed) by the
// watcher; otherwise config.client is used and is left open by Close.
//
// For each non-nil callback, a background goroutine watches the matching
// Workload API stream (X.509 contexts for x509ContextFn, JWT bundles for
// jwtBundlesFn). newWatcher then waits for the first update of that kind
// before continuing. ctx bounds only client creation and these initial waits.
// The watches themselves run on a context derived from context.Background
// and are stopped by Close.
//
// On any error, the partially initialized watcher is closed and the Close
// error is combined into the returned error via errs.Combine.
func newWatcher(ctx context.Context, config watcherConfig, x509ContextFn func(*X509Context), jwtBundlesFn func(*jwtbundle.Set)) (_ *watcher, err error) {
	w := &watcher{
		updatedCh:      make(chan struct{}, 1),
		client:         config.client,
		cancel:         func() {},
		x509ContextFn:  x509ContextFn,
		x509ContextSet: make(chan struct{}),
		jwtBundlesFn:   jwtBundlesFn,
		jwtBundlesSet:  make(chan struct{}),
	}

	// If this function fails, clean up the partially initialized watcher.
	// err is the named result, so this sees the error from every return path.
	defer func() {
		if err != nil {
			err = errs.Combine(err, w.Close())
		}
	}()

	// Create a new client unless one was provided in config.client. The
	// inner err shadows the named result, but `return nil, err` still assigns
	// it, so the deferred cleanup runs.
	if w.client == nil {
		client, err := New(ctx, config.clientOptions...)
		if err != nil {
			return nil, err
		}
		w.client = client
		w.ownsClient = true
	}

	// Each watch goroutine sends its return value here exactly once. A buffer
	// of two (one slot per possible goroutine) keeps those sends from blocking
	// even when nobody is receiving.
	errCh := make(chan error, 2)

	// waitFor blocks until has is closed (the first update arrived), a watch
	// goroutine returns, or ctx is done. A value received from errCh may come
	// from either watch goroutine, not necessarily the one being waited on.
	// NOTE(review): if a watch ever returned a nil error, waitFor would report
	// success even though has was never closed — confirm the Watch methods
	// only return non-nil errors.
	waitFor := func(has <-chan struct{}) error {
		select {
		case <-has:
			return nil
		case err := <-errCh:
			return err
		case <-ctx.Done():
			return ctx.Err()
		}
	}

	// Start background goroutines that watch the Workload API for updates.
	// They use a context detached from ctx so they outlive initialization;
	// w.cancel (called from Close) stops them.
	var watchCtx context.Context
	watchCtx, w.cancel = context.WithCancel(context.Background())
	if w.x509ContextFn != nil {
		w.wg.Add(1)
		go func() {
			defer w.wg.Done()
			errCh <- w.client.WatchX509Context(watchCtx, w)
		}()
		if err := waitFor(w.x509ContextSet); err != nil {
			return nil, err
		}
	}
	if w.jwtBundlesFn != nil {
		w.wg.Add(1)
		go func() {
			defer w.wg.Done()
			errCh <- w.client.WatchJWTBundles(watchCtx, w)
		}()
		if err := waitFor(w.jwtBundlesSet); err != nil {
			return nil, err
		}
	}

	// Drain the update channel. This function already blocks until the
	// initial update(s) arrive, and callers should not see a pending update
	// on a source that was just initialized. If the watcher is ever allowed to
	// initialize without waiting, this reset should be removed.
	w.drainUpdated()
	return w, nil
}
// Close stops the background watches and drops the connection to the
// Workload API.
//
// It cancels the watch context, waits for every watch goroutine to exit, and
// then closes the client if (and only if) the watcher created it. Teardown
// happens once; subsequent calls return the error recorded by the first call.
// Concurrent calls are serialized by closeMtx.
func (w *watcher) Close() error {
	w.closeMtx.Lock()
	defer w.closeMtx.Unlock()

	if w.closed {
		return w.closeErr
	}

	w.cancel()
	w.wg.Wait()

	// newWatcher also calls Close to tear down a partially initialized
	// watcher, so the client may still be unset here. A client supplied by
	// the caller is never closed.
	if w.ownsClient && w.client != nil {
		w.closeErr = w.client.Close()
	}
	w.closed = true
	return w.closeErr
}
// OnX509ContextUpdate is called by the X.509 context watch started in
// newWatcher. It forwards the context to x509ContextFn, signals an update,
// and on the first call closes x509ContextSet to release newWatcher's
// initial wait.
func (w *watcher) OnX509ContextUpdate(x509Context *X509Context) {
	w.x509ContextFn(x509Context)
	// Signal before marking the context as set, so the initial signal is
	// already pending when newWatcher drains updatedCh after its wait returns.
	w.triggerUpdated()
	markSet := func() { close(w.x509ContextSet) }
	w.x509ContextSetOnce.Do(markSet)
}
// OnX509ContextWatchError receives errors from the X.509 context watch and
// deliberately ignores them: the watcher has no special handling for these
// errors. If logging is wanted, it should be configured on the Workload API
// client instead.
func (w *watcher) OnX509ContextWatchError(_ error) {}
// OnJWTBundlesUpdate is called by the JWT bundle watch started in newWatcher.
// It forwards the bundle set to jwtBundlesFn, signals an update, and on the
// first call closes jwtBundlesSet to release newWatcher's initial wait.
func (w *watcher) OnJWTBundlesUpdate(jwtBundles *jwtbundle.Set) {
	w.jwtBundlesFn(jwtBundles)
	// Signal before marking the bundles as set, so the initial signal is
	// already pending when newWatcher drains updatedCh after its wait returns.
	w.triggerUpdated()
	markSet := func() { close(w.jwtBundlesSet) }
	w.jwtBundlesSetOnce.Do(markSet)
}
// OnJWTBundlesWatchError receives errors from the JWT bundle watch and
// deliberately ignores them: the watcher has no special handling for these
// errors. If logging is wanted, it should be configured on the Workload API
// client instead.
func (w *watcher) OnJWTBundlesWatchError(_ error) {}
// WaitUntilUpdated blocks until an update is signaled or ctx is done. A
// pending update signal is consumed. When ctx ends first, ctx.Err() is
// returned.
func (w *watcher) WaitUntilUpdated(ctx context.Context) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-w.updatedCh:
		return nil
	}
}
// Updated exposes, receive-only, the channel on which update signals are
// delivered.
func (w *watcher) Updated() <-chan struct{} {
	var signals <-chan struct{} = w.updatedCh
	return signals
}
// drainUpdated discards a pending update signal, if there is one, without
// blocking.
func (w *watcher) drainUpdated() {
	pending := w.updatedCh
	select {
	default:
		// Nothing was pending.
	case <-pending:
		// The pending signal has been consumed and dropped.
	}
}
// triggerUpdated records that an update happened. updatedCh has a buffer of
// one, so a single pending signal stands for any number of coalesced updates.
//
// The send must be non-blocking. The X.509 and JWT watch goroutines can call
// this concurrently. With a drain-then-blocking-send, both could drain an
// empty channel, and the second send would then block until a consumer reads.
// That stalls the watch goroutine inside its callback and makes Close hang in
// wg.Wait.
func (w *watcher) triggerUpdated() {
	select {
	case w.updatedCh <- struct{}{}:
	default:
		// A signal is already pending; this update is coalesced into it.
	}
}
|