File: error.go

package info (click to toggle)
golang-github-zitadel-oidc 3.37.0-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental, sid, trixie
  • size: 1,484 kB
  • sloc: makefile: 5
file content (197 lines) | stat: -rw-r--r-- 6,699 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
package op

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"net/http"

	httphelper "github.com/zitadel/oidc/v3/pkg/http"
	"github.com/zitadel/oidc/v3/pkg/oidc"
)

// ErrAuthRequest is the minimal view of an auth request that the error
// helpers in this file read when building an error response:
// the redirect URI, the response type and the state.
type ErrAuthRequest interface {
	// GetRedirectURI returns the URI an error response is redirected to.
	// When it is empty the helpers do not redirect.
	GetRedirectURI() string
	// GetResponseType is passed to AuthResponseURL along with the redirect URI.
	GetResponseType() oidc.ResponseType
	// GetState returns the value copied into the State field of the error.
	GetState() string
}

// LogAuthRequest is an optional interface
// that allows logging AuthRequest fields.
// If the AuthRequest does not implement this interface,
// no details shall be printed to the logs.
//
// The error helpers in this file detect it with a type assertion and,
// when it is implemented, attach the request to the logger under the
// "auth_request" key.
type LogAuthRequest interface {
	ErrAuthRequest
	// LogValuer supplies the value that is logged for the request.
	slog.LogValuer
}

// AuthRequestError reports err for a failed authorization request.
//
// The error is converted with oidc.DefaultToServerError and logged at the
// level returned by its LogLevel method. Without an auth request, without a
// redirect URI, or when the error disables redirects, a plain
// http.StatusBadRequest response is written. In every other case the client is
// redirected with http.StatusFound to the URL that AuthResponseURL builds from
// the redirect URI, response type, optional response mode and the error,
// after State and SessionState have been copied from the request.
func AuthRequestError(w http.ResponseWriter, r *http.Request, authReq ErrAuthRequest, err error, authorizer Authorizer) {
	oidcErr := oidc.DefaultToServerError(err, err.Error())
	reqLogger := authorizer.Logger().With("oidc_error", oidcErr)
	ctx := r.Context()

	// Without a request there is no redirect target: answer with the raw error text.
	if authReq == nil {
		reqLogger.Log(ctx, oidcErr.LogLevel(), "auth request")
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	// Request details are only logged when the request opts in.
	if loggable, ok := authReq.(LogAuthRequest); ok {
		reqLogger = reqLogger.With("auth_request", loggable)
	}

	canRedirect := authReq.GetRedirectURI() != "" && !oidcErr.IsRedirectDisabled()
	if !canRedirect {
		reqLogger.Log(ctx, oidcErr.LogLevel(), "auth request: not redirecting")
		http.Error(w, oidcErr.Description, http.StatusBadRequest)
		return
	}

	oidcErr.State = authReq.GetState()

	// SessionState is always overwritten; it stays empty when the request
	// does not expose a session state.
	sessionState := ""
	if withSession, ok := authReq.(AuthRequestSessionState); ok {
		sessionState = withSession.GetSessionState()
	}
	oidcErr.SessionState = sessionState

	// The response mode is optional and defaults to its zero value.
	type responseModer interface {
		GetResponseMode() oidc.ResponseMode
	}
	var mode oidc.ResponseMode
	if withMode, ok := authReq.(responseModer); ok {
		mode = withMode.GetResponseMode()
	}

	target, urlErr := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), mode, oidcErr, authorizer.Encoder())
	if urlErr != nil {
		reqLogger.ErrorContext(ctx, "auth response URL", "error", urlErr)
		http.Error(w, urlErr.Error(), http.StatusBadRequest)
		return
	}

	reqLogger.Log(ctx, oidcErr.LogLevel(), "auth request")
	http.Redirect(w, r, target, http.StatusFound)
}

// RequestError logs err and hands it to httphelper.MarshalJSONWithStatus
// to be written to w.
//
// The error is converted with oidc.DefaultToServerError. The status is
// http.StatusUnauthorized for errors of type oidc.InvalidClient and
// http.StatusBadRequest for everything else.
func RequestError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) {
	oidcErr := oidc.DefaultToServerError(err, err.Error())

	var status int
	switch oidcErr.ErrorType {
	case oidc.InvalidClient:
		status = http.StatusUnauthorized
	default:
		status = http.StatusBadRequest
	}

	logger.Log(r.Context(), oidcErr.LogLevel(), "request error", "oidc_error", oidcErr)
	httphelper.MarshalJSONWithStatus(w, oidcErr, status)
}

