1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
|
package client
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"gitlab.com/gitlab-org/labkit/correlation"
"gitlab.com/gitlab-org/labkit/tracing"
)
// URL prefixes recognized when building a client, plus related defaults.
const (
// socketBaseURL is the base URL reported as the client host when requests
// are sent over a Unix socket.
socketBaseURL = "http://unix"
// unixSocketProtocol is the GitLab URL prefix that selects the Unix socket
// transport; the remainder of the URL is the socket path.
unixSocketProtocol = "http+unix://"
// httpProtocol is the GitLab URL prefix that selects the plain HTTP transport.
httpProtocol = "http://"
// httpsProtocol is the GitLab URL prefix that selects the TLS transport.
httpsProtocol = "https://"
// defaultReadTimeoutSeconds is the client timeout, in seconds, used when a
// read timeout of zero is requested.
defaultReadTimeoutSeconds = 300
)
// ErrCafileNotFound is the sentinel error wrapped into the error returned when
// the configured CA file does not exist. Match it with errors.Is.
var ErrCafileNotFound = errors.New("cafile not found")
// HTTPClient pairs a standard library HTTP client with the base URL of the
// GitLab instance it was built for.
type HTTPClient struct {
// Client is embedded, so the methods of *http.Client are available directly
// on HTTPClient.
*http.Client
// Host is the base URL chosen when the client was built: the GitLab URL
// itself for HTTP and HTTPS, or "http://unix" (optionally followed by the
// relative URL root) for Unix sockets.
Host string
}
// httpClientCfg collects the TLS-related file locations consulted when an
// HTTPS transport is built.
type httpClientCfg struct {
	// keyPath and certPath locate the client key and client certificate that
	// are loaded as an X.509 key pair.
	keyPath  string
	certPath string
	// caFile is a single PEM file and caPath a directory of PEM files; both
	// are added to the root CA pool.
	caFile string
	caPath string
}

