File: assume_role_provider.go

package info (click to toggle)
golang-mongodb-mongo-driver 1.17.1%2Bds1-2
  • links: PTS, VCS
  • area: main
  • in suites: experimental, sid, trixie
  • size: 25,988 kB
  • sloc: perl: 533; ansic: 491; python: 432; sh: 327; makefile: 174
file content (148 lines) | stat: -rw-r--r-- 4,892 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
// Copyright (C) MongoDB, Inc. 2023-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package credproviders

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"time"

	"go.mongodb.org/mongo-driver/internal/aws/credentials"
	"go.mongodb.org/mongo-driver/internal/uuid"
)

// assumeRoleProviderName is the provider name stamped on credential values
// produced by the assume role provider.
const assumeRoleProviderName = "AssumeRoleProvider"

// stsURI is the format string for the STS AssumeRoleWithWebIdentity request
// URL. Its three %s verbs take, in order, the role session name, the role ARN
// and the web identity token.
const stsURI = `https://sts.amazonaws.com/?Action=AssumeRoleWithWebIdentity&RoleSessionName=%s&RoleArn=%s&WebIdentityToken=%s&Version=2011-06-15`

// An AssumeRoleProvider retrieves credentials for assume role with web identity.
//
// The three exported fields name the environment variables that
// RetrieveWithContext reads: the role ARN, the path of the file containing the
// web identity token, and the role session name.
type AssumeRoleProvider struct {
	AwsRoleArnEnv              EnvVar
	AwsWebIdentityTokenFileEnv EnvVar
	AwsRoleSessionNameEnv      EnvVar

	// httpClient is the client used to send the request to STS.
	httpClient *http.Client
	// expiration is the instant after which IsExpired reports true. It is set
	// by RetrieveWithContext to the expiration returned by STS minus expiryWindow.
	expiration time.Time

	// expiryWindow allows the credentials to trigger refreshing before they actually expire.
	// This is beneficial so that expiring credentials do not cause requests to fail unexpectedly.
	//
	// An expiryWindow of 10s causes calls to IsExpired() to return true
	// 10 seconds before the credentials actually expire.
	expiryWindow time.Duration
}

// NewAssumeRoleProvider builds an assume role provider that sends its requests
// through httpClient and treats credentials as expired expiryWindow before
// their actual expiration. The provider reads its inputs from the
// AWS_ROLE_ARN, AWS_WEB_IDENTITY_TOKEN_FILE and AWS_ROLE_SESSION_NAME
// environment variables.
func NewAssumeRoleProvider(httpClient *http.Client, expiryWindow time.Duration) *AssumeRoleProvider {
	provider := new(AssumeRoleProvider)

	// Environment variables consulted on every retrieval.
	provider.AwsRoleArnEnv = EnvVar("AWS_ROLE_ARN")
	provider.AwsWebIdentityTokenFileEnv = EnvVar("AWS_WEB_IDENTITY_TOKEN_FILE")
	provider.AwsRoleSessionNameEnv = EnvVar("AWS_ROLE_SESSION_NAME")

	// Caller-supplied transport and refresh margin.
	provider.httpClient = httpClient
	provider.expiryWindow = expiryWindow

	return provider
}

