1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
|
package op
import (
"context"
"errors"
"log/slog"
"net/http"
"net/url"
"path"
httphelper "github.com/zitadel/oidc/v3/pkg/http"
"github.com/zitadel/oidc/v3/pkg/oidc"
)
type SessionEnder interface {
Decoder() httphelper.Decoder
Storage() Storage
IDTokenHintVerifier(context.Context) *IDTokenHintVerifier
DefaultLogoutRedirectURI() string
Logger() *slog.Logger
}
// endSessionHandler binds ender to EndSession and returns the result as a plain
// HTTP handler function suitable for registering on a router.
func endSessionHandler(ender SessionEnder) func(http.ResponseWriter, *http.Request) {
	handle := func(w http.ResponseWriter, r *http.Request) {
		EndSession(w, r, ender)
	}
	return handle
}
// EndSession handles an end_session (RP-initiated logout) request.
//
// It parses the request with ender's decoder, validates it via
// ValidateEndSessionRequest, and terminates the session in storage. If the
// storage implements CanTerminateSessionFromRequest, that method is used and
// its returned URI replaces the validated redirect target. Otherwise
// Storage.TerminateSession is called with the session's user and client IDs.
// On success the user agent is redirected with 302 Found.
//
// Every failure, including a malformed request, is reported through
// RequestError with ender's logger.
func EndSession(w http.ResponseWriter, r *http.Request, ender SessionEnder) {
	ctx, span := tracer.Start(r.Context(), "EndSession")
	defer span.End()
	r = r.WithContext(ctx)

	req, err := ParseEndSessionRequest(r, ender.Decoder())
	if err != nil {
		// ParseEndSessionRequest only returns invalid_request errors (client
		// mistakes). Answering them with a bare 500 and the raw error text
		// mislabelled them as server faults and skipped the logger; report them
		// like every other failure in this handler instead.
		RequestError(w, r, err, ender.Logger())
		return
	}
	session, err := ValidateEndSessionRequest(r.Context(), req, ender)
	if err != nil {
		RequestError(w, r, err, ender.Logger())
		return
	}

	redirect := session.RedirectURI
	if fromRequest, ok := ender.Storage().(CanTerminateSessionFromRequest); ok {
		// This variant may decide the final redirect target itself.
		redirect, err = fromRequest.TerminateSessionFromRequest(r.Context(), session)
	} else {
		err = ender.Storage().TerminateSession(r.Context(), session.UserID, session.ClientID)
	}
	if err != nil {
		RequestError(w, r, oidc.DefaultToServerError(err, "error terminating session"), ender.Logger())
		return
	}
	http.Redirect(w, r, redirect, http.StatusFound)
}
// ParseEndSessionRequest populates r.Form (via http.Request.ParseForm) and
// decodes it into a new oidc.EndSessionRequest using decoder.
//
// A failure in either step is returned as an oidc invalid_request error that
// wraps the underlying cause.
func ParseEndSessionRequest(r *http.Request, decoder httphelper.Decoder) (*oidc.EndSessionRequest, error) {
	if parseErr := r.ParseForm(); parseErr != nil {
		return nil, oidc.ErrInvalidRequest().WithDescription("error parsing form").WithParent(parseErr)
	}

	var decoded oidc.EndSessionRequest
	if decodeErr := decoder.Decode(&decoded, r.Form); decodeErr != nil {
		return nil, oidc.ErrInvalidRequest().WithDescription("error decoding form").WithParent(decodeErr)
	}
	return &decoded, nil
}
// ValidateEndSessionRequest turns a decoded end_session request into the
// EndSessionRequest the OP acts upon.
//
// The result initially redirects to ender.DefaultLogoutRedirectURI(). Then:
//   - If id_token_hint is set, it is verified. An error that is an
//     IDTokenHintExpiredError is tolerated; any other error yields
//     invalid_request. The hint's subject and claims are stored on the result.
//     An explicit client_id must equal the hint's azp, and req.ClientID is
//     overwritten with that azp (req is modified in place).
//   - If a client_id is known at that point, the client is loaded from storage
//     and its ID recorded. A requested post_logout_redirect_uri is checked with
//     ValidateEndSessionPostLogoutRedirectURI and, if accepted, becomes the
//     redirect target.
//   - If state is set, it is merged into the query of the final redirect URI.
//
// Errors from the client lookup and from url.Parse are passed through
// oidc.DefaultToServerError.
func ValidateEndSessionRequest(ctx context.Context, req *oidc.EndSessionRequest, ender SessionEnder) (*EndSessionRequest, error) {
	ctx, span := tracer.Start(ctx, "ValidateEndSessionRequest")
	defer span.End()

	result := &EndSessionRequest{RedirectURI: ender.DefaultLogoutRedirectURI()}

	if req.IdTokenHint != "" {
		verifier := ender.IDTokenHintVerifier(ctx)
		claims, verifyErr := VerifyIDTokenHint[*oidc.IDTokenClaims](ctx, req.IdTokenHint, verifier)
		// An expired hint still identifies the user, so only other
		// verification failures reject the request.
		if verifyErr != nil && !errors.As(verifyErr, &IDTokenHintExpiredError{}) {
			return nil, oidc.ErrInvalidRequest().WithDescription("id_token_hint invalid").WithParent(verifyErr)
		}
		result.UserID = claims.GetSubject()
		result.IDTokenHintClaims = claims

		azp := claims.GetAuthorizedParty()
		if req.ClientID != "" && req.ClientID != azp {
			return nil, oidc.ErrInvalidRequest().WithDescription("client_id does not match azp of id_token_hint")
		}
		req.ClientID = azp
	}

	if req.ClientID != "" {
		client, lookupErr := ender.Storage().GetClientByClientID(ctx, req.ClientID)
		if lookupErr != nil {
			return nil, oidc.DefaultToServerError(lookupErr, "")
		}
		result.ClientID = client.GetID()

		if requested := req.PostLogoutRedirectURI; requested != "" {
			if invalidErr := ValidateEndSessionPostLogoutRedirectURI(requested, client); invalidErr != nil {
				return nil, invalidErr
			}
			result.RedirectURI = requested
		}
	}

	if req.State == "" {
		return result, nil
	}
	target, parseErr := url.Parse(result.RedirectURI)
	if parseErr != nil {
		return nil, oidc.DefaultToServerError(parseErr, "")
	}
	result.RedirectURI = mergeQueryParams(target, url.Values{"state": {req.State}})
	return result, nil
}
// ValidateEndSessionPostLogoutRedirectURI reports whether postLogoutRedirectURI
// is registered for client.
//
// It first looks for an exact entry in client.PostLogoutRedirectURIs(). If there
// is none and the client implements HasRedirectGlobs, each pattern from
// PostLogoutRedirectURIGlobs() is tried with path.Match.
//
// It returns nil on a match and a server_error when a glob pattern is
// malformed. Otherwise it returns an invalid_request error.
func ValidateEndSessionPostLogoutRedirectURI(postLogoutRedirectURI string, client Client) error {
	registered := client.PostLogoutRedirectURIs()
	for i := range registered {
		if registered[i] == postLogoutRedirectURI {
			return nil
		}
	}

	globClient, supportsGlobs := client.(HasRedirectGlobs)
	if !supportsGlobs {
		return oidc.ErrInvalidRequest().WithDescription("post_logout_redirect_uri invalid")
	}
	for _, pattern := range globClient.PostLogoutRedirectURIGlobs() {
		matched, matchErr := path.Match(pattern, postLogoutRedirectURI)
		switch {
		case matchErr != nil:
			// The pattern comes from the client's registration, so a malformed
			// one is reported as a server-side problem.
			return oidc.ErrServerError().WithParent(matchErr)
		case matched:
			return nil
		}
	}
	return oidc.ErrInvalidRequest().WithDescription("post_logout_redirect_uri invalid")
}
|