1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
|
//go:build example
// +build example
package rdsutils_test
import (
"database/sql"
"flag"
"fmt"
"net/url"
"os"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/rds/rdsutils"
)
// ExampleConnectionStringBuilder contains usage of assuming a role and using
// that to build the auth token.
//
// For MySQL you may need to configure your MySQL driver with the TLS
// certificate to ensure verified TLS handshake,
// https://s3.amazonaws.com/rds-downloads/rds-combined-ca-bundle.pem
//
// Usage:
//
// ./main -user "iamuser" -dbname "foo" -region "us-west-2" -rolearn "arn" -endpoint "dbendpoint" -port 3306
func ExampleConnectionStringBuilder() {
	// Command-line inputs. Flags with an empty default have to be supplied by
	// the caller; region, port and table fall back to the defaults below.
	userPtr := flag.String("user", "", "user of the credentials")
	regionPtr := flag.String("region", "us-east-1", "region to be used when grabbing sts creds")
	roleArnPtr := flag.String("rolearn", "", "role arn to be used when grabbing sts creds")
	endpointPtr := flag.String("endpoint", "", "DB endpoint to be connected to")
	portPtr := flag.Int("port", 3306, "DB port to be connected to")
	tablePtr := flag.String("table", "test_table", "DB table to query against")
	dbNamePtr := flag.String("dbname", "", "DB name to query against")
	flag.Parse()

	// Check required flags. Will exit with status code 1 if
	// required field isn't set.
	if err := requiredFlags(
		userPtr,
		regionPtr,
		roleArnPtr,
		endpointPtr,
		portPtr,
		dbNamePtr,
	); err != nil {
		fmt.Printf("Error: %v\n\n", err)
		flag.PrintDefaults()
		os.Exit(1)
	}

	// Credentials obtained by assuming the given role; they are handed to the
	// connection string builder below.
	sess := session.Must(session.NewSession())
	creds := stscreds.NewCredentials(sess, *roleArnPtr)

	v := url.Values{}
	// required fields for DB connection
	v.Add("tls", "rds")
	v.Add("allowCleartextPasswords", "true")

	endpoint := fmt.Sprintf("%s:%d", *endpointPtr, *portPtr)
	b := rdsutils.NewConnectionStringBuilder(endpoint, *regionPtr, *userPtr, *dbNamePtr, creds)

	// Build can fail; surface that here rather than passing an unusable
	// connection string on to sql.Open.
	connectStr, err := b.WithTCPFormat().WithParams(v).Build()
	if err != nil {
		panic(fmt.Errorf("failed to build connection string: %v", err))
	}

	// NOTE(review): no driver package is imported in the visible part of this
	// file; a driver named "mysql" is assumed to be registered elsewhere.
	const dbType = "mysql"

	// sql.Open may only validate its arguments without creating a connection,
	// so connectivity problems (for example incorrect security groups) can
	// also surface on the first query below.
	db, err := sql.Open(dbType, connectStr)
	if err != nil {
		panic(fmt.Errorf("failed to open connection to the database: %v", err))
	}
	defer db.Close()

	// The table name is interpolated into the statement because identifiers
	// cannot be bound as query parameters. It comes straight from the -table
	// flag, so only run this with a trusted value.
	rows, err := db.Query(fmt.Sprintf("SELECT * FROM %s LIMIT 1", *tablePtr))
	if err != nil {
		panic(fmt.Errorf("failed to select from table, %q, with %v", *tablePtr, err))
	}
	defer rows.Close()

	for rows.Next() {
		columns, err := rows.Columns()
		if err != nil {
			panic(fmt.Errorf("failed to read columns from row: %v", err))
		}

		fmt.Printf("rows colums:\n%d\n", len(columns))
	}

	// Next returns false both at the end of the result set and on failure;
	// Err distinguishes the two.
	if err := rows.Err(); err != nil {
		panic(fmt.Errorf("failed to iterate rows from table, %q, with %v", *tablePtr, err))
	}
}
func requiredFlags(flags ...interface{}) error {
for _, f := range flags {
switch f.(type) {
case nil:
return fmt.Errorf("one or more required flags were not set")
}
}
return nil
}
|