// RetrieveWithContext retrieves the keys from the AWS service.
//
// It reads the role ARN and the web identity token file path from the
// provider's environment variables (both are required), loads the token from
// that file, and calls the STS AssumeRoleWithWebIdentity action. The role
// session name is taken from its environment variable, or a newly generated
// UUID when that variable is empty.
//
// The request is bounded by ctx and by an additional 10 second timeout. On
// success the returned value carries the access key ID, secret access key and
// session token from the response, and the provider's expiration is set to the
// expiration reported by STS minus the expiry window. A non-nil error is
// returned when the environment is incomplete, the token file cannot be read,
// the request fails, the response status is not 200 OK, the body cannot be
// decoded, or the decoded credentials are missing keys.
func (a *AssumeRoleProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
	const defaultHTTPTimeout = 10 * time.Second

	v := credentials.Value{ProviderName: assumeRoleProviderName}

	roleArn := a.AwsRoleArnEnv.Get()
	tokenFile := a.AwsWebIdentityTokenFileEnv.Get()
	switch {
	case tokenFile == "" && roleArn == "":
		return v, errors.New("AWS_WEB_IDENTITY_TOKEN_FILE and AWS_ROLE_ARN are missing")
	case roleArn == "":
		return v, errors.New("AWS_WEB_IDENTITY_TOKEN_FILE is set, but AWS_ROLE_ARN is missing")
	case tokenFile == "":
		return v, errors.New("AWS_ROLE_ARN is set, but AWS_WEB_IDENTITY_TOKEN_FILE is missing")
	}
	token, err := ioutil.ReadFile(tokenFile)
	if err != nil {
		return v, err
	}

	sessionName := a.AwsRoleSessionNameEnv.Get()
	if sessionName == "" {
		// Use a UUID if the RoleSessionName is not given.
		id, err := uuid.New()
		if err != nil {
			return v, err
		}
		sessionName = id.String()
	}

	// Escape every value before placing it in the query string: an unescaped
	// '&', '=', '+', '#' or whitespace character would otherwise alter the
	// request's parameters, and control characters (for example a trailing
	// newline in the token file) would make URL parsing fail.
	//
	// NOTE(review): the token still travels in the request URL, so errors
	// returned by the HTTP client may include it — confirm callers do not log
	// these errors verbatim.
	fullURI := fmt.Sprintf(
		stsURI,
		url.QueryEscape(sessionName),
		url.QueryEscape(roleArn),
		url.QueryEscape(string(token)),
	)

	// Cap the round trip even when the caller's context has no deadline.
	ctx, cancel := context.WithTimeout(ctx, defaultHTTPTimeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURI, nil)
	if err != nil {
		return v, err
	}
	req.Header.Set("Accept", "application/json")

	resp, err := a.httpClient.Do(req)
	if err != nil {
		return v, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return v, fmt.Errorf("response failure: %s", resp.Status)
	}

	// Only the credential fields of the JSON response are decoded.
	var stsResp struct {
		Response struct {
			Result struct {
				Credentials struct {
					AccessKeyID     string  `json:"AccessKeyId"`
					SecretAccessKey string  `json:"SecretAccessKey"`
					Token           string  `json:"SessionToken"`
					Expiration      float64 `json:"Expiration"`
				} `json:"Credentials"`
			} `json:"AssumeRoleWithWebIdentityResult"`
		} `json:"AssumeRoleWithWebIdentityResponse"`
	}

	err = json.NewDecoder(resp.Body).Decode(&stsResp)
	if err != nil {
		return v, err
	}
	creds := stsResp.Response.Result.Credentials
	v.AccessKeyID = creds.AccessKeyID
	v.SecretAccessKey = creds.SecretAccessKey
	v.SessionToken = creds.Token
	if !v.HasKeys() {
		return v, errors.New("failed to retrieve web identity keys")
	}

	// Expiration is decoded as a number and used as Unix seconds; any
	// fractional part is dropped. Subtracting the expiry window makes
	// IsExpired report true ahead of the real expiry.
	sec := int64(creds.Expiration)
	a.expiration = time.Unix(sec, 0).Add(-a.expiryWindow)

	return v, nil
}

// Retrieve fetches the keys from the AWS service by delegating to
// RetrieveWithContext with a background context.
func (a *AssumeRoleProvider) Retrieve() (credentials.Value, error) {
	ctx := context.Background()
	return a.RetrieveWithContext(ctx)
}

// IsExpired reports whether the current time is past the provider's recorded
// expiration. While the expiration is still the zero time.Time, this reports
// true.
func (a *AssumeRoleProvider) IsExpired() bool {
	deadline := a.expiration
	return time.Now().After(deadline)
}