// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package credproviders

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"time"

	"go.mongodb.org/mongo-driver/internal/aws/credentials"
	"go.mongodb.org/mongo-driver/internal/uuid"
)

const (
	// assumeRoleProviderName is the name of the assume role provider.
	assumeRoleProviderName = "AssumeRoleProvider"

	// stsURI is the format string for the STS AssumeRoleWithWebIdentity request.
	// Its verbs are, in order, the role session name, the role ARN, and the web identity token.
	stsURI = `https://sts.amazonaws.com/?Action=AssumeRoleWithWebIdentity&RoleSessionName=%s&RoleArn=%s&WebIdentityToken=%s&Version=2011-06-15`
)

// AssumeRoleProvider retrieves credentials by assuming a role with a web identity token.
type AssumeRoleProvider struct {
	AwsRoleArnEnv              EnvVar
	AwsWebIdentityTokenFileEnv EnvVar
	AwsRoleSessionNameEnv      EnvVar

	httpClient *http.Client
	expiration time.Time

	// expiryWindow allows the credentials to be refreshed before they actually expire.
	// This keeps expiring credentials from causing requests to fail unexpectedly.
	//
	// For example, an expiryWindow of 10s causes calls to IsExpired() to return true
	// 10 seconds before the credentials actually expire.
	expiryWindow time.Duration
}

// NewAssumeRoleProvider returns a pointer to an assume role provider.
func NewAssumeRoleProvider(httpClient *http.Client, expiryWindow time.Duration) *AssumeRoleProvider {
	return &AssumeRoleProvider{
		// AwsRoleArnEnv is the environment variable for AWS_ROLE_ARN
		AwsRoleArnEnv: EnvVar("AWS_ROLE_ARN"),
		// AwsWebIdentityTokenFileEnv is the environment variable for AWS_WEB_IDENTITY_TOKEN_FILE
		AwsWebIdentityTokenFileEnv: EnvVar("AWS_WEB_IDENTITY_TOKEN_FILE"),
		// AwsRoleSessionNameEnv is the environment variable for AWS_ROLE_SESSION_NAME
		AwsRoleSessionNameEnv: EnvVar("AWS_ROLE_SESSION_NAME"),
		httpClient:            httpClient,
		expiryWindow:          expiryWindow,
	}
}

// RetrieveWithContext retrieves the keys from the AWS service.
func (a *AssumeRoleProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
	const defaultHTTPTimeout = 10 * time.Second

	v := credentials.Value{ProviderName: assumeRoleProviderName}

	roleArn := a.AwsRoleArnEnv.Get()
	tokenFile := a.AwsWebIdentityTokenFileEnv.Get()
	if tokenFile == "" && roleArn == "" {
		return v, errors.New("AWS_WEB_IDENTITY_TOKEN_FILE and AWS_ROLE_ARN are missing")
	}
	if tokenFile != "" && roleArn == "" {
		return v, errors.New("AWS_WEB_IDENTITY_TOKEN_FILE is set, but AWS_ROLE_ARN is missing")
	}
	if tokenFile == "" && roleArn != "" {
		return v, errors.New("AWS_ROLE_ARN is set, but AWS_WEB_IDENTITY_TOKEN_FILE is missing")
	}
	token, err := ioutil.ReadFile(tokenFile)
	if err != nil {
		return v, err
	}

	sessionName := a.AwsRoleSessionNameEnv.Get()
	if sessionName == "" {
		// Use a UUID if the RoleSessionName is not given.
		id, err := uuid.New()
		if err != nil {
			return v, err
		}
		sessionName = id.String()
	}

	fullURI := fmt.Sprintf(stsURI, sessionName, roleArn, string(token))

	req, err := http.NewRequest(http.MethodPost, fullURI, nil)
	if err != nil {
		return v, err
	}
	req.Header.Set("Accept", "application/json")

	ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
	defer cancel()
	resp, err := a.httpClient.Do(req.WithContext(ctx))
	if err != nil {
		return v, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return v, fmt.Errorf("response failure: %s", resp.Status)
	}

	// stsResp mirrors only the parts of the JSON response that are needed here.
	var stsResp struct {
		Response struct {
			Result struct {
				Credentials struct {
					AccessKeyID     string  `json:"AccessKeyId"`
					SecretAccessKey string  `json:"SecretAccessKey"`
					Token           string  `json:"SessionToken"`
					Expiration      float64 `json:"Expiration"`
				} `json:"Credentials"`
			} `json:"AssumeRoleWithWebIdentityResult"`
		} `json:"AssumeRoleWithWebIdentityResponse"`
	}
	err = json.NewDecoder(resp.Body).Decode(&stsResp)
	if err != nil {
		return v, err
	}
	v.AccessKeyID = stsResp.Response.Result.Credentials.AccessKeyID
	v.SecretAccessKey = stsResp.Response.Result.Credentials.SecretAccessKey
	v.SessionToken = stsResp.Response.Result.Credentials.Token
	if !v.HasKeys() {
		return v, errors.New("failed to retrieve web identity keys")
	}

	// Expiration is treated as Unix seconds. Subtracting expiryWindow makes
	// IsExpired report true before the credentials actually expire.
	sec := int64(stsResp.Response.Result.Credentials.Expiration)
	a.expiration = time.Unix(sec, 0).Add(-a.expiryWindow)

	return v, nil
}

// Retrieve retrieves the keys from the AWS service.
func (a *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
	return a.RetrieveWithContext(context.Background())
}

// IsExpired returns true if the credentials are expired.
func (a *AssumeRoleProvider) IsExpired() bool {
	return a.expiration.Before(time.Now())
}