File: session.go

package info (click to toggle)
golang-github-zitadel-oidc 3.37.0-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental, sid, trixie
  • size: 1,484 kB
  • sloc: makefile: 5
file content (130 lines) | stat: -rw-r--r-- 4,116 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
package op

import (
	"context"
	"errors"
	"log/slog"
	"net/http"
	"net/url"
	"path"

	httphelper "github.com/zitadel/oidc/v3/pkg/http"
	"github.com/zitadel/oidc/v3/pkg/oidc"
)

type SessionEnder interface {
	Decoder() httphelper.Decoder
	Storage() Storage
	IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
	DefaultLogoutRedirectURI() string
	Logger() *slog.Logger
}

// endSessionHandler binds ender to EndSession and returns the result as a
// plain handler function.
func endSessionHandler(ender SessionEnder) func(http.ResponseWriter, *http.Request) {
	handle := func(rw http.ResponseWriter, req *http.Request) {
		EndSession(rw, req, ender)
	}
	return handle
}

// EndSession handles an end_session request. It parses and validates the
// request, terminates the session through the storage, and redirects the user
// agent with 302 Found to the resulting redirect URI.
//
// If the storage implements CanTerminateSessionFromRequest, it receives the
// whole validated request, and the redirect URI it returns replaces the
// validated one. Otherwise Storage.TerminateSession is called with the user
// and client IDs, and the validated redirect URI is used.
//
// Every failure is reported through RequestError with ender's logger. This
// includes parse failures, which are invalid_request errors.
func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) {
	ctx, span := tracer.Start(r.Context(), "EndSession")
	defer span.End()
	r = r.WithContext(ctx)

	req, err := ParseEndSessionRequest(r, ender.Decoder())
	if err != nil {
		// ParseEndSessionRequest only yields invalid_request errors (a client
		// fault). Report them like the validation errors below, instead of as
		// a hard-coded 500 that echoes the raw error text.
		RequestError(w, r, err, ender.Logger())
		return
	}
	session, err := ValidateEndSessionRequest(r.Context(), req, ender)
	if err != nil {
		RequestError(w, r, err, ender.Logger())
		return
	}
	redirect := session.RedirectURI
	if fromRequest, ok := ender.Storage().(CanTerminateSessionFromRequest); ok {
		// The storage may choose a different redirect target based on the
		// full request.
		redirect, err = fromRequest.TerminateSessionFromRequest(r.Context(), session)
	} else {
		err = ender.Storage().TerminateSession(r.Context(), session.UserID, session.ClientID)
	}
	if err != nil {
		RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session"), ender.Logger())
		return
	}
	http.Redirect(w, r, redirect, http.StatusFound)
}

// ParseEndSessionRequest parses the form of r (r.ParseForm) and decodes
// r.Form into a new oidc.EndSessionRequest using decoder.
//
// A parse failure or a decode failure is returned as an invalid_request error
// that wraps the underlying cause.
func ParseEndSessionRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.EndSessionRequest, error) {
	if parseErr := r.ParseForm(); parseErr != nil {
		return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(parseErr)
	}
	var endSession oidc.EndSessionRequest
	if decodeErr := decoder.Decode(&endSession, r.Form); decodeErr != nil {
		return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(decodeErr)
	}
	return &endSession, nil
}

// ValidateEndSessionRequest checks a decoded end_session request and turns it
// into the EndSessionRequest that is handed to the storage.
//
// It proceeds in these steps:
//   - The redirect target starts as ender.DefaultLogoutRedirectURI().
//   - If id_token_hint is set, it is verified. An IDTokenHintExpiredError is
//     tolerated; any other verification error yields invalid_request. The
//     subject and claims are recorded. A non-empty client_id must equal the
//     token's azp, and req.ClientID is then overwritten with azp, so this
//     function mutates req.
//   - If a client ID is known, the client is loaded. A supplied
//     post_logout_redirect_uri must pass
//     ValidateEndSessionPostLogoutRedirectURI and then becomes the redirect
//     target.
//   - If state is set, it is merged into the redirect target's query.
//
// NOTE(review): in the expired-hint case the claims are still read, which
// presumes VerifyIDTokenHint returns them alongside IDTokenHintExpiredError.
// Confirm this in VerifyIDTokenHint.
func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest, ender SessionEnder) (*EndSessionRequest, error) {
	ctx, span := tracer.Start(ctx, "ValidateEndSessionRequest")
	defer span.End()

	result := &EndSessionRequest{RedirectURI: ender.DefaultLogoutRedirectURI()}

	if hint := req.IdTokenHint; hint != "" {
		claims, verifyErr := VerifyIDTokenHint[*oidc.IDTokenClaims](ctx, hint, ender.IDTokenHintVerifier(ctx))
		if verifyErr != nil && !errors.As(verifyErr, &IDTokenHintExpiredError{}) {
			return nil, oidc.ErrInvalidRequest().WithDescription("id_token_hint invalid").WithParent(verifyErr)
		}
		result.UserID = claims.GetSubject()
		result.IDTokenHintClaims = claims
		if req.ClientID != "" && req.ClientID != claims.GetAuthorizedParty() {
			return nil, oidc.ErrInvalidRequest().WithDescription("client_id does not match azp of id_token_hint")
		}
		req.ClientID = claims.GetAuthorizedParty()
	}

	if id := req.ClientID; id != "" {
		client, lookupErr := ender.Storage().GetClientByClientID(ctx, id)
		if lookupErr != nil {
			return nil, oidc.DefaultToServerError(lookupErr, "")
		}
		result.ClientID = client.GetID()
		if target := req.PostLogoutRedirectURI; target != "" {
			if redirectErr := ValidateEndSessionPostLogoutRedirectURI(target, client); redirectErr != nil {
				return nil, redirectErr
			}
			result.RedirectURI = target
		}
	}

	if req.State == "" {
		return result, nil
	}
	parsed, parseErr := url.Parse(result.RedirectURI)
	if parseErr != nil {
		return nil, oidc.DefaultToServerError(parseErr, "")
	}
	result.RedirectURI = mergeQueryParams(parsed, url.Values{"state": {req.State}})
	return result, nil
}

// ValidateEndSessionPostLogoutRedirectURI returns nil if postLogoutRedirectURI
// is permitted for client. Otherwise it returns an invalid_request error.
//
// The URI is first compared exactly against client.PostLogoutRedirectURIs().
// If the client also implements HasRedirectGlobs, each of its
// PostLogoutRedirectURIGlobs patterns is tried in order with path.Match. A
// malformed pattern aborts the check with a server_error.
func ValidateEndSessionPostLogoutRedirectURI(postLogoutRedirectURI string, client Client) error {
	registered := client.PostLogoutRedirectURIs()
	for i := 0; i < len(registered); i++ {
		if registered[i] == postLogoutRedirectURI {
			return nil
		}
	}

	globClient, supportsGlobs := client.(HasRedirectGlobs)
	if !supportsGlobs {
		return oidc.ErrInvalidRequest().WithDescription("post_logout_redirect_uri invalid")
	}
	for _, pattern := range globClient.PostLogoutRedirectURIGlobs() {
		matched, matchErr := path.Match(pattern, postLogoutRedirectURI)
		switch {
		case matchErr != nil:
			return oidc.ErrServerError().WithParent(matchErr)
		case matched:
			return nil
		}
	}
	return oidc.ErrInvalidRequest().WithDescription("post_logout_redirect_uri invalid")
}