// HaveCertAndKey reports whether both a client certificate path and a client
// key path have been configured.
func (hcc httpClientCfg) HaveCertAndKey() bool {
	if hcc.certPath == "" || hcc.keyPath == "" {
		return false
	}
	return true
}
// HTTPClientOpt is a functional option that adjusts the configuration used to
// build an HTTPClient.
type HTTPClientOpt func(*httpClientCfg)
// WithClientCert returns an option that records the client certificate and key
// file paths, so that an HTTPClient built with it presents that certificate
// when connecting to a server over HTTPS.
func WithClientCert(certPath, keyPath string) HTTPClientOpt {
	setPaths := func(cfg *httpClientCfg) {
		cfg.certPath, cfg.keyPath = certPath, keyPath
	}
	return setPaths
}
// validateCaFile checks that the CA file named by filename can be stat'ed.
//
// An empty filename means no CA file was configured and is accepted. When the
// file does not exist, the returned error wraps ErrCafileNotFound so callers
// can match it with errors.Is; any other os.Stat failure is returned as is.
func validateCaFile(filename string) error {
	if filename == "" {
		return nil
	}

	_, err := os.Stat(filename)
	switch {
	case err == nil:
		return nil
	case errors.Is(err, os.ErrNotExist):
		// errors.Is is the documented replacement for os.IsNotExist and also
		// sees through wrapped errors.
		return fmt.Errorf("cannot find cafile %q: %w", filename, ErrCafileNotFound)
	default:
		return err
	}
}
// NewHTTPClientWithOpts builds an HTTPClient for the GitLab instance at
// gitlabURL.
//
// The URL prefix selects the transport:
//   - "http+unix://" dials a Unix socket; gitlabRelativeURLRoot is appended to
//     the reported host.
//   - "http://" uses a plain HTTP transport.
//   - "https://" uses a TLS transport configured from caFile, caPath and opts.
//
// readTimeoutSeconds becomes the client timeout; zero selects the default.
// An error is returned when the URL has none of the known prefixes, when
// caFile fails validation, or when the HTTPS transport cannot be built.
func NewHTTPClientWithOpts(gitlabURL, gitlabRelativeURLRoot, caFile, caPath string, readTimeoutSeconds uint64, opts []HTTPClientOpt) (*HTTPClient, error) {
	var (
		rt      *http.Transport
		baseURL string
	)

	switch {
	case strings.HasPrefix(gitlabURL, unixSocketProtocol):
		rt, baseURL = buildSocketTransport(gitlabURL, gitlabRelativeURLRoot)

	case strings.HasPrefix(gitlabURL, httpProtocol):
		rt, baseURL = buildHTTPTransport(gitlabURL)

	case strings.HasPrefix(gitlabURL, httpsProtocol):
		if caErr := validateCaFile(caFile); caErr != nil {
			return nil, caErr
		}

		// Options are only meaningful for TLS, so they are applied here.
		cfg := httpClientCfg{caFile: caFile, caPath: caPath}
		for _, apply := range opts {
			apply(&cfg)
		}

		var buildErr error
		rt, baseURL, buildErr = buildHTTPSTransport(cfg, gitlabURL)
		if buildErr != nil {
			return nil, buildErr
		}

	default:
		return nil, errors.New("unknown GitLab URL prefix")
	}

	return &HTTPClient{
		Client: &http.Client{
			Transport: correlation.NewInstrumentedRoundTripper(tracing.NewRoundTripper(rt)),
			Timeout:   readTimeout(readTimeoutSeconds),
		},
		Host: baseURL,
	}, nil
}
// buildSocketTransport returns a transport whose connections always dial the
// Unix socket named in gitlabURL, together with the base URL to use as host.
//
// The socket path is gitlabURL with the "http+unix://" prefix removed. The
// host is "http://unix", followed by "/" and gitlabRelativeURLRoot (leading
// and trailing slashes trimmed) when that root is not empty.
func buildSocketTransport(gitlabURL, gitlabRelativeURLRoot string) (*http.Transport, string) {
	path := strings.TrimPrefix(gitlabURL, unixSocketProtocol)

	dialSocket := func(ctx context.Context, _, _ string) (net.Conn, error) {
		// The network and address requested by the HTTP stack are ignored;
		// every connection goes to the socket.
		var d net.Dialer
		return d.DialContext(ctx, "unix", path)
	}
	transport := &http.Transport{DialContext: dialSocket}

	root := strings.Trim(gitlabRelativeURLRoot, "/")
	if root == "" {
		return transport, socketBaseURL
	}
	return transport, socketBaseURL + "/" + root
}
// buildHTTPSTransport creates a TLS transport for gitlabURL.
//
// The root CA pool starts from the system pool, or from an empty pool when the
// system pool cannot be loaded. It is then extended, best effort, with
// hcc.caFile and with every non-directory entry of the hcc.caPath directory.
// When both a client certificate and key are configured they are loaded and
// attached to the TLS configuration.
//
// It returns the transport and gitlabURL (unchanged) as the host. The only
// error reported is a failure to load the client certificate/key pair.
func buildHTTPSTransport(hcc httpClientCfg, gitlabURL string) (*http.Transport, string, error) {
	certPool, err := x509.SystemCertPool()
	if err != nil {
		// Deliberate fallback: continue with an empty pool. This error must
		// not be returned at the end, otherwise the fallback would be useless
		// because the caller discards the transport on any non-nil error.
		certPool = x509.NewCertPool()
	}

	if hcc.caFile != "" {
		addCertToPool(certPool, hcc.caFile)
	}

	if hcc.caPath != "" {
		// Best effort: an unreadable directory contributes no certificates.
		entries, _ := os.ReadDir(hcc.caPath)
		for _, entry := range entries {
			if entry.IsDir() {
				continue
			}
			addCertToPool(certPool, filepath.Join(hcc.caPath, entry.Name()))
		}
	}

	tlsConfig := &tls.Config{
		RootCAs:    certPool,
		MinVersion: tls.VersionTLS12,
	}

	if hcc.HaveCertAndKey() {
		cert, certErr := tls.LoadX509KeyPair(hcc.certPath, hcc.keyPath)
		if certErr != nil {
			return nil, "", certErr
		}
		tlsConfig.Certificates = []tls.Certificate{cert}
	}

	transport := &http.Transport{
		TLSClientConfig: tlsConfig,
	}

	return transport, gitlabURL, nil
}
// addCertToPool reads fileName and appends the PEM-encoded certificates it
// contains to certPool. This is best effort: an unreadable file, or one with
// no parsable certificate, leaves the pool unchanged and is not reported.
func addCertToPool(certPool *x509.CertPool, fileName string) {
	pemData, readErr := os.ReadFile(fileName)
	if readErr != nil {
		return
	}
	certPool.AppendCertsFromPEM(pemData)
}
// buildHTTPTransport returns a fresh, zero-valued transport for plain HTTP and
// gitlabURL unchanged as the host.
func buildHTTPTransport(gitlabURL string) (*http.Transport, string) {
	transport := new(http.Transport)
	return transport, gitlabURL
}
// readTimeout converts a timeout expressed in whole seconds into a
// time.Duration.
//
// Zero selects defaultReadTimeoutSeconds. Values too large to be represented
// as a time.Duration are clamped to the largest whole number of seconds that
// fits, instead of overflowing into a negative or wrapped-around duration.
func readTimeout(timeoutSeconds uint64) time.Duration {
	// Largest second count whose nanosecond value still fits in an int64.
	const maxTimeoutSeconds = uint64(time.Duration(1<<63-1) / time.Second)

	switch {
	case timeoutSeconds == 0:
		timeoutSeconds = defaultReadTimeoutSeconds
	case timeoutSeconds > maxTimeoutSeconds:
		timeoutSeconds = maxTimeoutSeconds
	}

	return time.Duration(timeoutSeconds) * time.Second
}
|