File: connect.go

package info (click to toggle)
golang-github-aws-aws-sdk-go-v2 1.17.1-3
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 384,244 kB
  • sloc: java: 13,538; makefile: 400; sh: 137
file content (125 lines) | stat: -rw-r--r-- 3,461 bytes parent folder | download | duplicates (5)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
package auth

import (
	"context"
	"fmt"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/aws/signer/v4"
)

const (
	signingID        = "rds-db"
	emptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
)

// BuildAuthTokenOptions is the optional set of configuration properties for BuildAuthToken
type BuildAuthTokenOptions struct{}

// BuildAuthToken will return an authorization token used as the password for a DB
// connection.
//
// * endpoint - Endpoint consists of the port needed to connect to the DB. <host>:<port>
// * region - Region is the location of where the DB is
// * dbUser - User account within the database to sign in with
// * creds - Credentials to be signed with
//
// The following example shows how to use BuildAuthToken to create an authentication
// token for connecting to a MySQL database in RDS.
//
//	authToken, err := BuildAuthToken(dbEndpoint, awsRegion, dbUser, awsCreds)
//
//	// Create the MySQL DNS string for the DB connection
//	// user:password@protocol(endpoint)/dbname?<params>
//	connectStr = fmt.Sprintf("%s:%s@tcp(%s)/%s?allowCleartextPasswords=true&tls=rds",
//	   dbUser, authToken, dbEndpoint, dbName,
//	)
//
//	// Use db to perform SQL operations on database
//	db, err := sql.Open("mysql", connectStr)
//
// See http://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html
// for more information on using IAM database authentication with RDS.
func BuildAuthToken(ctx context.Context, endpoint, region, dbUser string, creds aws.CredentialsProvider, optFns ...func(options *BuildAuthTokenOptions)) (string, error) {
	_, port := validateURL(endpoint)
	if port == "" {
		return "", fmt.Errorf("the provided endpoint is missing a port, or the provided port is invalid")
	}

	o := BuildAuthTokenOptions{}

	for _, fn := range optFns {
		fn(&o)
	}

	if creds == nil {
		return "", fmt.Errorf("credetials provider must not ne nil")
	}

	// the scheme is arbitrary and is only needed because validation of the URL requires one.
	if !(strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://")) {
		endpoint = "https://" + endpoint
	}

	req, err := http.NewRequest("GET", endpoint, nil)
	if err != nil {
		return "", err
	}
	values := req.URL.Query()
	values.Set("Action", "connect")
	values.Set("DBUser", dbUser)
	req.URL.RawQuery = values.Encode()

	signer := v4.NewSigner()

	credentials, err := creds.Retrieve(ctx)
	if err != nil {
		return "", err
	}

	// Expire Time: 15 minute
	query := req.URL.Query()
	query.Set("X-Amz-Expires", "900")
	req.URL.RawQuery = query.Encode()

	signedURI, _, err := signer.PresignHTTP(ctx, credentials, req, emptyPayloadHash, signingID, region, time.Now().UTC())
	if err != nil {
		return "", err
	}

	url := signedURI
	if strings.HasPrefix(url, "http://") {
		url = url[len("http://"):]
	} else if strings.HasPrefix(url, "https://") {
		url = url[len("https://"):]
	}

	return url, nil
}

func validateURL(hostPort string) (host, port string) {
	colon := strings.LastIndexByte(hostPort, ':')
	if colon != -1 {
		host, port = hostPort[:colon], hostPort[colon+1:]
	}
	if !validatePort(port) {
		port = ""
		return
	}
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}

	return
}

func validatePort(port string) bool {
	if _, err := strconv.Atoi(port); err == nil {
		return true
	}
	return false
}