1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193
|
package dialers
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/fs"
"net"
"os"
"os/user"
"path/filepath"
"time"
)
// Defaults used by NewTLS when building a TLS dialer.
const (
	// defaultTLSTimeout is how long dialing libvirt may take by default.
	defaultTLSTimeout = time.Second * 20

	// defaultTLSPort is the port libvirtd listens on for TLS by default.
	defaultTLSPort = "16514"
)
// certDirs names one candidate location for the client key pair: KeyPath is
// the directory holding clientkey.pem and CertPath the directory holding
// clientcert.pem.
type certDirs struct {
	KeyPath, CertPath string
}
// TLS is a dialer that connects to libvirt on a remote server over TLS,
// authenticating with a client certificate. Construct it with NewTLS.
type TLS struct {
	// timeout bounds how long dialing the remote host may take.
	timeout time.Duration

	// host and port identify the remote libvirt endpoint.
	host string
	port string

	// insecureSkipVerify disables verification of the server certificate.
	insecureSkipVerify bool

	// certSearchPaths lists, in priority order, where to look for the
	// client certificate and key.
	certSearchPaths []certDirs

	// caSearchPaths lists, in priority order, directories that may hold
	// the CA certificate.
	caSearchPaths []string
}
// TLSOption customizes a TLS dialer; pass any number of them to NewTLS.
type TLSOption func(r *TLS)
// WithInsecureNoVerify disables verification of the server certificate, so
// any certificate the server presents is accepted. A missing CA certificate
// file is then tolerated as well.
func WithInsecureNoVerify() TLSOption {
	return func(t *TLS) { t.insecureSkipVerify = true }
}
// UseTLSPort overrides the port dialed for libvirt on the target host, which
// otherwise defaults to 16514.
func UseTLSPort(port string) TLSOption {
	return func(t *TLS) { t.port = port }
}
// UsePKIPath makes pkiPath the only directory searched for TLS files: the
// client certificate (clientcert.pem), the client key (clientkey.pem) and the
// CA certificate (cacert.pem). All default search paths are discarded.
func UsePKIPath(pkiPath string) TLSOption {
	dirs := certDirs{CertPath: pkiPath, KeyPath: pkiPath}
	return func(t *TLS) {
		t.certSearchPaths = []certDirs{dirs}
		t.caSearchPaths = []string{pkiPath}
	}
}
// NewTLS returns a dialer for connecting to libvirt running on the server at
// hostAddr over TLS, using the default port and timeout.
//
// By default the client certificate is looked up in /etc/pki/libvirt/, its
// key in /etc/pki/libvirt/private/, and the CA certificate in /etc/pki/CA/.
// For a non-root user whose home directory is known, ~/.pki/libvirt is
// searched first for all three. The opts are applied last, in order, so they
// override any of these defaults.
func NewTLS(hostAddr string, opts ...TLSOption) *TLS {
	r := &TLS{
		timeout: defaultTLSTimeout,
		host:    hostAddr,
		port:    defaultTLSPort,
		certSearchPaths: []certDirs{
			{
				KeyPath:  "/etc/pki/libvirt/private/",
				CertPath: "/etc/pki/libvirt/",
			},
		},
		caSearchPaths: []string{"/etc/pki/CA/"},
	}

	// user.Current can fail, and u is nil in that case, so only consult it on
	// success; otherwise fall back to the system-wide locations only. With an
	// empty HomeDir the per-user path would resolve relative to the working
	// directory, so skip it then too.
	if u, err := user.Current(); err == nil && u.Uid != "0" && u.HomeDir != "" {
		cd := filepath.Join(u.HomeDir, ".pki", "libvirt")
		r.certSearchPaths = append([]certDirs{{KeyPath: cd, CertPath: cd}}, r.certSearchPaths...)
		// Some libvirt docs erroneously state that the user location for the
		// CA cert is in ~/.pki/ but it is in fact in ~/.pki/libvirt/
		r.caSearchPaths = append([]string{cd}, r.caSearchPaths...)
	}

	for _, opt := range opts {
		opt(r)
	}
	return r
}
// clientCert loads the client key pair from the first certSearchPaths entry
// in which both clientcert.pem (under CertPath) and clientkey.pem (under
// KeyPath) can be read. Read failures from every entry are joined into the
// returned error. A key pair that is readable but invalid stops the search
// immediately with an error.
func (r *TLS) clientCert() (*tls.Certificate, error) {
	// With no search paths the loop below collects no errors, and
	// errors.Join of nothing is nil. The result would then be (nil, nil),
	// which config dereferences. Report the misconfiguration instead.
	if len(r.certSearchPaths) == 0 {
		return nil, errors.New("could not read tls client cert: no search paths configured")
	}

	var errs []error
	for _, dirs := range r.certSearchPaths {
		certFile, err := os.ReadFile(filepath.Join(dirs.CertPath, "clientcert.pem"))
		if err != nil {
			errs = append(errs,
				fmt.Errorf("could not read tls client cert: %w", err))
			continue
		}
		keyFile, err := os.ReadFile(filepath.Join(dirs.KeyPath, "clientkey.pem"))
		if err != nil {
			errs = append(errs,
				fmt.Errorf("could not read tls private key: %w", err))
			continue
		}
		cert, err := tls.X509KeyPair(certFile, keyFile)
		if err != nil {
			return nil, fmt.Errorf("invalid tls client cert: %w", err)
		}
		return &cert, nil
	}
	// Every iteration that reaches here appended an error, so this is non-nil.
	return nil, errors.Join(errs...)
}
// caCerts builds a certificate pool from the first usable cacert.pem found in
// caSearchPaths.
//
// When optional is true, a missing file is not an error. If no CA file exists
// anywhere, caCerts then returns a nil pool and a nil error. Other read
// failures are collected and returned together if no later path yields a
// usable CA file. When optional is false, so are files that contain no
// parseable PEM certificates.
func (r *TLS) caCerts(optional bool) (*x509.CertPool, error) {
	var errs []error
	for _, dir := range r.caSearchPaths {
		path := filepath.Join(dir, "cacert.pem")
		caFile, err := os.ReadFile(path)
		if err != nil {
			if !(optional && errors.Is(err, fs.ErrNotExist)) {
				errs = append(errs,
					fmt.Errorf("could not read tls CA cert: %w", err))
			}
			continue
		}

		pool := x509.NewCertPool()
		// AppendCertsFromPEM reports false when it parsed no certificates.
		// Returning that empty pool would only surface later as an opaque
		// "unknown authority" handshake failure, so keep searching instead.
		// When optional is true (config passes r.insecureSkipVerify), the
		// pool is not used to verify the server, so accept it as-is.
		if !pool.AppendCertsFromPEM(caFile) && !optional {
			errs = append(errs,
				fmt.Errorf("could not parse tls CA cert %s: no PEM certificates found", path))
			continue
		}
		return pool, nil
	}
	return nil, errors.Join(errs...)
}
// config assembles the tls.Config used by Dial. It loads the client key pair
// and the CA pool from the configured search paths. A missing CA file is
// tolerated only when server verification is disabled.
func (r *TLS) config() (*tls.Config, error) {
	cert, certErr := r.clientCert()
	if certErr != nil {
		return nil, certErr
	}

	pool, caErr := r.caCerts(r.insecureSkipVerify)
	if caErr != nil {
		return nil, caErr
	}

	conf := &tls.Config{
		Certificates:       []tls.Certificate{*cert},
		RootCAs:            pool,
		InsecureSkipVerify: r.insecureSkipVerify, //nolint:gosec
	}
	return conf, nil
}
// Dial connects to libvirt running on another server.
//
// Dial opens a TLS connection to host:port using the configured client
// certificate and CA pool. It then reads the single status byte libvirt sends
// after the handshake, which reports whether the server accepted the client.
// On any failure the connection is closed and an error is returned.
func (r *TLS) Dial() (net.Conn, error) {
	conf, err := r.config()
	if err != nil {
		return nil, err
	}

	netDialer := net.Dialer{
		Timeout: r.timeout,
	}
	c, err := tls.DialWithDialer(
		&netDialer,
		"tcp",
		net.JoinHostPort(r.host, r.port),
		conf,
	)
	if err != nil {
		return nil, err
	}

	// When running over TLS, after connection libvirt writes a single byte to
	// the socket to indicate whether the server's check of the client's
	// certificate has succeeded.
	// See https://github.com/digitalocean/go-libvirt/issues/89#issuecomment-1607300636
	// for more details.
	//
	// Bound the wait for that byte by r.timeout so an unresponsive server
	// cannot block Dial forever. A zero timeout means no deadline, matching
	// the net.Dialer semantics used above.
	if r.timeout > 0 {
		if err := c.SetReadDeadline(time.Now().Add(r.timeout)); err != nil {
			c.Close() // best-effort cleanup; report the original error
			return nil, err
		}
	}

	buf := make([]byte, 1)
	n, err := c.Read(buf)
	if err != nil {
		c.Close()
		return nil, err
	}
	if n != 1 || buf[0] != byte(1) {
		c.Close()
		return nil, errors.New("server verification (of our certificate or IP address) failed")
	}

	// Clear the deadline so the caller's later reads are not affected.
	if r.timeout > 0 {
		if err := c.SetReadDeadline(time.Time{}); err != nil {
			c.Close()
			return nil, err
		}
	}
	return c, nil
}
|