File: confidential.go

package info (click to toggle)
golang-github-azuread-microsoft-authentication-library-for-go 1.0.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 964 kB
  • sloc: makefile: 4
file content (205 lines) | stat: -rw-r--r-- 5,696 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

package main

import (
	"context"
	"fmt"
	"os"
	"runtime"
	"strconv"
	"sync"
	"text/template"
	"time"

	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base"
	internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth"
	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake"
	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
)

const (
	// accessToken is the token string every fake token response carries.
	accessToken = "fake_token"
)

// tokenScope is the scope list granted by the fake OAuth client's token response.
var tokenScope = []string{"fake_scope"}

// testParams describes one benchmark run.
type testParams struct {
	Concurrency int // number of goroutines used for both population and retrieval
	TokenCount  int // number of distinct tokens cached; must be a multiple of Concurrency
}

// fakeClient returns a base.Client backed entirely by in-memory fakes.
// Its token response carries accessToken and tokenScope and expires one
// hour from the call. Its authority, endpoint resolver and WS-Trust
// collaborators are fixed fakes, so no network traffic is made.
func fakeClient() (base.Client, error) {
	tokenResp := accesstokens.TokenResponse{
		AccessToken:   accessToken,
		ExpiresOn:     internalTime.DurationTime{T: time.Now().Add(time.Hour)},
		GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
	}
	discovery := authority.InstanceDiscoveryResponse{
		Metadata: []authority.InstanceDiscoveryMetadata{{
			PreferredNetwork: "fake_authority",
			Aliases:          []string{"fake_authority"},
		}},
	}
	endpoints := authority.Endpoints{
		AuthorizationEndpoint: "auth_endpoint",
		TokenEndpoint:         "token_endpoint",
	}
	oauthClient := &oauth.Client{
		AccessTokens: &fake.AccessTokens{AccessToken: tokenResp},
		Authority:    &fake.Authority{InstanceResp: discovery},
		Resolver:     &fake.ResolveEndpoints{Endpoints: endpoints},
		WSTrust:      &fake.WSTrust{},
	}
	// base.Client is used directly because it accepts an injected OAuth client.
	return base.New("fake_client_id", "https://fake_authority/fake", oauthClient)
}

// execTime is the wall-clock window [start, end] of one timed phase.
type execTime struct {
	start, end time.Time
}

// populateTokenCache stores params.TokenCount distinct tokens in client's
// cache via AuthResultFromToken.
//
// The work is split into equal contiguous chunks, one per goroutine, across
// params.Concurrency goroutines. Token i is granted the single scope
// strconv.Itoa(i); that unique scope is what keeps each cache entry distinct.
// The function prints a progress message to stdout and returns the
// wall-clock window of the population phase.
//
// It panics if params.Concurrency is not positive, if params.TokenCount is
// not divisible by params.Concurrency, or if caching any token fails.
func populateTokenCache(client base.Client, params testParams) execTime {
	// Check before the modulo below. A zero Concurrency would otherwise crash
	// with an opaque integer divide-by-zero, and a negative one would
	// silently populate nothing.
	if params.Concurrency <= 0 {
		panic("Concurrency must be positive")
	}
	if params.TokenCount%params.Concurrency != 0 {
		panic("TokenCount must be divisible by Concurrency")
	}
	parts := params.TokenCount / params.Concurrency

	// Scopes and AuthorizationType are set here, before any goroutine starts;
	// the goroutines below only read authParams.
	authParams := client.AuthParams
	authParams.Scopes = tokenScope
	authParams.AuthorizationType = authority.ATClientCredentials

	wg := &sync.WaitGroup{}
	fmt.Printf("Populating token cache with %d tokens...", params.TokenCount)
	start := time.Now()
	for n := 0; n < params.Concurrency; n++ {
		wg.Add(1)
		go func(chunk int) {
			defer wg.Done()
			for i := parts * chunk; i < parts*(chunk+1); i++ {
				_, err := client.AuthResultFromToken(context.Background(), authParams, accesstokens.TokenResponse{
					AccessToken:   accessToken,
					ExpiresOn:     internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
					GrantedScopes: accesstokens.Scopes{Slice: []string{strconv.Itoa(i)}},
				}, true)
				if err != nil {
					panic(err)
				}
			}
		}(n)
	}
	wg.Wait()
	return execTime{start: start, end: time.Now()}
}

