// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

// TODO(msal): Write some tests. The original code this came from didn't have tests, and I'm too
// tired at this point to write them. Like much of the other *Manager code I found, it was broken
// because it lacked mutex protection.

package oauth

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"sync"

	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops"
	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
)
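
// ADFS is an Active Directory Federation Services authority type.
const ADFS = "ADFS"

// cacheEntry holds the resolved endpoints for one authority. For ADFS authorities,
// ValidForDomainsInList records the UPN domains the endpoints were resolved for.
type cacheEntry struct {
	Endpoints             authority.Endpoints
	ValidForDomainsInList map[string]bool
}

// createcacheEntry returns a cacheEntry for endpoints with an empty domain set.
func createcacheEntry(endpoints authority.Endpoints) cacheEntry {
	return cacheEntry{endpoints, map[string]bool{}}
}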

// authorityEndpoint resolves and caches the endpoints an authority exposes for
// authorization and token acquisition. mu guards cache.
type authorityEndpoint struct {
	rest  *ops.REST
	mu    sync.Mutex
	cache map[string]cacheEntry
}

// newAuthorityEndpoint is the constructor for authorityEndpoint.
func newAuthorityEndpoint(rest *ops.REST) *authorityEndpoint {
	m := &authorityEndpoint{rest: rest, cache: map[string]cacheEntry{}}
	return m
}

// ResolveEndpoints returns the authorization and token endpoints and the issuer for
// authorityInfo. It serves results from the cache when possible; otherwise it fetches the
// authority's OpenID Connect discovery document, validates it, substitutes the tenant into
// the advertised URLs, and caches the result.
func (m *authorityEndpoint) ResolveEndpoints(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, error) {
	if endpoints, found := m.cachedEndpoints(authorityInfo, userPrincipalName); found {
		return endpoints, nil
	}

	endpoint, err := m.openIDConfigurationEndpoint(ctx, authorityInfo, userPrincipalName)
	if err != nil {
		return authority.Endpoints{}, err
	}

	resp, err := m.rest.Authority().GetTenantDiscoveryResponse(ctx, endpoint)
	if err != nil {
		return authority.Endpoints{}, err
	}
	if err := resp.Validate(); err != nil {
		return authority.Endpoints{}, fmt.Errorf("ResolveEndpoints(): %w", err)
	}

	tenant := authorityInfo.Tenant
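	// Worked example (hypothetical values): if the discovery document advertises the
	// token endpoint "https://login.example.com/{tenant}/oauth2/v2.0/token" and tenant is
	// "contoso.onmicrosoft.com", the cached token endpoint becomes
	// "https://login.example.com/contoso.onmicrosoft.com/oauth2/v2.0/token".
	// URLs without a "{tenant}" placeholder are used unchanged.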
	endpoints := authority.NewEndpoints(
		strings.ReplaceAll(resp.AuthorizationEndpoint, "{tenant}", tenant),
		strings.ReplaceAll(resp.TokenEndpoint, "{tenant}", tenant),
		strings.ReplaceAll(resp.Issuer, "{tenant}", tenant),
		authorityInfo.Host)

	m.addCachedEndpoints(authorityInfo, userPrincipalName, endpoints)

	return endpoints, nil
}

// cachedEndpoints returns the cached endpoints for authorityInfo if they exist;
// otherwise it returns false.
func (m *authorityEndpoint) cachedEndpoints(authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, bool) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if cacheEntry, ok := m.cache[authorityInfo.CanonicalAuthorityURI]; ok {
		if authorityInfo.AuthorityType == ADFS {
			domain, err := adfsDomainFromUpn(userPrincipalName)
			if err == nil {
				if _, ok := cacheEntry.ValidForDomainsInList[domain]; ok {
					return cacheEntry.Endpoints, true
				}
			}
		}
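		// Whether or not the ADFS domain check above matched, execution reaches this
		// return, so any entry cached for the authority is used. For example (hypothetical
		// UPNs), endpoints cached for an ADFS authority while resolving for
		// alice@contoso.com are also returned for bob@fabrikam.com on the same authority.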
		return cacheEntry.Endpoints, true
	}
	return authority.Endpoints{}, false
}
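
// addCachedEndpoints stores endpoints for authorityInfo. For ADFS authorities it also
// records the domain of userPrincipalName, keeping any domains already cached.
func (m *authorityEndpoint) addCachedEndpoints(authorityInfo authority.Info, userPrincipalName string, endpoints authority.Endpoints) {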
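	// Worked example (hypothetical UPNs): two calls for the same ADFS authority, first
	// with "alice@contoso.com" and then with "bob@fabrikam.com", leave a single entry
	// whose ValidForDomainsInList is {"contoso.com": true, "fabrikam.com": true} and
	// whose Endpoints come from the second call.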
	m.mu.Lock()
	defer m.mu.Unlock()

	updatedCacheEntry := createcacheEntry(endpoints)

	if authorityInfo.AuthorityType == ADFS {
		// Since we're here, we've just called the backend, so the new endpoints replace
		// the cached ones. Domains recorded earlier for this authority are carried over.
		if cacheEntry, ok := m.cache[authorityInfo.CanonicalAuthorityURI]; ok {
			for k := range cacheEntry.ValidForDomainsInList {
				updatedCacheEntry.ValidForDomainsInList[k] = true
			}
		}
		domain, err := adfsDomainFromUpn(userPrincipalName)
		if err == nil {
			updatedCacheEntry.ValidForDomainsInList[domain] = true
		}
	}

	m.cache[authorityInfo.CanonicalAuthorityURI] = updatedCacheEntry
}
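
// openIDConfigurationEndpoint returns the URL of the OpenID Connect discovery document
// for authorityInfo. ADFS authorities use the well-known ADFS path; validated authorities
// on untrusted hosts and regional authorities use the endpoint returned by AAD instance
// discovery; all others append the v2.0 well-known path to the canonical authority URI.
func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (string, error) {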
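	// Worked examples (hypothetical inputs; no region, and either authority validation
	// disabled or a trusted host, so instance discovery is skipped):
	//   Tenant "adfs", Host "fs.contoso.com"
	//     -> "https://fs.contoso.com/adfs/.well-known/openid-configuration"
	//   CanonicalAuthorityURI "https://login.microsoftonline.com/contoso.onmicrosoft.com/"
	//     -> "https://login.microsoftonline.com/contoso.onmicrosoft.com/v2.0/.well-known/openid-configuration"
	// The second case relies on CanonicalAuthorityURI ending with "/", since the path is
	// appended by plain string concatenation.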
	if authorityInfo.Tenant == "adfs" {
		return fmt.Sprintf("https://%s/adfs/.well-known/openid-configuration", authorityInfo.Host), nil
	}

	// Validated authorities on untrusted hosts and regional authorities both locate
	// their discovery document through AAD instance discovery.
	if (authorityInfo.ValidateAuthority && !authority.TrustedHost(authorityInfo.Host)) || authorityInfo.Region != "" {
		resp, err := m.rest.Authority().AADInstanceDiscovery(ctx, authorityInfo)
		if err != nil {
			return "", err
		}
		return resp.TenantDiscoveryEndpoint, nil
	}

	return authorityInfo.CanonicalAuthorityURI + "v2.0/.well-known/openid-configuration", nil
}
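
// adfsDomainFromUpn returns the domain of a user principal name: the text following the
// first "@", up to any second "@". It returns an error if userPrincipalName has no "@".
func adfsDomainFromUpn(userPrincipalName string) (string, error) {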
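	// Worked examples (hypothetical UPNs):
	//   "alice@contoso.com" -> "contoso.com", nil
	//   "alice"             -> "", error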
	parts := strings.Split(userPrincipalName, "@")
	if len(parts) < 2 {
		return "", errors.New("no @ present in user principal name")
	}
	return parts[1], nil
}