File: requestid_test.go

package info (click to toggle)
golang-github-smallstep-certificates 0.29.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 6,720 kB
  • sloc: sh: 385; makefile: 129
file content (105 lines) | stat: -rw-r--r-- 2,909 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
package requestid

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// newRequest returns a bodiless GET request targeting https://example.com.
// Construction errors abort the calling test immediately via require.
func newRequest(t *testing.T) *http.Request {
	t.Helper()

	const target = "https://example.com"
	req, err := http.NewRequest(http.MethodGet, target, http.NoBody)
	require.NoError(t, err)

	return req
}

// Test_Middleware exercises the request ID middleware with four requests:
//   - one carrying X-Request-Id;
//   - one with no ID header at all;
//   - one whose X-Request-Id is present but empty;
//   - one carrying only X-Smallstep-Id.
//
// For each case the wrapped handler checks three things:
//   - which request headers are set;
//   - that FromContext returns the same ID;
//   - that the ID is already reflected in the response's X-Request-Id header.
func Test_Middleware(t *testing.T) {
	requestWithID := newRequest(t)
	requestWithID.Header.Set("X-Request-Id", "reqID")

	requestWithoutID := newRequest(t)

	// Header.Set with an empty value still stores the key, so this request
	// differs from requestWithoutID only in that the header key is present.
	requestWithEmptyHeader := newRequest(t)
	requestWithEmptyHeader.Header.Set("X-Request-Id", "")

	requestWithSmallstepID := newRequest(t)
	requestWithSmallstepID.Header.Set("X-Smallstep-Id", "smallstepID")

	// Each next factory receives the subtest's *testing.T. Assertion failures
	// raised inside the handler are then reported against the subtest that
	// triggered them, not against Test_Middleware as a whole.
	tests := []struct {
		name        string
		traceHeader string
		next        func(t *testing.T) http.HandlerFunc
		req         *http.Request
	}{
		{
			name:        "default-request-id",
			traceHeader: defaultTraceHeader,
			next: func(t *testing.T) http.HandlerFunc {
				return func(w http.ResponseWriter, r *http.Request) {
					assert.Empty(t, r.Header.Get("X-Smallstep-Id"))
					assert.Equal(t, "reqID", r.Header.Get("X-Request-Id"))
					reqID, ok := FromContext(r.Context())
					if assert.True(t, ok) {
						assert.Equal(t, "reqID", reqID)
					}
					assert.Equal(t, "reqID", w.Header().Get("X-Request-Id"))
				}
			},
			req: requestWithID,
		},
		{
			name:        "no-request-id",
			traceHeader: "X-Request-Id",
			next: func(t *testing.T) http.HandlerFunc {
				return func(w http.ResponseWriter, r *http.Request) {
					assert.Empty(t, r.Header.Get("X-Smallstep-Id"))
					value := r.Header.Get("X-Request-Id")
					assert.NotEmpty(t, value)
					reqID, ok := FromContext(r.Context())
					if assert.True(t, ok) {
						assert.Equal(t, value, reqID)
					}
					assert.Equal(t, value, w.Header().Get("X-Request-Id"))
				}
			},
			req: requestWithoutID,
		},
		{
			name:        "empty-header",
			traceHeader: "",
			next: func(t *testing.T) http.HandlerFunc {
				return func(w http.ResponseWriter, r *http.Request) {
					assert.Empty(t, r.Header.Get("X-Request-Id"))
					value := r.Header.Get("X-Smallstep-Id")
					assert.NotEmpty(t, value)
					reqID, ok := FromContext(r.Context())
					if assert.True(t, ok) {
						assert.Equal(t, value, reqID)
					}
					assert.Equal(t, value, w.Header().Get("X-Request-Id"))
				}
			},
			req: requestWithEmptyHeader,
		},
		{
			name:        "fallback-header-name",
			traceHeader: defaultTraceHeader,
			next: func(t *testing.T) http.HandlerFunc {
				return func(w http.ResponseWriter, r *http.Request) {
					assert.Empty(t, r.Header.Get("X-Request-Id"))
					assert.Equal(t, "smallstepID", r.Header.Get("X-Smallstep-Id"))
					reqID, ok := FromContext(r.Context())
					if assert.True(t, ok) {
						assert.Equal(t, "smallstepID", reqID)
					}
					assert.Equal(t, "smallstepID", w.Header().Get("X-Request-Id"))
				}
			},
			req: requestWithSmallstepID,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var called bool
			next := tt.next(t)
			handler := New(tt.traceHeader).Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				called = true
				next(w, r)
			}))

			w := httptest.NewRecorder()
			handler.ServeHTTP(w, tt.req)

			// The in-handler assertions only run if the middleware delegates to
			// next. Without this check, a middleware that short-circuits would
			// still pass.
			assert.True(t, called, "middleware did not invoke the next handler")
			assert.NotEmpty(t, w.Header().Get("X-Request-Id"))
		})
	}
}