File: middleware.go

package info (click to toggle)
golang-github-zitadel-logging 0.6.2-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental, forky, sid, trixie
  • size: 156 kB
  • sloc: javascript: 8; makefile: 2
file content (155 lines) | stat: -rw-r--r-- 3,769 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
package logging

import (
	"net/http"
	"time"

	"log/slog"
)

// MiddlewareOption customizes the middleware built by [Middleware].
// Options are applied in the order in which they are passed.
type MiddlewareOption func(*middleware)

// WithLogger sets the base logger of the middleware. For every request it is
// extended with the request attributes and stored in the request's context,
// from where [FromContext] can retrieve it.
//
// A nil logger is ignored: the logger configured so far (by default
// [slog.Default]) stays in effect.
func WithLogger(logger *slog.Logger) MiddlewareOption {
	return func(m *middleware) {
		// ServeHTTP calls m.logger.With unconditionally, so storing nil here
		// would panic on the first request.
		if logger == nil {
			return
		}
		m.logger = logger
	}
}

// WithGroup places every attribute the middleware adds (request attribute,
// request ID, duration and response attributes) inside a slog group called
// name.
func WithGroup(name string) MiddlewareOption {
	setGroup := func(target *middleware) {
		target.group = name
	}
	return setGroup
}

// WithIDFunc installs nextID, which the middleware calls once per request to
// produce an ID attribute. The attribute is added to the request-scoped
// logger. Without this option no request ID is generated.
func WithIDFunc(nextID func() slog.Attr) MiddlewareOption {
	return MiddlewareOption(func(target *middleware) {
		target.nextID = nextID
	})
}

// WithDurationFunc replaces the function that turns the request's start time
// into the logged "duration" attribute. Overriding it is mainly useful for
// deterministic output in tests.
func WithDurationFunc(df func(time.Time) time.Duration) MiddlewareOption {
	var opt MiddlewareOption = func(target *middleware) {
		target.duration = df
	}
	return opt
}

// WithRequestAttr replaces the function that derives the request attribute,
// which is attached to the request-scoped logger before the wrapped handler
// runs.
func WithRequestAttr(requestToAttr func(*http.Request) slog.Attr) MiddlewareOption {
	apply := func(target *middleware) {
		target.reqAttr = requestToAttr
	}
	return apply
}

// WithLoggedWriter replaces the constructor used to wrap each request's
// ResponseWriter. After the handler returns, the wrapper's Attr is logged,
// and its Err selects between the success and the failure log line.
func WithLoggedWriter(wrap func(w http.ResponseWriter) LoggedWriter) MiddlewareOption {
	return MiddlewareOption(func(target *middleware) { target.wrapWriter = wrap })
}

// Middleware enables request logging and stores a request-scoped logger in
// the request context.
// Use [FromContext] to obtain that logger anywhere during the request's
// lifetime.
//
// The default base logger is [slog.Default], with the request's URL and
// Method as preset attributes; the duration is measured with [time.Since].
// When the handler returns, an INFO line "request served" is logged with the
// duration and the response attributes (by default the status code and the
// number of bytes written to the client). If the response writer reported a
// write error, a WARN line "write response" is logged instead.
// This behavior can be modified with options; nil options are ignored.
func Middleware(options ...MiddlewareOption) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		mw := &middleware{
			logger:     slog.Default(),
			duration:   time.Since,
			next:       next,
			reqAttr:    requestToAttr,
			wrapWriter: newLoggedWriter,
		}
		for _, opt := range options {
			// Calling a nil option would panic; treat it as a no-op.
			if opt != nil {
				opt(mw)
			}
		}
		return mw
	}
}

