File: resumption_test.go

package info (click to toggle)
golang-github-lucas-clemente-quic-go 0.54.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 4,312 kB
  • sloc: sh: 54; makefile: 7
file content (130 lines) | stat: -rw-r--r-- 3,610 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
package self_test

import (
	"context"
	"crypto/tls"
	"testing"
	"time"

	"github.com/quic-go/quic-go"

	"github.com/stretchr/testify/require"
)

// clientSessionCache wraps a tls.ClientSessionCache and reports the key of
// every lookup and every store on optional channels, so a test can observe
// when the TLS stack consults or fills the cache.
type clientSessionCache struct {
	cache tls.ClientSessionCache
	// gets receives the key of each Get, puts the key of each Put.
	// A nil channel disables reporting for that operation.
	gets, puts chan<- string
}

// newClientSessionCache returns a clientSessionCache delegating to cache.
// Either gets or puts may be nil to skip reporting that operation.
func newClientSessionCache(cache tls.ClientSessionCache, gets, puts chan<- string) *clientSessionCache {
	wrapped := &clientSessionCache{cache: cache}
	wrapped.gets, wrapped.puts = gets, puts
	return wrapped
}

var _ tls.ClientSessionCache = (*clientSessionCache)(nil)

// Get looks sessionKey up in the wrapped cache and then, if a gets channel is
// configured, sends sessionKey on it. The send blocks while the channel is
// full (or unbuffered with no receiver).
func (c *clientSessionCache) Get(sessionKey string) (*tls.ClientSessionState, bool) {
	state, found := c.cache.Get(sessionKey)
	if report := c.gets; report != nil {
		report <- sessionKey
	}
	return state, found
}

// Put stores cs under sessionKey in the wrapped cache and then, if a puts
// channel is configured, sends sessionKey on it.
func (c *clientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) {
	c.cache.Put(sessionKey, cs)
	if report := c.puts; report != nil {
		report <- sessionKey
	}
}

func TestTLSSessionResumption(t *testing.T) {
	t.Run("uses session resumption", func(t *testing.T) {
		handshakeWithSessionResumption(t, getTLSConfig(), true)
	})

	t.Run("disabled in tls.Config", func(t *testing.T) {
		sConf := getTLSConfig()
		sConf.SessionTicketsDisabled = true
		handshakeWithSessionResumption(t, sConf, false)
	})

	t.Run("disabled in tls.Config.GetConfigForClient", func(t *testing.T) {
		sConf := &tls.Config{
			GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
				conf := getTLSConfig()
				conf.SessionTicketsDisabled = true
				return conf, nil
			},
		}
		handshakeWithSessionResumption(t, sConf, false)
	})
}

// handshakeWithSessionResumption dials the same server twice with a shared
// client session cache and checks whether the second connection resumed the
// TLS session.
//
// serverTLSConf configures the listener. If expectSessionTicket is true, the
// client must store a session ticket after the first connection and look up
// the same key for the second connection, and both ends of the second
// connection must report DidResume. Otherwise no ticket may be stored within
// the wait window and neither end of the second connection may resume.
func handshakeWithSessionResumption(t *testing.T, serverTLSConf *tls.Config, expectSessionTicket bool) {
	server, err := quic.Listen(newUDPConnLocalhost(t), serverTLSConf, getQuicConfig(nil))
	require.NoError(t, err)
	defer server.Close()

	gets := make(chan string, 100)
	puts := make(chan string, 100)
	cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
	tlsConf := getTLSClientConfig()
	tlsConf.ClientSessionCache = cache

	// One deadline bounds every dial and accept below, so a server that never
	// accepts a connection fails the test instead of hanging it.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	// first connection - doesn't use resumption
	conn1, err := quic.Dial(ctx, newUDPConnLocalhost(t), server.Addr(), tlsConf, getQuicConfig(nil))
	require.NoError(t, err)
	defer conn1.CloseWithError(0, "")
	require.False(t, conn1.ConnectionState().TLS.DidResume)

	var sessionKey string
	select {
	case sessionKey = <-puts:
		if !expectSessionTicket {
			t.Fatal("unexpected session ticket")
		}
	case <-time.After(scaleDuration(50 * time.Millisecond)):
		if expectSessionTicket {
			t.Fatal("timeout waiting for session ticket")
		}
	}

	serverConn, err := server.Accept(ctx)
	require.NoError(t, err)
	require.False(t, serverConn.ConnectionState().TLS.DidResume)

	// crypto/tls queries the session cache on every client handshake, so the
	// first dial has already recorded a (missed) lookup in gets. Discard it so
	// the check after the second dial sees that connection's lookup, not a
	// stale one. Only this goroutine receives from gets, so len is reliable.
	for len(gets) > 0 {
		<-gets
	}

	// second connection - will use resumption, if enabled
	conn2, err := quic.Dial(ctx, newUDPConnLocalhost(t), server.Addr(), tlsConf, getQuicConfig(nil))
	require.NoError(t, err)
	defer conn2.CloseWithError(0, "")

	select {
	case k := <-gets:
		if expectSessionTicket {
			// we can only perform this check if we got a session ticket before
			require.Equal(t, sessionKey, k)
		}
	case <-time.After(scaleDuration(50 * time.Millisecond)):
		if expectSessionTicket {
			t.Fatal("timeout waiting for retrieval of session ticket")
		}
	}

	serverConn, err = server.Accept(ctx)
	require.NoError(t, err)

	require.Equal(t, expectSessionTicket, conn2.ConnectionState().TLS.DidResume, "client DidResume")
	require.Equal(t, expectSessionTicket, serverConn.ConnectionState().TLS.DidResume, "server DidResume")
}