File: builder.go

package info (click to toggle)
golang-github-aws-aws-sdk-go 1.16.18%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: buster, buster-backports, experimental
  • size: 93,084 kB
  • sloc: ruby: 193; makefile: 174; xml: 11
file content (127 lines) | stat: -rw-r--r-- 3,736 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package rdsutils

import (
	"fmt"
	"net/url"

	"github.com/aws/aws-sdk-go/aws/awserr"
	"github.com/aws/aws-sdk-go/aws/credentials"
)

// ConnectionFormat is the type of connection that will be used to connect
// to the database. Build places its string value directly in front of the
// parenthesized endpoint in the generated connection string.
type ConnectionFormat string

// Known ConnectionFormat values.
const (
	// NoConnectionFormat is the zero value, meaning no format was chosen;
	// Build refuses to run with it and returns ErrNoConnectionFormat.
	NoConnectionFormat = ConnectionFormat("")
	// TCPFormat selects a "tcp" connection.
	TCPFormat = ConnectionFormat("tcp")
)

var (
	// ErrNoConnectionFormat is returned by Build when the builder's
	// connection format is still NoConnectionFormat, i.e. neither
	// WithFormat nor WithTCPFormat was used to pick one.
	ErrNoConnectionFormat = awserr.New("NoConnectionFormat", "No connection format was specified", nil)
)

// ConnectionStringBuilder assembles a database connection string from the
// parameters it holds. It is used by value: every With* method returns a
// modified copy and leaves its receiver untouched. The params field is
// required to have a tls specification and allowCleartextPasswords must be
// set to true.
type ConnectionStringBuilder struct {
	// Database name and endpoint written into the connection string.
	dbName   string
	endpoint string
	// Region, user and credentials; together with the endpoint these are
	// passed to BuildAuthToken by Build.
	region string
	user   string
	creds  *credentials.Credentials

	// connectFormat is the connection type written in front of the endpoint.
	connectFormat ConnectionFormat
	// params, when non-empty, is encoded as the connection string's query.
	params url.Values
}

// NewConnectionStringBuilder returns a ConnectionStringBuilder holding the
// given endpoint, region, database user, database name and credentials.
// No connection format or params are set; pick a format (for example with
// WithTCPFormat) before calling Build.
func NewConnectionStringBuilder(endpoint, region, dbUser, dbName string, creds *credentials.Credentials) ConnectionStringBuilder {
	var b ConnectionStringBuilder
	b.endpoint, b.region = endpoint, region
	b.user, b.dbName = dbUser, dbName
	b.creds = creds
	return b
}

// WithEndpoint returns a copy of the builder whose endpoint is replaced by
// the given value; the receiver is not modified.
func (b ConnectionStringBuilder) WithEndpoint(endpoint string) ConnectionStringBuilder {
	next := b
	next.endpoint = endpoint
	return next
}

// WithRegion returns a copy of the builder whose region is replaced by the
// given value; the receiver is not modified.
func (b ConnectionStringBuilder) WithRegion(region string) ConnectionStringBuilder {
	next := b
	next.region = region
	return next
}

// WithUser returns a copy of the builder whose database user is replaced by
// the given value; the receiver is not modified.
func (b ConnectionStringBuilder) WithUser(user string) ConnectionStringBuilder {
	next := b
	next.user = user
	return next
}

// WithDBName returns a copy of the builder whose database name is replaced
// by the given value; the receiver is not modified.
func (b ConnectionStringBuilder) WithDBName(dbName string) ConnectionStringBuilder {
	next := b
	next.dbName = dbName
	return next
}

// WithParams will return a builder with the given params. The parameters
// will be included in the connection query string produced by Build; a nil
// or empty params value means no query string is appended.
//
// The params are deep-copied, so modifying the url.Values after this call
// does not affect the returned builder or any builder derived from it.
//
//	Example:
//	v := url.Values{}
//	v.Add("tls", "rds")
//	b := rdsutils.NewConnectionStringBuilder(endpoint, region, user, dbname, creds)
//	connectStr, err := b.WithParams(v).WithTCPFormat().Build()
func (b ConnectionStringBuilder) WithParams(params url.Values) ConnectionStringBuilder {
	if params == nil {
		b.params = nil
		return b
	}
	// The builder is a value type whose With* methods return independent
	// copies; url.Values is a map of slices, so copy both levels to avoid
	// sharing storage with the caller.
	cp := make(url.Values, len(params))
	for k, vs := range params {
		cp[k] = append([]string(nil), vs...)
	}
	b.params = cp
	return b
}

// WithFormat returns a copy of the builder that uses the given connection
// format; the receiver is not modified.
func (b ConnectionStringBuilder) WithFormat(f ConnectionFormat) ConnectionStringBuilder {
	next := b
	next.connectFormat = f
	return next
}

// WithTCPFormat returns a copy of the builder whose connection format is
// TCPFormat; it is shorthand for WithFormat(TCPFormat).
func (b ConnectionStringBuilder) WithTCPFormat() ConnectionStringBuilder {
	b.connectFormat = TCPFormat
	return b
}

// Build will return a new connection string that can be used to open a
// connection to the desired database. The string has the shape
//
//	user:token@format(endpoint)/dbName[?params]
//
// where token is obtained from BuildAuthToken and the query part is present
// only when params are non-empty. ErrNoConnectionFormat is returned if no
// format was chosen; an error from BuildAuthToken is returned unchanged.
//
//	Example:
//	b := rdsutils.NewConnectionStringBuilder(endpoint, region, user, dbname, creds)
//	connectStr, err := b.WithTCPFormat().Build()
//	if err != nil {
//		panic(err)
//	}
//	const dbType = "mysql"
//	db, err := sql.Open(dbType, connectStr)
func (b ConnectionStringBuilder) Build() (string, error) {
	// A format has to be picked explicitly; there is no default.
	if b.connectFormat == NoConnectionFormat {
		return "", ErrNoConnectionFormat
	}

	token, err := BuildAuthToken(b.endpoint, b.region, b.user, b.creds)
	if err != nil {
		return "", err
	}

	dsn := fmt.Sprintf("%s:%s@%s(%s)/%s",
		b.user, token, string(b.connectFormat), b.endpoint, b.dbName)

	// Only append a query string when there is something to encode.
	if len(b.params) == 0 {
		return dsn, nil
	}
	return dsn + "?" + b.params.Encode(), nil
}