File: scram_test.go

package info (click to toggle)
golang-mongodb-mongo-driver 1.8.4+ds1-3
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, bookworm-backports
  • size: 18,520 kB
  • sloc: perl: 533; ansic: 491; python: 432; makefile: 187; sh: 72
file content (120 lines) | stat: -rw-r--r-- 4,395 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package auth

import (
	"context"
	"testing"

	"go.mongodb.org/mongo-driver/internal/testutil/assert"
	"go.mongodb.org/mongo-driver/mongo/description"
	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
	"go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
)

// Fixed client nonces that TestSCRAM injects through WithNonceGenerator so the
// exchange is deterministic. The first canned server payload of each flow
// starts with "r=" followed by the matching nonce.
// NOTE(review): these values appear to be taken from the example exchanges in
// RFC 5802 (SCRAM-SHA-1) and RFC 7677 (SCRAM-SHA-256) — confirm before editing.
const (
	scramSha1Nonce   = "fyko+d2lbbFgONRv9qkxdawL"
	scramSha256Nonce = "rOprNGfwEbeRWgbNEkqO"
)

// Canned server payloads, listed in the order the server sends them.
// createSCRAMConversation wraps each one in a reply document and flags only
// the last reply as done.
var (
	// Short flow: the conversation ends on the server's "v=" message.
	scramSha1ShortPayloads = [][]byte{
		[]byte("r=fyko+d2lbbFgONRv9qkxdawLHo+Vgk7qvUOKUwuWLIWg4l/9SraGMHEE,s=rQ9ZY3MntBeuP3E1TDVC4w==,i=10000"),
		[]byte("v=UMWeI25JD1yNYZRMpZ4VHvhZ9e0="),
	}
	scramSha256ShortPayloads = [][]byte{
		[]byte("r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096"),
		[]byte("v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4="),
	}

	// Long flow: the short flow followed by one extra, empty server payload.
	// The full slice expression caps each short slice at its length. That
	// guarantees append copies into a fresh backing array and never writes
	// into (or shares storage with) the short fixtures.
	scramSha1LongPayloads   = append(scramSha1ShortPayloads[:len(scramSha1ShortPayloads):len(scramSha1ShortPayloads)], []byte{})
	scramSha256LongPayloads = append(scramSha256ShortPayloads[:len(scramSha256ShortPayloads):len(scramSha256ShortPayloads)], []byte{})
)

// TestSCRAM drives the SCRAM-SHA-1 and SCRAM-SHA-256 authenticators through
// scripted SASL conversations over a drivertest.ChannelConn.
//
// Each case pins the client nonce and pre-loads the canned server replies.
// It then runs Auth and checks two things about the first command sent: that
// it is saslStart, and that it carries the expected scramStartOptions.
// "short" cases end on the server's v= payload. "long" cases add a trailing
// empty payload.
func TestSCRAM(t *testing.T) {
	t.Run("conversation", func(t *testing.T) {
		testCases := []struct {
			name                  string
			createAuthenticatorFn func(*Cred) (Authenticator, error)
			payloads              [][]byte
			nonce                 string
		}{
			{"scram-sha-1 short conversation", newScramSHA1Authenticator, scramSha1ShortPayloads, scramSha1Nonce},
			{"scram-sha-256 short conversation", newScramSHA256Authenticator, scramSha256ShortPayloads, scramSha256Nonce},
			{"scram-sha-1 long conversation", newScramSHA1Authenticator, scramSha1LongPayloads, scramSha1Nonce},
			{"scram-sha-256 long conversation", newScramSHA256Authenticator, scramSha256LongPayloads, scramSha256Nonce},
		}
		for _, tc := range testCases {
			t.Run(tc.name, func(t *testing.T) {
				authenticator, err := tc.createAuthenticatorFn(&Cred{
					Username: "user",
					Password: "pencil",
					Source:   "admin",
				})
				assert.Nil(t, err, "error creating authenticator: %v", err)

				// The canned server payloads echo a fixed nonce, so the client
				// must generate that same nonce. Fail clearly (rather than
				// nil-dereference) if the authenticator is not the expected type.
				sa, ok := authenticator.(*ScramAuthenticator)
				if !ok {
					t.Fatalf("expected authenticator of type *ScramAuthenticator, got %T", authenticator)
				}
				sa.client = sa.client.WithNonceGenerator(func() string {
					return tc.nonce
				})

				// Buffered to hold every reply up front; the conn reads from it
				// as the client writes commands.
				responses := make(chan []byte, len(tc.payloads))
				writeReplies(t, responses, createSCRAMConversation(tc.payloads)...)

				desc := description.Server{
					WireVersion: &description.VersionRange{
						Max: 4,
					},
				}
				conn := &drivertest.ChannelConn{
					Written:  make(chan []byte, len(tc.payloads)),
					ReadResp: responses,
					Desc:     desc,
				}

				err = authenticator.Auth(context.Background(), &Config{Description: desc, Connection: conn})
				assert.Nil(t, err, "Auth error: %v", err)

				// A completed exchange writes saslStart plus at least one more
				// command; verify that the first one is saslStart.
				assert.True(t, len(conn.Written) > 1,
					"expected at least two wire messages to be written, got %d", len(conn.Written))
				startCmd, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
				assert.Nil(t, err, "error parsing wire message: %v", err)
				cmdName := startCmd.Index(0).Key()
				assert.Equal(t, cmdName, "saslStart", "cmd name mismatch; expected 'saslStart', got %v", cmdName)

				// Verify that the saslStart command always has {options: {skipEmptyExchange: true}}
				optionsVal, err := startCmd.LookupErr("options")
				assert.Nil(t, err, "no options found in saslStart command")
				optionsDoc := optionsVal.Document()
				assert.Equal(t, optionsDoc, scramStartOptions, "expected options %v, got %v", scramStartOptions, optionsDoc)
			})
		}
	})
}

// createSCRAMConversation turns a list of raw SCRAM payloads into the matching
// sequence of server reply documents, preserving order. Only the final reply
// is marked done; every earlier one has done=false.
func createSCRAMConversation(payloads [][]byte) []bsoncore.Document {
	last := len(payloads) - 1
	conversation := make([]bsoncore.Document, 0, len(payloads))
	for i := range payloads {
		conversation = append(conversation, createSCRAMServerResponse(payloads[i], i == last))
	}
	return conversation
}

// createSCRAMServerResponse builds one server reply of the form
// {conversationId: 1, payload: BinData(0, payload), done: <done>, ok: 1}.
// The fields appear in exactly that order.
func createSCRAMServerResponse(payload []byte, done bool) bsoncore.Document {
	var fields []byte
	fields = bsoncore.AppendInt32Element(fields, "conversationId", 1)
	fields = bsoncore.AppendBinaryElement(fields, "payload", 0x00, payload)
	fields = bsoncore.AppendBooleanElement(fields, "done", done)
	fields = bsoncore.AppendInt32Element(fields, "ok", 1)
	return bsoncore.BuildDocument(nil, fields)
}

// writeReplies encodes each document as a reply wire message with
// drivertest.MakeReply and sends them on c in order. Each send blocks once c's
// buffer is full, so callers should size c to hold all of docs. The t
// parameter is not currently used.
func writeReplies(t *testing.T, c chan []byte, docs ...bsoncore.Document) {
	for i := range docs {
		c <- drivertest.MakeReply(docs[i])
	}
}