1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
|
package logging
import (
"net/http"
"time"
"log/slog"
)
// MiddlewareOption configures the middleware built by [Middleware].
type MiddlewareOption func(*middleware)
// WithLogger sets the base logger used by the middleware. For every request
// the middleware derives a logger carrying the request attributes from it and
// stores that derived logger in the request's context.
// A nil logger is ignored and the previously configured logger is kept.
func WithLogger(logger *slog.Logger) MiddlewareOption {
	return func(m *middleware) {
		// A nil base logger would make every request panic when the
		// middleware calls logger.With, so keep the current one instead.
		if logger != nil {
			m.logger = logger
		}
	}
}
// WithGroup places every attribute the middleware adds to its log lines
// (request attributes, request ID, duration and response attributes)
// under a slog group with the given name.
func WithGroup(name string) MiddlewareOption {
	setGroup := func(m *middleware) {
		m.group = name
	}
	return setGroup
}
// WithIDFunc makes the middleware call nextID once per request and attach
// the returned attribute (presumably a request ID) to that request's logger.
// Passing nil disables ID generation.
func WithIDFunc(nextID func() slog.Attr) MiddlewareOption {
	return MiddlewareOption(func(m *middleware) { m.nextID = nextID })
}
// WithDurationFunc overrides how the request duration is measured. df is
// called with the request's start time after the wrapped handler returns,
// and its result is logged as "duration". The default is time.Since; this
// option is mainly useful for deterministic durations in tests.
// A nil df is ignored and the current measurement function is kept.
func WithDurationFunc(df func(time.Time) time.Duration) MiddlewareOption {
	return func(m *middleware) {
		// A nil func would panic on every request, after the handler ran.
		if df != nil {
			m.duration = df
		}
	}
}
// WithRequestAttr replaces the function that converts an incoming request
// into the attribute logged for it. That attribute is attached to the
// request's context logger and to the final log line.
// A nil function is ignored and the current one is kept.
func WithRequestAttr(requestToAttr func(*http.Request) slog.Attr) MiddlewareOption {
	return func(m *middleware) {
		// A nil func would panic at the start of every request.
		if requestToAttr != nil {
			m.reqAttr = requestToAttr
		}
	}
}
// WithLoggedWriter replaces the function used to wrap each request's
// http.ResponseWriter. After the handler returns, the middleware logs the
// wrapper's Attr() and checks its Err() to choose the log level.
// A nil function is ignored and the current one is kept.
func WithLoggedWriter(wrap func(w http.ResponseWriter) LoggedWriter) MiddlewareOption {
	return func(m *middleware) {
		// A nil func would panic on every request.
		if wrap != nil {
			m.wrapWriter = wrap
		}
	}
}
// Middleware returns a middleware that logs each request and stores a
// request-scoped logger in the request context.
// Use [FromContext] to obtain that logger anywhere during the request's lifetime.
//
// The default base logger is [slog.Default], with the request's URL and method
// as preset attributes. When the request finishes, an INFO line with the status
// code and the number of bytes written to the client is logged. If writing the
// response failed, a WARN line with the error is logged instead.
// These behaviors can be modified with options; nil options are ignored.
func Middleware(options ...MiddlewareOption) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		mw := &middleware{
			logger:     slog.Default(),
			duration:   time.Since,
			next:       next,
			reqAttr:    requestToAttr,
			wrapWriter: newLoggedWriter,
		}
		for _, opt := range options {
			// Skip nil entries so conditionally built option lists don't panic.
			if opt != nil {
				opt(mw)
			}
		}
		return mw
	}
}
// middleware is the http.Handler produced by [Middleware]. Middleware fills
// its fields with defaults, and MiddlewareOption values then adjust them.
type middleware struct {
	// next is the wrapped handler that actually serves the request.
	next http.Handler

	// logger is the base logger; request attributes are added to it per request.
	logger *slog.Logger
	// group is the name of the slog group holding the middleware's attributes.
	group string

	// reqAttr converts a request into the attribute logged for it.
	reqAttr func(*http.Request) slog.Attr
	// nextID, when non-nil, produces an extra per-request attribute (e.g. an ID).
	nextID func() slog.Attr
	// duration reports the time elapsed since the given start time.
	duration func(time.Time) time.Duration
	// wrapWriter wraps the ResponseWriter to capture response attributes and errors.
	wrapWriter func(http.ResponseWriter) LoggedWriter
}
// ServeHTTP stores a request-scoped logger in the request context, runs the
// wrapped handler with an instrumented ResponseWriter, and then logs a single
// summary line. The line is INFO "request served", or WARN "write response"
// with the error when writing the response failed.
//
// Every attribute the middleware contributes (request attributes, optional ID,
// duration and response attributes) is emitted inside ONE group named m.group.
// slog handlers do not deduplicate keys, so attaching the same named group
// through several With calls would repeat the group key in the output.
func (m *middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	start := time.Now()

	// Compute the request attributes once. They are shared by the context
	// logger and the summary line, and the ID func runs only once per request.
	reqAttrs := []any{m.reqAttr(r)}
	if m.nextID != nil {
		reqAttrs = append(reqAttrs, m.nextID())
	}
	logger := m.logger.With(slog.Group(m.group, reqAttrs...))
	r = r.WithContext(ToContext(r.Context(), logger))

	lw := m.wrapWriter(w)
	m.next.ServeHTTP(lw, r)

	// Log from the base logger (not the context logger) so the group appears
	// exactly once, holding request and response attributes together.
	summary := slog.Group(m.group, append(reqAttrs,
		slog.Duration("duration", m.duration(start)),
		lw.Attr(),
	)...)
	ctx := r.Context()
	if err := lw.Err(); err != nil {
		m.logger.LogAttrs(ctx, slog.LevelWarn, "write response", summary, slog.Any("error", err))
		return
	}
	m.logger.LogAttrs(ctx, slog.LevelInfo, "request served", summary)
}
// loggedWriter is the default LoggedWriter. It wraps an http.ResponseWriter
// and records the final status code, the number of body bytes written, and the
// first write error, so the middleware can log them after the handler returns.
type loggedWriter struct {
	http.ResponseWriter
	statusCode int   // final status code sent; 0 if the handler sent none
	written    int   // body bytes reported written by the underlying Write
	err        error // first error returned by the underlying Write
}

