File: logging.go

package info (click to toggle)
relic 7.6.1-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 3,108 kB
  • sloc: sh: 230; makefile: 10
file content (169 lines) | stat: -rw-r--r-- 4,571 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
package zhttp

import (
	"context"
	"fmt"
	stdlog "log"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
	"github.com/sassoftware/relic/v7/internal/logrotate"
)

// ctxKey is a package-private type for context keys, so values stored by this
// package cannot collide with keys defined by other packages.
type ctxKey int

// Context keys used by the logging middleware. They are constants rather than
// variables so that they can never be reassigned at runtime.
//
// ctxAccessCallbacks maps to a *[]AccessLogCallback and ctxDontLog maps to a
// *bool; both are installed per request by LoggingMiddleware.
const (
	ctxAccessCallbacks ctxKey = 1
	ctxDontLog         ctxKey = 2
)

const rfc3339Milli = "2006-01-02T15:04:05.000Z07:00" // RFC3339 with 3 decimal places, padded

// SetupLogging configures the global zerolog logger and routes the standard
// library logger through it.
//
// logFile selects the destination: "-" leaves the global logger's output
// untouched (JSON to stderr), "" switches to human-readable console output on
// stderr, and any other value is opened as a rotating log file that receives
// JSON. levelName is parsed as a zerolog level; when empty it defaults to
// info.
//
// An error is returned, prefixed with "log_file" or "log_level", if the log
// file cannot be opened or the level name is not recognized.
func SetupLogging(levelName, logFile string) error {
	zerolog.TimeFieldFormat = rfc3339Milli
	zerolog.DurationFieldInteger = true

	if logFile == "" {
		// no destination given: pretty text on stderr
		console := zerolog.ConsoleWriter{
			Out:        os.Stderr,
			TimeFormat: "15:04:05",
		}
		log.Logger = log.Logger.Output(console)
	} else if logFile != "-" {
		// a real path: JSON into a rotating file
		fileWriter, openErr := logrotate.NewWriter(logFile)
		if openErr != nil {
			return fmt.Errorf("log_file: %w", openErr)
		}
		log.Logger = log.Logger.Output(fileWriter)
	}
	// logFile == "-" needs no change: JSON keeps going to stderr

	// fall back to info when no level was requested
	name := levelName
	if name == "" {
		name = zerolog.InfoLevel.String()
	}
	parsed, parseErr := zerolog.ParseLevel(name)
	if parseErr != nil {
		return fmt.Errorf("log_level: %w", parseErr)
	}
	log.Logger = log.Logger.Level(parsed)

	// forward anything written via the stdlib logger into zerolog, without
	// the stdlib's own timestamp prefix
	stdlog.SetFlags(0)
	stdlog.SetOutput(log.Logger)
	return nil
}

// LoggingMiddleware returns HTTP middleware that attaches a request-scoped
// logger to each request's context and, once the wrapped handler returns,
// writes one info-level access log entry for the request.
//
// The base logger defaults to the global zerolog logger and the clock to
// time.Now; opts may override them. Handlers can suppress the entry with
// DontLog or add fields to it with AppendAccessLog.
func LoggingMiddleware(opts ...LoggingOption) func(http.Handler) http.Handler {
	cfg := loggingConfig{logger: log.Logger, now: time.Now}
	for _, apply := range opts {
		apply(&cfg)
	}
	return func(next http.Handler) http.Handler {
		serve := func(rw http.ResponseWriter, req *http.Request) {
			// Every log line emitted during this request carries the client
			// IP and the caller-supplied request ID.
			reqLogger := cfg.logger.With().
				Str("ip", StripPort(req.RemoteAddr)).
				Str("req_id", req.Header.Get("X-Request-Id")).
				Logger()

			// Per-request state that handlers reach through the context.
			var (
				extraFields []AccessLogCallback
				suppress    bool
			)
			ctx := reqLogger.WithContext(req.Context())
			ctx = context.WithValue(ctx, ctxAccessCallbacks, &extraFields)
			ctx = context.WithValue(ctx, ctxDontLog, &suppress)

			began := cfg.now()
			recorder := &Logger{ResponseWriter: rw, Now: cfg.now}
			req = req.WithContext(ctx)
			next.ServeHTTP(recorder, req)

			if suppress {
				return
			}
			// WithContext stored its own copy of the logger in ctx, so fetch
			// that copy back: it reflects any UpdateContext calls made by the
			// handlers, which reqLogger would not.
			entry := zerolog.Ctx(ctx).Info().
				Str("method", req.Method).
				Stringer("url", req.URL).
				Int("status", recorder.Status()).
				Int64("len", recorder.Length()).
				Dur("dur", cfg.now().Sub(began)).
				Dur("ttfb", recorder.Started().Sub(began)).
				Str("ua", req.UserAgent())
			for i := range extraFields {
				extraFields[i](entry)
			}
			entry.Send()
		}
		return http.HandlerFunc(serve)
	}
}

