| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 
 | /*
 *
 * Copyright 2020 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */
package advancedtls
import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"math/big"
	"testing"
	"time"
	"github.com/google/go-cmp/cmp"
	"google.golang.org/grpc/credentials/tls/certprovider"
	"google.golang.org/grpc/security/advancedtls/testdata"
)
// TestNewPEMFileProvider checks option validation in NewPEMFileProvider.
// Construction must fail when no credential file is configured, or when only
// a cert file (without its key file) is given. It must succeed for
// identity-only, root-only, and combined identity+root configurations.
// Every successfully created provider is closed before its subtest ends.
func (s) TestNewPEMFileProvider(t *testing.T) {
	tests := []struct {
		desc      string
		options   PEMFileProviderOptions
		wantError bool
	}{
		{
			desc:      "Expect error if no credential files specified",
			options:   PEMFileProviderOptions{},
			wantError: true,
		},
		{
			desc: "Expect error if only certFile is specified",
			options: PEMFileProviderOptions{
				CertFile: testdata.Path("client_cert_1.pem"),
			},
			wantError: true,
		},
		{
			desc: "Should be good if only identity key cert pairs are specified",
			options: PEMFileProviderOptions{
				KeyFile:  testdata.Path("client_key_1.pem"),
				CertFile: testdata.Path("client_cert_1.pem"),
			},
			wantError: false,
		},
		{
			desc: "Should be good if only root certs are specified",
			options: PEMFileProviderOptions{
				TrustFile: testdata.Path("client_trust_cert_1.pem"),
			},
			wantError: false,
		},
		{
			desc: "Should be good if both identity pairs and root certs are specified",
			options: PEMFileProviderOptions{
				KeyFile:   testdata.Path("client_key_1.pem"),
				CertFile:  testdata.Path("client_cert_1.pem"),
				TrustFile: testdata.Path("client_trust_cert_1.pem"),
			},
			wantError: false,
		},
	}
	for _, test := range tests {
		t.Run(test.desc, func(t *testing.T) {
			provider, err := NewPEMFileProvider(test.options)
			if gotError := err != nil; gotError != test.wantError {
				t.Fatalf("NewPEMFileProvider(%+v) returned error: %v, want error: %v", test.options, err, test.wantError)
			}
			if err != nil {
				// Expected failure; there is no provider to close.
				return
			}
			provider.Close()
		})
	}
}
// TestWatchingRoutineUpdates replaces the package-level credential reading
// functions (readKeyCertPairFunc and readTrustCertFunc) used by the provider's
// watching goroutine. It then checks the provider's KeyMaterial across three
// stages:
//   - Stage 0: the readers return clientPeer1/serverTrust1, which must be
//     picked up by the watching goroutine.
//   - Stage 1: the readers return errors; the previously loaded credentials
//     must be left unchanged.
//   - Stage 2: the readers return clientPeer2/serverTrust2, which must be
//     picked up as the new credentials.
func (s) TestWatchingRoutineUpdates(t *testing.T) {
	// Load certificates.
	cs := &certStore{}
	if err := cs.loadCerts(); err != nil {
		t.Fatalf("cs.loadCerts() failed: %v", err)
	}
	tests := []struct {
		desc         string
		options      PEMFileProviderOptions
		wantKmStage0 certprovider.KeyMaterial
		wantKmStage1 certprovider.KeyMaterial
		wantKmStage2 certprovider.KeyMaterial
	}{
		{
			desc: "use identity certs and root certs",
			options: PEMFileProviderOptions{
				CertFile:  "not_empty_cert_file",
				KeyFile:   "not_empty_key_file",
				TrustFile: "not_empty_trust_file",
			},
			wantKmStage0: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}, Roots: cs.serverTrust1},
			wantKmStage1: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}, Roots: cs.serverTrust1},
			wantKmStage2: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer2}, Roots: cs.serverTrust2},
		},
		{
			desc: "use identity certs only",
			options: PEMFileProviderOptions{
				CertFile: "not_empty_cert_file",
				KeyFile:  "not_empty_key_file",
			},
			wantKmStage0: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}},
			wantKmStage1: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer1}},
			wantKmStage2: certprovider.KeyMaterial{Certs: []tls.Certificate{cs.clientPeer2}},
		},
		{
			desc: "use trust certs only",
			options: PEMFileProviderOptions{
				TrustFile: "not_empty_trust_file",
			},
			wantKmStage0: certprovider.KeyMaterial{Roots: cs.serverTrust1},
			wantKmStage1: certprovider.KeyMaterial{Roots: cs.serverTrust1},
			wantKmStage2: certprovider.KeyMaterial{Roots: cs.serverTrust2},
		},
	}
	const (
		// testInterval is the refresh interval for both the identity and root
		// watchers; it is well below settleDelay so each stage spans several
		// refresh cycles.
		testInterval = 200 * time.Millisecond
		// settleDelay is how long each stage waits for the watching goroutine
		// to observe the current stage before KeyMaterial is inspected.
		settleDelay = 1 * time.Second
		// keyMaterialTimeout bounds each KeyMaterial call so that, provided
		// KeyMaterial honors ctx, a provider that never produces material
		// fails the test instead of hanging it.
		keyMaterialTimeout = 5 * time.Second
	)
	// Parsed certificates (via big.Int) and x509.CertPool carry unexported
	// fields that cmp must be allowed to inspect.
	cmpOpt := cmp.AllowUnexported(big.Int{}, x509.CertPool{})
	for _, test := range tests {
		test.options.IdentityInterval = testInterval
		test.options.RootInterval = testInterval
		t.Run(test.desc, func(t *testing.T) {
			stage := &stageInfo{}
			oldReadKeyCertPairFunc := readKeyCertPairFunc
			readKeyCertPairFunc = func(certFile, keyFile string) (tls.Certificate, error) {
				switch stage.read() {
				case 0:
					return cs.clientPeer1, nil
				case 1:
					return tls.Certificate{}, fmt.Errorf("error occurred while reloading")
				case 2:
					return cs.clientPeer2, nil
				default:
					return tls.Certificate{}, fmt.Errorf("test stage not supported")
				}
			}
			defer func() {
				readKeyCertPairFunc = oldReadKeyCertPairFunc
			}()
			oldReadTrustCertFunc := readTrustCertFunc
			readTrustCertFunc = func(trustFile string) (*x509.CertPool, error) {
				switch stage.read() {
				case 0:
					return cs.serverTrust1, nil
				case 1:
					return nil, fmt.Errorf("error occurred while reloading")
				case 2:
					return cs.serverTrust2, nil
				default:
					return nil, fmt.Errorf("test stage not supported")
				}
			}
			defer func() {
				readTrustCertFunc = oldReadTrustCertFunc
			}()
			provider, err := NewPEMFileProvider(test.options)
			if err != nil {
				t.Fatalf("NewPEMFileProvider failed: %v", err)
			}
			defer provider.Close()

			// checkKeyMaterial waits for the watching goroutine to pick up the
			// current stage, then compares the provider's key material with
			// want. A KeyMaterial error is reported instead of dereferencing a
			// nil result.
			checkKeyMaterial := func(stageName string, want certprovider.KeyMaterial) {
				t.Helper()
				time.Sleep(settleDelay)
				ctx, cancel := context.WithTimeout(context.Background(), keyMaterialTimeout)
				defer cancel()
				gotKM, err := provider.KeyMaterial(ctx)
				if err != nil {
					t.Fatalf("%s: provider.KeyMaterial() failed: %v", stageName, err)
				}
				if !cmp.Equal(*gotKM, want, cmpOpt) {
					t.Fatalf("%s: provider.KeyMaterial() = %+v, want %+v", stageName, *gotKM, want)
				}
			}

			// Stage 0: initial credentials are loaded.
			checkKeyMaterial("stage 0", test.wantKmStage0)
			// Stage 1: reads fail; previous credentials must be kept.
			stage.increase()
			checkKeyMaterial("stage 1", test.wantKmStage1)
			// Stage 2: new credentials must replace the old ones.
			stage.increase()
			checkKeyMaterial("stage 2", test.wantKmStage2)
			stage.reset()
		})
	}
}
 |