File: token_refresh.go

package info (click to toggle)
golang-github-zitadel-oidc 3.37.0-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental, sid, trixie
  • size: 1,484 kB
  • sloc: makefile: 5
file content (152 lines) | stat: -rw-r--r-- 5,806 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
package op

import (
	"context"
	"errors"
	"net/http"
	"slices"
	"time"

	httphelper "github.com/zitadel/oidc/v3/pkg/http"
	"github.com/zitadel/oidc/v3/pkg/oidc"
)

// RefreshTokenRequest is the data Storage returns for a refresh_token: it
// represents the original auth request the token belongs to. The refresh grant
// reads the stored values through the getters and may narrow the scopes for the
// newly issued tokens through SetCurrentScopes.
type RefreshTokenRequest interface {
	// GetClientID returns the client ID stored with the original request.
	// ValidateRefreshTokenRequest requires it to equal the authenticated client's ID.
	GetClientID() string
	// GetSubject returns the subject stored with the original request.
	GetSubject() string
	// GetAudience returns the audience stored with the original request.
	GetAudience() []string
	// GetScopes returns the scopes of the original request; scopes requested
	// during a refresh must be a subset of these.
	GetScopes() []string
	// SetCurrentScopes records the scopes requested for this refresh.
	SetCurrentScopes(scopes []string)
	// GetAMR returns the amr values stored with the original request.
	GetAMR() []string
	// GetAuthTime returns the auth time stored with the original request.
	GetAuthTime() time.Time
}

// RefreshTokenExchange handles the OAuth 2.0 refresh_token grant, including
// parsing, validating, authorizing the client and finally exchanging the
// refresh_token for new tokens.
//
// Every failure is reported through RequestError and ends the handler; on
// success the token response is written as JSON.
func RefreshTokenExchange(w http.ResponseWriter, r *http.Request, exchanger Exchanger) {
	ctx, span := tracer.Start(r.Context(), "RefreshTokenExchange")
	defer span.End()
	r = r.WithContext(ctx)

	tokenReq, err := ParseRefreshTokenRequest(r, exchanger.Decoder())
	if err != nil {
		RequestError(w, r, err, exchanger.Logger())
		// Must stop here: tokenReq is nil on error, and continuing would
		// dereference it and attempt to write a second response.
		return
	}
	validatedRequest, client, err := ValidateRefreshTokenRequest(r.Context(), tokenReq, exchanger)
	if err != nil {
		RequestError(w, r, err, exchanger.Logger())
		return
	}
	resp, err := CreateTokenResponse(r.Context(), validatedRequest, client, exchanger, true, "", tokenReq.RefreshToken)
	if err != nil {
		RequestError(w, r, err, exchanger.Logger())
		return
	}
	httphelper.MarshalJSON(w, resp)
}

// ParseRefreshTokenRequest parses the HTTP request r into an
// oidc.RefreshTokenRequest using decoder. The actual decoding is delegated to
// ParseAuthenticatedTokenRequest; when it fails, the error is returned together
// with a nil request.
func ParseRefreshTokenRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.RefreshTokenRequest, error) {
	var parsed oidc.RefreshTokenRequest
	if err := ParseAuthenticatedTokenRequest(r, decoder, &parsed); err != nil {
		return nil, err
	}
	return &parsed, nil
}

// ValidateRefreshTokenRequest checks a parsed refresh_token request and returns
// the data representing the original auth request together with the client.
//
// The checks, in order:
//   - the refresh_token must be present (else invalid_request);
//   - the client must pass AuthorizeRefreshClient, which also loads the stored request;
//   - the authenticated client's ID must equal the stored client ID (else invalid_grant);
//   - requested scopes must pass ValidateRefreshTokenScopes.
func ValidateRefreshTokenRequest(ctx context.Context, tokenReq *oidc.RefreshTokenRequest, exchanger Exchanger) (RefreshTokenRequest, Client, error) {
	ctx, span := tracer.Start(ctx, "ValidateRefreshTokenRequest")
	defer span.End()

	if tokenReq.RefreshToken == "" {
		return nil, nil, oidc.ErrInvalidRequest().WithDescription("refresh_token missing")
	}

	stored, authClient, authErr := AuthorizeRefreshClient(ctx, tokenReq, exchanger)
	switch {
	case authErr != nil:
		return nil, nil, authErr
	case authClient.GetID() != stored.GetClientID():
		// The refresh_token belongs to a different client.
		return nil, nil, oidc.ErrInvalidGrant()
	}

	if scopeErr := ValidateRefreshTokenScopes(tokenReq.Scopes, stored); scopeErr != nil {
		return nil, nil, scopeErr
	}
	return stored, authClient, nil
}

