File: verifier.go

package info (click to toggle)
golang-github-zitadel-oidc 3.44.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,520 kB
  • sloc: makefile: 5
file content (275 lines) | stat: -rw-r--r-- 9,126 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
package oidc

import (
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"slices"
	"strings"
	"time"

	jose "github.com/go-jose/go-jose/v4"
)

// Claims is the set of token claim accessors consumed by the Check* functions
// in this file. It embeds ClaimsSignature, so a Claims value can also be
// handed to CheckSignature to record the algorithm of a verified signature.
type Claims interface {
	// GetIssuer returns the value CheckIssuer compares to the expected issuer.
	GetIssuer() string
	// GetSubject returns the value CheckSubject requires to be non-empty.
	GetSubject() string
	// GetAudience returns the audiences; CheckAudience requires the client ID
	// to be among them and CheckAZPVerifier inspects their count.
	GetAudience() []string
	// GetExpiration returns the instant CheckExpiration compares to the current time.
	GetExpiration() time.Time
	// GetIssuedAt returns the instant validated by CheckIssuedAt; the zero
	// time is treated there as "missing".
	GetIssuedAt() time.Time
	// GetNonce returns the value CheckNonce compares to the expected nonce.
	GetNonce() string
	// GetAuthenticationContextClassReference returns the value passed to the
	// ACRVerifier by CheckAuthorizationContextClassReference.
	GetAuthenticationContextClassReference() string
	// GetAuthTime returns the instant validated by CheckAuthTime; the zero
	// time is treated there as "missing".
	GetAuthTime() time.Time
	// GetAuthorizedParty returns the value validated by CheckAZPVerifier.
	GetAuthorizedParty() string
	ClaimsSignature
}

// ClaimsSignature is implemented by claim types that can record the
// algorithm their token was signed with.
type ClaimsSignature interface {
	// SetSignatureAlgorithm is called by CheckSignature, after the signature
	// has been verified, with the algorithm taken from the signature header.
	SetSignatureAlgorithm(algorithm jose.SignatureAlgorithm)
}

// IDClaims extends Claims with two additional accessors. Neither is used by
// the functions in this file.
type IDClaims interface {
	Claims
	// GetSignatureAlgorithm presumably returns the algorithm previously stored
	// through SetSignatureAlgorithm — confirm against the implementations.
	GetSignatureAlgorithm() jose.SignatureAlgorithm
	// GetAccessTokenHash presumably returns the at_hash claim (see ErrAtHash)
	// — confirm against the implementations.
	GetAccessTokenHash() string
}

// Sentinel errors for token parsing and verification.
//
// Several of these are returned wrapped (via %w) by the functions in this
// file, so callers should match them with errors.Is rather than ==. Some of
// them (for example ErrDiscoveryFailed and ErrAtHash) are not referenced by
// any function in this file.
var (
	ErrParse                   = errors.New("parsing of request failed")
	ErrIssuerInvalid           = errors.New("issuer does not match")
	ErrDiscoveryFailed         = errors.New("OpenID Provider Configuration Discovery has failed")
	ErrSubjectMissing          = errors.New("subject missing")
	ErrAudience                = errors.New("audience is not valid")
	ErrAzpMissing              = errors.New("authorized party is not set. If Token is valid for multiple audiences, azp must not be empty")
	ErrAzpInvalid              = errors.New("authorized party is not valid")
	ErrSignatureMissing        = errors.New("id_token does not contain a signature")
	ErrSignatureMultiple       = errors.New("id_token contains multiple signatures")
	ErrSignatureUnsupportedAlg = errors.New("signature algorithm not supported")
	ErrSignatureInvalidPayload = errors.New("signature does not match Payload")
	ErrSignatureInvalid        = errors.New("invalid signature")
	ErrExpired                 = errors.New("token has expired")
	ErrIatMissing              = errors.New("issuedAt of token is missing")
	ErrIatInFuture             = errors.New("issuedAt of token is in the future")
	ErrIatToOld                = errors.New("issuedAt of token is to old")
	ErrNonceInvalid            = errors.New("nonce does not match")
	ErrAcrInvalid              = errors.New("acr is invalid")
	ErrAuthTimeNotPresent      = errors.New("claim `auth_time` of token is missing")
	ErrAuthTimeToOld           = errors.New("auth time of token is too old")
	ErrAtHash                  = errors.New("at_hash does not correspond to access token")
)

