1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
|
// Package requestid provides HTTP request ID functionality
package requestid
import (
"context"
"net/http"
"github.com/rs/xid"
"go.step.sm/crypto/randutil"
)
const (
// requestIDHeader is the header name used for propagating request IDs. If
// available in an HTTP request, it'll be used instead of the X-Smallstep-Id
// header. It'll always be used in response and set to the request ID.
requestIDHeader = "X-Request-Id"
// defaultTraceHeader is the default Smallstep tracing header that's currently
// in use. It is used as a fallback to retrieve a request ID from, if the
// "X-Request-Id" request header is not set.
defaultTraceHeader = "X-Smallstep-Id"
)
type Handler struct {
legacyTraceHeader string
}
// New creates a new request ID [handler]. It takes a trace header,
// which is used keep the legacy behavior intact, which relies on the
// X-Smallstep-Id header instead of X-Request-Id.
func New(legacyTraceHeader string) *Handler {
if legacyTraceHeader == "" {
legacyTraceHeader = defaultTraceHeader
}
return &Handler{legacyTraceHeader: legacyTraceHeader}
}
// Middleware wraps an [http.Handler] with request ID extraction
// from the X-Reqeust-Id header by default, or from the X-Smallstep-Id
// header if not set. If both are not set, a new request ID is generated.
// In all cases, the request ID is added to the request context, and
// set to be reflected in the response.
func (h *Handler) Middleware(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, req *http.Request) {
requestID := req.Header.Get(requestIDHeader)
if requestID == "" {
requestID = req.Header.Get(h.legacyTraceHeader)
}
if requestID == "" {
requestID = newRequestID()
req.Header.Set(h.legacyTraceHeader, requestID) // legacy behavior
}
// immediately set the request ID to be reflected in the response
w.Header().Set(requestIDHeader, requestID)
// continue down the handler chain
ctx := NewContext(req.Context(), requestID)
next.ServeHTTP(w, req.WithContext(ctx))
}
return http.HandlerFunc(fn)
}
// newRequestID generates a new random UUIDv4 request ID. If UUIDv4
// generation fails, it'll fallback to generating a random ID using
// github.com/rs/xid.
func newRequestID() string {
requestID, err := randutil.UUIDv4()
if err != nil {
requestID = xid.New().String()
}
return requestID
}
type contextKey struct{}
// NewContext returns a new context with the given request ID added to the
// context.
func NewContext(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, contextKey{}, requestID)
}
// FromContext returns the request ID from the context if it exists and
// is not the empty value.
func FromContext(ctx context.Context) (string, bool) {
v, ok := ctx.Value(contextKey{}).(string)
return v, ok && v != ""
}
|