1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
|
package self_test
import (
"context"
"crypto/tls"
"testing"
"time"
"github.com/quic-go/quic-go"
"github.com/stretchr/testify/require"
)
// clientSessionCache decorates a tls.ClientSessionCache so that every session
// key passed to Get or Put is also reported on a channel, letting tests observe
// the TLS stack's cache traffic.
type clientSessionCache struct {
	// cache receives every Get/Put call unchanged.
	cache tls.ClientSessionCache
	// gets and puts are sent the session key after each Get and Put,
	// respectively. A nil channel turns reporting off for that operation.
	gets, puts chan<- string
}

// newClientSessionCache wraps cache, reporting lookups on gets and stores on
// puts. Either channel may be nil.
func newClientSessionCache(cache tls.ClientSessionCache, gets, puts chan<- string) *clientSessionCache {
	c := new(clientSessionCache)
	c.cache = cache
	c.gets = gets
	c.puts = puts
	return c
}

// Compile-time proof that *clientSessionCache implements tls.ClientSessionCache.
var _ tls.ClientSessionCache = (*clientSessionCache)(nil)

// Get looks the key up in the wrapped cache first and only then reports it on
// the gets channel (when set). The report is a plain channel send, so it blocks
// if the channel has no free buffer space.
func (c *clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) {
	state, found := c.cache.Get(key)
	if c.gets == nil {
		return state, found
	}
	c.gets <- key
	return state, found
}

// Put stores the state in the wrapped cache first and only then reports the key
// on the puts channel (when set). Like Get, the report blocks on a full channel.
func (c *clientSessionCache) Put(key string, state *tls.ClientSessionState) {
	c.cache.Put(key, state)
	if c.puts == nil {
		return
	}
	c.puts <- key
}
func TestTLSSessionResumption(t *testing.T) {
t.Run("uses session resumption", func(t *testing.T) {
handshakeWithSessionResumption(t, getTLSConfig(), true)
})
t.Run("disabled in tls.Config", func(t *testing.T) {
sConf := getTLSConfig()
sConf.SessionTicketsDisabled = true
handshakeWithSessionResumption(t, sConf, false)
})
t.Run("disabled in tls.Config.GetConfigForClient", func(t *testing.T) {
sConf := &tls.Config{
GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
conf := getTLSConfig()
conf.SessionTicketsDisabled = true
return conf, nil
},
}
handshakeWithSessionResumption(t, sConf, false)
})
}
// handshakeWithSessionResumption dials the same QUIC server twice using one
// shared, instrumented client session cache and checks whether the second
// handshake resumed the TLS session established by the first.
//
// serverTLSConf is the server's TLS configuration. expectSessionTicket states
// whether the server should hand out a session ticket on the first connection.
// If it should, the second connection must report resumption on both the
// client and the server side. Otherwise neither side may.
func handshakeWithSessionResumption(t *testing.T, serverTLSConf *tls.Config, expectSessionTicket bool) {
	server, err := quic.Listen(newUDPConnLocalhost(t), serverTLSConf, getQuicConfig(nil))
	require.NoError(t, err)
	defer server.Close()

	// Buffered so the cache wrapper's channel sends don't stall the handshake.
	gets := make(chan string, 100)
	puts := make(chan string, 100)
	cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
	tlsConf := getTLSClientConfig()
	tlsConf.ClientSessionCache = cache

	// first connection - doesn't use resumption
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	conn1, err := quic.Dial(ctx, newUDPConnLocalhost(t), server.Addr(), tlsConf, getQuicConfig(nil))
	require.NoError(t, err)
	defer conn1.CloseWithError(0, "")
	require.False(t, conn1.ConnectionState().TLS.DidResume)

	var sessionKey string
	select {
	case sessionKey = <-puts:
		if !expectSessionTicket {
			t.Fatal("unexpected session ticket")
		}
	case <-time.After(scaleDuration(50 * time.Millisecond)):
		if expectSessionTicket {
			t.Fatal("timeout waiting for session ticket")
		}
	}

	serverConn, err := server.Accept(ctx)
	require.NoError(t, err)
	require.False(t, serverConn.ConnectionState().TLS.DidResume)

	// The TLS client consults the session cache on every handshake, so the
	// first connection already queued a (missing) lookup on gets. Discard it,
	// otherwise the check below would read that stale entry instead of the
	// second connection's lookup. This goroutine is the only receiver, so the
	// len check cannot race with another reader.
	for len(gets) > 0 {
		<-gets
	}

	// second connection - will use resumption, if enabled
	conn2, err := quic.Dial(ctx, newUDPConnLocalhost(t), server.Addr(), tlsConf, getQuicConfig(nil))
	require.NoError(t, err)
	defer conn2.CloseWithError(0, "")

	select {
	case k := <-gets:
		if expectSessionTicket {
			// we can only perform this check if we got a session ticket before
			require.Equal(t, sessionKey, k)
		}
	case <-time.After(scaleDuration(50 * time.Millisecond)):
		if expectSessionTicket {
			t.Fatal("timeout waiting for retrieval of session ticket")
		}
	}

	// Use a fresh, bounded context. context.Background() would hang until the
	// global test timeout if the connection never arrived. The shared ctx may
	// already be close to expiry after the waits above, and an expired ctx
	// would race with the ready connection inside Accept.
	acceptCtx, acceptCancel := context.WithTimeout(context.Background(), time.Second)
	defer acceptCancel()
	serverConn, err = server.Accept(acceptCtx)
	require.NoError(t, err)
	require.Equal(t, expectSessionTicket, conn2.ConnectionState().TLS.DidResume, "client-side resumption")
	require.Equal(t, expectSessionTicket, serverConn.ConnectionState().TLS.DidResume, "server-side resumption")
}
|