// executeTest times cache reads. It starts params.Concurrency goroutines,
// and each one calls AcquireTokenSilent once for every scope "0" through
// strconv.Itoa(params.TokenCount-1). The client therefore serves
// Concurrency*TokenCount lookups in total.
//
// The function prints a progress message to stdout and returns the
// wall-clock window of the retrieval phase. It panics if any lookup fails.
func executeTest(client base.Client, params testParams) execTime {
	var wg sync.WaitGroup
	fmt.Print("Begin token retrieval.....")
	begin := time.Now()
	for g := 0; g < params.Concurrency; g++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for idx := 0; idx < params.TokenCount; idx++ {
				silent := base.AcquireTokenSilentParameters{
					Scopes:      []string{strconv.Itoa(idx)},
					RequestType: accesstokens.ATConfidential,
					Credential:  &accesstokens.Credential{Secret: "fake_secret"},
				}
				if _, err := client.AcquireTokenSilent(context.Background(), silent); err != nil {
					panic(err)
				}
			}
		}()
	}
	wg.Wait()
	return execTime{start: begin, end: time.Now()}
}

// Stats is used with statsTemplText for reporting purposes. Its exported
// fields and methods are the names the template references.
type Stats struct {
	popExec     execTime // window of the cache-population phase
	retExec     execTime // window of the token-retrieval phase
	Concurrency int      // goroutines used in the run
	Count       int64    // tokens in the cache; divisor for the averages
}

// PopDur returns the total duration for populating the cache.
func (s *Stats) PopDur() time.Duration {
	return s.popExec.end.Sub(s.popExec.start)
}

// RetDur returns the total duration for retrieving tokens.
func (s *Stats) RetDur() time.Duration {
	return s.retExec.end.Sub(s.retExec.start)
}

// PopAvg returns the mean time to cache one token, or 0 when Count is not
// positive.
func (s *Stats) PopAvg() time.Duration {
	return s.avg(s.PopDur())
}

// RetAvg returns the total retrieval duration divided by Count, or 0 when
// Count is not positive.
func (s *Stats) RetAvg() time.Duration {
	return s.avg(s.RetDur())
}

// avg divides total by Count. When Count is zero or negative it returns 0
// instead of dividing: a zero Count would panic with integer divide-by-zero,
// which text/template reports as an Execute error.
func (s *Stats) avg(total time.Duration) time.Duration {
	if s.Count <= 0 {
		return 0
	}
	return total / time.Duration(s.Count)
}

// statsTemplText renders a one-line summary of a test run. Its dotted names
// resolve to the fields and methods of the value passed to Execute.
// statsTempl is the parsed template, built once at package initialization;
// template.Must panics at startup if the text fails to parse.
var (
	statsTemplText = `
Test Results:
[{{.Concurrency}} goroutines][{{.Count}} tokens] [population: total {{.PopDur}}, avg {{.PopAvg}}] [retrieval: total {{.RetDur}}, avg {{.RetAvg}}]
==========================================================================
`
	statsTempl = template.Must(template.New("stats").Parse(statsTemplText))
)

// main benchmarks the token cache for a series of cache sizes. Each run:
//   - uses a fresh fake client and one goroutine per logical CPU;
//   - prints its parameters;
//   - populates the cache, then retrieves every token from every goroutine;
//   - writes a summary rendered by statsTempl to stdout.
//
// Any failure panics.
func main() {
	workers := runtime.NumCPU()
	var tests []testParams
	for _, size := range []int{100, 1000, 10000, 20000} {
		tests = append(tests, testParams{Concurrency: workers, TokenCount: size})
	}

	for _, params := range tests {
		client, err := fakeClient()
		if err != nil {
			panic(err)
		}
		fmt.Printf("Test Params: %#v\n", params)
		populated := populateTokenCache(client, params)
		retrieved := executeTest(client, params)
		report := &Stats{
			popExec:     populated,
			retExec:     retrieved,
			Concurrency: params.Concurrency,
			Count:       int64(params.TokenCount),
		}
		if err = statsTempl.Execute(os.Stdout, report); err != nil {
			panic(err)
		}
	}
}