// newLoggedWriter wraps w in a loggedWriter.
func newLoggedWriter(w http.ResponseWriter) LoggedWriter {
	return &loggedWriter{ResponseWriter: w}
}

// WriteHeader records the status code actually sent to the client and
// forwards the call to the wrapped writer.
//
// Informational 1xx headers (other than 101 Switching Protocols) may precede
// the final header, and net/http ignores any WriteHeader after the final one.
// So only the first non-informational code is recorded.
func (lw *loggedWriter) WriteHeader(statusCode int) {
	informational := statusCode >= 100 && statusCode <= 199 &&
		statusCode != http.StatusSwitchingProtocols
	if lw.statusCode == 0 && !informational {
		lw.statusCode = statusCode
	}
	lw.ResponseWriter.WriteHeader(statusCode)
}

// Write sends an implicit 200 OK header if no final status was sent yet,
// writes b to the wrapped writer, and tallies the bytes written and any error.
func (lw *loggedWriter) Write(b []byte) (int, error) {
	if lw.statusCode == 0 {
		lw.WriteHeader(http.StatusOK)
	}
	n, err := lw.ResponseWriter.Write(b)
	lw.written += n
	// Keep the first failure: a later successful Write must not hide it.
	if err != nil && lw.err == nil {
		lw.err = err
	}
	return n, err
}

// Flush implements http.Flusher so streaming handlers keep working behind the
// middleware. Flushing sends the header, so an implicit 200 is recorded first.
// If the wrapped writer cannot flush, this is a no-op.
func (lw *loggedWriter) Flush() {
	if lw.statusCode == 0 {
		lw.WriteHeader(http.StatusOK)
	}
	// http.Flusher has no error result; flushing is best-effort here.
	_ = http.NewResponseController(lw.ResponseWriter).Flush()
}

// Unwrap returns the wrapped ResponseWriter so http.ResponseController can
// reach optional capabilities (hijacking, deadlines) the wrapper does not expose.
func (lw *loggedWriter) Unwrap() http.ResponseWriter {
	return lw.ResponseWriter
}

// Attr returns the recorded response data as a "response" group containing
// "status" and "written".
func (lw *loggedWriter) Attr() slog.Attr {
	return slog.Group("response",
		"status", lw.statusCode,
		"written", lw.written,
	)
}

// Err returns the first error reported by the wrapped writer's Write, if any.
func (lw *loggedWriter) Err() error {
	return lw.err
}
|