// TryErrorRedirect tries to handle an error by redirecting a client.
// If this attempt fails, an error is returned that must be returned
// to the client instead.
//
// No redirect is attempted when authReq is nil, when it has no redirect URI,
// or when the error disables redirects; in those cases the error is returned
// as a StatusError defaulting to http.StatusBadRequest. The same default
// applies when AuthResponseURL fails to build the redirect URL.
func TryErrorRedirect(ctx context.Context, authReq ErrAuthRequest, parent error, encoder httphelper.Encoder, logger *slog.Logger) (*Redirect, error) {
	oidcErr := oidc.DefaultToServerError(parent, parent.Error())
	logger = logger.With("oidc_error", oidcErr)

	// reject logs msg with the current logger and hands the error back
	// for the caller to return to the client.
	reject := func(msg string) (*Redirect, error) {
		logger.Log(ctx, oidcErr.LogLevel(), msg)
		return nil, AsStatusError(oidcErr, http.StatusBadRequest)
	}

	if authReq == nil {
		return reject("auth request")
	}

	// Request details are only logged when the request opts in.
	if loggable, ok := authReq.(LogAuthRequest); ok {
		logger = logger.With("auth_request", loggable)
	}

	canRedirect := authReq.GetRedirectURI() != "" && !oidcErr.IsRedirectDisabled()
	if !canRedirect {
		return reject("auth request: not redirecting")
	}

	oidcErr.State = authReq.GetState()

	// SessionState is always overwritten; it stays empty when the request
	// does not expose a session state.
	sessionState := ""
	if withSession, ok := authReq.(AuthRequestSessionState); ok {
		sessionState = withSession.GetSessionState()
	}
	oidcErr.SessionState = sessionState

	// The response mode is optional and defaults to its zero value.
	type responseModer interface {
		GetResponseMode() oidc.ResponseMode
	}
	var mode oidc.ResponseMode
	if withMode, ok := authReq.(responseModer); ok {
		mode = withMode.GetResponseMode()
	}

	target, urlErr := AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), mode, oidcErr, encoder)
	if urlErr != nil {
		logger.ErrorContext(ctx, "auth response URL", "error", urlErr)
		return nil, AsStatusError(urlErr, http.StatusBadRequest)
	}

	logger.Log(ctx, oidcErr.LogLevel(), "auth request redirect", "url", target)
	return NewRedirect(target), nil
}

// StatusError wraps an error with a HTTP status code.
// The status code is passed to the handler's writer.
type StatusError struct {
	parent     error
	statusCode int
}

// NewStatusError sets the parent and statusCode to a new StatusError.
// It is recommended for parent to be an [oidc.Error].
//
// Typically implementations should only use this to signal something
// very specific, like an internal server error.
// If a returned error is not a StatusError, the framework
// will set a statusCode based on what the standard specifies,
// which is [http.StatusBadRequest] for most of the time.
// If the error encountered can be described clearly with an [oidc.Error],
// do not use this function, as it might break standard rules!
func NewStatusError(parent error, statusCode int) StatusError {
	return StatusError{
		parent:     parent,
		statusCode: statusCode,
	}
}

// AsStatusError unwraps a StatusError from err
// and returns it unmodified if found.
// If no StatusError was found, a new one is returned
// that wraps err and uses statusCode as its default status code.
func AsStatusError(err error, statusCode int) (target StatusError) {
	if errors.As(err, &target) {
		return target
	}
	return NewStatusError(err, statusCode)
}

// Error returns the HTTP status text of the status code,
// followed by the message of the parent error.
// A StatusError without a parent yields the status text alone.
func (e StatusError) Error() string {
	statusText := http.StatusText(e.statusCode)
	// NewStatusError accepts a nil parent; calling Error on it would panic.
	if e.parent == nil {
		return statusText
	}
	return fmt.Sprintf("%s: %s", statusText, e.parent.Error())
}

// Unwrap returns the parent error, which may be nil.
func (e StatusError) Unwrap() error {
	return e.parent
}

// Is reports whether a StatusError can be found in err that has the same
// status code and whose parent matches the parent of e according to errors.Is.
func (e StatusError) Is(err error) bool {
	var target StatusError
	if !errors.As(err, &target) {
		return false
	}
	return errors.Is(e.parent, target.parent) &&
		e.statusCode == target.statusCode
}

// WriteError asserts for a [StatusError] containing an [oidc.Error].
// If no `StatusError` is found, the status code will default to [http.StatusBadRequest].
// If no `oidc.Error` was found in the parent, the error type defaults to [oidc.ServerError].
// When there was no `StatusError` and the `oidc.Error` is of type `oidc.ServerError`,
// the status code will be set to [http.StatusInternalServerError].
//
// A `StatusError` that carries no parent error is described by the
// HTTP status text of its status code.
func WriteError(w http.ResponseWriter, r *http.Request, err error, logger *slog.Logger) {
	var statusError StatusError
	if errors.As(err, &statusError) {
		// NewStatusError accepts a nil parent; avoid calling Error on it
		// and fall back to the status text as description.
		description := http.StatusText(statusError.statusCode)
		if statusError.parent != nil {
			description = statusError.parent.Error()
		}
		writeError(w, r,
			oidc.DefaultToServerError(statusError.parent, description),
			statusError.statusCode, logger,
		)
		return
	}

	// Without an explicit status the code is derived from the error type.
	statusCode := http.StatusBadRequest
	e := oidc.DefaultToServerError(err, err.Error())
	if e.ErrorType == oidc.ServerError {
		statusCode = http.StatusInternalServerError
	}
	writeError(w, r, e, statusCode, logger)
}

// writeError logs err together with statusCode at the level reported by
// err.LogLevel and then passes both to httphelper.MarshalJSONWithStatus.
func writeError(w http.ResponseWriter, r *http.Request, err *oidc.Error, statusCode int, logger *slog.Logger) {
	ctx, level := r.Context(), err.LogLevel()
	logger.Log(ctx, level, "request error",
		"oidc_error", err,
		"status_code", statusCode,
	)
	httphelper.MarshalJSONWithStatus(w, err, statusCode)
}