// AccessLogCallback is a function that is handed the pending access log event
// so that it can attach additional fields before the entry is sent.
type AccessLogCallback func(*zerolog.Event)

// AppendAccessLog registers f to be called with the access log event of the
// given request, allowing it to add extra fields to that entry. It is a
// convenience wrapper around AppendAccessLogContext using the request's
// context.
func AppendAccessLog(req *http.Request, f AccessLogCallback) {
	ctx := req.Context()
	AppendAccessLogContext(ctx, f)
}

// AppendAccessLogContext registers f to be called with the access log event
// of the request that ctx belongs to, allowing it to add extra fields to that
// entry. If ctx does not carry the middleware's callback list, the call is a
// no-op.
func AppendAccessLogContext(ctx context.Context, f AccessLogCallback) {
	if list, ok := ctx.Value(ctxAccessCallbacks).(*[]AccessLogCallback); ok && list != nil {
		*list = append(*list, f)
	}
}

// DontLog flags the current request so that no access log entry is emitted
// for it. If the request's context does not carry the middleware's flag, the
// call is a no-op.
func DontLog(req *http.Request) {
	if flag, ok := req.Context().Value(ctxDontLog).(*bool); ok && flag != nil {
		*flag = true
	}
}

// loggingConfig holds the settings of LoggingMiddleware that LoggingOption
// functions may override.
type loggingConfig struct {
	// logger is the base logger from which each request's logger is derived
	logger zerolog.Logger
	// now is the clock used to time requests
	now    func() time.Time
}

// LoggingOption is a function that adjusts the configuration of
// LoggingMiddleware.
type LoggingOption func(*loggingConfig)

// WithLogger returns an option that makes the middleware derive its
// per-request loggers from logger instead of the default base logger.
func WithLogger(logger zerolog.Logger) LoggingOption {
	useLogger := func(cfg *loggingConfig) {
		cfg.logger = logger
	}
	return useLogger
}

// StripPort removes the port from an address such as Request.RemoteAddr and
// returns only the host part.
//
// A bracketed address like "[fe80::]:1234" yields the text between the
// brackets. An address with exactly one colon, not in first position, like
// "127.0.0.1:1234" yields everything before the colon. Any other input is
// returned unchanged.
func StripPort(clientIP string) string {
	// bracketed form: [fe80::]:1234
	if strings.HasPrefix(clientIP, "[") {
		if closing := strings.IndexByte(clientIP, ']'); closing > 1 {
			return clientIP[1:closing]
		}
	}
	// host:port form: 127.0.0.1:1234
	if strings.Count(clientIP, ":") == 1 {
		if colon := strings.IndexByte(clientIP, ':'); colon > 0 {
			return clientIP[:colon]
		}
	}
	// bare host, or an unbracketed IPv6 address with several colons
	return clientIP
}