// middleware is the http.Handler produced by [Middleware]; its fields come
// from defaults overridden by [MiddlewareOption]s.
type middleware struct {
	// logger is the base logger; each request derives its own logger from it.
	logger *slog.Logger

	// group names the slog group holding every attribute the middleware adds.
	group string

	// nextID optionally produces a per-request ID attribute; nil disables it.
	nextID func() slog.Attr

	// next is the wrapped handler.
	next http.Handler

	// duration maps the request start time to the logged duration.
	duration func(time.Time) time.Duration

	// reqAttr produces the request attribute added before the handler runs.
	reqAttr func(*http.Request) slog.Attr

	// wrapWriter wraps the ResponseWriter so its Attr and Err can be logged
	// after the handler returns.
	wrapWriter func(http.ResponseWriter) LoggedWriter
}

// ServeHTTP derives a request-scoped logger (request attribute plus optional
// ID, grouped under m.group), stores it in the request context and calls the
// next handler. Afterwards it appends the elapsed duration and the writer's
// attributes and logs "write response" at WARN if the writer reported an
// error, or "request served" at INFO otherwise.
func (m *middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	began := time.Now()

	reqLogger := m.logger.With(slog.Group(m.group, m.reqAttr(r)))
	if genID := m.nextID; genID != nil {
		reqLogger = reqLogger.With(slog.Group(m.group, genID()))
	}
	ctx := ToContext(r.Context(), reqLogger)

	recorder := m.wrapWriter(w)
	m.next.ServeHTTP(recorder, r.WithContext(ctx))

	// Duration first, then the writer's attributes: the order in which they
	// appear on the final line.
	elapsed := slog.Duration("duration", m.duration(began))
	final := reqLogger.With(slog.Group(m.group, elapsed, recorder.Attr()))

	err := recorder.Err()
	if err == nil {
		final.InfoContext(ctx, "request served")
		return
	}
	final.WarnContext(ctx, "write response", "error", err)
}

// loggedWriter is the default [LoggedWriter]. It wraps an http.ResponseWriter
// and records the final response status code, the number of body bytes
// written and the first write error.
type loggedWriter struct {
	http.ResponseWriter

	statusCode int
	written    int
	err        error
}

// newLoggedWriter wraps w in a *loggedWriter.
func newLoggedWriter(w http.ResponseWriter) LoggedWriter {
	return &loggedWriter{
		ResponseWriter: w,
	}
}

// Unwrap returns the wrapped ResponseWriter, allowing [http.ResponseController]
// to reach optional interfaces (such as http.Flusher or http.Hijacker) that
// loggedWriter itself does not implement.
func (w *loggedWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}

// finalStatus reports whether a final status code has been recorded.
// 1xx codes other than 101 Switching Protocols are informational and are
// followed by the real status, so they do not count.
func (w *loggedWriter) finalStatus() bool {
	return w.statusCode >= http.StatusOK || w.statusCode == http.StatusSwitchingProtocols
}

// WriteHeader forwards statusCode to the wrapped writer. It records the code
// only while no final status has been recorded: net/http ignores WriteHeader
// calls made after the final status, so recording them would log a status
// the client never received.
func (w *loggedWriter) WriteHeader(statusCode int) {
	if !w.finalStatus() {
		w.statusCode = statusCode
	}
	w.ResponseWriter.WriteHeader(statusCode)
}

// Write writes b to the wrapped writer. Like net/http, it first sends an
// implicit 200 OK when no final status has been written yet (including after
// informational 1xx headers). It accumulates the number of bytes written and
// keeps the first error returned.
func (w *loggedWriter) Write(b []byte) (int, error) {
	if !w.finalStatus() {
		w.WriteHeader(http.StatusOK)
	}
	n, err := w.ResponseWriter.Write(b)
	w.written += n
	// Keep the first failure: a later successful write must not hide it.
	if err != nil && w.err == nil {
		w.err = err
	}
	return n, err
}

// Attr returns the recorded data as a "response" group with "status" and
// "written" attributes. The status is 0 when neither WriteHeader nor Write
// was called.
func (w *loggedWriter) Attr() slog.Attr {
	return slog.Group("response",
		"status", w.statusCode,
		"written", w.written,
	)
}

// Err returns the first error reported by the wrapped writer's Write, or nil.
func (w *loggedWriter) Err() error {
	return w.err
}