// Verifier carries configuration for the various token verification
// functions. Use package specific constructor functions to know
// which values need to be set.
//
// NOTE(review): no function in this file reads a Verifier directly. The
// field names mirror the parameters of the Check* functions (issuer,
// maxAgeIAT, offset, clientID, supported signature algorithms, maxAge,
// ACR/AZP verifiers, key set), presumably so that callers pass them
// through — confirm against the constructors and their call sites.
type Verifier struct {
	Issuer            string
	MaxAgeIAT         time.Duration
	Offset            time.Duration
	ClientID          string
	SupportedSignAlgs []string
	MaxAge            time.Duration
	ACR               ACRVerifier
	AZP               AZPVerifier
	KeySet            KeySet
	Nonce             func(ctx context.Context) string
}

// ACRVerifier is the callback used to validate the acr claim of a token.
// It receives the claim value and returns a non-nil error to reject it.
type ACRVerifier func(string) error

// DefaultACRVerifier returns an ACRVerifier that accepts an acr value only
// when it is equal to one of possibleValues. Any other value (including the
// empty string, unless it is listed) is rejected with an error naming the
// accepted values and the value received.
func DefaultACRVerifier(possibleValues []string) ACRVerifier {
	return func(acr string) error {
		if slices.Contains(possibleValues, acr) {
			return nil
		}
		return fmt.Errorf("expected one of: %v, got: %q", possibleValues, acr)
	}
}

// DecryptToken is a placeholder for token decryption. It currently performs
// no decryption: the input is returned unchanged and the error is always nil.
func DecryptToken(tokenString string) (string, error) {
	// TODO: impl
	return tokenString, nil
}

// ParseToken splits a compact-serialized JWT into its three dot-separated
// segments, base64url-decodes (unpadded) the middle segment and unmarshals
// the resulting JSON into claims.
//
// claims must be a value accepted by json.Unmarshal (typically a pointer).
// The signature segment is not inspected here; use CheckSignature for that.
//
// It returns the decoded payload bytes. Every failure wraps ErrParse, so
// callers can use errors.Is(err, ErrParse) uniformly:
//   - wrong number of segments or invalid base64: the payload is nil;
//   - invalid JSON: the decoded payload is still returned alongside the
//     error, and the underlying encoding/json error remains reachable
//     through errors.As / errors.Is.
func ParseToken(tokenString string, claims any) ([]byte, error) {
	parts := strings.Split(tokenString, ".")
	if len(parts) != 3 {
		return nil, fmt.Errorf("%w: token contains an invalid number of segments", ErrParse)
	}
	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
	if err != nil {
		return nil, fmt.Errorf("%w: malformed jwt payload: %v", ErrParse, err)
	}
	if err := json.Unmarshal(payload, claims); err != nil {
		// Wrap both errors: ErrParse for consistency with the branches above,
		// and the json error so existing errors.As checks keep working.
		return payload, fmt.Errorf("%w: malformed jwt claims: %w", ErrParse, err)
	}
	return payload, nil
}

// CheckSubject returns ErrSubjectMissing when the subject reported by claims
// is the empty string, and nil otherwise.
func CheckSubject(claims Claims) error {
	if sub := claims.GetSubject(); sub != "" {
		return nil
	}
	return ErrSubjectMissing
}

// CheckIssuer verifies that the issuer reported by claims is exactly equal
// to issuer. On mismatch it returns an error wrapping ErrIssuerInvalid that
// names both the expected and the received value.
func CheckIssuer(claims Claims, issuer string) error {
	got := claims.GetIssuer()
	if got == issuer {
		return nil
	}
	return fmt.Errorf("%w: Expected: %s, got: %s", ErrIssuerInvalid, issuer, got)
}

// CheckAudience verifies that clientID is one of the audiences reported by
// claims. If it is not, the returned error wraps ErrAudience.
// Additional audiences are currently accepted without further checks.
func CheckAudience(claims Claims, clientID string) error {
	audiences := claims.GetAudience()
	if slices.Contains(audiences, clientID) {
		// TODO: check aud trusted
		return nil
	}
	return fmt.Errorf("%w: Audience must contain client_id %q", ErrAudience, clientID)
}

// AZPVerifier is the callback used to validate the azp (authorized party)
// claim of a token. It receives the claim value and returns a non-nil error
// to reject it.
type AZPVerifier func(string) error

