1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
|
package fakebundleendpoint
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"net/http"
"sync"
"testing"
"time"
"github.com/spiffe/go-spiffe/v2/bundle/spiffebundle"
"github.com/spiffe/go-spiffe/v2/internal/test"
"github.com/spiffe/go-spiffe/v2/internal/x509util"
"github.com/spiffe/go-spiffe/v2/spiffetls/tlsconfig"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
"github.com/stretchr/testify/assert"
)
// Server is a fake SPIFFE bundle endpoint for tests. It serves HTTPS on a
// loopback address and hands out the configured bundles, one per request,
// on its "/test-bundle" path.
type Server struct {
	// tb is the test handle used to report assertion failures.
	tb testing.TB

	// wg tracks the goroutine running the HTTP server so Shutdown can wait
	// for it to exit.
	wg sync.WaitGroup

	// addr is the address of the listener the server accepts connections on.
	addr net.Addr

	// httpServer is the underlying HTTPS server.
	httpServer *http.Server

	// rootCAs holds the root certificates clients use to verify the
	// server's certificate.
	rootCAs *x509.CertPool

	// tlscfg is the TLS configuration used by the server.
	tlscfg *tls.Config

	// bundles lists the SPIFFE bundles still to be returned, in order.
	bundles []*spiffebundle.Bundle
}
// ServerOption customizes a Server. New applies options in the order they
// are given, before the HTTPS server is created and started.
type ServerOption interface {
	apply(*Server)
}
// New creates a fake bundle endpoint Server, applies the given options, and
// starts serving HTTPS on a loopback address. By default the server uses the
// credentials returned by test.CreateWebCredentials; WithSPIFFEAuth replaces
// them. If the server cannot be started the test is stopped via tb.Fatalf.
// Callers should call Shutdown when they are done with the server.
func New(tb testing.TB, option ...ServerOption) *Server {
	// Report a start-up failure at the caller's line, not inside this helper.
	tb.Helper()

	rootCAs, cert := test.CreateWebCredentials(tb)
	s := &Server{
		tb:      tb,
		rootCAs: rootCAs,
		tlscfg: &tls.Config{
			Certificates: []tls.Certificate{*cert},
			MinVersion:   tls.VersionTLS12,
		},
	}
	for _, opt := range option {
		opt.apply(s)
	}

	// The mux and http.Server are built after the options have run so that
	// a TLS configuration installed by an option is the one actually served.
	mux := http.NewServeMux()
	mux.HandleFunc("/test-bundle", s.testbundle)
	s.httpServer = &http.Server{
		Handler:           mux,
		TLSConfig:         s.tlscfg,
		ReadHeaderTimeout: 10 * time.Second,
	}

	if err := s.start(); err != nil {
		tb.Fatalf("Failed to start: %v", err)
	}
	return s
}
// Shutdown gracefully stops the HTTP server, reporting any shutdown error
// through the test, and then blocks until the serving goroutine has exited.
func (s *Server) Shutdown() {
	ctx := context.Background()
	assert.NoError(s.tb, s.httpServer.Shutdown(ctx))
	s.wg.Wait()
}
// Addr returns the address the server is listening on, as produced by the
// listener's net.Addr String method.
func (s *Server) Addr() string {
	listenAddr := s.addr
	return listenAddr.String()
}
// FetchBundleURL returns the HTTPS URL that clients use to fetch the test
// bundles served by s.
func (s *Server) FetchBundleURL() string {
	endpoint := fmt.Sprintf("https://%s/test-bundle", s.Addr())
	return endpoint
}
// RootCAs returns the certificate pool clients should trust to verify the
// server's certificate. When WithSPIFFEAuth was used, the pool is built from
// the X.509 authorities of the bundle passed to that option.
func (s *Server) RootCAs() *x509.CertPool {
	pool := s.rootCAs
	return pool
}
// start listens on an OS-chosen TCP port on 127.0.0.1, records the listener
// address in s.addr, and serves HTTPS on it in a background goroutine
// tracked by s.wg. It returns an error only if the listener cannot be
// created.
func (s *Server) start() error {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return err
	}
	s.addr = ln.Addr()

	s.wg.Add(1)
	go func() {
		// Deferred calls run last-in-first-out: the listener is closed
		// before Done, so Shutdown's wg.Wait covers the entire goroutine.
		// The explicit close is still needed in case ServeTLS returns
		// before handing the listener to Serve (which would close it).
		defer s.wg.Done()
		defer ln.Close()

		// ServeTLS returns http.ErrServerClosed after Shutdown; any other
		// error is an unexpected serving failure.
		err := s.httpServer.ServeTLS(ln, "", "")
		assert.EqualError(s.tb, err, http.ErrServerClosed.Error())
	}()
	return nil
}
// testbundle handles requests to "/test-bundle". Each request consumes the
// next configured bundle and writes its marshaled form with Content-Type
// application/json. Once every bundle has been returned, or if none were
// configured, it responds 404 Not Found. If a bundle cannot be marshaled,
// the test is failed and the client receives 500 Internal Server Error.
//
// NOTE(review): s.bundles is read and re-sliced without any locking, so this
// handler is not safe for concurrent requests; callers presumably fetch
// bundles one at a time.
func (s *Server) testbundle(w http.ResponseWriter, _ *http.Request) {
	if len(s.bundles) == 0 {
		w.WriteHeader(http.StatusNotFound)
		return
	}

	next := s.bundles[0]
	s.bundles = s.bundles[1:]

	body, err := next.Marshal()
	if !assert.NoError(s.tb, err) {
		// Don't answer 200 with an empty body for a bundle that could not be
		// encoded; the failed assertion above has already marked the test.
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	n, err := w.Write(body)
	assert.NoError(s.tb, err)
	assert.Equal(s.tb, len(body), n)
}
// serverOption adapts a plain function to the ServerOption interface.
type serverOption func(*Server)
// WithTestBundles sets the bundles returned by the bundle endpoint. The
// bundles are handed out in the order given, one per request to the
// endpoint; after all of them have been returned, further requests receive
// 404 Not Found. If the option is passed more than once, the last call wins.
func WithTestBundles(bundles ...*spiffebundle.Bundle) ServerOption {
	setBundles := func(s *Server) {
		s.bundles = bundles
	}
	return serverOption(setBundles)
}
// WithSPIFFEAuth replaces the server's default web credentials. The server
// uses the TLS configuration built by tlsconfig.TLSServerConfig from svid.
// RootCAs returns a pool built from the X.509 authorities in bundle, which
// clients can use to verify the server.
func WithSPIFFEAuth(bundle *spiffebundle.Bundle, svid *x509svid.SVID) ServerOption {
	return serverOption(func(srv *Server) {
		srv.rootCAs = x509util.NewCertPool(bundle.X509Authorities())
		srv.tlscfg = tlsconfig.TLSServerConfig(svid)
	})
}
// apply implements ServerOption by invoking the wrapped function on s.
func (fn serverOption) apply(s *Server) {
	fn(s)
}
|