File: certificate_test.go

package info (click to toggle)
golang-github-pion-dtls-v3 3.0.7-1
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 2,124 kB
  • sloc: makefile: 4
file content (93 lines) | stat: -rw-r--r-- 2,431 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package dtls

import (
	"crypto/tls"
	"testing"

	"github.com/pion/dtls/v3/pkg/crypto/selfsign"
	"github.com/stretchr/testify/assert"
)

func TestGetCertificate(t *testing.T) {
	certificateWildcard, err := selfsign.GenerateSelfSignedWithDNS("*.test.test")
	assert.NoError(t, err)

	certificateTest, err := selfsign.GenerateSelfSignedWithDNS("test.test", "www.test.test", "pop.test.test")
	assert.NoError(t, err)

	certificateRandom, err := selfsign.GenerateSelfSigned()
	assert.NoError(t, err)

	testCases := []struct {
		localCertificates   []tls.Certificate
		desc                string
		serverName          string
		expectedCertificate tls.Certificate
		getCertificate      func(info *ClientHelloInfo) (*tls.Certificate, error)
	}{
		{
			desc: "Simple match in CN",
			localCertificates: []tls.Certificate{
				certificateRandom,
				certificateTest,
				certificateWildcard,
			},
			serverName:          "test.test",
			expectedCertificate: certificateTest,
		},
		{
			desc: "Simple match in SANs",
			localCertificates: []tls.Certificate{
				certificateRandom,
				certificateTest,
				certificateWildcard,
			},
			serverName:          "www.test.test",
			expectedCertificate: certificateTest,
		},

		{
			desc: "Wildcard match",
			localCertificates: []tls.Certificate{
				certificateRandom,
				certificateTest,
				certificateWildcard,
			},
			serverName:          "foo.test.test",
			expectedCertificate: certificateWildcard,
		},
		{
			desc: "No match return first",
			localCertificates: []tls.Certificate{
				certificateRandom,
				certificateTest,
				certificateWildcard,
			},
			serverName:          "foo.bar",
			expectedCertificate: certificateRandom,
		},
		{
			desc: "Get certificate from callback",
			getCertificate: func(*ClientHelloInfo) (*tls.Certificate, error) {
				return &certificateTest, nil
			},
			expectedCertificate: certificateTest,
		},
	}

	for _, test := range testCases {
		test := test
		t.Run(test.desc, func(t *testing.T) {
			cfg := &handshakeConfig{
				localCertificates:   test.localCertificates,
				localGetCertificate: test.getCertificate,
			}
			cert, err := cfg.getCertificate(&ClientHelloInfo{ServerName: test.serverName})
			assert.NoError(t, err)
			assert.Equal(t, test.expectedCertificate.Leaf, cert.Leaf, "Certificate Leaf should match expected")
		})
	}
}