// DefaultAZPVerifier returns an AZPVerifier that accepts an empty azp value
// as well as one equal to clientID. Any other value is rejected with an
// error wrapping ErrAzpInvalid.
func DefaultAZPVerifier(clientID string) AZPVerifier {
	return func(azp string) error {
		if azp == "" || azp == clientID {
			return nil
		}
		return fmt.Errorf("%w: azp %q must be equal to client_id %q", ErrAzpInvalid, azp, clientID)
	}
}

// CheckAuthorizedParty checks the azp (authorized party) claim using
// DefaultAZPVerifier(clientID); it is shorthand for calling CheckAZPVerifier
// with that verifier.
//
// If the ID Token contains multiple audiences, the Client SHOULD verify that an azp Claim is present.
// If an azp Claim is present, the Client MAY verify that its client_id is the Claim Value.
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
func CheckAuthorizedParty(claims Claims, clientID string) error {
	verify := DefaultAZPVerifier(clientID)
	return CheckAZPVerifier(claims, verify)
}

// CheckAZPVerifier checks azp (authorized party) claim requirements.
//
// If the ID Token contains multiple audiences, the Client SHOULD verify that an azp Claim is present.
// If an azp Claim is present, the Client MAY verify that its client_id is the Claim Value.
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
//
// The check has two steps:
//  1. when claims reports more than one audience, an empty azp yields
//     ErrAzpMissing;
//  2. the azp value (possibly empty) is handed to the azp verifier.
//
// An error from the verifier is always reported as ErrAzpInvalid. If the
// verifier's error already wraps ErrAzpInvalid (as DefaultAZPVerifier's does)
// it is returned unchanged, so the sentinel's text is not repeated in the
// message; otherwise it is wrapped.
//
// azp must be non-nil: calling a nil function value panics.
func CheckAZPVerifier(claims Claims, azp AZPVerifier) error {
	authorizedParty := claims.GetAuthorizedParty()
	if authorizedParty == "" && len(claims.GetAudience()) > 1 {
		return ErrAzpMissing
	}

	err := azp(authorizedParty)
	if err == nil {
		return nil
	}
	if errors.Is(err, ErrAzpInvalid) {
		// Already carries the sentinel; wrapping again would only duplicate
		// the "authorized party is not valid" prefix.
		return err
	}
	return fmt.Errorf("%w: %v", ErrAzpInvalid, err)
}

// CheckSignature parses token as a JWS, verifies its signature through set
// and confirms that the signed payload is byte-for-byte equal to payload.
// On success the algorithm named in the signature header is stored on claims
// via SetSignatureAlgorithm.
//
// supportedSigAlgs limits the algorithms accepted while parsing; when it is
// empty, toJoseSignatureAlgorithms substitutes RS256, ES256 and PS256.
//
// Errors:
//   - ErrSignatureUnsupportedAlg: parsing failed because of an unexpected
//     signature algorithm (detected by the go-jose error message prefix);
//   - ErrParse: any other parse failure;
//   - ErrSignatureMissing / ErrSignatureMultiple: the JWS does not carry
//     exactly one signature;
//   - ErrSignatureInvalid (wrapped): set.VerifySignature failed;
//   - ErrSignatureInvalidPayload: the verified payload differs from payload.
func CheckSignature(ctx context.Context, token string, payload []byte, claims ClaimsSignature, supportedSigAlgs []string, set KeySet) error {
	allowed := toJoseSignatureAlgorithms(supportedSigAlgs)
	jws, parseErr := jose.ParseSigned(token, allowed)
	if parseErr != nil {
		// TODO(v4): we should wrap errors instead of returning static ones.
		// This is a workaround so we keep returning the same error for now.
		if strings.HasPrefix(parseErr.Error(), "go-jose/go-jose: unexpected signature algorithm") {
			return ErrSignatureUnsupportedAlg
		}
		return ErrParse
	}

	// Exactly one signature is required.
	switch count := len(jws.Signatures); {
	case count == 0:
		return ErrSignatureMissing
	case count > 1:
		return ErrSignatureMultiple
	}
	// Read the algorithm before verification, as the original code did.
	alg := jose.SignatureAlgorithm(jws.Signatures[0].Header.Algorithm)

	verified, verifyErr := set.VerifySignature(ctx, jws)
	if verifyErr != nil {
		return fmt.Errorf("%w (%v)", ErrSignatureInvalid, verifyErr)
	}
	if !bytes.Equal(verified, payload) {
		return ErrSignatureInvalidPayload
	}

	claims.SetSignatureAlgorithm(alg)
	return nil
}

