1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227
|
package tlsutil
import (
"context"
"crypto/tls"
"fmt"
"math/rand"
"sync"
"time"
)
// RenewFunc defines the type of the functions used to get a new tls
// certificate.
//
// It returns the renewed certificate together with the tls.Config to use with
// it, or an error if the renewal failed. The Renewer reads the Leaf of the
// returned certificate to schedule the next renewal, so implementations must
// populate it.
type RenewFunc func() (*tls.Certificate, *tls.Config, error)
// MinCertDuration is the minimum validity of a certificate. NewRenewer returns
// an error for certificates whose validity period (NotAfter - NotBefore) is
// shorter than this value.
var MinCertDuration = time.Minute
// Renewer automatically renews a tls certificate using a RenewFunc.
//
// The embedded RWMutex is taken by the Renewer's own methods when they read or
// replace the certificate, the configuration, the expiration time and the
// timer.
//
//nolint:gocritic // ignore exposedSyncMutex
type Renewer struct {
sync.RWMutex
// RenewFunc is called to obtain a new certificate and tls.Config.
RenewFunc RenewFunc
// cert is the certificate returned by GetCertificate and
// GetClientCertificate.
cert *tls.Certificate
// config is the tls.Config returned by GetConfig and GetConfigForClient.
config *tls.Config
// timer schedules the next call to renewCertificate; it is nil until Run is
// called.
timer *time.Timer
// renewBefore is how long before the certificate's NotAfter the renewal is
// scheduled.
renewBefore time.Duration
// renewJitter is the upper bound of the random duration subtracted from the
// scheduled renewal time; half of it is also the base retry delay after a
// failed renewal.
renewJitter time.Duration
// certNotAfter is the time after which the accessors force a renewal.
certNotAfter time.Time
}
// renewerOptions is the type of the functional options accepted by
// NewRenewer. An option that returns an error aborts the construction.
type renewerOptions func(r *Renewer) error
// WithRenewBefore returns an option that sets the renewBefore attribute of a
// tls renewer, that is, how long before the certificate expiration the renewal
// is scheduled. A zero value makes NewRenewer fall back to its default.
func WithRenewBefore(b time.Duration) func(r *Renewer) error {
	setRenewBefore := func(renewer *Renewer) error {
		renewer.renewBefore = b
		return nil
	}
	return setRenewBefore
}
// WithRenewJitter returns an option that sets the renewJitter attribute of a
// tls renewer, the upper bound of the random amount by which each renewal is
// moved earlier. A zero value makes NewRenewer fall back to its default.
func WithRenewJitter(j time.Duration) func(r *Renewer) error {
	setRenewJitter := func(renewer *Renewer) error {
		renewer.renewJitter = j
		return nil
	}
	return setRenewJitter
}
// NewRenewer creates a TLS renewer for the given cert. It will use the given
// RenewFunc to get a new certificate when required.
//
// The config is cloned, and the GetCertificate, GetClientCertificate and
// GetConfigForClient callbacks of the clone are pointed at the renewer unless
// they are already set. Options are applied in order; a zero renewBefore
// defaults to 1/3 of the certificate validity period and a zero renewJitter to
// 1/20th of it.
//
// It returns an error if cert, cert.Leaf or config is nil, if an option fails,
// or if the certificate validity period is shorter than MinCertDuration.
func NewRenewer(cert *tls.Certificate, config *tls.Config, fn RenewFunc, opts ...renewerOptions) (*Renewer, error) {
	// These values are dereferenced below; report them as errors instead of
	// panicking with a nil pointer dereference.
	switch {
	case cert == nil:
		return nil, fmt.Errorf("certificate cannot be nil")
	case cert.Leaf == nil:
		return nil, fmt.Errorf("certificate leaf cannot be nil")
	case config == nil:
		return nil, fmt.Errorf("tls config cannot be nil")
	}

	r := &Renewer{
		RenewFunc:    fn,
		cert:         cert,
		config:       config.Clone(),
		certNotAfter: cert.Leaf.NotAfter,
	}

	// Use renewer methods.
	if r.config.GetCertificate == nil {
		r.config.GetCertificate = r.GetCertificate
	}
	if r.config.GetClientCertificate == nil {
		r.config.GetClientCertificate = r.GetClientCertificate
	}
	if r.config.GetConfigForClient == nil {
		r.config.GetConfigForClient = r.GetConfigForClient
	}

	for _, f := range opts {
		if err := f(r); err != nil {
			return nil, fmt.Errorf("error applying options: %w", err)
		}
	}

	period := cert.Leaf.NotAfter.Sub(cert.Leaf.NotBefore)
	if period < MinCertDuration {
		return nil, fmt.Errorf("period must be greater than or equal to %s, but got %v", MinCertDuration, period)
	}

	// By default the renewal is scheduled when 1/3 of the validity period is
	// left, that is, once roughly 2/3 of it have elapsed.
	if r.renewBefore == 0 {
		r.renewBefore = period / 3
	}
	// By default we set the jitter to 1/20th of the validity period.
	if r.renewJitter == 0 {
		r.renewJitter = period / 20
	}

	return r, nil
}
// GetConfig returns the current tls.Config. Like the per-request accessor, it
// forces a renewal first if the certificate has expired.
func (r *Renewer) GetConfig() *tls.Config {
	current := r.getConfigForClient()
	return current
}
// Run starts the certificate renewer for the given certificate. It schedules
// the first renewal on a timer, based on the expiration time of the current
// certificate, and returns immediately.
func (r *Renewer) Run() {
	r.Lock()
	defer r.Unlock()

	wait := r.nextRenewDuration(r.certNotAfter)
	r.timer = time.AfterFunc(wait, r.renewCertificate)
}
// RunContext starts the certificate renewer for the given certificate and
// stops it when the context is done. The goroutine that waits for the context
// lives until ctx is cancelled.
func (r *Renewer) RunContext(ctx context.Context) {
	r.Run()

	stopWhenDone := func() {
		<-ctx.Done()
		r.Stop()
	}
	go stopWhenDone()
}
// Stop prevents the renew timer from firing. It returns the result of
// stopping the timer, or true if no timer has been created yet.
func (r *Renewer) Stop() bool {
	r.Lock()
	defer r.Unlock()

	if r.timer == nil {
		return true
	}
	return r.timer.Stop()
}
// GetCertificate returns the current server certificate; the error is always
// nil.
//
// This method is set in the tls.Config GetCertificate property.
func (r *Renewer) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
	serverCert := r.getCertificate()
	return serverCert, nil
}
// GetClientCertificate returns the current client certificate; the error is
// always nil.
//
// This method is set in the tls.Config GetClientCertificate property.
func (r *Renewer) GetClientCertificate(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) {
	clientCert := r.getCertificate()
	return clientCert, nil
}
// GetConfigForClient returns the tls.Config used per request; the error is
// always nil.
//
// This method is set in the tls.Config GetConfigForClient property.
func (r *Renewer) GetConfigForClient(_ *tls.ClientHelloInfo) (*tls.Config, error) {
	perRequest := r.getConfigForClient()
	return perRequest, nil
}
// getCertificate returns the certificate using a read-only lock. It will
// automatically renew the certificate if it has expired.
func (r *Renewer) getCertificate() *tls.Certificate {
	r.RLock()
	// The timer may not have fired, for example after the computer has been
	// asleep. In that case the renewal is forced here; the read lock is
	// released around it because the renewal takes the write lock.
	if r.certNotAfter.Before(time.Now()) {
		r.RUnlock()
		r.renewCertificate()
		r.RLock()
	}
	current := r.cert
	r.RUnlock()
	return current
}
// getConfigForClient returns the tls.Config using a read-only lock. It will
// automatically renew the certificate if it has expired.
func (r *Renewer) getConfigForClient() *tls.Config {
	r.RLock()
	// The timer may not have fired, for example after the computer has been
	// asleep. In that case the renewal is forced here; the read lock is
	// released around it because the renewal takes the write lock.
	if r.certNotAfter.Before(time.Now()) {
		r.RUnlock()
		r.renewCertificate()
		r.RLock()
	}
	current := r.config
	r.RUnlock()
	return current
}
// setCertificate updates the certificate and the tls.Config using a
// read-write lock. It also updates certNotAfter with 1m of delta; this will
// force the renewal of the certificate if it is about to expire.
//
// The renewer callbacks are installed in the given config for every callback
// that is not already set. Both cert and cert.Leaf must be non-nil, and config
// must be non-nil.
func (r *Renewer) setCertificate(cert *tls.Certificate, config *tls.Config) {
	// Compute the expiration before taking the lock or touching any state, so
	// an invalid certificate cannot leave the renewer half-updated.
	//
	// The certificate is considered expired one minute early so that a
	// handshake arriving right before NotAfter triggers a renewal instead of
	// serving a certificate that is about to expire.
	notAfter := cert.Leaf.NotAfter.Add(-1 * time.Minute)

	r.Lock()
	defer r.Unlock()

	r.cert = cert
	r.config = config
	r.certNotAfter = notAfter

	// Use renewer methods.
	if r.config.GetCertificate == nil {
		r.config.GetCertificate = r.GetCertificate
	}
	if r.config.GetClientCertificate == nil {
		r.config.GetClientCertificate = r.GetClientCertificate
	}
	if r.config.GetConfigForClient == nil {
		r.config.GetConfigForClient = r.GetConfigForClient
	}
}
// renewCertificate calls RenewFunc and, on success, installs the new
// certificate and configuration and schedules the next renewal from the new
// expiration time. On failure it keeps the current certificate and schedules a
// retry after a random delay between renewJitter/2 and renewJitter.
//
// It is invoked by the timer and by the accessors when the certificate has
// expired, so it must not be called with the lock held.
func (r *Renewer) renewCertificate() {
	var next time.Duration

	cert, config, err := r.RenewFunc()
	if err != nil || cert == nil || cert.Leaf == nil || config == nil {
		// A result that cannot be installed is handled like an error:
		// dereferencing it would panic, so keep the current certificate and
		// retry later.
		next = r.renewJitter / 2
		// rand.Int63n panics for a non-positive argument.
		if next > 0 {
			next += time.Duration(mathRandInt63n(int64(next)))
		}
	} else {
		r.setCertificate(cert, config)
		next = r.nextRenewDuration(cert.Leaf.NotAfter)
	}

	r.Lock()
	defer r.Unlock()
	// The timer is nil when an accessor forces a renewal before Run has been
	// called; there is nothing to reschedule in that case.
	if r.timer != nil {
		r.timer.Reset(next)
	}
}
// nextRenewDuration returns how long to wait before renewing a certificate
// that expires at notAfter: the time left until notAfter, minus renewBefore,
// minus a random jitter in [0, renewJitter). The result is never negative; a
// renewal that is already due is scheduled immediately.
func (r *Renewer) nextRenewDuration(notAfter time.Time) time.Duration {
	d := time.Until(notAfter) - r.renewBefore
	// rand.Int63n panics for a non-positive argument, which can happen if a
	// negative jitter was configured; apply no jitter in that case.
	if r.renewJitter > 0 {
		d -= time.Duration(mathRandInt63n(int64(r.renewJitter)))
	}
	if d < 0 {
		d = 0
	}
	return d
}
// mathRandInt63n returns a non-negative pseudo-random number in [0, n). It
// returns 0 when n is not positive, instead of panicking like rand.Int63n
// does, so callers can pass a jitter of zero.
//
//nolint:gosec // not used for security reasons
func mathRandInt63n(n int64) int64 {
	if n <= 0 {
		return 0
	}
	return rand.Int63n(n)
}
|