1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
|
package v4a
import (
"context"
"crypto/ecdsa"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/internal/sdk"
)
// Credentials is Context, ECDSA, and Optional Session Token that can be used
// to sign requests using SigV4a
type Credentials struct {
Context string
PrivateKey *ecdsa.PrivateKey
SessionToken string
// Time the credentials will expire.
CanExpire bool
Expires time.Time
}
// Expired returns if the credentials have expired.
func (v Credentials) Expired() bool {
if v.CanExpire {
return !v.Expires.After(sdk.NowTime())
}
return false
}
// HasKeys returns if the credentials keys are set.
func (v Credentials) HasKeys() bool {
return len(v.Context) > 0 && v.PrivateKey != nil
}
// SymmetricCredentialAdaptor wraps a SigV4 AccessKey/SecretKey provider and adapts the credentials
// to a ECDSA PrivateKey for signing with SiV4a
type SymmetricCredentialAdaptor struct {
SymmetricProvider aws.CredentialsProvider
asymmetric atomic.Value
m sync.Mutex
}
// Retrieve retrieves symmetric credentials from the underlying provider.
func (s *SymmetricCredentialAdaptor) Retrieve(ctx context.Context) (aws.Credentials, error) {
symCreds, err := s.retrieveFromSymmetricProvider(ctx)
if err != nil {
return aws.Credentials{}, err
}
if asymCreds := s.getCreds(); asymCreds == nil {
return symCreds, nil
}
s.m.Lock()
defer s.m.Unlock()
asymCreds := s.getCreds()
if asymCreds == nil {
return symCreds, nil
}
// if the context does not match the access key id clear it
if asymCreds.Context != symCreds.AccessKeyID {
s.asymmetric.Store((*Credentials)(nil))
}
return symCreds, nil
}
// RetrievePrivateKey returns credentials suitable for SigV4a signing
func (s *SymmetricCredentialAdaptor) RetrievePrivateKey(ctx context.Context) (Credentials, error) {
if asymCreds := s.getCreds(); asymCreds != nil {
return *asymCreds, nil
}
s.m.Lock()
defer s.m.Unlock()
if asymCreds := s.getCreds(); asymCreds != nil {
return *asymCreds, nil
}
symmetricCreds, err := s.retrieveFromSymmetricProvider(ctx)
if err != nil {
return Credentials{}, fmt.Errorf("failed to retrieve symmetric credentials: %v", err)
}
privateKey, err := deriveKeyFromAccessKeyPair(symmetricCreds.AccessKeyID, symmetricCreds.SecretAccessKey)
if err != nil {
return Credentials{}, fmt.Errorf("failed to derive assymetric key from credentials")
}
creds := Credentials{
Context: symmetricCreds.AccessKeyID,
PrivateKey: privateKey,
SessionToken: symmetricCreds.SessionToken,
CanExpire: symmetricCreds.CanExpire,
Expires: symmetricCreds.Expires,
}
s.asymmetric.Store(&creds)
return creds, nil
}
func (s *SymmetricCredentialAdaptor) getCreds() *Credentials {
v := s.asymmetric.Load()
if v == nil {
return nil
}
c := v.(*Credentials)
if c != nil && c.HasKeys() && !c.Expired() {
return c
}
return nil
}
func (s *SymmetricCredentialAdaptor) retrieveFromSymmetricProvider(ctx context.Context) (aws.Credentials, error) {
credentials, err := s.SymmetricProvider.Retrieve(ctx)
if err != nil {
return aws.Credentials{}, err
}
return credentials, nil
}
// CredentialsProvider is the interface for a provider to retrieve credentials
// to sign requests with.
type CredentialsProvider interface {
RetrievePrivateKey(context.Context) (Credentials, error)
}
|