// toJoseSignatureAlgorithms converts algorithm names to
// jose.SignatureAlgorithm values, preserving their order. When algorithms is
// empty the result is the default set RS256, ES256, PS256.
//
// TODO(v4): Use the new jose.SignatureAlgorithm type directly, instead of string.
func toJoseSignatureAlgorithms(algorithms []string) []jose.SignatureAlgorithm {
	if len(algorithms) == 0 {
		return []jose.SignatureAlgorithm{jose.RS256, jose.ES256, jose.PS256}
	}
	converted := make([]jose.SignatureAlgorithm, 0, len(algorithms))
	for _, name := range algorithms {
		converted = append(converted, jose.SignatureAlgorithm(name))
	}
	return converted
}

// CheckExpiration returns ErrExpired unless the current time, shifted by
// offset, is strictly before the expiration reported by claims. A token whose
// expiration equals the shifted current time therefore counts as expired.
func CheckExpiration(claims Claims, offset time.Duration) error {
	expiresAt := claims.GetExpiration()
	now := time.Now().Add(offset)
	if now.Before(expiresAt) {
		return nil
	}
	return ErrExpired
}

// CheckIssuedAt validates the issued-at time reported by claims.
//
//   - A zero issued-at time yields ErrIatMissing.
//   - An issued-at time after now+offset (rounded to the second) yields an
//     error wrapping ErrIatInFuture.
//   - When maxAgeIAT is non-zero, an issued-at time before now-maxAgeIAT
//     (rounded to the second; offset is not applied here) yields an error
//     wrapping ErrIatToOld.
//
// A maxAgeIAT of zero disables the age check.
func CheckIssuedAt(claims Claims, maxAgeIAT, offset time.Duration) error {
	iat := claims.GetIssuedAt()
	if iat.IsZero() {
		return ErrIatMissing
	}

	if latest := time.Now().Add(offset).Round(time.Second); iat.After(latest) {
		return fmt.Errorf("%w: (iat: %v, now with offset: %v)", ErrIatInFuture, iat, latest)
	}

	if maxAgeIAT != 0 {
		oldest := time.Now().Add(-maxAgeIAT).Round(time.Second)
		if iat.Before(oldest) {
			return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrIatToOld, oldest, iat, oldest.Sub(iat))
		}
	}
	return nil
}

// CheckNonce verifies that the nonce reported by claims is exactly equal to
// nonce. On mismatch the returned error wraps ErrNonceInvalid and quotes
// both the expected and the received value.
func CheckNonce(claims Claims, nonce string) error {
	got := claims.GetNonce()
	if got == nonce {
		return nil
	}
	return fmt.Errorf("%w: expected %q but was %q", ErrNonceInvalid, nonce, got)
}

// CheckAuthorizationContextClassReference validates the acr value reported
// by claims with the supplied verifier. A nil acr verifier disables the
// check. An error returned by the verifier is reported wrapped in
// ErrAcrInvalid.
func CheckAuthorizationContextClassReference(claims Claims, acr ACRVerifier) error {
	if acr == nil {
		return nil
	}
	value := claims.GetAuthenticationContextClassReference()
	if err := acr(value); err != nil {
		return fmt.Errorf("%w: %v", ErrAcrInvalid, err)
	}
	return nil
}

// CheckAuthTime validates the authentication time reported by claims against
// maxAge.
//
//   - A maxAge of zero disables the check.
//   - A zero authentication time yields ErrAuthTimeNotPresent.
//   - An authentication time before now-maxAge (rounded to the second)
//     yields an error wrapping ErrAuthTimeToOld. The message reports the
//     cut-off instant, the authentication time and how far the latter lies
//     before the former.
func CheckAuthTime(claims Claims, maxAge time.Duration) error {
	if maxAge == 0 {
		return nil
	}
	authTime := claims.GetAuthTime()
	if authTime.IsZero() {
		return ErrAuthTimeNotPresent
	}
	maxAuthTime := time.Now().Add(-maxAge).Round(time.Second)
	if authTime.Before(maxAuthTime) {
		// Report the cut-off instant (maxAuthTime), not the maxAge duration:
		// "must not be older than" is compared against a point in time.
		return fmt.Errorf("%w: must not be older than %v, but was %v (%v to old)", ErrAuthTimeToOld, maxAuthTime, authTime, maxAuthTime.Sub(authTime))
	}
	return nil
}