File: userinfo.go

package info (click to toggle)
golang-github-zitadel-oidc 3.37.0-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental, sid, trixie
  • size: 1,484 kB
  • sloc: makefile: 5
file content (104 lines) | stat: -rw-r--r-- 2,962 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
package op

import (
	"context"
	"errors"
	"net/http"
	"strings"

	httphelper "github.com/zitadel/oidc/v3/pkg/http"
	"github.com/zitadel/oidc/v3/pkg/oidc"
)

// UserinfoProvider bundles the dependencies the userinfo endpoint needs to
// parse a request, resolve the presented access token and load the user's
// claims.
type UserinfoProvider interface {
	// Decoder returns the decoder used to map form values of the request
	// onto an oidc.UserInfoRequest.
	Decoder() httphelper.Decoder
	// Crypto returns the Crypto whose Decrypt method is tried first on the
	// presented access token.
	Crypto() Crypto
	// Storage returns the Storage whose SetUserinfoFromToken fills in the
	// userinfo response.
	Storage() Storage
	// AccessTokenVerifier returns the verifier handed to VerifyAccessToken
	// when the access token could not be decrypted.
	AccessTokenVerifier(context.Context) *AccessTokenVerifier
}

// userinfoHandler adapts Userinfo to a plain HTTP handler function bound to
// the given provider.
func userinfoHandler(userinfoProvider UserinfoProvider) func(http.ResponseWriter, *http.Request) {
	handle := func(rw http.ResponseWriter, req *http.Request) {
		Userinfo(rw, req, userinfoProvider)
	}
	return handle
}

// Userinfo serves a userinfo request.
//
// It answers with 401 "access token missing" when no access token can be
// parsed from the request, with 401 "access token invalid" when the token
// cannot be resolved to a token ID and subject, and with 403 carrying the
// storage error when the storage refuses to populate the userinfo. On success
// the populated oidc.UserInfo is written via httphelper.MarshalJSON.
func Userinfo(w http.ResponseWriter, r *http.Request, userinfoProvider UserinfoProvider) {
	ctx, span := tracer.Start(r.Context(), "Userinfo")
	defer span.End()
	// Carry the span context on the request so nested calls continue the trace.
	r = r.WithContext(ctx)

	token, parseErr := ParseUserinfoRequest(r, userinfoProvider.Decoder())
	if parseErr != nil {
		http.Error(w, "access token missing", http.StatusUnauthorized)
		return
	}

	tokenID, subject, valid := getTokenIDAndSubject(r.Context(), userinfoProvider, token)
	if !valid {
		http.Error(w, "access token invalid", http.StatusUnauthorized)
		return
	}

	storage := userinfoProvider.Storage()
	origin := r.Header.Get("origin")
	userinfo := &oidc.UserInfo{}
	if storeErr := storage.SetUserinfoFromToken(r.Context(), userinfo, tokenID, subject, origin); storeErr != nil {
		httphelper.MarshalJSONWithStatus(w, storeErr, http.StatusForbidden)
		return
	}
	httphelper.MarshalJSON(w, userinfo)
}

// ParseUserinfoRequest returns the access token presented with a userinfo
// request.
//
// The Authorization header is consulted first. Only when it yields no token
// is the request form parsed and decoded into an oidc.UserInfoRequest, whose
// AccessToken field is then returned. Any form parsing or decoding failure is
// reported as "unable to parse request".
func ParseUserinfoRequest(r *http.Request, decoder httphelper.Decoder) (string, error) {
	ctx, span := tracer.Start(r.Context(), "ParseUserinfoRequest")
	defer span.End()
	r = r.WithContext(ctx)

	if token, headerErr := getAccessToken(r); headerErr == nil {
		return token, nil
	}

	// No usable header: fall back to the access token sent as a form value.
	if formErr := r.ParseForm(); formErr != nil {
		return "", errors.New("unable to parse request")
	}
	var userinfoReq oidc.UserInfoRequest
	if decodeErr := decoder.Decode(&userinfoReq, r.Form); decodeErr != nil {
		return "", errors.New("unable to parse request")
	}
	return userinfoReq.AccessToken, nil
}

// getAccessToken extracts the bearer token from the request's Authorization
// header.
//
// The header value must begin with the "Bearer" scheme followed by a space.
// The scheme is matched case-insensitively, since HTTP authentication scheme
// names are case-insensitive. Whitespace surrounding the token is trimmed.
//
// It returns "no auth header" when the header is absent or empty, and
// "invalid auth header" when the value does not start with the Bearer scheme
// (including when other text precedes it) or carries no token.
func getAccessToken(r *http.Request) (string, error) {
	ctx, span := tracer.Start(r.Context(), "getAccessToken")
	r = r.WithContext(ctx)
	defer span.End()

	authHeader := r.Header.Get("authorization")
	if authHeader == "" {
		return "", errors.New("no auth header")
	}

	const bearerPrefix = "Bearer "
	// Anchor the scheme at the start of the value: splitting on "Bearer "
	// anywhere in the string would accept headers such as "Basic Bearer x".
	if len(authHeader) < len(bearerPrefix) || !strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
		return "", errors.New("invalid auth header")
	}

	token := strings.TrimSpace(authHeader[len(bearerPrefix):])
	if token == "" {
		// A scheme without credentials is not a usable bearer token.
		return "", errors.New("invalid auth header")
	}
	return token, nil
}

// getTokenIDAndSubject resolves an access token to its token ID and subject.
//
// The token is first run through the provider's Crypto.Decrypt. If that
// succeeds, the plaintext must consist of exactly two ":"-separated fields,
// returned as token ID and subject. If decryption fails, the token is instead
// verified with VerifyAccessToken and the JWTID and Subject claims are
// returned. The boolean is false whenever neither path yields a result.
func getTokenIDAndSubject(ctx context.Context, userinfoProvider UserinfoProvider, accessToken string) (string, string, bool) {
	ctx, span := tracer.Start(ctx, "getTokenIDAndSubject")
	defer span.End()

	plaintext, decryptErr := userinfoProvider.Crypto().Decrypt(accessToken)
	if decryptErr != nil {
		// Not decryptable by this provider: treat the token as a signed
		// access token and read the identifiers from its claims.
		claims, verifyErr := VerifyAccessToken[*oidc.AccessTokenClaims](ctx, accessToken, userinfoProvider.AccessTokenVerifier(ctx))
		if verifyErr != nil {
			return "", "", false
		}
		return claims.JWTID, claims.Subject, true
	}

	fields := strings.Split(plaintext, ":")
	if len(fields) == 2 {
		return fields[0], fields[1], true
	}
	return "", "", false
}