1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
|
package s3
import (
"context"
"crypto/hmac"
"crypto/sha256"
"errors"
"fmt"
"sync"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/internal/sdk"
"github.com/aws/aws-sdk-go-v2/internal/sync/singleflight"
"github.com/aws/smithy-go/container/private/cache"
"github.com/aws/smithy-go/container/private/cache/lru"
)
// Defaults applied by newDefaultS3ExpressCredentialsProvider.
const (
	// s3ExpressCacheCap is the capacity handed to lru.New for the session
	// credentials cache.
	s3ExpressCacheCap = 100

	// s3ExpressRefreshWindow is the default refresh window: cached
	// credentials whose remaining lifetime is at or below this duration
	// trigger a refresh when they are retrieved.
	s3ExpressRefreshWindow = time.Minute
)
// cacheKey identifies one cached set of session credentials: the base
// (sigv4) credentials that were used plus the bucket the session is for.
type cacheKey struct {
	CredentialsHash string // hmac(sigv4 akid, sigv4 secret)
	Bucket          string // bucket name passed to CreateSession
}
// Slug renders the key as a single string: the credentials hash immediately
// followed by the bucket name, with no separator in between.
func (c cacheKey) Slug() string {
	// fmt.Sprint only inserts a space between operands when neither is a
	// string, so two string operands are joined verbatim.
	return fmt.Sprint(c.CredentialsHash, c.Bucket)
}
// sessionCredsCache pairs a cache.Cache with a mutex; its Get and Put methods
// hold the mutex for the duration of each call into the underlying cache.
type sessionCredsCache struct {
	mu    sync.Mutex  // guards every access to cache
	cache cache.Cache // underlying store; entries are *aws.Credentials keyed by cacheKey
}
// Get looks up the credentials stored under key while holding the cache
// mutex. The boolean reports whether an entry was present; on a miss the
// returned pointer is nil.
func (c *sessionCredsCache) Get(key cacheKey) (*aws.Credentials, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()

	entry, found := c.cache.Get(key)
	if !found {
		return nil, false
	}
	// Put only ever stores *aws.Credentials, so the assertion is unchecked.
	return entry.(*aws.Credentials), true
}
// Put stores creds under key while holding the cache mutex. What happens to
// an existing entry for the same key is decided by the underlying cache.
func (c *sessionCredsCache) Put(key cacheKey, creds *aws.Credentials) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.cache.Put(key, creds)
}
// defaultS3ExpressCredentialsProvider is the default S3Express credentials
// provider. As constructed by newDefaultS3ExpressCredentialsProvider it uses
// an LRU cache with a capacity of 100.
//
// Credentials will be refreshed asynchronously when a Retrieve() call is made
// for cached credentials within an expiry window (1 minute, currently
// non-configurable).
type defaultS3ExpressCredentialsProvider struct {
	// sf collapses concurrent session requests that share the same cache key
	// slug into a single in-flight call.
	sf singleflight.Group

	// client issues CreateSession. The constructor in this file does not set
	// it — presumably it is assigned by the owning client after construction
	// (TODO confirm against the caller).
	client createSessionAPIClient

	// cache holds session credentials; clones made by
	// CloneWithBaseCredentials share the same pointer.
	cache *sessionCredsCache

	// refreshWindow is the remaining-lifetime threshold at or below which a
	// cached entry triggers a background refresh.
	refreshWindow time.Duration

	v4creds aws.CredentialsProvider // underlying credentials used for CreateSession
}
// createSessionAPIClient abstracts the single operation the credentials
// provider calls to obtain session credentials, so the provider depends on
// this interface rather than on a concrete client type.
type createSessionAPIClient interface {
	// CreateSession requests a session for the bucket named in the input;
	// the trailing functions are per-call option mutators.
	CreateSession(context.Context, *CreateSessionInput, ...func(*Options)) (*CreateSessionOutput, error)
}
// newDefaultS3ExpressCredentialsProvider builds a provider backed by a fresh
// LRU cache of capacity s3ExpressCacheCap and using s3ExpressRefreshWindow as
// its refresh window. The client and base credentials fields are left unset.
func newDefaultS3ExpressCredentialsProvider() *defaultS3ExpressCredentialsProvider {
	credsCache := &sessionCredsCache{cache: lru.New(s3ExpressCacheCap)}

	provider := &defaultS3ExpressCredentialsProvider{
		cache:         credsCache,
		refreshWindow: s3ExpressRefreshWindow,
	}
	return provider
}
// CloneWithBaseCredentials returns a copy of the provider that uses v4creds
// as its base credentials. It is used when per-operation config mutations
// change the credentials provider.
//
// The clone reuses the receiver's client, cache pointer and refresh window.
// Its singleflight group is not copied: the clone starts with a zero-value
// group of its own.
func (p *defaultS3ExpressCredentialsProvider) CloneWithBaseCredentials(v4creds aws.CredentialsProvider) *defaultS3ExpressCredentialsProvider {
	clone := new(defaultS3ExpressCredentialsProvider)
	clone.client = p.client
	clone.cache = p.cache
	clone.refreshWindow = p.refreshWindow
	clone.v4creds = v4creds
	return clone
}
// Retrieve returns session credentials for bucket.
//
// The base credentials are resolved first and combined with the bucket name
// to form the cache key. A cached, unexpired entry is returned immediately;
// if its remaining lifetime is at or below the refresh window, a refresh is
// also started without waiting for its result. On a cache miss or an expired
// entry the call blocks until new credentials are fetched or ctx is done.
//
// An error resolving the base credentials is returned wrapped with the
// prefix "get sigv4 creds".
func (p *defaultS3ExpressCredentialsProvider) Retrieve(ctx context.Context, bucket string) (aws.Credentials, error) {
	base, err := p.v4creds.Retrieve(ctx)
	if err != nil {
		return aws.Credentials{}, fmt.Errorf("get sigv4 creds: %w", err)
	}

	key := cacheKey{
		CredentialsHash: gethmac(base.AccessKeyID, base.SecretAccessKey),
		Bucket:          bucket,
	}

	cached, found := p.cache.Get(key)
	if found && !cached.Expired() {
		remaining := cached.Expires.Sub(sdk.NowTime())
		if remaining <= p.refreshWindow {
			// Fire-and-forget: the result channel is intentionally dropped,
			// the refreshed credentials land in the cache for later calls.
			p.doChanRetrieve(ctx, key)
		}
		return *cached, nil
	}

	return p.awaitDoChanRetrieve(ctx, key)
}
// doChanRetrieve starts a credentials fetch for key through the provider's
// singleflight group, using the key's slug as the de-duplication key, and
// returns the channel on which the result will be delivered.
func (p *defaultS3ExpressCredentialsProvider) doChanRetrieve(ctx context.Context, key cacheKey) <-chan singleflight.Result {
	fetch := func() (interface{}, error) {
		return p.retrieve(ctx, key)
	}
	return p.sf.DoChan(key.Slug(), fetch)
}
// awaitDoChanRetrieve starts (or joins) a credentials fetch for key and
// blocks until it completes or ctx is done, whichever happens first.
//
// On completion it returns the fetched credentials, or the fetch error with
// zero-value credentials. If ctx finishes first, the returned error wraps
// ctx.Err(), so callers can distinguish cancellation from a deadline with
// errors.Is(err, context.Canceled) / errors.Is(err, context.DeadlineExceeded).
func (p *defaultS3ExpressCredentialsProvider) awaitDoChanRetrieve(ctx context.Context, key cacheKey) (aws.Credentials, error) {
	ch := p.doChanRetrieve(ctx, key)

	select {
	case r := <-ch:
		if r.Err != nil {
			return aws.Credentials{}, r.Err
		}
		// The fetch function always yields aws.Credentials; the checked
		// assertion turns a broken invariant into an error, not a panic.
		creds, ok := r.Val.(aws.Credentials)
		if !ok {
			return aws.Credentials{}, fmt.Errorf("s3express retrieve credentials: unexpected result type %T", r.Val)
		}
		return creds, nil
	case <-ctx.Done():
		return aws.Credentials{}, fmt.Errorf("s3express retrieve credentials canceled: %w", ctx.Err())
	}
}
// retrieve calls CreateSession for the key's bucket, converts the response
// into credentials, stores them in the cache under key and returns a copy.
// Errors from CreateSession or from the conversion are returned unchanged,
// and nothing is cached in that case.
func (p *defaultS3ExpressCredentialsProvider) retrieve(ctx context.Context, key cacheKey) (aws.Credentials, error) {
	input := &CreateSessionInput{Bucket: aws.String(key.Bucket)}

	out, err := p.client.CreateSession(ctx, input)
	if err != nil {
		return aws.Credentials{}, err
	}

	session, err := credentialsFromResponse(out)
	if err != nil {
		return aws.Credentials{}, err
	}

	p.cache.Put(key, session)
	return *session, nil
}
// credentialsFromResponse converts a CreateSession response into
// aws.Credentials that are marked as expiring at the session's expiration.
//
// It returns an error when the response (or its Credentials member) is nil,
// or when any of the access key ID, secret access key, session token or
// expiration is missing.
func credentialsFromResponse(o *CreateSessionOutput) (*aws.Credentials, error) {
	// A nil output is treated like missing credentials so a client that
	// returns (nil, nil) cannot cause a nil-pointer dereference here.
	if o == nil || o.Credentials == nil {
		return nil, errors.New("s3express session credentials unset")
	}

	c := o.Credentials
	if c.AccessKeyId == nil || c.SecretAccessKey == nil || c.SessionToken == nil || c.Expiration == nil {
		return nil, errors.New("s3express session credentials missing one or more required fields")
	}

	return &aws.Credentials{
		AccessKeyID:     *c.AccessKeyId,
		SecretAccessKey: *c.SecretAccessKey,
		SessionToken:    *c.SessionToken,
		CanExpire:       true,
		Expires:         *c.Expiration,
	}, nil
}
// gethmac computes HMAC-SHA256 over p using key as the secret and returns
// the raw 32-byte digest as a string (not hex- or base64-encoded).
func gethmac(p, key string) string {
	mac := hmac.New(sha256.New, []byte(key))
	// hash.Hash documents that Write never returns an error.
	mac.Write([]byte(p))
	digest := mac.Sum(nil)
	return string(digest)
}
|