// ValidateRefreshTokenScopes validates that the requested scopes are a subset
// of the scopes of the original auth request (authRequest.GetScopes()).
//
// If requestedScopes is empty, nothing is changed and the original scopes stay
// in effect. Otherwise every requested scope must be one of the original
// scopes. If any is not, oidc.ErrInvalidScope is returned and authRequest is
// left untouched. On success the requested scopes are set as the current scopes
// via SetCurrentScopes.
func ValidateRefreshTokenScopes(requestedScopes []string, authRequest RefreshTokenRequest) error {
	if len(requestedScopes) == 0 {
		return nil
	}
	// The granted scopes are loop-invariant; fetch them once rather than
	// making an interface call for every requested scope.
	granted := authRequest.GetScopes()
	for _, scope := range requestedScopes {
		if !slices.Contains(granted, scope) {
			return oidc.ErrInvalidScope()
		}
	}
	authRequest.SetCurrentScopes(requestedScopes)
	return nil
}

// AuthorizeRefreshClient authenticates the client of a refresh_token request,
// checks that it may use the refresh_token grant and that it used the
// authentication method it is registered with. It then returns the data
// representing the original auth request corresponding to the refresh_token.
//
// Two paths exist:
//   - a JWT client assertion: requires the exchanger to implement
//     JWTAuthorizationGrantExchanger and to support private_key_jwt, and the
//     client is obtained from AuthorizePrivateJWTKey;
//   - otherwise the client is looked up by client_id. Clients registered for
//     private_key_jwt are rejected (no assertion was sent), clients with auth
//     method "none" need no secret, and all others must pass
//     AuthorizeClientIDSecret (the post method only if the exchanger supports it).
func AuthorizeRefreshClient(ctx context.Context, tokenReq *oidc.RefreshTokenRequest, exchanger Exchanger) (request RefreshTokenRequest, client Client, err error) {
	ctx, span := tracer.Start(ctx, "AuthorizeRefreshClient")
	defer span.End()

	if tokenReq.ClientAssertionType == oidc.ClientAssertionTypeJWTAssertion {
		jwtExchanger, supported := exchanger.(JWTAuthorizationGrantExchanger)
		if !supported || !exchanger.AuthMethodPrivateKeyJWTSupported() {
			return nil, nil, errors.New("auth_method private_key_jwt not supported")
		}
		if client, err = AuthorizePrivateJWTKey(ctx, tokenReq.ClientAssertion, jwtExchanger); err != nil {
			return nil, nil, err
		}
		if !ValidateGrantType(client, oidc.GrantTypeRefreshToken) {
			return nil, nil, oidc.ErrUnauthorizedClient()
		}
	} else {
		if client, err = exchanger.Storage().GetClientByClientID(ctx, tokenReq.ClientID); err != nil {
			return nil, nil, err
		}
		if !ValidateGrantType(client, oidc.GrantTypeRefreshToken) {
			return nil, nil, oidc.ErrUnauthorizedClient()
		}
		switch client.AuthMethod() {
		case oidc.AuthMethodPrivateKeyJWT:
			// Registered for private_key_jwt, but no client assertion was sent.
			return nil, nil, oidc.ErrInvalidClient()
		case oidc.AuthMethodNone:
			// No client secret to verify.
		case oidc.AuthMethodPost:
			if !exchanger.AuthMethodPostSupported() {
				return nil, nil, oidc.ErrInvalidClient().WithDescription("auth_method post not supported")
			}
			fallthrough
		default:
			if err = AuthorizeClientIDSecret(ctx, tokenReq.ClientID, tokenReq.ClientSecret, exchanger.Storage()); err != nil {
				return nil, nil, err
			}
		}
	}

	request, err = RefreshTokenRequestByRefreshToken(ctx, exchanger.Storage(), tokenReq.RefreshToken)
	return request, client, err
}

// RefreshTokenRequestByRefreshToken loads the RefreshTokenRequest (data
// representing the original auth request) for refreshToken via
// storage.TokenRequestByRefreshToken. Any storage error is reported as
// oidc.ErrInvalidGrant, with the storage error attached as its parent.
func RefreshTokenRequestByRefreshToken(ctx context.Context, storage Storage, refreshToken string) (RefreshTokenRequest, error) {
	ctx, span := tracer.Start(ctx, "RefreshTokenRequestByRefreshToken")
	defer span.End()

	tokenRequest, lookupErr := storage.TokenRequestByRefreshToken(ctx, refreshToken)
	if lookupErr != nil {
		return nil, oidc.ErrInvalidGrant().WithParent(lookupErr)
	}
	return tokenRequest, nil
}